Skip to content

Commit c6f60a2

Browse files
authored
Merge pull request #8 from workflowai/guillaume/fix-stream-and-timeout
Fix stream errors and increase httpx timeout
2 parents c9c1d6f + 17ba601 commit c6f60a2

File tree

8 files changed

+194
-28
lines changed

8 files changed

+194
-28
lines changed

tests/e2e/deploy_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,18 @@ async def test_deploy_task(wai: workflowai.Client):
2828

2929
# Run using the environment and the same input
3030
task_run2 = await wai.run(
31-
task, task_input=CityToCapitalTaskInput(city="Osaka"), environment="dev",
31+
task,
32+
task_input=CityToCapitalTaskInput(city="Osaka"),
33+
environment="dev",
3234
)
3335
# IDs will match since we are using cache
3436
assert task_run.id == task_run2.id
3537

3638
# Run using the environment and a different input
3739
task_run3 = await wai.run(
38-
task, task_input=CityToCapitalTaskInput(city="Toulouse"), environment="dev",
40+
task,
41+
task_input=CityToCapitalTaskInput(city="Toulouse"),
42+
environment="dev",
3943
)
4044
assert task_run3.task_output.capital == "Paris"
4145
assert task_run3.id != task_run2.id

tests/e2e/stream_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""End-to-end test for streaming task runs against the WorkflowAI API."""

from typing import Optional

import pytest
from pydantic import BaseModel

import workflowai
from workflowai.core.domain.task import Task


class ImprovePromptTaskInput(BaseModel):
    # All fields optional with None defaults: the task tolerates partial input.
    original_prompt: Optional[str] = None
    prompt_input: Optional[str] = None
    prompt_output: Optional[str] = None
    user_evaluation: Optional[str] = None


class ImprovePromptTaskOutput(BaseModel):
    improved_prompt: Optional[str] = None
    changelog: Optional[str] = None


class ImprovePromptTask(Task[ImprovePromptTaskInput, ImprovePromptTaskOutput]):
    # id/schema_id identify an existing task deployed on the WorkflowAI backend
    # — presumably schema version 3 of "improve-prompt"; verify against the API.
    id: str = "improve-prompt"
    schema_id: int = 3
    input_class: type[ImprovePromptTaskInput] = ImprovePromptTaskInput
    output_class: type[ImprovePromptTaskOutput] = ImprovePromptTaskOutput


@pytest.mark.skip("This hits the API")
async def test_stream_task(wai: workflowai.Client):
    """Run a task with stream=True and assert the output arrives in more
    than one chunk (i.e. streaming is not collapsed into a single payload).

    Skipped by default because it performs a live API call.
    """
    task = ImprovePromptTask()

    task_input = ImprovePromptTaskInput(
        original_prompt="Say hello to the guest",
        prompt_input='{"guest": "John", "language": "French"}',
        prompt_output='{"greeting": "Hello John"}',
        user_evaluation="Not in the right language",
    )

    # use_cache="never" forces a fresh generation so chunks actually stream
    # instead of a cached run being returned whole.
    streamed = await wai.run(task, task_input=task_input, stream=True, use_cache="never")
    chunks = [chunk async for chunk in streamed]

    assert len(chunks) > 1

workflowai/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from workflowai.core.client import Client as Client
44
from workflowai.core.domain.cache_usage import CacheUsage as CacheUsage
5-
from workflowai.core.domain.errors import NotFoundError as NotFoundError
5+
from workflowai.core.domain.errors import WorkflowAIError as WorkflowAIError
66
from workflowai.core.domain.llm_completion import LLMCompletion as LLMCompletion
77
from workflowai.core.domain.task import Task as Task
88
from workflowai.core.domain.task_evaluation import TaskEvaluation as TaskEvaluation

workflowai/core/client/api.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from typing import Any, AsyncIterator, Literal, Optional, TypeVar, Union, overload
22

33
import httpx
4-
from pydantic import BaseModel, TypeAdapter
4+
from pydantic import BaseModel, TypeAdapter, ValidationError
5+
6+
from workflowai.core.client.utils import split_chunks
7+
from workflowai.core.domain.errors import BaseError, ErrorResponse, WorkflowAIError
58

69
# A type for return values
710
_R = TypeVar("_R")
@@ -24,6 +27,7 @@ def _client(self) -> httpx.AsyncClient:
2427
"Authorization": f"Bearer {self.api_key}",
2528
**(self.source_headers or {}),
2629
},
30+
timeout=120.0,
2731
)
2832
return client
2933

@@ -84,21 +88,37 @@ async def delete(self, path: str) -> None:
8488
response = await client.delete(path)
8589
response.raise_for_status()
8690

91+
def _extract_error(self, data: Union[bytes, str], exception: Optional[Exception] = None) -> WorkflowAIError:
    """Convert a raw API payload into a WorkflowAIError.

    If `data` parses as a structured ErrorResponse, the corresponding
    WorkflowAIError is *returned*; otherwise a fallback WorkflowAIError
    wrapping the raw payload is *raised*.

    NOTE(review): the two branches are inconsistent (return vs raise).
    The known caller does `raise self._extract_error(...)`, so the outcome
    is the same either way, but a consistent `return` in both branches
    would match the declared return type — confirm before changing.
    """
    try:
        res = ErrorResponse.model_validate_json(data)
        return WorkflowAIError(res.error, task_run_id=res.task_run_id)
    except ValidationError:
        # Payload is not a structured error response: surface it raw,
        # keeping the original parsing exception (if any) as the cause.
        raise WorkflowAIError(
            error=BaseError(
                message="Unknown error" if exception is None else str(exception),
                details={
                    "raw": str(data),
                },
            ),
        ) from exception
104+
87105
async def stream(
    self,
    method: Literal["GET", "POST"],
    path: str,
    data: BaseModel,
    returns: type[_M],
) -> AsyncIterator[_M]:
    """Send a streaming request and yield one parsed `returns` model per
    SSE payload.

    A single received byte chunk may contain several `data: {...}` events,
    so each chunk is split with split_chunks before JSON validation.
    """
    async with self._client() as client, client.stream(
        method,
        path,
        content=data.model_dump_json(exclude_none=True),
        headers={"Content-Type": "application/json"},
    ) as response:
        async for chunk in response.aiter_bytes():
            # `payload` keeps the last payload attempted so the except
            # branch can report which fragment failed validation.
            payload = ""
            try:
                for payload in split_chunks(chunk):
                    yield returns.model_validate_json(payload)
            except ValidationError as e:
                # Payload did not match `returns`: it is likely an error
                # body from the API — convert it and re-raise.
                raise self._extract_error(payload, e) from None

workflowai/core/client/client.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
TaskRunResponse,
2727
)
2828
from workflowai.core.domain.cache_usage import CacheUsage
29-
from workflowai.core.domain.errors import NotFoundError
29+
from workflowai.core.domain.errors import BaseError, WorkflowAIError
3030
from workflowai.core.domain.task import Task, TaskInput, TaskOutput
3131
from workflowai.core.domain.task_example import TaskExample
3232
from workflowai.core.domain.task_run import TaskRun
@@ -146,7 +146,13 @@ async def run( # noqa: C901
146146
return res.to_domain(task)
147147
except HTTPStatusError as e:
148148
if e.response.status_code == 404:
149-
raise NotFoundError("Task not found") from e
149+
raise WorkflowAIError(
150+
error=BaseError(
151+
status_code=404,
152+
code="not_found",
153+
message="Task not found",
154+
),
155+
) from e
150156
retry_after = e.response.headers.get("Retry-After")
151157
if retry_after:
152158
try:
@@ -180,7 +186,7 @@ async def _stream():
180186
yield task.output_class.model_construct(None, **chunk.task_output)
181187
except HTTPStatusError as e:
182188
if e.response.status_code == 404:
183-
raise NotFoundError("Task not found") from e
189+
raise WorkflowAIError(error=BaseError(message="Task not found")) from e
184190
retry_after = e.response.headers.get("Retry-After")
185191

186192
if retry_after:
@@ -194,7 +200,7 @@ async def _stream():
194200
except (TypeError, ValueError, OverflowError):
195201
delay = min(delay * 2, max_retry_delay / 1000)
196202
elif e.response.status_code == 429 and delay < max_retry_delay / 1000:
197-
delay = min(delay * 2, max_retry_delay / 1000)
203+
delay = min(delay * 2, max_retry_delay / 1000)
198204
await asyncio.sleep(delay)
199205
retry_count += 1
200206

workflowai/core/client/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import re

# Payload boundaries look like `}\n\ndata: {"`. Keeping the trailing quote in
# the pattern more or less guarantees the delimiter is not matched inside a
# quoted string within a payload.
delimiter = re.compile(r'\}\n\ndata: \{"')


def split_chunks(chunk: bytes):
    """Yield the individual JSON payload strings packed into one SSE chunk.

    Sometimes two (or more) `data: {...}` events arrive in a single message;
    this splits them apart without parsing the JSON itself.
    """
    text = chunk.removeprefix(b"data: ").removesuffix(b"\n\n").decode()
    cursor = 0
    for boundary in delimiter.finditer(text):
        # Include the closing brace of the current payload (+1), then resume
        # at the opening `{` of the next one (end() - 2 skips back over `{"`).
        yield text[cursor : boundary.start() + 1]
        cursor = boundary.end() - 2
    yield text[cursor:]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest

from workflowai.core.client.utils import split_chunks


@pytest.mark.parametrize(
    ("chunk", "expected"),
    [
        # Two payloads packed into a single SSE chunk
        (b'data: {"foo": "bar"}\n\ndata: {"foo": "baz"}', ['{"foo": "bar"}', '{"foo": "baz"}']),
        # Three payloads in a single chunk
        (
            b'data: {"foo": "bar"}\n\ndata: {"foo": "baz"}\n\ndata: {"foo": "qux"}',
            ['{"foo": "bar"}', '{"foo": "baz"}', '{"foo": "qux"}'],
        ),
    ],
)
def test_split_chunks(chunk: bytes, expected: list[str]):
    """split_chunks decodes bytes and yields str payloads.

    Fix: `expected` was annotated `list[bytes]` although every parametrized
    expected value (and split_chunks' output) is a list of `str`.
    """
    assert list(split_chunks(chunk)) == expected

workflowai/core/domain/errors.py

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,73 @@
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel

# Error codes that originate from the underlying LLM providers.
ProviderErrorCode = Literal[
    # Max number of tokens were exceeded in the prompt
    "max_tokens_exceeded",
    # The model failed to generate a response
    "failed_generation",
    # The model generated a response but it was not valid
    "invalid_generation",
    # The model returned an error that we currently do not handle
    # The returned status code will match the provider status code and the entire
    # provider response will be provided the error details.
    #
    # This error is intended as a fallback since we do not control what the providers
    # return. We track this error on our end and the error should eventually
    # be assigned a different status code
    "unknown_provider_error",
    # The provider returned a rate limit error
    "rate_limit",
    # The provider returned a server overloaded error
    "server_overloaded",
    # The requested provider does not support the model
    "invalid_provider_config",
    # The provider returned a 500
    "provider_internal_error",
    # The provider returned a 502 or 503
    "provider_unavailable",
    # The request timed out
    "read_timeout",
]

ErrorCode = Union[
    ProviderErrorCode,
    Literal[
        # The object was not found
        "object_not_found",
        # There are no configured providers supporting the requested model
        # This error will never happen when using WorkflowAI keys
        "no_provider_supporting_model",
        # The requested provider does not support the model
        "provider_does_not_support_model",
        # The requested model does not support the requested generation mode
        # (e-g a model that does not support images generation was sent an image)
        "model_does_not_support_mode",
        # Run properties are invalid, for example the model does not exist
        "invalid_run_properties",
        # An internal error occurred
        "internal_error",
        # The request was invalid
        "bad_request",
    ],
    str,  # Using as a fallback to avoid validation error if an error code is added to the API
]


class BaseError(BaseModel):
    """Structured error payload returned by the WorkflowAI API."""

    details: Optional[dict[str, Any]] = None
    message: str
    status_code: Optional[int] = None
    code: Optional[ErrorCode] = None


class ErrorResponse(BaseModel):
    """Top-level error response body, optionally tied to a task run."""

    error: BaseError
    task_run_id: Optional[str] = None


class WorkflowAIError(Exception):
    """Exception wrapping a structured API error.

    Exposes the parsed `error` payload and, when available, the id of the
    task run that produced it.
    """

    def __init__(self, error: BaseError, task_run_id: Optional[str] = None):
        self.error = error
        self.task_run_id = task_run_id
        # Fix: forward the message to Exception so `str(exc)` and `exc.args`
        # are populated. Without this call they are empty, a regression from
        # the previous implementation whose __str__ returned the message.
        super().__init__(error.message)

0 commit comments

Comments
 (0)