Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions codeflash/languages/java/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,19 @@
GRACEFUL_SHUTDOWN_WAIT = 5 # seconds to wait after SIGTERM before SIGKILL


def _run_java_with_graceful_timeout(
java_command: list[str], env: dict[str, str], timeout: int, stage_name: str
) -> None:
def _run_java_with_graceful_timeout(java_command: list[str], env: dict[str, str], timeout: int, stage_name: str) -> int:
"""Run a Java command with graceful timeout handling.

Sends SIGTERM first (allowing JFR dump and shutdown hooks to run),
then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds.

Returns the process exit code, or -1 if the process was killed due to timeout.
"""
if not timeout:
subprocess.run(java_command, env=env, check=False)
return
result = subprocess.run(java_command, env=env, check=False)
if result.returncode != 0:
logger.warning("%s exited with code %d", stage_name, result.returncode)
return result.returncode

import signal

Expand All @@ -45,6 +47,11 @@ def _run_java_with_graceful_timeout(
logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name)
proc.kill()
proc.wait()
return -1

if proc.returncode != 0:
logger.warning("%s exited with code %d", stage_name, proc.returncode)
return proc.returncode


# --add-opens flags needed for Kryo serialization on Java 16+
Expand Down Expand Up @@ -78,21 +85,27 @@ def trace(
jfr_file = trace_db_path.with_suffix(".jfr")
trace_db_path.parent.mkdir(parents=True, exist_ok=True)

# Stage 1: JFR Profiling
# Stage 1: JFR Profiling (non-fatal — JFR data is supplementary)
logger.info("Stage 1: Running JFR profiling...")
jfr_env = self.build_jfr_env(jfr_file)
_run_java_with_graceful_timeout(java_command, jfr_env, timeout, "JFR profiling")
jfr_exit = _run_java_with_graceful_timeout(java_command, jfr_env, timeout, "JFR profiling")

if not jfr_file.exists():
if jfr_exit != 0:
logger.warning("JFR profiling failed (exit code %d), continuing without profiling data", jfr_exit)
elif not jfr_file.exists():
logger.warning("JFR file was not created at %s", jfr_file)

# Stage 2: Argument Capture via Tracing Agent
# Stage 2: Argument Capture via Tracing Agent (fatal — trace data is essential)
logger.info("Stage 2: Running argument capture...")
config_path = self.create_tracer_config(
trace_db_path, packages, project_root=project_root, max_function_count=max_function_count, timeout=timeout
)
agent_env = self.build_agent_env(config_path)
_run_java_with_graceful_timeout(java_command, agent_env, timeout, "Argument capture")
capture_exit = _run_java_with_graceful_timeout(java_command, agent_env, timeout, "Argument capture")

if capture_exit != 0:
msg = f"Argument capture failed with exit code {capture_exit} — cannot proceed without trace data"
raise RuntimeError(msg)

if not trace_db_path.exists():
logger.error("Trace database was not created at %s", trace_db_path)
Expand Down
147 changes: 147 additions & 0 deletions tests/test_languages/test_java/test_tracer_exit_codes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch

if TYPE_CHECKING:
from pathlib import Path

import pytest

from codeflash.languages.java.tracer import JavaTracer, _run_java_with_graceful_timeout


class TestRunJavaWithGracefulTimeout:
def test_returns_zero_on_success(self) -> None:
mock_result = MagicMock()
mock_result.returncode = 0
with patch("codeflash.languages.java.tracer.subprocess.run", return_value=mock_result):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 0, "test")
assert rc == 0

def test_returns_nonzero_on_failure(self) -> None:
mock_result = MagicMock()
mock_result.returncode = 1
with patch("codeflash.languages.java.tracer.subprocess.run", return_value=mock_result):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 0, "test")
assert rc == 1

def test_returns_exit_code_137_oom_kill(self) -> None:
mock_result = MagicMock()
mock_result.returncode = 137
with patch("codeflash.languages.java.tracer.subprocess.run", return_value=mock_result):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 0, "test")
assert rc == 137

def test_timeout_path_returns_zero_on_success(self) -> None:
mock_proc = MagicMock()
mock_proc.returncode = 0
with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=mock_proc):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")
assert rc == 0

