Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions src/agents/extensions/sandbox/e2b/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,8 @@ class _E2BPtyProcessEntry:
output_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
output_notify: asyncio.Event = field(default_factory=asyncio.Event)
last_used: float = field(default_factory=time.monotonic)
done: bool = False
exit_code: int | None = None


@dataclass(frozen=True)
Expand Down Expand Up @@ -982,6 +984,7 @@ async def _append_output(payload: bytes | bytearray | str | object) -> None:
on_data=_append_output,
)
entry.handle = handle
asyncio.create_task(self._run_pty_waiter(entry))
await self._sandbox.pty.send_stdin(
cast(Any, handle).pid,
f"{command_text}\n".encode(),
Expand All @@ -999,6 +1002,7 @@ async def _append_output(payload: bytes | bytearray | str | object) -> None:
on_stderr=_append_output,
)
entry.handle = handle
asyncio.create_task(self._run_pty_waiter(entry))
async with self._pty_lock:
process_id = allocate_pty_process_id(self._reserved_pty_process_ids)
self._reserved_pty_process_ids.add(process_id)
Expand Down Expand Up @@ -1044,6 +1048,24 @@ async def _append_output(payload: bytes | bytearray | str | object) -> None:
original_token_count=original_token_count,
)

async def _run_pty_waiter(self, entry: _E2BPtyProcessEntry) -> None:
wait = getattr(entry.handle, "wait", None)
if not callable(wait):
return

try:
result = wait()
if inspect.isawaitable(result):
await result
exit_code = getattr(entry.handle, "exit_code", None)
if exit_code is not None:
entry.exit_code = int(exit_code)
except Exception:
pass
finally:
entry.done = True
Comment on lines +1063 to +1066

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Avoid marking PTY done when wait fails

When E2B's wait() raises for a stream/RPC error rather than because the process exited, this catch-all still falls through to finally and sets entry.done. The next _finalize_pty_update treats that as an exited process, removes the session, and calls kill, so a long-running PTY/background command can be reported as process_id=None/exit_code=None and be terminated just because the waiter connection failed. Only mark the entry done after a successful wait or a CommandExitException/result that proves the process actually ended.

Useful? React with 👍 / 👎.

entry.output_notify.set()

