Skip to content

Commit d0ea9e7

Browse files
committed
fix: issue with streaming partial nested objects
1 parent 55e9881 commit d0ea9e7

File tree

7 files changed

+52
-34
lines changed

7 files changed

+52
-34
lines changed

workflowai/core/_common_types.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,18 @@
1919

2020

2121
class OutputValidator(Protocol, Generic[AgentOutputCov]):
22-
def __call__(self, data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutputCov: ...
22+
def __call__(self, data: dict[str, Any], partial: bool) -> AgentOutputCov:
23+
"""A way to convert a json object into an AgentOutput
24+
25+
Args:
26+
data (dict[str, Any]): The json object to convert
27+
partial (bool): Whether the json is partial, meaning that
28+
it may not contain all the fields required by the AgentOutput model.
29+
30+
Returns:
31+
AgentOutputCov: The converted AgentOutput
32+
"""
33+
...
2334

2435

2536
class VersionRunParams(TypedDict):

workflowai/core/client/_fn_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
RunParams,
2424
RunTemplate,
2525
)
26-
from workflowai.core.client._utils import intolerant_validator
26+
from workflowai.core.client._utils import default_validator
2727
from workflowai.core.client.agent import Agent
2828
from workflowai.core.domain.errors import InvalidGenerationError
2929
from workflowai.core.domain.model import ModelOrStr
@@ -144,14 +144,15 @@ async def __call__(self, input: AgentInput, **kwargs: Unpack[RunParams[AgentOutp
144144
except InvalidGenerationError as e:
145145
if e.partial_output and e.run_id:
146146
try:
147-
validator, _ = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls))
147+
validator, _ = self._sanitize_validator(kwargs, default_validator(self.output_cls))
148148
run = self._build_run_no_tools(
149149
chunk=RunResponse(
150150
id=e.run_id,
151151
task_output=e.partial_output,
152152
),
153153
schema_id=self.schema_id or 0,
154154
validator=validator,
155+
partial=False,
155156
)
156157
run.error = e.error
157158
return run

workflowai/core/client/_models.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,18 @@ def to_domain(
134134
task_id: str,
135135
task_schema_id: int,
136136
validator: OutputValidator[AgentOutput],
137+
partial: Optional[bool] = None,
137138
) -> Run[AgentOutput]:
139+
# We do partial validation if either:
140+
# - there are tool call requests, which means that the output can be empty
141+
# - the run has not yet finished, for example when streaming, in which case the duration_seconds is None
142+
if partial is None:
143+
partial = bool(self.tool_call_requests) or self.duration_seconds is None
138144
return Run(
139145
id=self.id,
140146
agent_id=task_id,
141147
schema_id=task_schema_id,
142-
output=validator(self.task_output, self.tool_call_requests is not None),
148+
output=validator(self.task_output, partial),
143149
version=self.version and self.version.to_domain(),
144150
duration_seconds=self.duration_seconds,
145151
cost_usd=self.cost_usd,

workflowai/core/client/_models_test.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from tests.utils import fixture_text
77
from workflowai.core.client._models import RunResponse
8-
from workflowai.core.client._utils import intolerant_validator, tolerant_validator
8+
from workflowai.core.client._utils import default_validator
99
from workflowai.core.domain.run import Run
1010
from workflowai.core.domain.tool_call import ToolCallRequest
1111

@@ -41,7 +41,7 @@ def test_no_version_not_optional(self):
4141
with pytest.raises(ValidationError): # sanity
4242
_TaskOutput.model_validate({"a": 1})
4343

44-
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
44+
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
4545
assert isinstance(parsed, Run)
4646
assert parsed.output.a == 1
4747
# b is not defined
@@ -52,18 +52,19 @@ def test_no_version_optional(self):
5252
chunk = RunResponse.model_validate_json('{"id": "1", "task_output": {"a": 1}}')
5353
assert chunk
5454

55-
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutputOpt))
55+
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutputOpt))
5656
assert isinstance(parsed, Run)
5757
assert parsed.output.a == 1
5858
assert parsed.output.b is None
5959

6060
def test_with_version(self):
61+
"""Full output is validated since the duration is passed and there are no tool calls"""
6162
chunk = RunResponse.model_validate_json(
6263
'{"id": "1", "task_output": {"a": 1, "b": "test"}, "cost_usd": 0.1, "duration_seconds": 1, "version": {"properties": {"a": 1, "b": "test"}}}', # noqa: E501
6364
)
6465
assert chunk
6566

66-
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
67+
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
6768
assert isinstance(parsed, Run)
6869
assert parsed.output.a == 1
6970
assert parsed.output.b == "test"
@@ -73,17 +74,19 @@ def test_with_version(self):
7374

7475
def test_with_version_validation_fails(self):
7576
chunk = RunResponse.model_validate_json(
76-
'{"id": "1", "task_output": {"a": 1}, "version": {"properties": {"a": 1, "b": "test"}}}',
77+
"""{"id": "1", "task_output": {"a": 1},
78+
"version": {"properties": {"a": 1, "b": "test"}}, "duration_seconds": 1}""",
7779
)
7880
with pytest.raises(ValidationError):
79-
chunk.to_domain(task_id="1", task_schema_id=1, validator=intolerant_validator(_TaskOutput))
81+
chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
8082

