Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/agents/run_internal/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"normalize_resumed_input",
"fingerprint_input_item",
"deduplicate_input_items",
"deduplicate_input_items_preferring_latest",
"function_rejection_item",
"shell_rejection_item",
"apply_patch_rejection_item",
Expand Down Expand Up @@ -176,6 +177,15 @@ def deduplicate_input_items(items: Sequence[TResponseInputItem]) -> list[TRespon
return deduplicated


def deduplicate_input_items_preferring_latest(
    items: Sequence[TResponseInputItem],
) -> list[TResponseInputItem]:
    """Remove duplicate input items, letting the last occurrence of each key survive.

    The surviving items are returned in their original relative order.
    """
    # deduplicate_input_items retains the first item it meets for each dedupe key, so
    # feeding it the sequence back-to-front makes the newest duplicate the one it keeps.
    newest_first = list(reversed(items))
    survivors = deduplicate_input_items(newest_first)
    # Flip the survivors again to restore the caller's ordering.
    return survivors[::-1]


def function_rejection_item(
agent: Any,
tool_call: Any,
Expand Down
6 changes: 3 additions & 3 deletions src/agents/run_internal/run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
from .items import (
REJECTION_MESSAGE,
copy_input_items,
deduplicate_input_items,
deduplicate_input_items_preferring_latest,
ensure_input_item_format,
normalize_input_items_for_api,
normalize_resumed_input,
Expand Down Expand Up @@ -1109,7 +1109,7 @@ async def run_single_turn_streamed(
system_instructions=system_prompt,
)
if isinstance(filtered.input, list):
filtered.input = deduplicate_input_items(filtered.input)
filtered.input = deduplicate_input_items_preferring_latest(filtered.input)
if server_conversation_tracker is not None:
logger.debug(
"filtered.input has %s items; ids=%s",
Expand Down Expand Up @@ -1418,7 +1418,7 @@ async def get_new_response(
system_instructions=system_prompt,
)
if isinstance(filtered.input, list):
filtered.input = deduplicate_input_items(filtered.input)
filtered.input = deduplicate_input_items_preferring_latest(filtered.input)

if server_conversation_tracker is not None:
server_conversation_tracker.mark_input_as_sent(filtered.input)
Expand Down
6 changes: 3 additions & 3 deletions src/agents/run_internal/session_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..run_state import RunState
from .items import (
copy_input_items,
deduplicate_input_items,
deduplicate_input_items_preferring_latest,
drop_orphan_function_calls,
ensure_input_item_format,
fingerprint_input_item,
Expand Down Expand Up @@ -136,7 +136,7 @@ async def prepare_input_with_session(
prepared_as_inputs = [ensure_input_item_format(item) for item in prepared_items_raw]
filtered = drop_orphan_function_calls(prepared_as_inputs)
normalized = normalize_input_items_for_api(filtered)
deduplicated = deduplicate_input_items(normalized)
deduplicated = deduplicate_input_items_preferring_latest(normalized)

return deduplicated, [ensure_input_item_format(item) for item in appended_items]

Expand Down Expand Up @@ -259,7 +259,7 @@ async def save_result_to_session(
for item in new_items_for_fingerprint
]

items_to_save = deduplicate_input_items(input_list + new_items_as_input)
items_to_save = deduplicate_input_items_preferring_latest(input_list + new_items_as_input)

if is_openai_conversation_session and items_to_save:
items_to_save = [_sanitize_openai_conversation_item(item) for item in items_to_save]
Expand Down
139 changes: 139 additions & 0 deletions tests/test_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
drop_orphan_function_calls,
ensure_input_item_format,
normalize_input_items_for_api,
normalize_resumed_input,
)
from agents.run_internal.oai_conversation import OpenAIServerConversationTracker
from agents.run_internal.run_loop import get_new_response
Expand Down Expand Up @@ -256,6 +257,43 @@ def _has_call(call_type: str, call_id: str) -> bool:
assert _has_call("local_shell_call", "local_shell_keep")


def test_normalize_resumed_input_drops_orphan_function_calls():
    """normalize_resumed_input removes a function_call with no output but keeps a paired one."""
    unanswered_call = cast(
        TResponseInputItem,
        {
            "type": "function_call",
            "call_id": "orphan_call",
            "name": "tool_orphan",
            "arguments": "{}",
        },
    )
    answered_call = cast(
        TResponseInputItem,
        {
            "type": "function_call",
            "call_id": "paired_call",
            "name": "tool_paired",
            "arguments": "{}",
        },
    )
    answer = cast(
        TResponseInputItem,
        {"type": "function_call_output", "call_id": "paired_call", "output": "ok"},
    )
    raw_input: list[TResponseInputItem] = [unanswered_call, answered_call, answer]

    normalized = normalize_resumed_input(raw_input)
    assert isinstance(normalized, list)

    # Collect the call_id of every function_call that made it through normalization.
    surviving_call_ids = []
    for entry in normalized:
        if isinstance(entry, dict) and entry.get("type") == "function_call":
            surviving_call_ids.append(cast(dict[str, Any], entry).get("call_id"))

    assert "orphan_call" not in surviving_call_ids
    assert "paired_call" in surviving_call_ids


def testnormalize_input_items_for_api_preserves_provider_data():
items: list[TResponseInputItem] = [
cast(
Expand Down Expand Up @@ -1161,6 +1199,71 @@ async def test_prepare_input_with_session_keeps_function_call_outputs():
assert last_item["content"] == "hello"


@pytest.mark.asyncio
async def test_prepare_input_with_session_prefers_latest_function_call_output():
    """A new function_call_output replaces the session-history output with the same call_id."""
    shared_call_id = "call_latest"
    stale_output = cast(
        TResponseInputItem,
        {
            "type": "function_call_output",
            "call_id": shared_call_id,
            "output": "history-output",
        },
    )
    fresh_output = cast(
        TResponseInputItem,
        {
            "type": "function_call_output",
            "call_id": shared_call_id,
            "output": "new-output",
        },
    )
    session = SimpleListSession(history=[stale_output])

    prepared_input, session_items = await prepare_input_with_session([fresh_output], session, None)

    assert isinstance(prepared_input, list)

    # Gather every prepared output that targets the shared call_id.
    matching_outputs: list[dict[str, Any]] = []
    for entry in prepared_input:
        if not isinstance(entry, dict):
            continue
        if entry.get("type") != "function_call_output":
            continue
        if entry.get("call_id") == shared_call_id:
            matching_outputs.append(cast(dict[str, Any], entry))

    assert len(matching_outputs) == 1
    assert matching_outputs[0]["output"] == "new-output"
    assert len(session_items) == 1
    assert cast(dict[str, Any], session_items[0])["output"] == "new-output"


@pytest.mark.asyncio
async def test_prepare_input_with_session_drops_orphan_function_calls():
    """An unanswered function_call in session history is left out of the prepared input."""

    def _is_orphan_call(entry: Any) -> bool:
        # Matches the unanswered function_call seeded into the session history below.
        if not isinstance(entry, dict):
            return False
        return entry.get("type") == "function_call" and entry.get("call_id") == "orphan_call"

    def _is_hello_from_user(entry: Any) -> bool:
        # Matches the user message produced from the plain-string input "hello".
        if not isinstance(entry, dict):
            return False
        return entry.get("role") == "user" and entry.get("content") == "hello"

    session = SimpleListSession(
        history=[
            cast(
                TResponseInputItem,
                {
                    "type": "function_call",
                    "call_id": "orphan_call",
                    "name": "tool_orphan",
                    "arguments": "{}",
                },
            )
        ]
    )

    prepared_input, session_items = await prepare_input_with_session("hello", session, None)

    assert isinstance(prepared_input, list)
    assert len(session_items) == 1
    assert not any(_is_orphan_call(entry) for entry in prepared_input)
    assert any(_is_hello_from_user(entry) for entry in prepared_input)


def test_ensure_api_input_item_handles_model_dump_objects():
class _ModelDumpItem:
def model_dump(self, exclude_unset: bool = True) -> dict[str, Any]:
Expand Down Expand Up @@ -1417,6 +1520,42 @@ async def test_save_result_to_session_preserves_function_outputs():
assert "output" in saved_dict


@pytest.mark.asyncio
async def test_save_result_to_session_prefers_latest_duplicate_function_outputs():
    """When the original input and a new run item share a call_id, only the new output is saved."""
    session = SimpleListSession()
    stale_output = cast(
        TResponseInputItem,
        {
            "type": "function_call_output",
            "call_id": "call_duplicate",
            "output": "old-output",
        },
    )
    fresh_run_item = _DummyRunItem(
        {
            "type": "function_call_output",
            "call_id": "call_duplicate",
            "output": "new-output",
        }
    )

    await save_result_to_session(
        session,
        [stale_output],
        [cast(RunItem, fresh_run_item)],
        None,
    )

    # Pick out every saved function_call_output that belongs to the duplicated call_id.
    saved_for_call: list[dict[str, Any]] = []
    for entry in session.saved_items:
        if not isinstance(entry, dict):
            continue
        is_output = entry.get("type") == "function_call_output"
        if is_output and entry.get("call_id") == "call_duplicate":
            saved_for_call.append(cast(dict[str, Any], entry))

    assert len(saved_for_call) == 1
    assert saved_for_call[0]["output"] == "new-output"


@pytest.mark.asyncio
async def test_rewind_handles_id_stripped_sessions() -> None:
session = IdStrippingSession()
Expand Down
94 changes: 92 additions & 2 deletions tests/test_call_model_input_filter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from typing import Any
from typing import Any, cast

import pytest

from agents import Agent, RunConfig, Runner, UserError
from agents import Agent, RunConfig, Runner, TResponseInputItem, UserError
from agents.run import CallModelData, ModelInputData

from .fake_model import FakeModel
Expand Down Expand Up @@ -77,3 +77,93 @@ def invalid_filter(_data: CallModelData[Any]):
input="start",
run_config=RunConfig(call_model_input_filter=invalid_filter),
)


@pytest.mark.asyncio
async def test_call_model_input_filter_prefers_latest_duplicate_outputs_non_streamed() -> None:
    """Runner.run sends only the later of two same-call_id outputs injected by the input filter."""
    model = FakeModel()
    agent = Agent(name="test", model=model)
    model.set_next_output([get_text_message("ok")])

    call_id = "dup-call"
    stale = cast(
        TResponseInputItem,
        {"type": "function_call_output", "call_id": call_id, "output": "old-value"},
    )
    fresh = cast(
        TResponseInputItem,
        {"type": "function_call_output", "call_id": call_id, "output": "new-value"},
    )

    def filter_fn(data: CallModelData[Any]) -> ModelInputData:
        # Prepend both duplicates, older one first, ahead of whatever the runner prepared.
        injected = [stale, fresh]
        injected.extend(data.model_data.input)
        return ModelInputData(input=injected, instructions=data.model_data.instructions)

    run_config = RunConfig(call_model_input_filter=filter_fn)
    await Runner.run(agent, input="start", run_config=run_config)

    # Inspect what the fake model actually received on its last turn.
    outputs = []
    for item in model.last_turn_args["input"]:
        if item.get("type") == "function_call_output" and item.get("call_id") == call_id:
            outputs.append(item)

    assert len(outputs) == 1
    assert outputs[0]["output"] == "new-value"


@pytest.mark.asyncio
async def test_call_model_input_filter_prefers_latest_duplicate_outputs_streamed() -> None:
    """Runner.run_streamed sends only the later of two same-call_id outputs from the filter."""
    model = FakeModel()
    agent = Agent(name="test", model=model)
    model.set_next_output([get_text_message("ok")])

    call_id = "dup-call-stream"
    stale = cast(
        TResponseInputItem,
        {"type": "function_call_output", "call_id": call_id, "output": "old-value"},
    )
    fresh = cast(
        TResponseInputItem,
        {"type": "function_call_output", "call_id": call_id, "output": "new-value"},
    )

    async def filter_fn(data: CallModelData[Any]) -> ModelInputData:
        # Prepend both duplicates, older one first, ahead of whatever the runner prepared.
        injected = [stale, fresh]
        injected.extend(data.model_data.input)
        return ModelInputData(input=injected, instructions=data.model_data.instructions)

    run_config = RunConfig(call_model_input_filter=filter_fn)
    result = Runner.run_streamed(agent, input="start", run_config=run_config)
    # Drain the event stream so the run is driven to completion before inspecting the model.
    async for _ in result.stream_events():
        pass

    # Inspect what the fake model actually received on its last turn.
    outputs = []
    for item in model.last_turn_args["input"]:
        if item.get("type") == "function_call_output" and item.get("call_id") == call_id:
            outputs.append(item)

    assert len(outputs) == 1
    assert outputs[0]["output"] == "new-value"