diff --git a/src/agents/extensions/sandbox/e2b/sandbox.py b/src/agents/extensions/sandbox/e2b/sandbox.py index 425436f3e0..1ccaaa639f 100644 --- a/src/agents/extensions/sandbox/e2b/sandbox.py +++ b/src/agents/extensions/sandbox/e2b/sandbox.py @@ -686,6 +686,8 @@ class _E2BPtyProcessEntry: output_lock: asyncio.Lock = field(default_factory=asyncio.Lock) output_notify: asyncio.Event = field(default_factory=asyncio.Event) last_used: float = field(default_factory=time.monotonic) + done: bool = False + exit_code: int | None = None @dataclass(frozen=True) @@ -982,6 +984,7 @@ async def _append_output(payload: bytes | bytearray | str | object) -> None: on_data=_append_output, ) entry.handle = handle + asyncio.create_task(self._run_pty_waiter(entry)) await self._sandbox.pty.send_stdin( cast(Any, handle).pid, f"{command_text}\n".encode(), @@ -999,6 +1002,7 @@ async def _append_output(payload: bytes | bytearray | str | object) -> None: on_stderr=_append_output, ) entry.handle = handle + asyncio.create_task(self._run_pty_waiter(entry)) async with self._pty_lock: process_id = allocate_pty_process_id(self._reserved_pty_process_ids) self._reserved_pty_process_ids.add(process_id) @@ -1044,6 +1048,24 @@ async def _append_output(payload: bytes | bytearray | str | object) -> None: original_token_count=original_token_count, ) + async def _run_pty_waiter(self, entry: _E2BPtyProcessEntry) -> None: + wait = getattr(entry.handle, "wait", None) + if not callable(wait): + return + + try: + result = wait() + if inspect.isawaitable(result): + await result + exit_code = getattr(entry.handle, "exit_code", None) + if exit_code is not None: + entry.exit_code = int(exit_code) + except Exception: + pass + finally: + entry.done = True + entry.output_notify.set() + async def pty_write_stdin( self, *, @@ -1195,7 +1217,7 @@ async def _collect_pty_output( if time.monotonic() >= deadline: break - if self._entry_exit_code(entry) is not None: + if self._entry_done(entry): async with entry.output_lock: while entry.output_chunks: output.extend(entry.output_chunks.popleft()) @@ -1226,7 +1248,7 @@ async def _finalize_pty_update( exit_code = self._entry_exit_code(entry) live_process_id: int | None = process_id - if exit_code is not None: + if self._entry_done(entry): async with self._pty_lock: removed = self._pty_processes.pop(process_id, None) self._reserved_pty_process_ids.discard(process_id) @@ -1246,7 +1268,7 @@ def _prune_pty_processes_if_needed(self) -> _E2BPtyProcessEntry | None: return None meta: list[tuple[int, float, bool]] = [ - (process_id, entry.last_used, self._entry_exit_code(entry) is not None) + (process_id, entry.last_used, self._entry_done(entry)) for process_id, entry in self._pty_processes.items() ] process_id = process_id_to_prune_from_meta(meta) @@ -1257,6 +1279,8 @@ def _prune_pty_processes_if_needed(self) -> _E2BPtyProcessEntry | None: return self._pty_processes.pop(process_id, None) def _entry_exit_code(self, entry: _E2BPtyProcessEntry) -> int | None: + if entry.exit_code is not None: + return entry.exit_code value = getattr(entry.handle, "exit_code", None) if value is None: return None @@ -1265,6 +1289,9 @@ def _entry_exit_code(self, entry: _E2BPtyProcessEntry) -> int | None: except (TypeError, ValueError): return None + def _entry_done(self, entry: _E2BPtyProcessEntry) -> bool: + return entry.done or self._entry_exit_code(entry) is not None + async def _terminate_pty_entry(self, entry: _E2BPtyProcessEntry) -> None: kill = getattr(entry.handle, "kill", None) if callable(kill): diff --git a/tests/extensions/sandbox/test_e2b.py b/tests/extensions/sandbox/test_e2b.py index 6a123770c3..60819edefb 100644 --- a/tests/extensions/sandbox/test_e2b.py +++ b/tests/extensions/sandbox/test_e2b.py @@ -244,6 +244,9 @@ def __init__(self) -> None: self.next_result = _FakeE2BResult() self.background_calls: list[dict[str, object]] = [] self.background_error: BaseException | None = None + self.background_stdout: bytes | str | None = "started\n" + self.background_exit_code = 0 + self.background_wait_event: asyncio.Event | None = None async def run( self, @@ -273,18 +276,18 @@ async def run( "background": background, } ) - if callable(on_stdout): - result = on_stdout("started\n") + if callable(on_stdout) and self.background_stdout is not None: + result = on_stdout(self.background_stdout) if inspect.isawaitable(result): await result - class _Handle: - exit_code = 0 - - async def kill(self) -> None: - return None - - return cast(_FakeE2BResult, _Handle()) + return cast( + _FakeE2BResult, + _FakeE2BBackgroundHandle( + exit_code=self.background_exit_code, + wait_event=self.background_wait_event, + ), + ) self.calls.append( { @@ -322,14 +325,47 @@ async def kill(self) -> None: return result +class _FakeE2BBackgroundHandle: + def __init__( + self, + *, + exit_code: int, + wait_event: asyncio.Event | None = None, + ) -> None: + self.exit_code: int | None = exit_code if wait_event is None else None + self._final_exit_code = exit_code + self._wait_event = wait_event + + async def wait(self) -> None: + if self._wait_event is not None: + await self._wait_event.wait() + self.exit_code = self._final_exit_code + + async def kill(self) -> None: + self.exit_code = 0 + if self._wait_event is not None: + self._wait_event.set() + + class _FakeE2BPtyHandle: def __init__(self) -> None: self.pid = "pty-123" self.exit_code: int | None = None self.stdin_payloads: list[bytes] = [] + self._final_exit_code = 0 + self._done = asyncio.Event() + + async def wait(self) -> None: + await self._done.wait() + self.exit_code = self._final_exit_code async def kill(self) -> None: - self.exit_code = 0 + self.finish(0) + + def finish(self, exit_code: int) -> None: + self._final_exit_code = exit_code + self.exit_code = exit_code + self._done.set() class _FakeE2BPty: @@ -338,6 +374,7 @@ def __init__(self) -> None: self.on_data: object | None = None self.create_error: BaseException | None = None self.send_stdin_error: BaseException | None = None + self.stdin_outputs: list[bytes] = [b">>> ", b"10\n"] async def create( self, @@ -364,8 +401,8 @@ async def send_stdin( if self.send_stdin_error is not None: raise self.send_stdin_error self.handle.stdin_payloads.append(data) - if callable(self.on_data): - payload = b">>> " if len(self.handle.stdin_payloads) == 1 else b"10\n" + if callable(self.on_data) and self.stdin_outputs: + payload = self.stdin_outputs.pop(0) result = self.on_data(payload) if inspect.isawaitable(result): await result @@ -1851,6 +1888,69 @@ async def test_e2b_pty_start_non_tty_uses_commands_run_in_background() -> None: ] +@pytest.mark.asyncio +async def test_e2b_pty_start_non_tty_wakes_when_process_exits_without_output() -> None: + sandbox = _FakeE2BSandbox() + sandbox.commands.background_stdout = None + sandbox.commands.background_exit_code = 7 + sandbox.commands.background_wait_event = asyncio.Event() + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + task = asyncio.create_task( + session.pty_exec_start("python3", shell=False, tty=False, yield_time_s=5.0) + ) + await asyncio.sleep(0.01) + assert not task.done() + + sandbox.commands.background_wait_event.set() + update = await asyncio.wait_for(task, timeout=0.5) + + assert update.process_id is None + assert update.output == b"" + assert update.exit_code == 7 + + +@pytest.mark.asyncio +async def test_e2b_pty_write_stdin_wakes_when_process_exits_without_output() -> None: + sandbox = _FakeE2BSandbox() + sandbox.pty.stdin_outputs = [b">>> "] + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + started = await session.pty_exec_start("python3", shell=False, tty=True, yield_time_s=0.01) + assert started.process_id is not None + + task = asyncio.create_task( + session.pty_write_stdin( + session_id=started.process_id, + chars="raise SystemExit(9)\n", + yield_time_s=5.0, + ) + ) + await asyncio.sleep(0.01) + assert not task.done() + + sandbox.pty.handle.finish(9) + update = await asyncio.wait_for(task, timeout=0.5) + + assert update.process_id is None + assert update.output == b"" + assert update.exit_code == 9 + + @pytest.mark.asyncio async def test_e2b_pty_start_non_tty_wraps_background_run_failures() -> None: sandbox = _FakeE2BSandbox()