Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/agents/extensions/sandbox/e2b/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,8 @@ class _E2BPtyProcessEntry:
output_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
output_notify: asyncio.Event = field(default_factory=asyncio.Event)
last_used: float = field(default_factory=time.monotonic)
exit_code: int | None = None
wait_task: asyncio.Task[None] | None = None


@dataclass(frozen=True)
Expand Down Expand Up @@ -999,6 +1001,7 @@ async def _append_output(payload: bytes | bytearray | str | object) -> None:
on_stderr=_append_output,
)
entry.handle = handle
entry.wait_task = asyncio.create_task(self._run_pty_waiter(entry))
async with self._pty_lock:
process_id = allocate_pty_process_id(self._reserved_pty_process_ids)
self._reserved_pty_process_ids.add(process_id)
Expand Down Expand Up @@ -1215,6 +1218,24 @@ async def _collect_pty_output(
truncated_text, original_token_count = truncate_text_by_tokens(text, max_output_tokens)
return truncated_text.encode("utf-8", errors="replace"), original_token_count

async def _run_pty_waiter(self, entry: _E2BPtyProcessEntry) -> None:
    """Wait for the entry's handle to finish and record its exit code.

    The exit code is taken from the result returned by ``handle.wait()``.
    If ``wait()`` raises instead, the ``exit_code`` attribute of the
    exception (when present and convertible to ``int``) is used. Any other
    failure leaves ``entry.exit_code`` untouched.

    Cancellation is propagated to the caller. In every case, including
    cancellation, ``entry.output_notify`` is set before returning so that
    anything waiting on that event is woken up.
    """
    try:
        finished = await cast(Any, entry.handle).wait()
        entry.exit_code = int(finished.exit_code)
    except asyncio.CancelledError:
        raise
    except Exception as exc:
        # E2B raises CommandExitException, which carries the exit code, when a
        # command exits nonzero.
        reported = getattr(exc, "exit_code", None)
        if reported is None:
            return
        try:
            entry.exit_code = int(reported)
        except (TypeError, ValueError):
            # Unusable exit code: leave the entry without one.
            return
    finally:
        entry.output_notify.set()

async def _finalize_pty_update(
self,
*,
Expand Down Expand Up @@ -1258,6 +1279,8 @@ def _prune_pty_processes_if_needed(self) -> _E2BPtyProcessEntry | None:

def _entry_exit_code(self, entry: _E2BPtyProcessEntry) -> int | None:
value = getattr(entry.handle, "exit_code", None)
if value is None:
value = entry.exit_code
if value is None:
return None
try:
Expand All @@ -1266,12 +1289,20 @@ def _entry_exit_code(self, entry: _E2BPtyProcessEntry) -> int | None:
return None

async def _terminate_pty_entry(self, entry: _E2BPtyProcessEntry) -> None:
    """Stop the process behind *entry* and reap its waiter task.

    ``kill()`` is only attempted when no exit code is known for the entry
    yet, so a process that already exited is never killed. Afterwards the
    entry's waiter task (if any) is cancelled when still pending and awaited
    so it does not outlive the entry.
    """
    wait_task = entry.wait_task

    # Only kill a process that has not reported an exit code.
    if self._entry_exit_code(entry) is None:
        kill = getattr(entry.handle, "kill", None)
        if callable(kill):
            try:
                await kill()
            except Exception:
                # Best-effort termination: a failed kill must not prevent
                # the waiter task below from being cleaned up.
                pass

    if wait_task is not None:
        if not wait_task.done():
            wait_task.cancel()
        # return_exceptions=True collects the task's cancellation or error
        # as a result instead of raising it here.
        await asyncio.gather(wait_task, return_exceptions=True)

def _tar_exclude_args(self) -> list[str]:
return shell_tar_exclude_args(self._persist_workspace_skip_relpaths())
Expand Down
237 changes: 216 additions & 21 deletions tests/extensions/sandbox/test_e2b.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,51 @@ def __init__(self, *, stdout: str = "", stderr: str = "", exit_code: int = 0) ->
self.exit_code = exit_code


class _FakeE2BCommandExitException(Exception):
    """Test exception that carries an ``exit_code`` attribute.

    Used as a stand-in for the error a command handle's ``wait()`` raises
    when the command exits with a nonzero status.
    """

    def __init__(self, *, exit_code: int) -> None:
        message = f"command exited with {exit_code}"
        super().__init__(message)
        self.exit_code = exit_code


class _FakeE2BAsyncCommandHandle:
    """Configurable fake of an async command handle for tests.

    ``wait()`` can be made to block forever (``wait_never``), sleep first
    (``wait_delay_s``), raise (``wait_error``) or return a result carrying
    ``result_exit_code``. Calls to ``wait()`` and ``kill()`` are counted, and
    cancellation of ``wait()`` is recorded in ``wait_cancelled``.
    """

    def __init__(
        self,
        *,
        result_exit_code: int = 0,
        wait_delay_s: float = 0,
        wait_error: BaseException | None = None,
        wait_never: bool = False,
    ) -> None:
        # Scripted behaviour for wait().
        self.result_exit_code = result_exit_code
        self.wait_delay_s = wait_delay_s
        self.wait_error = wait_error
        self.wait_never = wait_never
        # Observable state that tests assert on.
        self.exit_code: int | None = None
        self.wait_calls = 0
        self.wait_cancelled = False
        self.kill_calls = 0

    async def wait(self) -> _FakeE2BResult:
        """Simulate waiting for the command; return its result or raise."""
        self.wait_calls += 1
        try:
            if self.wait_never:
                # A fresh event that nobody sets: blocks until cancelled.
                blocker = asyncio.Event()
                await blocker.wait()
            delay = self.wait_delay_s
            if delay:
                await asyncio.sleep(delay)
            failure = self.wait_error
            if failure is not None:
                raise failure
        except asyncio.CancelledError:
            self.wait_cancelled = True
            raise
        code = self.result_exit_code
        self.exit_code = code
        return _FakeE2BResult(exit_code=code)

    async def kill(self) -> bool:
        """Record the kill, mark the handle as exited with code 0."""
        self.kill_calls += 1
        self.exit_code = 0
        return True


class _FakeE2BFiles:
def __init__(self) -> None:
self.make_dir_calls: list[tuple[str, float | None]] = []
Expand Down Expand Up @@ -244,6 +289,8 @@ def __init__(self) -> None:
self.next_result = _FakeE2BResult()
self.background_calls: list[dict[str, object]] = []
self.background_error: BaseException | None = None
self.next_async_command_handle: _FakeE2BAsyncCommandHandle | None = None
self.async_command_stdout_chunks: list[bytes | str] = []

async def run(
self,
Expand All @@ -257,7 +304,7 @@ async def run(
stdin: bool | None = None,
timeout: float | None = None,
request_timeout: float | None = None,
) -> _FakeE2BResult:
) -> object:
_ = request_timeout
if background:
if self.background_error is not None:
Expand All @@ -274,17 +321,12 @@ async def run(
}
)
if callable(on_stdout):
result = on_stdout("started\n")
if inspect.isawaitable(result):
await result

class _Handle:
exit_code = 0

async def kill(self) -> None:
return None
for chunk in self.async_command_stdout_chunks:
result = on_stdout(chunk)
if inspect.isawaitable(result):
await result

return cast(_FakeE2BResult, _Handle())
return self.next_async_command_handle or _FakeE2BAsyncCommandHandle()

self.calls.append(
{
Expand Down Expand Up @@ -322,20 +364,30 @@ async def kill(self) -> None:
return result


class _FakeE2BPtyHandle(_FakeE2BAsyncCommandHandle):
    """Fake PTY handle built on the fake async command handle.

    Differs from the base fake only in that ``wait()`` blocks forever by
    default (``wait_never=True``), and in carrying a fixed ``pid`` plus a log
    of the payloads sent to stdin. ``exit_code``, ``wait()`` and ``kill()``
    come from the base class, so ``kill_calls`` is tracked for PTY handles
    too.
    """

    def __init__(
        self,
        *,
        result_exit_code: int = 0,
        wait_delay_s: float = 0,
        wait_error: BaseException | None = None,
        wait_never: bool = True,
    ) -> None:
        super().__init__(
            result_exit_code=result_exit_code,
            wait_delay_s=wait_delay_s,
            wait_error=wait_error,
            wait_never=wait_never,
        )
        self.pid = "pty-123"
        self.stdin_payloads: list[bytes] = []


class _FakeE2BPty:
def __init__(self) -> None:
self.handle = _FakeE2BPtyHandle()
self.on_data: object | None = None
self.stdin_output_chunks: list[bytes | str] = []
self.create_error: BaseException | None = None
self.send_stdin_error: BaseException | None = None

Expand Down Expand Up @@ -365,10 +417,11 @@ async def send_stdin(
raise self.send_stdin_error
self.handle.stdin_payloads.append(data)
if callable(self.on_data):
payload = b">>> " if len(self.handle.stdin_payloads) == 1 else b"10\n"
result = self.on_data(payload)
if inspect.isawaitable(result):
await result
for chunk in self.stdin_output_chunks:
result = self.on_data(chunk)
if inspect.isawaitable(result):
await result
self.stdin_output_chunks.clear()


class _FakeE2BSandbox:
Expand Down Expand Up @@ -1798,6 +1851,7 @@ async def _fake_rm(path: Path | str, *, recursive: bool = False) -> None:
@pytest.mark.asyncio
async def test_e2b_pty_start_and_write_stdin() -> None:
sandbox = _FakeE2BSandbox()
sandbox.pty.stdin_output_chunks = [b">>> "]
state = E2BSandboxSessionState(
session_id=uuid.uuid4(),
manifest=Manifest(root="/workspace"),
Expand All @@ -1812,6 +1866,7 @@ async def test_e2b_pty_start_and_write_stdin() -> None:
assert started.process_id is not None
assert b">>>" in started.output

sandbox.pty.stdin_output_chunks = [b"10\n"]
updated = await session.pty_write_stdin(
session_id=started.process_id,
chars="5 + 5\n",
Expand All @@ -1822,10 +1877,13 @@ async def test_e2b_pty_start_and_write_stdin() -> None:
assert b"10" in updated.output
assert sandbox.pty.handle.stdin_payloads == [b"python3\n", b"5 + 5\n"]

await session.pty_terminate_all()


@pytest.mark.asyncio
async def test_e2b_pty_start_non_tty_uses_commands_run_in_background() -> None:
sandbox = _FakeE2BSandbox()
sandbox.commands.async_command_stdout_chunks = ["started\n"]
state = E2BSandboxSessionState(
session_id=uuid.uuid4(),
manifest=Manifest(root="/workspace"),
Expand All @@ -1851,6 +1909,143 @@ async def test_e2b_pty_start_non_tty_uses_commands_run_in_background() -> None:
]


@pytest.mark.asyncio
async def test_e2b_pty_start_non_tty_wakes_when_exit_follows_last_output() -> None:
    """A non-TTY command that exits just after its last output ends the start call.

    The fake handle finishes 10 ms after emitting ``"started\\n"``. Although
    the yield window is 10 s, ``pty_exec_start`` must return within the 1 s
    guard, report the exit code, and never kill the handle.
    """
    command_handle = _FakeE2BAsyncCommandHandle(wait_delay_s=0.01)
    fake_sandbox = _FakeE2BSandbox()
    fake_sandbox.commands.async_command_stdout_chunks = ["started\n"]
    fake_sandbox.commands.next_async_command_handle = command_handle
    pty_session = E2BSandboxSession.from_state(
        E2BSandboxSessionState(
            session_id=uuid.uuid4(),
            manifest=Manifest(root="/workspace"),
            snapshot=NoopSnapshot(id="snapshot"),
            sandbox_id=fake_sandbox.sandbox_id,
            workspace_root_ready=True,
        ),
        sandbox=fake_sandbox,
    )

    start_call = pty_session.pty_exec_start(
        "python3", shell=False, tty=False, yield_time_s=10
    )
    outcome = await asyncio.wait_for(start_call, timeout=1)

    # The process finished, so no process id is handed back.
    assert outcome.process_id is None
    assert outcome.exit_code == 0
    assert outcome.output == b"started\n"
    assert command_handle.wait_calls == 1
    assert command_handle.kill_calls == 0


@pytest.mark.asyncio
async def test_e2b_pty_start_tty_wakes_when_session_exits_after_output() -> None:
    """A TTY session that exits shortly after its output ends the start call.

    The fake PTY handle completes ``wait()`` after 10 ms and the fake PTY
    emits ``b"bye\\n"`` on the first stdin write. ``pty_exec_start`` must
    return within the 1 s guard despite the 10 s yield window.
    """
    pty_handle = _FakeE2BPtyHandle(wait_never=False, wait_delay_s=0.01)
    fake_sandbox = _FakeE2BSandbox()
    fake_sandbox.pty.handle = pty_handle
    fake_sandbox.pty.stdin_output_chunks = [b"bye\n"]
    pty_session = E2BSandboxSession.from_state(
        E2BSandboxSessionState(
            session_id=uuid.uuid4(),
            manifest=Manifest(root="/workspace"),
            snapshot=NoopSnapshot(id="snapshot"),
            sandbox_id=fake_sandbox.sandbox_id,
            workspace_root_ready=True,
        ),
        sandbox=fake_sandbox,
    )

    start_call = pty_session.pty_exec_start(
        "exit", shell=False, tty=True, yield_time_s=10
    )
    outcome = await asyncio.wait_for(start_call, timeout=1)

    assert outcome.process_id is None
    assert outcome.exit_code == 0
    assert outcome.output == b"bye\n"
    # The command is delivered to the PTY as a single stdin payload.
    assert pty_handle.stdin_payloads == [b"exit\n"]
    assert pty_handle.wait_calls == 1
    assert pty_handle.kill_calls == 0


@pytest.mark.asyncio
async def test_e2b_pty_start_non_tty_wakes_on_quiet_exit() -> None:
    """A non-TTY command that exits without any output ends the start call.

    No stdout chunks are configured and the fake handle finishes after 10 ms,
    so the only wake-up signal is the process exit itself.
    """
    command_handle = _FakeE2BAsyncCommandHandle(wait_delay_s=0.01)
    fake_sandbox = _FakeE2BSandbox()
    fake_sandbox.commands.next_async_command_handle = command_handle
    pty_session = E2BSandboxSession.from_state(
        E2BSandboxSessionState(
            session_id=uuid.uuid4(),
            manifest=Manifest(root="/workspace"),
            snapshot=NoopSnapshot(id="snapshot"),
            sandbox_id=fake_sandbox.sandbox_id,
            workspace_root_ready=True,
        ),
        sandbox=fake_sandbox,
    )

    start_call = pty_session.pty_exec_start(
        "true", shell=False, tty=False, yield_time_s=10
    )
    outcome = await asyncio.wait_for(start_call, timeout=1)

    assert outcome.process_id is None
    assert outcome.exit_code == 0
    assert outcome.output == b""
    assert command_handle.wait_calls == 1
    assert command_handle.kill_calls == 0


@pytest.mark.asyncio
async def test_e2b_pty_start_non_tty_wakes_on_nonzero_wait_exit() -> None:
    """A nonzero exit signalled by an exception from ``wait()`` is reported.

    The fake handle's ``wait()`` raises an exception carrying
    ``exit_code=2`` after 10 ms. ``pty_exec_start`` must return that exit
    code within the 1 s guard.
    """
    exit_error = _FakeE2BCommandExitException(exit_code=2)
    command_handle = _FakeE2BAsyncCommandHandle(
        wait_delay_s=0.01,
        wait_error=exit_error,
    )
    fake_sandbox = _FakeE2BSandbox()
    fake_sandbox.commands.next_async_command_handle = command_handle
    pty_session = E2BSandboxSession.from_state(
        E2BSandboxSessionState(
            session_id=uuid.uuid4(),
            manifest=Manifest(root="/workspace"),
            snapshot=NoopSnapshot(id="snapshot"),
            sandbox_id=fake_sandbox.sandbox_id,
            workspace_root_ready=True,
        ),
        sandbox=fake_sandbox,
    )

    start_call = pty_session.pty_exec_start(
        "false", shell=False, tty=False, yield_time_s=10
    )
    outcome = await asyncio.wait_for(start_call, timeout=1)

    assert outcome.process_id is None
    assert outcome.exit_code == 2
    assert outcome.output == b""
    assert command_handle.wait_calls == 1
    assert command_handle.kill_calls == 0


@pytest.mark.asyncio
async def test_e2b_pty_start_non_tty_running_command_cleans_up_waiter() -> None:
    """A still-running command keeps its waiter until the session terminates it.

    The fake handle never finishes, so the start call returns a process id
    after the short yield window. ``pty_terminate_all`` must then kill the
    handle exactly once and cancel the pending ``wait()``.
    """
    running_handle = _FakeE2BAsyncCommandHandle(wait_never=True)
    fake_sandbox = _FakeE2BSandbox()
    fake_sandbox.commands.next_async_command_handle = running_handle
    pty_session = E2BSandboxSession.from_state(
        E2BSandboxSessionState(
            session_id=uuid.uuid4(),
            manifest=Manifest(root="/workspace"),
            snapshot=NoopSnapshot(id="snapshot"),
            sandbox_id=fake_sandbox.sandbox_id,
            workspace_root_ready=True,
        ),
        sandbox=fake_sandbox,
    )

    outcome = await pty_session.pty_exec_start(
        "sleep", "60", shell=False, tty=False, yield_time_s=0.01
    )

    # Still running: a process id is returned and no exit code is known.
    assert outcome.process_id is not None
    assert outcome.exit_code is None
    assert running_handle.wait_calls == 1
    assert running_handle.kill_calls == 0

    await pty_session.pty_terminate_all()

    assert running_handle.wait_cancelled
    assert running_handle.kill_calls == 1


@pytest.mark.asyncio
async def test_e2b_pty_start_non_tty_wraps_background_run_failures() -> None:
sandbox = _FakeE2BSandbox()
Expand Down
Loading