8183
def test_with_tool_calls(self):
8284
chunk = RunResponse.model_validate_json(
83-
'{"id": "1", "task_output": {}, "tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}]}',
85+
"""{"id": "1", "task_output": {},
86+
"tool_call_requests": [{"id": "1", "name": "test", "input": {"a": 1}}], "duration_seconds": 1}""",
8487
)
8588
assert chunk
8689

87-
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=tolerant_validator(_TaskOutput))
90+
parsed = chunk.to_domain(task_id="1", task_schema_id=1, validator=default_validator(_TaskOutput))
8891
assert isinstance(parsed, Run)
8992
assert parsed.tool_call_requests == [ToolCallRequest(id="1", name="test", input={"a": 1})]

workflowai/core/client/_utils.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -87,18 +87,15 @@ async def _wait_for_exception(e: WorkflowAIError):
8787
return _should_retry, _wait_for_exception
8888

8989

90-
def tolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]:
91-
def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput: # noqa: ARG001
92-
return construct_model_recursive(m, data)
93-
94-
return _validator
95-
96-
97-
def intolerant_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]:
98-
def _validator(data: dict[str, Any], has_tool_call_requests: bool) -> AgentOutput:
90+
def default_validator(m: type[AgentOutput]) -> OutputValidator[AgentOutput]:
91+
def _validator(data: dict[str, Any], partial: bool) -> AgentOutput:
9992
# When we have tool call requests, the output can be empty
100-
if has_tool_call_requests:
101-
return tolerant_validator(m)(data, has_tool_call_requests)
93+
if partial:
94+
try:
95+
return construct_model_recursive(m, data)
96+
except Exception: # noqa: BLE001
97+
logger.warning("Failed to validate partial data: %s", data)
98+
return m.model_construct(None, **data)
10299

103100
return m.model_validate(data)
104101

workflowai/core/client/_utils_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
from workflowai.core.client._utils import (
88
build_retryable_wait,
9+
default_validator,
910
global_default_version_reference,
1011
split_chunks,
11-
tolerant_validator,
1212
)
1313
from workflowai.core.domain.errors import BaseError, WorkflowAIError
1414

@@ -59,13 +59,13 @@ class Ingredient(BaseModel):
5959
ingredients: list[Ingredient]
6060

6161

62-
class TestTolerantValidator:
62+
class TestValidator:
6363
def test_tolerant_validator_nested_object(self):
64-
validated = tolerant_validator(Recipe)(
64+
validated = default_validator(Recipe)(
6565
{
6666
"ingredients": [{"name": "salt"}],
6767
},
68-
has_tool_call_requests=False,
68+
partial=True,
6969
)
7070
for ingredient in validated.ingredients:
7171
assert isinstance(ingredient, Recipe.Ingredient)

workflowai/core/client/agent.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@
2222
from workflowai.core.client._types import RunParams
2323
from workflowai.core.client._utils import (
2424
build_retryable_wait,
25+
default_validator,
2526
global_default_version_reference,
26-
intolerant_validator,
27-
tolerant_validator,
2827
)
2928
from workflowai.core.domain.completion import Completion
3029
from workflowai.core.domain.errors import BaseError, WorkflowAIError
@@ -295,8 +294,9 @@ def _build_run_no_tools(
295294
chunk: RunResponse,
296295
schema_id: int,
297296
validator: OutputValidator[AgentOutput],
297+
partial: Optional[bool] = None,
298298
) -> Run[AgentOutput]:
299-
run = chunk.to_domain(self.agent_id, schema_id, validator)
299+
run = chunk.to_domain(self.agent_id, schema_id, validator, partial)
300300
run._agent = self # pyright: ignore [reportPrivateUsage]
301301
return run
302302

@@ -362,7 +362,7 @@ async def run(
362362
Run[AgentOutput]: The task run object.
363363
"""
364364
prepared_run = await self._prepare_run(agent_input, stream=False, **kwargs)
365-
validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls))
365+
validator, new_kwargs = self._sanitize_validator(kwargs, default_validator(self.output_cls))
366366

367367
last_error = None
368368
while prepared_run.should_retry():
@@ -374,7 +374,6 @@ async def run(
374374
validator,
375375
current_iteration=0,
376376
# TODO[test]: add test with custom validator
377-
# We popped validator above
378377
**new_kwargs,
379378
)
380379
except WorkflowAIError as e: # noqa: PERF203
@@ -419,10 +418,11 @@ async def stream(
419418
AsyncIterator[Run[AgentOutput]]: An async iterator yielding task run objects.
420419
"""
421420
prepared_run = await self._prepare_run(agent_input, stream=True, **kwargs)
422-
validator, new_kwargs = self._sanitize_validator(kwargs, tolerant_validator(self.output_cls))
421+
validator, new_kwargs = self._sanitize_validator(kwargs, default_validator(self.output_cls))
423422

424423
while prepared_run.should_retry():
425424
try:
425+
chunk: Optional[RunResponse] = None
426426
async for chunk in self.api.stream(
427427
method="POST",
428428
path=prepared_run.route,
@@ -462,7 +462,7 @@ async def reply(
462462
"""
463463

464464
prepared_run = await self._prepare_reply(run_id, user_message, tool_results, stream=False, **kwargs)
465-
validator, new_kwargs = self._sanitize_validator(kwargs, intolerant_validator(self.output_cls))
465+
validator, new_kwargs = self._sanitize_validator(kwargs, default_validator(self.output_cls))
466466

467467
res = await self.api.post(prepared_run.route, prepared_run.request, returns=RunResponse, run=True)
468468
return await self._build_run(

0 commit comments

Comments
 (0)