def test_timeout_path_returns_nonzero_on_failure(self) -> None:
mock_proc = MagicMock()
mock_proc.returncode = 1
with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=mock_proc):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")
assert rc == 1

def test_timeout_returns_negative_one(self) -> None:
import subprocess

mock_proc = MagicMock()
# First wait() times out, SIGTERM wait succeeds
mock_proc.wait.side_effect = [
subprocess.TimeoutExpired(cmd="java", timeout=60),
None, # SIGTERM wait succeeds
]
with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=mock_proc):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")
assert rc == -1

def test_timeout_sends_sigterm_then_sigkill(self) -> None:
import signal
import subprocess

mock_proc = MagicMock()
# First wait() times out, SIGTERM wait also times out
mock_proc.wait.side_effect = [
subprocess.TimeoutExpired(cmd="java", timeout=60),
subprocess.TimeoutExpired(cmd="java", timeout=5),
None,
]
with patch("codeflash.languages.java.tracer.subprocess.Popen", return_value=mock_proc):
rc = _run_java_with_graceful_timeout(["java", "-version"], {}, 60, "test")

assert rc == -1
mock_proc.send_signal.assert_called_once_with(signal.SIGTERM)
mock_proc.kill.assert_called_once()


class TestJavaTracerExitCodeHandling:
def test_stage1_failure_continues(self, tmp_path: Path) -> None:
trace_db_path = (tmp_path / "trace.db").resolve()
tracer = JavaTracer()

# Stage 1 fails (exit code 1), Stage 2 succeeds (exit code 0)
exit_codes = iter([1, 0])

def mock_run_timeout(java_command: list[str], env: dict, timeout: int, stage_name: str) -> int:
rc = next(exit_codes)
if stage_name == "Argument capture":
trace_db_path.write_bytes(b"fake-db")
return rc

with (
patch("codeflash.languages.java.tracer._run_java_with_graceful_timeout", side_effect=mock_run_timeout),
patch.object(tracer, "build_jfr_env", return_value={}),
patch.object(tracer, "build_agent_env", return_value={}),
patch.object(tracer, "create_tracer_config", return_value=tmp_path / "config.json"),
):
trace_db, _jfr_file = tracer.trace(
java_command=["java", "-cp", ".", "Main"], trace_db_path=trace_db_path, packages=["com.example"]
)
# Should complete despite Stage 1 failure
assert trace_db == trace_db_path

def test_stage2_failure_raises(self, tmp_path: Path) -> None:
trace_db_path = (tmp_path / "trace.db").resolve()
tracer = JavaTracer()

# Stage 1 succeeds (exit code 0), Stage 2 fails (exit code 1)
exit_codes = iter([0, 1])

def mock_run_timeout(java_command: list[str], env: dict, timeout: int, stage_name: str) -> int:
return next(exit_codes)

with (
patch("codeflash.languages.java.tracer._run_java_with_graceful_timeout", side_effect=mock_run_timeout),
patch.object(tracer, "build_jfr_env", return_value={}),
patch.object(tracer, "build_agent_env", return_value={}),
patch.object(tracer, "create_tracer_config", return_value=tmp_path / "config.json"),
pytest.raises(RuntimeError, match="Argument capture failed with exit code 1"),
):
tracer.trace(
java_command=["java", "-cp", ".", "Main"], trace_db_path=trace_db_path, packages=["com.example"]
)

def test_both_stages_succeed(self, tmp_path: Path) -> None:
trace_db_path = (tmp_path / "trace.db").resolve()
tracer = JavaTracer()

def mock_run_timeout(java_command: list[str], env: dict, timeout: int, stage_name: str) -> int:
if stage_name == "Argument capture":
trace_db_path.write_bytes(b"fake-db")
return 0

with (
patch("codeflash.languages.java.tracer._run_java_with_graceful_timeout", side_effect=mock_run_timeout),
patch.object(tracer, "build_jfr_env", return_value={}),
patch.object(tracer, "build_agent_env", return_value={}),
patch.object(tracer, "create_tracer_config", return_value=tmp_path / "config.json"),
):
trace_db, _jfr_file = tracer.trace(
java_command=["java", "-cp", ".", "Main"], trace_db_path=trace_db_path, packages=["com.example"]
)
assert trace_db == trace_db_path
Loading