# MaxKernel Evaluation Module

The evaluation module in MaxKernel provides tools to benchmark and evaluate JAX kernels on remote TPU VMs (or locally). It allows you to compare a reference implementation with an optimized implementation, verify correctness, and measure performance (speedup).

## Overview

The module consists of two main components:
1. **`jax_kernel_evaluator.py`**: A Python module containing the `JAXKernelEvaluator` class, which orchestrates evaluations on a remote TPU VM or locally.
2. **`benchmark.py`**: A command-line interface (CLI) tool that automates the evaluation process across a dataset of problems.

---

## JAXKernelEvaluator

The `JAXKernelEvaluator` class is responsible for setting up the environment, uploading necessary files to a remote TPU VM (if applicable), executing the evaluation harness, and parsing the results.

### Key Features
- Supports both **remote** evaluation (on a TPU VM via SSH) and **local** evaluation.
- Automates file transfer and command execution on remote TPUs.
- Includes an optional **adaptation** step using the Gemini API to refactor code or generate task configurations if they are missing.
- Handles timeouts and attempts to clean up runaway processes on remote VMs.

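The timeout handling can be sketched with a plain `subprocess` call. `run_with_timeout` below is a hypothetical helper for illustration only, not the evaluator's actual API; on a remote VM the evaluator additionally has to hunt down runaway processes over SSH, which this sketch does not show.

```python
import subprocess
import sys

def run_with_timeout(cmd, timeout_seconds):
    """Run a command, killing it if it exceeds the timeout (illustrative sketch)."""
    try:
        proc = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout_seconds
        )
        return proc.returncode, proc.stdout
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child on timeout; report failure to the caller.
        return None, ""

# A fast command completes normally; a hung one returns (None, "") after the timeout.
code, out = run_with_timeout([sys.executable, "-c", "print('ok')"], timeout_seconds=10)
```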
### Programmatic Usage

Here is an example of how to use `JAXKernelEvaluator` in a Python script:

```python
from evaluation.jax_kernel_evaluator import JAXKernelEvaluator

# Initialize the evaluator for remote execution
evaluator = JAXKernelEvaluator(
    local=False,
    tpu_name="your-tpu-name",
    project="your-gcp-project",
    zone="your-tpu-zone",
    venv_path="/path/to/venv/on/tpu",
)

# Run evaluation
result = evaluator.evaluate(
    reference_code_path="path/to/reference.py",
    optimized_code_path="path/to/optimized.py",
    task_yaml_path="path/to/kernel_task.yaml",
    atol=1e-3,
    rtol=1e-3,
)

print(f"Task ID: {result.task_id}")
print(f"Correctness: {result.correctness}")
print(f"Speedup: {result.speedup}x")
if result.error_trace:
    print(f"Error: {result.error_trace}")
```

### Initialization Arguments

| Argument | Type | Description |
| :--- | :--- | :--- |
| `local` | `bool` | If `True`, runs evaluation locally. Default is `False`. |
| `tpu_name` | `str` | Name of the TPU VM to connect to for remote evaluation. |
| `project` | `str` | Google Cloud project ID. |
| `zone` | `str` | Google Cloud zone. |
| `venv_path` | `str` | Path to the Python virtual environment locally or on the remote TPU. |
| `remote_base_dir` | `str` | Base directory on the remote TPU for storing temporary files. |

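Conceptually, the correctness and speedup figures reduce to an elementwise tolerance comparison plus a timing ratio. The sketch below illustrates the idea with NumPy on a toy matmul; it is not the evaluator's actual harness, which runs the two scripts under JAX on the target device:

```python
import time
import numpy as np

def measure(reference_fn, optimized_fn, x, atol=1e-3, rtol=1e-3, iters=10):
    """Illustrative correctness + speedup measurement (not the real harness)."""
    # Correctness: outputs must agree within the given tolerances.
    correct = np.allclose(reference_fn(x), optimized_fn(x), atol=atol, rtol=rtol)

    def mean_runtime(fn):
        start = time.perf_counter()
        for _ in range(iters):
            fn(x)
        return (time.perf_counter() - start) / iters

    # Speedup: reference runtime divided by optimized runtime.
    return correct, mean_runtime(reference_fn) / mean_runtime(optimized_fn)

x = np.random.rand(128, 128)
correct, speedup = measure(lambda a: a @ a, lambda a: np.matmul(a, a), x)
```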
### CLI Usage

> [!NOTE]
> All CLI commands should be run from the project root directory (`MaxKernel`).

You can also use `jax_kernel_evaluator.py` directly from the command line to evaluate a single kernel without setting up a dataset.

**Example:**

```bash
python3 -m evaluation.jax_kernel_evaluator \
  --reference_code_path path/to/reference.py \
  --optimized_code_path path/to/optimized.py \
  --task_yaml_path path/to/kernel_task.yaml \
  --local
```

**Arguments:**

| Argument | Type | Description |
| :--- | :--- | :--- |
| `--reference_code_path` | `str` | (Required) Local path to the reference JAX script. |
| `--optimized_code_path` | `str` | (Required) Local path to the optimized JAX script. |
| `--task_yaml_path` | `str` | Path to the `kernel_task.yaml` file. |
| `--local` | Flag | Run evaluation locally instead of on a remote TPU. |
| `--tpu_name` | `str` | Name of the TPU VM (required if not local). |
| `--project` | `str` | Google Cloud project ID. |
| `--zone` | `str` | Google Cloud zone. |
| `--venv_path` | `str` | Path to the Python virtual environment. |
| `--remote_base_dir` | `str` | Base directory on the remote TPU for temporary files. |
| `--adapt` | `str` list | Components to adapt via LLM (`reference_code`, `optimized_code`, `kernel_task`). |
| `--timeout_seconds` | `int` | Maximum time in seconds to wait for execution (default: 300). |
| `--no_cleanup` | Flag | Skip cleanup of temporary files on the remote VM. |
| `--atol` | `float` | Absolute tolerance for correctness check (default: 1e-3). |
| `--rtol` | `float` | Relative tolerance for correctness check (default: 1e-3). |

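The `--atol`/`--rtol` pair can be understood via the usual NumPy tolerance criterion, `|actual - expected| <= atol + rtol * |actual|` (an assumption here, based on standard `numpy.allclose` semantics; the harness's exact check may differ):

```python
import numpy as np

expected = np.array([1.0, 2.0, 3.0])

# Each element is within atol + rtol * |value| of expected -> passes.
close = np.array([1.0005, 2.001, 3.002])
# The last element is off by 0.01, beyond the tolerance -> fails.
far = np.array([1.0005, 2.001, 3.01])

print(np.allclose(expected, close, atol=1e-3, rtol=1e-3))  # True
print(np.allclose(expected, far, atol=1e-3, rtol=1e-3))    # False
```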
---

## Benchmark CLI (`benchmark.py`)

The `benchmark.py` script is a command-line tool that iterates through a dataset of problems and evaluates each one using `JAXKernelEvaluator`.

### Dataset Structure

To use the benchmark CLI, your dataset directory should have the following structure:

```text
dataset_dir/
├── problem_1/
│   ├── reference.py
│   ├── optimized.py
│   └── kernel_task.yaml
├── problem_2/
│   ├── reference.py
│   ├── optimized.py
│   └── kernel_task.yaml
...
```

### CLI Usage Examples

> [!NOTE]
> All CLI commands should be run from the project root directory (`MaxKernel`).

**Run benchmark on a remote TPU:**

```bash
python3 -m evaluation.benchmark \
  --tpu_name my-tpu \
  --zone us-central1-a \
  --project my-project \
  --dataset_dir /path/to/dataset \
  --venv_path /path/to/venv/on/tpu
```

**Run benchmark locally:**

```bash
python3 -m evaluation.benchmark \
  --local \
  --dataset_dir /path/to/dataset \
  --venv_path /path/to/local/venv
```

**Run with code adaptation (using Gemini API):**

If your dataset is missing `kernel_task.yaml` or requires code adaptation to fit the harness, you can use the `--adapt` flag. Note that this requires setting the `GEMINI_API_KEY` environment variable.

```bash
export GEMINI_API_KEY="your-api-key"

python3 -m evaluation.benchmark \
  --tpu_name my-tpu \
  ... \
  --adapt reference_code optimized_code kernel_task
```

### CLI Arguments

| Argument | Type | Description |
| :--- | :--- | :--- |
| `--local` | Flag | Run evaluation locally instead of on a remote TPU. |
| `--tpu_name` | `str` | The name of the target TPU VM. |
| `--zone` | `str` | The Google Cloud zone where the TPU is located. |
| `--project` | `str` | The Google Cloud project ID. |
| `--venv_path` | `str` | Absolute path to the Python venv to use. |
| `--dataset_dir` | `str` | Local directory containing the benchmark dataset. |
| `--adapt` | `str` list | Components to adapt via LLM (`reference_code`, `optimized_code`, `kernel_task`). |
| `--reference_file_name` | `str` | Filename for reference code (default: `reference.py`). |
| `--optimized_file_name` | `str` | Filename for optimized code (default: `optimized.py`). |
| `--task_file_name` | `str` | Filename for the kernel task YAML (default: `kernel_task.yaml`). |

### Results

The CLI saves execution results to a JSON file in `dataset_dir` named `evaluation_results_<tpu_name>.json`. If this file already exists, the run automatically resumes from it. A summary of the results is printed at the end of the run.
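
The resume behaviour can be sketched as: load any existing results file, skip tasks that already have an entry, and re-save after each evaluation. This is illustrative only; the field names (`task_id`, `correctness`, `speedup`) and exact JSON layout are assumptions:

```python
import json
from pathlib import Path

def load_results(path):
    """Return previous results keyed by task ID, or an empty dict (sketch)."""
    p = Path(path)
    if p.is_file():
        return {r["task_id"]: r for r in json.loads(p.read_text())}
    return {}

def save_results(path, results):
    """Persist results after every task so an interrupted run can resume."""
    Path(path).write_text(json.dumps(list(results.values()), indent=2))

results_path = "evaluation_results_my-tpu.json"
results = load_results(results_path)
for task_id in ["problem_1", "problem_2"]:
    if task_id in results:
        continue  # already evaluated in a previous run; skip
    # ... evaluate the problem here (placeholder values below) ...
    results[task_id] = {"task_id": task_id, "correctness": True, "speedup": 1.0}
    save_results(results_path, results)
```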