async def pty_write_stdin(
self,
*,
Expand Down Expand Up @@ -1195,7 +1217,7 @@ async def _collect_pty_output(
if time.monotonic() >= deadline:
break

if self._entry_exit_code(entry) is not None:
if self._entry_done(entry):
async with entry.output_lock:
while entry.output_chunks:
output.extend(entry.output_chunks.popleft())
Expand Down Expand Up @@ -1226,7 +1248,7 @@ async def _finalize_pty_update(
exit_code = self._entry_exit_code(entry)
live_process_id: int | None = process_id

if exit_code is not None:
if self._entry_done(entry):
async with self._pty_lock:
removed = self._pty_processes.pop(process_id, None)
self._reserved_pty_process_ids.discard(process_id)
Expand All @@ -1246,7 +1268,7 @@ def _prune_pty_processes_if_needed(self) -> _E2BPtyProcessEntry | None:
return None

meta: list[tuple[int, float, bool]] = [
(process_id, entry.last_used, self._entry_exit_code(entry) is not None)
(process_id, entry.last_used, self._entry_done(entry))
for process_id, entry in self._pty_processes.items()
]
process_id = process_id_to_prune_from_meta(meta)
Expand All @@ -1257,6 +1279,8 @@ def _prune_pty_processes_if_needed(self) -> _E2BPtyProcessEntry | None:
return self._pty_processes.pop(process_id, None)

def _entry_exit_code(self, entry: _E2BPtyProcessEntry) -> int | None:
if entry.exit_code is not None:
return entry.exit_code
value = getattr(entry.handle, "exit_code", None)
if value is None:
return None
Expand All @@ -1265,6 +1289,9 @@ def _entry_exit_code(self, entry: _E2BPtyProcessEntry) -> int | None:
except (TypeError, ValueError):
return None

def _entry_done(self, entry: _E2BPtyProcessEntry) -> bool:
return entry.done or self._entry_exit_code(entry) is not None

async def _terminate_pty_entry(self, entry: _E2BPtyProcessEntry) -> None:
kill = getattr(entry.handle, "kill", None)
if callable(kill):
Expand Down
124 changes: 112 additions & 12 deletions tests/extensions/sandbox/test_e2b.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ def __init__(self) -> None:
self.next_result = _FakeE2BResult()
self.background_calls: list[dict[str, object]] = []
self.background_error: BaseException | None = None
self.background_stdout: bytes | str | None = "started\n"
self.background_exit_code = 0
self.background_wait_event: asyncio.Event | None = None

async def run(
self,
Expand Down Expand Up @@ -273,18 +276,18 @@ async def run(
"background": background,
}
)
if callable(on_stdout):
result = on_stdout("started\n")
if callable(on_stdout) and self.background_stdout is not None:
result = on_stdout(self.background_stdout)
if inspect.isawaitable(result):
await result

class _Handle:
exit_code = 0

async def kill(self) -> None:
return None

return cast(_FakeE2BResult, _Handle())
return cast(
_FakeE2BResult,
_FakeE2BBackgroundHandle(
exit_code=self.background_exit_code,
wait_event=self.background_wait_event,
),
)

self.calls.append(
{
Expand Down Expand Up @@ -322,14 +325,47 @@ async def kill(self) -> None:
return result


class _FakeE2BBackgroundHandle:
def __init__(
self,
*,
exit_code: int,
wait_event: asyncio.Event | None = None,
) -> None:
self.exit_code: int | None = exit_code if wait_event is None else None
self._final_exit_code = exit_code
self._wait_event = wait_event

async def wait(self) -> None:
if self._wait_event is not None:
await self._wait_event.wait()
self.exit_code = self._final_exit_code

async def kill(self) -> None:
self.exit_code = 0
if self._wait_event is not None:
self._wait_event.set()


class _FakeE2BPtyHandle:
def __init__(self) -> None:
self.pid = "pty-123"
self.exit_code: int | None = None
self.stdin_payloads: list[bytes] = []
self._final_exit_code = 0
self._done = asyncio.Event()

async def wait(self) -> None:
await self._done.wait()
self.exit_code = self._final_exit_code

async def kill(self) -> None:
self.exit_code = 0
self.finish(0)

def finish(self, exit_code: int) -> None:
self._final_exit_code = exit_code
self.exit_code = exit_code
self._done.set()


class _FakeE2BPty:
Expand All @@ -338,6 +374,7 @@ def __init__(self) -> None:
self.on_data: object | None = None
self.create_error: BaseException | None = None
self.send_stdin_error: BaseException | None = None
self.stdin_outputs: list[bytes] = [b">>> ", b"10\n"]

async def create(
self,
Expand All @@ -364,8 +401,8 @@ async def send_stdin(
if self.send_stdin_error is not None:
raise self.send_stdin_error
self.handle.stdin_payloads.append(data)
if callable(self.on_data):
payload = b">>> " if len(self.handle.stdin_payloads) == 1 else b"10\n"
if callable(self.on_data) and self.stdin_outputs:
payload = self.stdin_outputs.pop(0)
result = self.on_data(payload)
if inspect.isawaitable(result):
await result
Expand Down Expand Up @@ -1851,6 +1888,69 @@ async def test_e2b_pty_start_non_tty_uses_commands_run_in_background() -> None:
]


@pytest.mark.asyncio
async def test_e2b_pty_start_non_tty_wakes_when_process_exits_without_output() -> None:
    """A silent background command must end the start call as soon as it exits.

    The fake emits no stdout and its ``wait()`` is gated on an event, so the
    start call can only return before its 5 s yield window if the exit itself
    wakes it.
    """
    fake_sandbox = _FakeE2BSandbox()
    exit_gate = asyncio.Event()
    fake_sandbox.commands.background_stdout = None
    fake_sandbox.commands.background_exit_code = 7
    fake_sandbox.commands.background_wait_event = exit_gate
    session = E2BSandboxSession.from_state(
        E2BSandboxSessionState(
            session_id=uuid.uuid4(),
            manifest=Manifest(root="/workspace"),
            snapshot=NoopSnapshot(id="snapshot"),
            sandbox_id=fake_sandbox.sandbox_id,
            workspace_root_ready=True,
        ),
        sandbox=fake_sandbox,
    )

    pending_start = asyncio.create_task(
        session.pty_exec_start("python3", shell=False, tty=False, yield_time_s=5.0)
    )
    # Give the start call a moment to begin waiting; it must still be blocked.
    await asyncio.sleep(0.01)
    assert not pending_start.done()

    # Let the fake process exit; the call must finish well inside the yield window.
    exit_gate.set()
    final_update = await asyncio.wait_for(pending_start, timeout=0.5)

    assert final_update.process_id is None
    assert final_update.output == b""
    assert final_update.exit_code == 7


@pytest.mark.asyncio
async def test_e2b_pty_write_stdin_wakes_when_process_exits_without_output() -> None:
    """A stdin write must return once the PTY process exits, even with no output.

    Only one output chunk is queued on the fake PTY, so the later write yields
    nothing and can only finish before its 5 s yield window if the exit wakes it.
    """
    fake_sandbox = _FakeE2BSandbox()
    fake_sandbox.pty.stdin_outputs = [b">>> "]
    session = E2BSandboxSession.from_state(
        E2BSandboxSessionState(
            session_id=uuid.uuid4(),
            manifest=Manifest(root="/workspace"),
            snapshot=NoopSnapshot(id="snapshot"),
            sandbox_id=fake_sandbox.sandbox_id,
            workspace_root_ready=True,
        ),
        sandbox=fake_sandbox,
    )

    opened = await session.pty_exec_start("python3", shell=False, tty=True, yield_time_s=0.01)
    assert opened.process_id is not None

    pending_write = asyncio.create_task(
        session.pty_write_stdin(
            session_id=opened.process_id,
            chars="raise SystemExit(9)\n",
            yield_time_s=5.0,
        )
    )
    # Give the write a moment to begin waiting; it must still be blocked.
    await asyncio.sleep(0.01)
    assert not pending_write.done()

    # End the fake PTY process; the write must finish well inside the yield window.
    fake_sandbox.pty.handle.finish(9)
    final_update = await asyncio.wait_for(pending_write, timeout=0.5)

    assert final_update.process_id is None
    assert final_update.output == b""
    assert final_update.exit_code == 9


@pytest.mark.asyncio
async def test_e2b_pty_start_non_tty_wraps_background_run_failures() -> None:
sandbox = _FakeE2BSandbox()
Expand Down