From a5d65a73f875ed05180db326db2d58ff111b430d Mon Sep 17 00:00:00 2001 From: Rohit Rastogi Date: Tue, 9 Jun 2026 18:35:39 -0700 Subject: [PATCH] Fix E2B PTY exit wakeup --- src/agents/extensions/sandbox/e2b/sandbox.py | 43 +++- tests/extensions/sandbox/test_e2b.py | 237 +++++++++++++++++-- 2 files changed, 253 insertions(+), 27 deletions(-) diff --git a/src/agents/extensions/sandbox/e2b/sandbox.py b/src/agents/extensions/sandbox/e2b/sandbox.py index 425436f3e0..514c497b90 100644 --- a/src/agents/extensions/sandbox/e2b/sandbox.py +++ b/src/agents/extensions/sandbox/e2b/sandbox.py @@ -686,6 +686,8 @@ class _E2BPtyProcessEntry: output_lock: asyncio.Lock = field(default_factory=asyncio.Lock) output_notify: asyncio.Event = field(default_factory=asyncio.Event) last_used: float = field(default_factory=time.monotonic) + exit_code: int | None = None + wait_task: asyncio.Task[None] | None = None @dataclass(frozen=True) @@ -999,6 +1001,7 @@ async def _append_output(payload: bytes | bytearray | str | object) -> None: on_stderr=_append_output, ) entry.handle = handle + entry.wait_task = asyncio.create_task(self._run_pty_waiter(entry)) async with self._pty_lock: process_id = allocate_pty_process_id(self._reserved_pty_process_ids) self._reserved_pty_process_ids.add(process_id) @@ -1215,6 +1218,24 @@ async def _collect_pty_output( truncated_text, original_token_count = truncate_text_by_tokens(text, max_output_tokens) return truncated_text.encode("utf-8", errors="replace"), original_token_count + async def _run_pty_waiter(self, entry: _E2BPtyProcessEntry) -> None: + try: + result = await cast(Any, entry.handle).wait() + entry.exit_code = int(result.exit_code) + except asyncio.CancelledError: + raise + except Exception as e: + # E2B raises CommandExitException, which carries the exit code, when a + # command exits nonzero. + value = getattr(e, "exit_code", None) + if value is not None: + try: + entry.exit_code = int(value) + except (TypeError, ValueError): + pass + finally: + entry.output_notify.set() + async def _finalize_pty_update( self, *, @@ -1258,6 +1279,8 @@ def _prune_pty_processes_if_needed(self) -> _E2BPtyProcessEntry | None: def _entry_exit_code(self, entry: _E2BPtyProcessEntry) -> int | None: value = getattr(entry.handle, "exit_code", None) + if value is None: + value = entry.exit_code if value is None: return None try: @@ -1266,12 +1289,20 @@ def _entry_exit_code(self, entry: _E2BPtyProcessEntry) -> int | None: return None async def _terminate_pty_entry(self, entry: _E2BPtyProcessEntry) -> None: - kill = getattr(entry.handle, "kill", None) - if callable(kill): - try: - await kill() - except Exception: - pass + wait_task = entry.wait_task + + if self._entry_exit_code(entry) is None: + kill = getattr(entry.handle, "kill", None) + if callable(kill): + try: + await kill() + except Exception: + pass + + if wait_task is not None: + if not wait_task.done(): + wait_task.cancel() + await asyncio.gather(wait_task, return_exceptions=True) def _tar_exclude_args(self) -> list[str]: return shell_tar_exclude_args(self._persist_workspace_skip_relpaths()) diff --git a/tests/extensions/sandbox/test_e2b.py b/tests/extensions/sandbox/test_e2b.py index 6a123770c3..f68a5f92d1 100644 --- a/tests/extensions/sandbox/test_e2b.py +++ b/tests/extensions/sandbox/test_e2b.py @@ -212,6 +212,51 @@ def __init__(self, *, stdout: str = "", stderr: str = "", exit_code: int = 0) -> self.exit_code = exit_code +class _FakeE2BCommandExitException(Exception): + def __init__(self, *, exit_code: int) -> None: + super().__init__(f"command exited with {exit_code}") + self.exit_code = exit_code + + +class _FakeE2BAsyncCommandHandle: + def __init__( + self, + *, + result_exit_code: int = 0, + wait_delay_s: float = 0, + wait_error: BaseException | None = None, + wait_never: bool = False, + ) -> None: + self.exit_code: int | None = None + self.result_exit_code = result_exit_code + self.wait_delay_s = wait_delay_s + self.wait_error = wait_error + self.wait_never = wait_never + self.wait_calls = 0 + self.wait_cancelled = False + self.kill_calls = 0 + + async def wait(self) -> _FakeE2BResult: + self.wait_calls += 1 + try: + if self.wait_never: + await asyncio.Event().wait() + if self.wait_delay_s: + await asyncio.sleep(self.wait_delay_s) + if self.wait_error is not None: + raise self.wait_error + self.exit_code = self.result_exit_code + return _FakeE2BResult(exit_code=self.result_exit_code) + except asyncio.CancelledError: + self.wait_cancelled = True + raise + + async def kill(self) -> bool: + self.kill_calls += 1 + self.exit_code = 0 + return True + + class _FakeE2BFiles: def __init__(self) -> None: self.make_dir_calls: list[tuple[str, float | None]] = [] @@ -244,6 +289,8 @@ def __init__(self) -> None: self.next_result = _FakeE2BResult() self.background_calls: list[dict[str, object]] = [] self.background_error: BaseException | None = None + self.next_async_command_handle: _FakeE2BAsyncCommandHandle | None = None + self.async_command_stdout_chunks: list[bytes | str] = [] async def run( self, @@ -257,7 +304,7 @@ async def run( stdin: bool | None = None, timeout: float | None = None, request_timeout: float | None = None, - ) -> _FakeE2BResult: + ) -> object: _ = request_timeout if background: if self.background_error is not None: @@ -274,17 +321,12 @@ async def run( } ) if callable(on_stdout): - result = on_stdout("started\n") - if inspect.isawaitable(result): - await result - - class _Handle: - exit_code = 0 - - async def kill(self) -> None: - return None + for chunk in self.async_command_stdout_chunks: + result = on_stdout(chunk) + if inspect.isawaitable(result): + await result - return cast(_FakeE2BResult, _Handle()) + return self.next_async_command_handle or _FakeE2BAsyncCommandHandle() self.calls.append( { @@ -322,20 +364,30 @@ async def kill(self) -> None: return result -class _FakeE2BPtyHandle: - def __init__(self) -> None: +class _FakeE2BPtyHandle(_FakeE2BAsyncCommandHandle): + def __init__( + self, + *, + result_exit_code: int = 0, + wait_delay_s: float = 0, + wait_error: BaseException | None = None, + wait_never: bool = True, + ) -> None: + super().__init__( + result_exit_code=result_exit_code, + wait_delay_s=wait_delay_s, + wait_error=wait_error, + wait_never=wait_never, + ) self.pid = "pty-123" - self.exit_code: int | None = None self.stdin_payloads: list[bytes] = [] - async def kill(self) -> None: - self.exit_code = 0 - class _FakeE2BPty: def __init__(self) -> None: self.handle = _FakeE2BPtyHandle() self.on_data: object | None = None + self.stdin_output_chunks: list[bytes | str] = [] self.create_error: BaseException | None = None self.send_stdin_error: BaseException | None = None @@ -365,10 +417,11 @@ async def send_stdin( raise self.send_stdin_error self.handle.stdin_payloads.append(data) if callable(self.on_data): - payload = b">>> " if len(self.handle.stdin_payloads) == 1 else b"10\n" - result = self.on_data(payload) - if inspect.isawaitable(result): - await result + for chunk in self.stdin_output_chunks: + result = self.on_data(chunk) + if inspect.isawaitable(result): + await result + self.stdin_output_chunks.clear() class _FakeE2BSandbox: @@ -1798,6 +1851,7 @@ async def _fake_rm(path: Path | str, *, recursive: bool = False) -> None: @pytest.mark.asyncio async def test_e2b_pty_start_and_write_stdin() -> None: sandbox = _FakeE2BSandbox() + sandbox.pty.stdin_output_chunks = [b">>> "] state = E2BSandboxSessionState( session_id=uuid.uuid4(), manifest=Manifest(root="/workspace"), @@ -1812,6 +1866,7 @@ async def test_e2b_pty_start_and_write_stdin() -> None: assert started.process_id is not None assert b">>>" in started.output + sandbox.pty.stdin_output_chunks = [b"10\n"] updated = await session.pty_write_stdin( session_id=started.process_id, chars="5 + 5\n", @@ -1822,10 +1877,13 @@ async def test_e2b_pty_start_and_write_stdin() -> None: assert b"10" in updated.output assert sandbox.pty.handle.stdin_payloads == [b"python3\n", b"5 + 5\n"] + await session.pty_terminate_all() + @pytest.mark.asyncio async def test_e2b_pty_start_non_tty_uses_commands_run_in_background() -> None: sandbox = _FakeE2BSandbox() + sandbox.commands.async_command_stdout_chunks = ["started\n"] state = E2BSandboxSessionState( session_id=uuid.uuid4(), manifest=Manifest(root="/workspace"), @@ -1851,6 +1909,143 @@ async def test_e2b_pty_start_non_tty_uses_commands_run_in_background() -> None: ] +@pytest.mark.asyncio +async def test_e2b_pty_start_non_tty_wakes_when_exit_follows_last_output() -> None: + sandbox = _FakeE2BSandbox() + handle = _FakeE2BAsyncCommandHandle(wait_delay_s=0.01) + sandbox.commands.next_async_command_handle = handle + sandbox.commands.async_command_stdout_chunks = ["started\n"] + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + started = await asyncio.wait_for( + session.pty_exec_start("python3", shell=False, tty=False, yield_time_s=10), + timeout=1, + ) + + assert started.process_id is None + assert started.exit_code == 0 + assert started.output == b"started\n" + assert handle.wait_calls == 1 + assert handle.kill_calls == 0 + + +@pytest.mark.asyncio +async def test_e2b_pty_start_tty_wakes_when_session_exits_after_output() -> None: + sandbox = _FakeE2BSandbox() + handle = _FakeE2BPtyHandle(wait_never=False, wait_delay_s=0.01) + sandbox.pty.handle = handle + sandbox.pty.stdin_output_chunks = [b"bye\n"] + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + started = await asyncio.wait_for( + session.pty_exec_start("exit", shell=False, tty=True, yield_time_s=10), + timeout=1, + ) + + assert started.process_id is None + assert started.exit_code == 0 + assert started.output == b"bye\n" + assert handle.stdin_payloads == [b"exit\n"] + assert handle.wait_calls == 1 + assert handle.kill_calls == 0 + + +@pytest.mark.asyncio +async def test_e2b_pty_start_non_tty_wakes_on_quiet_exit() -> None: + sandbox = _FakeE2BSandbox() + handle = _FakeE2BAsyncCommandHandle(wait_delay_s=0.01) + sandbox.commands.next_async_command_handle = handle + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + started = await asyncio.wait_for( + session.pty_exec_start("true", shell=False, tty=False, yield_time_s=10), + timeout=1, + ) + + assert started.process_id is None + assert started.exit_code == 0 + assert started.output == b"" + assert handle.wait_calls == 1 + assert handle.kill_calls == 0 + + +@pytest.mark.asyncio +async def test_e2b_pty_start_non_tty_wakes_on_nonzero_wait_exit() -> None: + sandbox = _FakeE2BSandbox() + handle = _FakeE2BAsyncCommandHandle( + wait_delay_s=0.01, + wait_error=_FakeE2BCommandExitException(exit_code=2), + ) + sandbox.commands.next_async_command_handle = handle + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + started = await asyncio.wait_for( + session.pty_exec_start("false", shell=False, tty=False, yield_time_s=10), + timeout=1, + ) + + assert started.process_id is None + assert started.exit_code == 2 + assert started.output == b"" + assert handle.wait_calls == 1 + assert handle.kill_calls == 0 + + +@pytest.mark.asyncio +async def test_e2b_pty_start_non_tty_running_command_cleans_up_waiter() -> None: + sandbox = _FakeE2BSandbox() + handle = _FakeE2BAsyncCommandHandle(wait_never=True) + sandbox.commands.next_async_command_handle = handle + state = E2BSandboxSessionState( + session_id=uuid.uuid4(), + manifest=Manifest(root="/workspace"), + snapshot=NoopSnapshot(id="snapshot"), + sandbox_id=sandbox.sandbox_id, + workspace_root_ready=True, + ) + session = E2BSandboxSession.from_state(state, sandbox=sandbox) + + started = await session.pty_exec_start("sleep", "60", shell=False, tty=False, yield_time_s=0.01) + + assert started.process_id is not None + assert started.exit_code is None + assert handle.wait_calls == 1 + assert handle.kill_calls == 0 + + await session.pty_terminate_all() + + assert handle.wait_cancelled + assert handle.kill_calls == 1 + + @pytest.mark.asyncio async def test_e2b_pty_start_non_tty_wraps_background_run_failures() -> None: sandbox = _FakeE2BSandbox()