Commit 0103ca2

Implement kernel evaluation harness including JAX evaluator, code adapter, and dataset definitions.
1 parent 39a246d commit 0103ca2


41 files changed: +3495 −0 lines

MaxKernel/evaluation/.gitignore

Lines changed: 4 additions & 0 deletions
```
.pytest_cache/
.env
*functional_test.py
test_data/*
```

MaxKernel/evaluation/README.md

Lines changed: 180 additions & 0 deletions
# MaxKernel Evaluation Module

The evaluation module in MaxKernel provides tools to benchmark and evaluate JAX kernels on remote TPU VMs (or locally). It lets you compare a reference implementation against an optimized implementation, verify numerical correctness, and measure performance (speedup).

## Overview

The module consists of two main components:

1. **`jax_kernel_evaluator.py`**: A Python module containing the `JAXKernelEvaluator` class, which orchestrates evaluation runs on a remote TPU VM or locally.
2. **`benchmark.py`**: A command-line interface (CLI) tool that automates the evaluation process across a dataset of problems.
---
## JAXKernelEvaluator

The `JAXKernelEvaluator` class is responsible for setting up the environment, uploading necessary files to a remote TPU VM (if applicable), executing the evaluation harness, and parsing the results.

### Key Features

- Supports both **remote** evaluation (on a TPU VM via SSH) and **local** evaluation.
- Automates file transfer and command execution on remote TPUs.
- Includes an optional **adaptation** step using the Gemini API to refactor code or generate task configurations if they are missing.
- Handles timeouts and attempts to clean up runaway processes on remote VMs.
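To make the performance measurement concrete: a kernel benchmark of this kind typically warms up each implementation (to exclude one-time compilation cost), times repeated calls, and reports the ratio of mean runtimes. The sketch below illustrates that pattern; it is not the harness's actual code, and for JAX you would additionally call `.block_until_ready()` on outputs to account for asynchronous dispatch.

```python
import time

def measure_speedup(reference_fn, optimized_fn, *args, repeats=10):
    """Return reference_time / optimized_time (>1 means the optimized kernel is faster)."""
    def mean_runtime(fn):
        fn(*args)  # warm-up call: exclude compilation/caching costs from timing
        start = time.perf_counter()
        for _ in range(repeats):
            fn(*args)  # with JAX, use fn(*args).block_until_ready() here
        return (time.perf_counter() - start) / repeats
    return mean_runtime(reference_fn) / mean_runtime(optimized_fn)
```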
### Programmatic Usage

Here is an example of how to use `JAXKernelEvaluator` in a Python script:

```python
from evaluation.jax_kernel_evaluator import JAXKernelEvaluator

# Initialize the evaluator for remote execution
evaluator = JAXKernelEvaluator(
    local=False,
    tpu_name="your-tpu-name",
    project="your-gcp-project",
    zone="your-tpu-zone",
    venv_path="/path/to/venv/on/tpu",
)

# Run evaluation
result = evaluator.evaluate(
    reference_code_path="path/to/reference.py",
    optimized_code_path="path/to/optimized.py",
    task_yaml_path="path/to/kernel_task.yaml",
    atol=1e-3,
    rtol=1e-3,
)

print(f"Task ID: {result.task_id}")
print(f"Correctness: {result.correctness}")
print(f"Speedup: {result.speedup}x")
if result.error_trace:
    print(f"Error: {result.error_trace}")
```
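The example above accesses `result.task_id`, `result.correctness`, `result.speedup`, and `result.error_trace`. The return type is not shown here, but you can picture it as a small dataclass along these lines (a hypothetical sketch inferred from the fields used above, not the module's actual definition):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class EvaluationResult:
    """Hypothetical shape of the object returned by evaluate()."""
    task_id: str
    correctness: bool                   # outputs matched within atol/rtol
    speedup: float                      # reference runtime / optimized runtime
    error_trace: Optional[str] = None   # populated when execution failed

# Illustrative values only
result = EvaluationResult(task_id="example_task", correctness=True, speedup=2.4)
```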
### Initialization Arguments

| Argument | Type | Description |
| :--- | :--- | :--- |
| `local` | `bool` | If `True`, runs evaluation locally. Default is `False`. |
| `tpu_name` | `str` | Name of the TPU VM to connect to for remote evaluation. |
| `project` | `str` | Google Cloud project ID. |
| `zone` | `str` | Google Cloud zone. |
| `venv_path` | `str` | Path to the Python virtual environment, locally or on the remote TPU. |
| `remote_base_dir` | `str` | Base directory on the remote TPU for storing temporary files. |
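For remote runs, the `tpu_name`/`project`/`zone` arguments typically end up in a `gcloud compute tpus tpu-vm ssh` invocation. As an illustration of how such a command could be assembled for `subprocess` (a sketch, not the evaluator's actual code):

```python
def build_remote_command(tpu_name, project, zone, remote_cmd):
    """Assemble a gcloud SSH invocation that runs remote_cmd on the TPU VM."""
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
        f"--project={project}",
        f"--zone={zone}",
        f"--command={remote_cmd}",
    ]

# Hypothetical usage: activate the venv, then run the harness script
cmd = build_remote_command(
    "my-tpu", "my-project", "us-central1-a",
    "source /path/to/venv/bin/activate && python3 harness.py",
)
```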
### CLI Usage

> [!NOTE]
> All CLI commands should be run from the project root directory (`MaxKernel`).

You can also use `jax_kernel_evaluator.py` directly from the command line to evaluate a single kernel without setting up a dataset.

**Example:**

```bash
python3 -m evaluation.jax_kernel_evaluator \
  --reference_code_path path/to/reference.py \
  --optimized_code_path path/to/optimized.py \
  --task_yaml_path path/to/kernel_task.yaml \
  --local
```
**Arguments:**

| Argument | Type | Description |
| :--- | :--- | :--- |
| `--reference_code_path` | `str` | (Required) Local path to the reference JAX script. |
| `--optimized_code_path` | `str` | (Required) Local path to the optimized JAX script. |
| `--task_yaml_path` | `str` | Path to the `kernel_task.yaml` file. |
| `--local` | Flag | Run evaluation locally instead of on a remote TPU. |
| `--tpu_name` | `str` | Name of the TPU VM (required if not local). |
| `--project` | `str` | Google Cloud project ID. |
| `--zone` | `str` | Google Cloud zone. |
| `--venv_path` | `str` | Path to the Python virtual environment. |
| `--remote_base_dir` | `str` | Base directory on the remote TPU for temp files. |
| `--adapt` | `str` list | Components to adapt via LLM (`reference_code`, `optimized_code`, `kernel_task`). |
| `--timeout_seconds` | `int` | Maximum time in seconds to wait for execution (default: 300). |
| `--no_cleanup` | Flag | Skip cleanup of temporary files on the remote VM. |
| `--atol` | `float` | Absolute tolerance for the correctness check (default: 1e-3). |
| `--rtol` | `float` | Relative tolerance for the correctness check (default: 1e-3). |
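The `--atol`/`--rtol` tolerances correspond to the standard elementwise closeness criterion `|opt - ref| <= atol + rtol * |ref|` (the same shape of test as `numpy.allclose`). A minimal pure-Python sketch of such a check, assuming flat sequences of floats (the real harness would compare array outputs):

```python
def check_correctness(reference_out, optimized_out, atol=1e-3, rtol=1e-3):
    """True when every element satisfies |opt - ref| <= atol + rtol * |ref|.

    Illustrative only; with arrays you would use numpy.allclose / jnp.allclose.
    """
    return all(abs(o - r) <= atol + rtol * abs(r)
               for o, r in zip(optimized_out, reference_out))
```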
---
## Benchmark CLI (`benchmark.py`)

The `benchmark.py` script is a command-line tool that iterates through a dataset of problems and evaluates each one using `JAXKernelEvaluator`.

### Dataset Structure

To use the benchmark CLI, your dataset directory should have the following structure:
```text
dataset_dir/
├── problem_1/
│   ├── reference.py
│   ├── optimized.py
│   └── kernel_task.yaml
├── problem_2/
│   ├── reference.py
│   ├── optimized.py
│   └── kernel_task.yaml
...
```
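Iterating over such a layout reduces to listing subdirectories and checking for the expected files. A sketch of that discovery step, assuming the default filenames (not necessarily how `benchmark.py` implements it):

```python
from pathlib import Path

def discover_problems(dataset_dir,
                      reference_name="reference.py",
                      optimized_name="optimized.py",
                      task_name="kernel_task.yaml"):
    """Yield (problem_name, ref_path, opt_path, task_path) for usable problems."""
    for problem in sorted(Path(dataset_dir).iterdir()):
        if not problem.is_dir():
            continue
        ref = problem / reference_name
        opt = problem / optimized_name
        task = problem / task_name
        # The task yaml may be absent if it is to be generated via --adapt
        if ref.exists() and opt.exists():
            yield problem.name, ref, opt, task
```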
### CLI Usage Examples

> [!NOTE]
> All CLI commands should be run from the project root directory (`MaxKernel`).

**Run benchmark on a remote TPU:**

```bash
python3 -m evaluation.benchmark \
  --tpu_name my-tpu \
  --zone us-central1-a \
  --project my-project \
  --dataset_dir /path/to/dataset \
  --venv_path /path/to/venv/on/tpu
```

**Run benchmark locally:**

```bash
python3 -m evaluation.benchmark \
  --local \
  --dataset_dir /path/to/dataset \
  --venv_path /path/to/local/venv
```
**Run with code adaptation (using the Gemini API):**

If your dataset is missing `kernel_task.yaml` or requires code adaptation to fit the harness, you can use the `--adapt` flag. Note that this requires setting the `GEMINI_API_KEY` environment variable.

```bash
export GEMINI_API_KEY="your-api-key"

python3 -m evaluation.benchmark \
  --tpu_name my-tpu \
  ... \
  --adapt reference_code optimized_code kernel_task
```
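Since adaptation calls the Gemini API, it is worth failing fast when the key is absent rather than partway through a run. A minimal guard along these lines (an illustration, not the script's actual check):

```python
import os

def require_gemini_key():
    """Return the Gemini API key, or exit with a clear message if it is unset."""
    key = os.environ.get("GEMINI_API_KEY")
    if not key:
        raise SystemExit(
            "--adapt requires the GEMINI_API_KEY environment variable to be set"
        )
    return key
```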
### CLI Arguments

| Argument | Type | Description |
| :--- | :--- | :--- |
| `--local` | Flag | Run evaluation locally instead of on a remote TPU. |
| `--tpu_name` | `str` | The name of the target TPU VM. |
| `--zone` | `str` | The Google Cloud zone where the TPU is located. |
| `--project` | `str` | The Google Cloud project ID. |
| `--venv_path` | `str` | Absolute path to the Python venv to use. |
| `--dataset_dir` | `str` | Local directory containing the benchmark dataset. |
| `--adapt` | `str` list | Components to adapt via LLM (`reference_code`, `optimized_code`, `kernel_task`). |
| `--reference_file_name` | `str` | Filename for reference code (default: `reference.py`). |
| `--optimized_file_name` | `str` | Filename for optimized code (default: `optimized.py`). |
| `--task_file_name` | `str` | Filename for the kernel task YAML (default: `kernel_task.yaml`). |
### Results

The CLI saves execution results to a JSON file in `dataset_dir` named `evaluation_results_<tpu_name>.json`. If this file already exists, the run automatically resumes from the previous results. At the end of the run, a summary of results is printed.
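The resume behavior can be pictured as: load the JSON file if it exists, then skip any problem whose task ID already has a recorded result. A sketch of that logic (an illustration, not the script's actual implementation):

```python
import json
from pathlib import Path

def load_previous_results(dataset_dir, tpu_name):
    """Return {task_id: result_dict} from a prior run, or {} on a first run."""
    path = Path(dataset_dir) / f"evaluation_results_{tpu_name}.json"
    if not path.exists():
        return {}
    return json.loads(path.read_text())

def pending_tasks(all_task_ids, previous):
    """Task IDs that still need to be evaluated."""
    return [t for t in all_task_ids if t not in previous]
```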

MaxKernel/evaluation/__init__.py

Whitespace-only changes.
