diff --git a/tools/expand_graph_paths.sh b/tools/expand_graph_paths.sh new file mode 100644 index 0000000000..78f464d055 --- /dev/null +++ b/tools/expand_graph_paths.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +PASSNET_DIR="/path/to/passnet/repo/" + +INPUT_LISTS=( + "${PASSNET_DIR}/sample_lists/hf_typical_samples_v2.txt" + "${PASSNET_DIR}/sample_lists/hf_sole_op_samples_v2.txt" + "${PASSNET_DIR}/sample_lists/hf_fusible_samples_v2.txt" +) + +if [ ! -d "$PASSNET_DIR" ]; then + echo "Error: base directory $PASSNET_DIR not found" + exit 1 +fi + +for INPUT_LIST in "${INPUT_LISTS[@]}"; do + BASENAME=$(basename "$INPUT_LIST" .txt) + OUTPUT_FILE="${BASENAME}_all_expanded.txt" + > "$OUTPUT_FILE" + + echo "Processing $INPUT_LIST ..." + + count=0 + while IFS= read -r rel_path || [ -n "$rel_path" ]; do + clean_rel_path=$(echo "$rel_path" | tr -d '\r' | xargs) + [ -z "$clean_rel_path" ] && continue + + TARGET_FILE="${PASSNET_DIR}/${clean_rel_path}/graph_list.txt" + + if [ -f "$TARGET_FILE" ]; then + cat "$TARGET_FILE" >> "$OUTPUT_FILE" + ((count++)) + else + echo "Skipped: $TARGET_FILE not found" + fi + done < "$INPUT_LIST" + + echo "Done: $count directories processed -> $(pwd)/$OUTPUT_FILE" +done + +echo "All tasks completed." diff --git a/tools/extract_triton_kernels.sh b/tools/extract_triton_kernels.sh new file mode 100755 index 0000000000..2f86c1a8b9 --- /dev/null +++ b/tools/extract_triton_kernels.sh @@ -0,0 +1,68 @@ +#!/bin/bash +set -euo pipefail + +# Thin launcher for the triton kernel extraction pipeline. +# +# This script sets machine-specific paths and delegates all logic to the +# Python module at tools/triton_kernel_extractor. +# +# Usage: +# bash extract_triton_kernels.sh [gpu_ids] +# +# Args: +# source (required): "list" or "hf" +# gpu_ids (optional): comma-separated GPU IDs, e.g. 
"0,2,5,7" +# +# Examples: +# bash extract_triton_kernels.sh list # list source, auto-detect GPUs +# bash extract_triton_kernels.sh hf 0,2,5,7 # hf source, specified GPUs + +# ============================================================ +# Arguments +# ============================================================ + +SOURCE="${1:?Usage: bash extract_triton_kernels.sh [gpu_ids] (source: list | hf)}" +GPU_ARG="${2:-}" + +# ============================================================ +# Machine-specific path configuration +# +# Edit the variables below to match your local environment. +# ============================================================ + +DATASET_BASE_DIR="/path/to/dataset_output" +GRAPHNET_DIR="/path/to/GraphNet/GitHub/repo/" + +# ============================================================ +# Environment setup +# ============================================================ + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PASSNET_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" + +export PYTHONPATH="$GRAPHNET_DIR:${PASSNET_DIR}:${PYTHONPATH:-}" + +# ============================================================ +# Build Python CLI arguments +# ============================================================ + +PYTHON_ARGS=( + --source "$SOURCE" + --dataset-base-dir "$DATASET_BASE_DIR" + --graphnet-dir "$GRAPHNET_DIR" + --passnet-dir "$PASSNET_DIR" + --max-autotune + --enable-cache-analysis +) + +if [ -n "$GPU_ARG" ]; then + # Convert comma-separated "0,2,5,7" to space-separated args. 
+ IFS=',' read -ra GPU_IDS <<< "$GPU_ARG" + PYTHON_ARGS+=(--gpu-ids "${GPU_IDS[@]}") +fi + +# ============================================================ +# Run +# ============================================================ + +exec python3 -m tools.triton_kernel_extractor "${PYTHON_ARGS[@]}" diff --git a/tools/triton_kernel_extractor/README.md b/tools/triton_kernel_extractor/README.md new file mode 100644 index 0000000000..78809ad8da --- /dev/null +++ b/tools/triton_kernel_extractor/README.md @@ -0,0 +1,182 @@ +# Triton Kernel Extractor + +A pipeline that compiles computational subgraphs through TorchInductor, filters +the results by kernel-level speedup, and extracts the autotuning-selected Triton +kernel source together with the corresponding PTX assembly from the inductor +compilation cache. + +## Background + +When `torch.compile` processes a model via the TorchInductor backend with +`TORCH_COMPILE_DEBUG=1`, the compiler produces a per-graph cache directory +containing: + +- **`output_code.py`** — the generated Python wrapper that calls into Triton + kernels via `async_compile.triton('kernel_name', '''...''')`. The kernels + appearing here are the final, autotuning-selected implementations adopted by + the inductor scheduler. +- **`triton/0/{HASH}/`** — one directory per autotuning candidate + configuration (varying `XBLOCK`, `YBLOCK`, `num_warps`, etc.), each holding + the compiled artifacts (`.ptx`, `.cubin`, `.ttir`, `.llir`, `.source`, + `.json`). When autotuning explores N configurations for a kernel, N + directories are created. +- **`*.best_config`** — a JSON file written by the Triton autotuner recording + the winning configuration. Its `triton_cache_hash` field maps back to one of + the `triton/0/{HASH}/` directories. + +This pipeline automates the full workflow: compile → filter → clean → extract → +pair, producing clean `(subgraph, triton_kernel, ptx)` triples ready for +downstream analysis. 
+
+## Pipeline Steps
+
+The pipeline processes three dataset categories — `sole_op_subgraphs`,
+`fusible_subgraphs`, and `typical_subgraphs` — executing five steps for each:
+
+### Step 1: Multi-GPU Parallel Compilation
+
+Compiles each subgraph sample using `graph_net_bench.torch.test_compiler
+--kernel-time` in an isolated subprocess. Samples are distributed across
+available GPUs in round-robin fashion, with one `ProcessPoolExecutor` worker per
+GPU. Each subprocess receives a dedicated `CUDA_VISIBLE_DEVICES` and an
+isolated `TORCHINDUCTOR_CACHE_DIR`. Pass `--max-autotune` to enable Inductor's
+`max_autotune` mode (via `torch.compile(mode="max-autotune")`), which activates
+comprehensive autotuning including `max_autotune_gemm`,
+`coordinate_descent_tuning`, and `epilogue_fusion`.
+
+### Step 2: Speedup Filtering
+
+Parses the `[Speedup][kernel]:` metric from each sample's compilation log (the
+last occurrence is used). Samples achieving a speedup ≥ 1.0 are moved to
+`kept/`; the rest are moved to `discarded/`.
+
+### Step 3: Temporary File Cleanup
+
+Recursively removes `__pycache__/` directories, `*.pyc`, and `*.pyo` files from
+the `kept/` tree to reduce storage footprint before extraction.
+
+### Step 4: Kernel and PTX Extraction
+
+For each kept sample that contains `original_graph/graph_hash.txt`:
+
+1. Copies `original_graph/model.py` (the source subgraph) into the output.
+2. Parses `output_code.py` to extract all Triton kernel definitions using a
+   regex equivalent of the original Perl one-liner.
+3. Writes each kernel source to `triton_kernel/{kernel_name}.py`.
+4. Locates the corresponding PTX for each kernel by scanning `triton/0/` and
+   disambiguating via `.best_config` when multiple autotuning candidates exist,
+   then writes it to `ptx/{kernel_name}.ptx`.
+
+Output is written atomically (`.tmp` directory + `rename`) so that an
+interrupted run never leaves half-written data.
+
+### Step 5: Empty Sample Cleanup
+
+Removes output samples that contain `original_graph/` but no `triton_kernel/`
+directory (i.e., samples where no Triton kernels were extracted).
+
+## PTX Resolution Algorithm
+
+Each Triton kernel may have been compiled under multiple autotuning
+configurations. The algorithm to locate the winning PTX is:
+
+1. Scan `triton/0/*/` for directories containing `{kernel_name}.ptx`.
+2. If exactly one candidate exists, use it directly (no autotuning was needed).
+3. If multiple candidates exist, collect `triton_cache_hash` values from all
+   `*.best_config` files in the sample, and select the candidate whose directory
+   name matches one of these hashes.
+
+This approach was validated on 125 kernels across 98 samples with a 100% match
+rate.
+
+## Output Structure
+
+```
+{output_dir}/{sample_name}/
+  original_graph/
+    model.py                    # source subgraph
+  triton_kernel/
+    triton_poi_fused_xxx_0.py   # Triton kernel source
+    triton_poi_fused_yyy_1.py
+  ptx/
+    triton_poi_fused_xxx_0.ptx  # corresponding PTX assembly
+    triton_poi_fused_yyy_1.ptx
+```
+
+## Cache Analysis
+
+Analyzes an inductor cache directory post-hoc, available as the `analyze`
+subcommand or triggered automatically by passing `--enable-cache-analysis` to
+the `extract` subcommand. Concatenates `test_compiler_log.log` files across
+all sample states (root, kept, discarded), computes kernel and end-to-end
+speedup distributions (mean, median, percentiles, threshold breakdowns), and
+generates histogram, CDF, and optionally violin/ES(t) plots. Output defaults
+to `{cache_dir}_analysis/`.
+ +## Usage + +### Via the Bash Launcher + +```bash +# Edit machine-specific paths in extract_triton_kernels.sh first, then: +bash tools/extract_triton_kernels.sh list # auto-detect GPUs +bash tools/extract_triton_kernels.sh hf 0,2,5,7 # specify GPUs +``` + +### Via Python Directly + +```bash +python3 -m tools.triton_kernel_extractor \ + --source list \ + --dataset-base-dir /data/passnet_dataset \ + --graphnet-dir /opt/GraphNet \ + --passnet-dir /opt/passnet \ + --passnet-hf-dir /opt/passnet/graphs/hf_subgraphs_v2 \ + --gpu-ids 0 2 5 7 \ + --max-autotune \ + --enable-cache-analysis + +# Cache analysis can also be run standalone: +python3 -m tools.triton_kernel_extractor analyze [--output-dir DIR] +``` + +### CLI Arguments + +| Argument | Required | Description | +|----------------------------|----------|-------------------------------------------------------| +| `--source` | Yes | `list` (sample paths from text files) or `hf` (scan HuggingFace directories) | +| `--dataset-base-dir` | Yes | Root directory for cache and extraction output | +| `--graphnet-dir` | Yes | Path to the GraphNet repository (for `graph_net_bench` on PYTHONPATH) | +| `--passnet-dir` | Yes | Root of the PassNet repository (model path prefix) | +| `--passnet-hf-dir` | No | HuggingFace graph data directory; defaults to `{passnet-dir}/graphs/hf_subgraphs_v2` | +| `--gpu-ids` | No | GPU IDs for compilation; auto-detected when omitted | +| `--max-autotune` | No | Enable Inductor max_autotune mode (`torch.compile(mode="max-autotune")`) | +| `--enable-cache-analysis` | No | Run cache analysis on each dataset after extraction | + +## Module Structure + +``` +triton_kernel_extractor/ + __init__.py # package marker + __main__.py # CLI entry point (subcommands: extract, analyze) + config.py # PipelineConfig, DatasetDescriptor, constants + sample_enumerator.py # enumerate samples from "list" or "hf" sources + compiler.py # Step 1: multi-GPU parallel compilation + speedup_filter.py # Step 2: filter by 
kernel speedup + temp_cleaner.py # Step 3: remove __pycache__ / *.pyc / *.pyo + kernel_extractor.py # Step 4: extract Triton kernels and PTX + empty_sample_cleaner.py # Step 5: remove samples without Triton kernels + pipeline.py # orchestrate Steps 1–5 for all datasets + cache_analyzer.py # analyze cache: logs, statistics, plots +``` + +## Idempotency and Resume + +Every step implements skip logic to support safe re-execution: + +- **Compilation** skips samples whose log already contains `[Speedup][kernel]:` + or that already exist under `kept/` or `discarded/`. +- **Filtering** skips samples already classified into `kept/` or `discarded/`. +- **Extraction** skips output samples that already exist in the output directory. + Stale `.tmp` directories from prior interrupted runs are cleaned up + automatically on startup. diff --git a/tools/triton_kernel_extractor/__init__.py b/tools/triton_kernel_extractor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/triton_kernel_extractor/__main__.py b/tools/triton_kernel_extractor/__main__.py new file mode 100644 index 0000000000..0500aff536 --- /dev/null +++ b/tools/triton_kernel_extractor/__main__.py @@ -0,0 +1,298 @@ +"""CLI entry point for the triton kernel extraction pipeline. + +Subcommands +----------- + +**extract** (default when no subcommand is given):: + + python3 -m tools.triton_kernel_extractor [extract] \\ + --source list \\ + --dataset-base-dir /data/passnet_dataset \\ + --graphnet-dir /opt/GraphNet \\ + --passnet-dir /opt/passnet \\ + [--passnet-hf-dir /opt/passnet/graphs/hf_subgraphs_v2] \\ + [--gpu-ids 0 2 5 7] + +**analyze**:: + + python3 -m tools.triton_kernel_extractor analyze \\ + [--output-dir DIR] + +When ``--gpu-ids`` is omitted the script auto-detects all available GPUs +by parsing the output of ``nvidia-smi -L``. 
+""" + +from __future__ import annotations + +import argparse +import logging +import os +import re +import subprocess +import sys +from pathlib import Path + +from .config import PipelineConfig + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# GPU detection (shared by the extract subcommand) +# --------------------------------------------------------------------------- + +def _detect_gpu_ids() -> list[int]: + """Auto-detect available GPU IDs. + + Priority order (matching the original bash script): + 1. ``CUDA_VISIBLE_DEVICES`` environment variable + 2. ``nvidia-smi -L`` output + + Returns a list of integer GPU indices. Raises ``RuntimeError`` + when no GPUs are found. + """ + # Priority 1: honour CUDA_VISIBLE_DEVICES if set. + cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip() + if cuda_env: + try: + return [int(x) for x in cuda_env.split(",") if x.strip()] + except ValueError: + pass # Fall through to nvidia-smi. + + # Priority 2: auto-detect from nvidia-smi. + try: + result = subprocess.run( + ["nvidia-smi", "-L"], + capture_output=True, + text=True, + timeout=10, + ) + ids = [int(m) for m in re.findall(r"GPU (\d+):", result.stdout)] + except (FileNotFoundError, subprocess.TimeoutExpired): + ids = [] + + if not ids: + raise RuntimeError( + "No GPUs detected. Pass --gpu-ids explicitly or check nvidia-smi." + ) + return ids + + +# --------------------------------------------------------------------------- +# Subcommand: extract +# --------------------------------------------------------------------------- + +def _add_extract_parser(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser( + "extract", + help="Run the full compilation and extraction pipeline.", + description=( + "Compile graph datasets and extract " + "(subgraph, triton_kernel, ptx) triples." 
+ ), + ) + parser.add_argument( + "--source", + required=True, + choices=("list", "hf"), + help="Data source type: 'list' (txt file paths) or 'hf' (scan HF dirs).", + ) + parser.add_argument( + "--gpu-ids", + type=int, + nargs="*", + default=None, + help=( + "GPU IDs to use for parallel compilation. " + "Auto-detected via nvidia-smi when omitted." + ), + ) + parser.add_argument( + "--dataset-base-dir", + type=Path, + required=True, + help="Root directory of the dataset collection.", + ) + parser.add_argument( + "--graphnet-dir", + type=Path, + required=True, + help="Path to the GraphNet repository (added to PYTHONPATH for graph_net_bench).", + ) + parser.add_argument( + "--passnet-dir", + type=Path, + required=True, + help="Root of the PassNet repository (prefix for model paths in 'list' mode).", + ) + parser.add_argument( + "--passnet-hf-dir", + type=Path, + default=None, + help=( + "Root of the HuggingFace graph data directory. " + "Defaults to {passnet-dir}/graphs/hf_subgraphs_v2 when not specified." + ), + ) + parser.add_argument( + "--enable-cache-analysis", + action="store_true", + default=False, + help=( + "Run cache analysis (log concatenation, speedup statistics, plots) " + "on each dataset's cache directory after the extraction pipeline." + ), + ) + parser.add_argument( + "--max-autotune", + action="store_true", + default=False, + help=( + "Enable Inductor max_autotune mode during compilation. " + "Passes mode='max-autotune' to the GraphNet InductorBackend, " + "which activates comprehensive autotuning via torch.compile." 
+ ), + ) + parser.set_defaults(func=_run_extract) + + +def _run_extract(args: argparse.Namespace) -> None: + from .pipeline import run_pipeline + + gpu_ids = args.gpu_ids if args.gpu_ids else _detect_gpu_ids() + + passnet_hf_dir = args.passnet_hf_dir + if passnet_hf_dir is None: + passnet_hf_dir = args.passnet_dir / "graphs" / "hf_subgraphs_v2" + + config = PipelineConfig( + source=args.source, + gpu_ids=gpu_ids, + dataset_base_dir=args.dataset_base_dir, + graphnet_dir=args.graphnet_dir, + passnet_dir=args.passnet_dir, + passnet_hf_dir=passnet_hf_dir, + max_autotune=args.max_autotune, + ) + + logger.info("Source: %s", config.source) + logger.info( + "Using %d GPU(s): %s", + len(config.gpu_ids), + " ".join(str(g) for g in config.gpu_ids), + ) + + # Unset CUDA_VISIBLE_DEVICES in the parent process so that worker + # subprocesses start with a clean slate and receive only the per-GPU + # value assigned by compiler.py. Matches the bash: `unset CUDA_VISIBLE_DEVICES`. + os.environ.pop("CUDA_VISIBLE_DEVICES", None) + + run_pipeline(config, enable_cache_analysis=args.enable_cache_analysis) + + +# --------------------------------------------------------------------------- +# Subcommand: analyze +# --------------------------------------------------------------------------- + +def _add_analyze_parser(subparsers: argparse._SubParsersAction) -> None: + parser = subparsers.add_parser( + "analyze", + help="Analyze an inductor cache directory (logs, statistics, plots).", + description=( + "Concatenate compiler logs, compute speedup statistics, and " + "generate distribution plots for an inductor cache directory." + ), + ) + parser.add_argument( + "cache_dir", + type=Path, + help="Inductor cache directory to analyze.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help=( + "Directory for analysis output. " + "Defaults to _analysis." 
+ ), + ) + parser.set_defaults(func=_run_analyze) + + +def _run_analyze(args: argparse.Namespace) -> None: + from .cache_analyzer import analyze_cache + + cache_dir: Path = args.cache_dir + output_dir: Path = args.output_dir or Path(f"{cache_dir}_analysis") + analyze_cache(cache_dir, output_dir) + + +# --------------------------------------------------------------------------- +# Backward-compatible argument detection +# --------------------------------------------------------------------------- + +def _needs_implicit_extract(argv: list[str]) -> bool: + """Return True if *argv* looks like the old extract-only CLI. + + The old CLI had no subcommand — it started directly with ``--source``. + If the first argument is not a known subcommand and not a help flag, + we prepend ``extract`` for backward compatibility. + """ + if not argv: + return False + known_subcommands = {"extract", "analyze"} + first = argv[0] + if first in known_subcommands: + return False + # Do not intercept top-level --help / -h. + if first in ("-h", "--help"): + return False + return True + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +def main(argv: list[str] | None = None) -> None: + logging.basicConfig( + format="%(message)s", + level=logging.INFO, + stream=sys.stderr, + ) + + if argv is None: + argv = sys.argv[1:] + + # Backward compatibility: insert "extract" when no subcommand is given. + if _needs_implicit_extract(argv): + argv = ["extract"] + argv + + parser = argparse.ArgumentParser( + prog="python3 -m tools.triton_kernel_extractor", + description=( + "Triton kernel extraction toolkit: compile, filter, extract, " + "and analyze TorchInductor compilation caches." 
+ ), + ) + subparsers = parser.add_subparsers(dest="command") + _add_extract_parser(subparsers) + _add_analyze_parser(subparsers) + + args = parser.parse_args(argv) + + if not hasattr(args, "func"): + parser.print_help() + sys.exit(1) + + try: + args.func(args) + except KeyboardInterrupt: + logger.info("") + logger.info("Interrupted.") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tools/triton_kernel_extractor/cache_analyzer.py b/tools/triton_kernel_extractor/cache_analyzer.py new file mode 100644 index 0000000000..71ccd9ff5d --- /dev/null +++ b/tools/triton_kernel_extractor/cache_analyzer.py @@ -0,0 +1,441 @@ +"""Analyze an inductor cache directory and produce summary statistics and plots. + +This module replaces the standalone ``analyze_inductor_cache.sh`` script. It +concatenates compiler logs from all sample states (root, kept, discarded), +computes speedup statistics, and generates distribution plots. + +The analysis can be invoked via the CLI:: + + python3 -m tools.triton_kernel_extractor analyze [--output-dir DIR] +""" + +from __future__ import annotations + +import logging +import re +import subprocess +import sys +from datetime import datetime +from pathlib import Path + +from .config import SPEEDUP_KERNEL_PATTERN, is_sample_dir + +logger = logging.getLogger(__name__) + +# Pattern for end-to-end speedup (secondary metric, includes framework overhead). +_SPEEDUP_E2E_PATTERN = re.compile(r"\[Speedup\]\[e2e\]:\s*([\d.]+)") + +# Reuse the kernel speedup pattern from config (compiled for findall). +_SPEEDUP_KERNEL_RE = re.compile(SPEEDUP_KERNEL_PATTERN) + + +# --------------------------------------------------------------------------- +# Step 1: Log concatenation +# --------------------------------------------------------------------------- + +def _concat_logs(search_dir: Path) -> tuple[str, int]: + """Concatenate ``test_compiler_log.log`` from all samples under *search_dir*. + + Returns the combined text and the number of log files found. 
+ """ + if not search_dir.is_dir(): + return "", 0 + + parts: list[str] = [] + count = 0 + for sample_dir in sorted(search_dir.iterdir()): + if not sample_dir.is_dir(): + continue + if not is_sample_dir(sample_dir.name): + continue + log_file = sample_dir / "test_compiler_log.log" + if log_file.is_file(): + parts.append(log_file.read_text(encoding="utf-8", errors="replace")) + count += 1 + return "\n".join(parts), count + + +def concatenate_logs(cache_dir: Path, output_dir: Path) -> tuple[Path, Path, Path]: + """Concatenate logs from root, kept, and discarded sample directories. + + Writes three files to *output_dir* and returns their paths: + ``(all_log, kept_log, discarded_log)``. + """ + root_text, root_count = _concat_logs(cache_dir) + kept_text, kept_count = _concat_logs(cache_dir / "kept") + discarded_text, discarded_count = _concat_logs(cache_dir / "discarded") + + all_text = "\n".join(filter(None, [root_text, kept_text, discarded_text])) + + all_log = output_dir / "all_samples.log" + kept_log = output_dir / "kept_samples.log" + discarded_log = output_dir / "discarded_samples.log" + + all_log.write_text(all_text, encoding="utf-8") + kept_log.write_text(kept_text, encoding="utf-8") + discarded_log.write_text(discarded_text, encoding="utf-8") + + total = root_count + kept_count + discarded_count + logger.info( + " Logs concatenated: %d total (%d root, %d kept, %d discarded)", + total, root_count, kept_count, discarded_count, + ) + logger.info(" All: %s", all_log) + logger.info(" Kept: %s", kept_log) + logger.info(" Discarded: %s", discarded_log) + + return all_log, kept_log, discarded_log + + +# --------------------------------------------------------------------------- +# Step 2: Summary statistics +# --------------------------------------------------------------------------- + +def _parse_speedups(text: str, pattern: re.Pattern[str]) -> list[float]: + """Extract all speedup values matching *pattern* from log text.""" + return [float(m) for m in 
pattern.findall(text)] + + +def _percentile(values: list[float], p: float) -> float: + """Compute the *p*-th percentile (0–100) of a sorted list.""" + if not values: + return 0.0 + k = (len(values) - 1) * p / 100.0 + f = int(k) + c = f + 1 if f + 1 < len(values) else f + return values[f] + (k - f) * (values[c] - values[f]) + + +def _format_speedup_stats(values: list[float], label: str) -> str: + """Format a block of descriptive statistics for a speedup distribution.""" + lines: list[str] = [] + n = len(values) + lines.append(f" Samples with {label} speedup: {n}") + + if n == 0: + return "\n".join(lines) + + values_sorted = sorted(values) + mean = sum(values_sorted) / n + median = _percentile(values_sorted, 50) + + lines.append("") + lines.append(f" Mean: {mean:.4f}") + lines.append(f" Median: {median:.4f}") + lines.append(f" Min: {values_sorted[0]:.4f}") + lines.append(f" Max: {values_sorted[-1]:.4f}") + lines.append(f" P5: {_percentile(values_sorted, 5):.4f}") + lines.append(f" P25: {_percentile(values_sorted, 25):.4f}") + lines.append(f" P75: {_percentile(values_sorted, 75):.4f}") + lines.append(f" P95: {_percentile(values_sorted, 95):.4f}") + + ge2 = sum(1 for v in values_sorted if v >= 2.0) + ge1_5 = sum(1 for v in values_sorted if v >= 1.5) + ge1 = sum(1 for v in values_sorted if v >= 1.0) + lt1 = sum(1 for v in values_sorted if v < 1.0) + lt0_5 = sum(1 for v in values_sorted if v < 0.5) + + lines.append("") + lines.append(f" Speedup >= 2.0: {ge2} ({ge2/n*100:.1f}%)") + lines.append(f" Speedup >= 1.5: {ge1_5} ({ge1_5/n*100:.1f}%)") + lines.append(f" Speedup >= 1.0: {ge1} ({ge1/n*100:.1f}%)") + lines.append(f" Speedup < 1.0: {lt1} ({lt1/n*100:.1f}%) [negative optimization]") + lines.append(f" Speedup < 0.5: {lt0_5} ({lt0_5/n*100:.1f}%) [severe regression]") + + return "\n".join(lines) + + +def _count_subdirs(directory: Path) -> int: + """Count immediate subdirectories of *directory*.""" + if not directory.is_dir(): + return 0 + return sum(1 for d in 
directory.iterdir() if d.is_dir()) + + +def generate_summary( + cache_dir: Path, + all_log_text: str, + discarded_log_text: str, + output_dir: Path, +) -> Path: + """Generate a text summary report and return its path.""" + kernel_speedups = _parse_speedups(all_log_text, _SPEEDUP_KERNEL_RE) + e2e_speedups = _parse_speedups(all_log_text, _SPEEDUP_E2E_PATTERN) + + # Count samples in each state. + root_samples = sum( + 1 for d in cache_dir.iterdir() + if d.is_dir() and is_sample_dir(d.name) + ) if cache_dir.is_dir() else 0 + kept_samples = _count_subdirs(cache_dir / "kept") + discarded_samples = _count_subdirs(cache_dir / "discarded") + total_samples = root_samples + kept_samples + discarded_samples + + def pct(n: int) -> str: + return f"{n/total_samples*100:.1f}" if total_samples > 0 else "0.0" + + lines: list[str] = [ + "Inductor Cache Analysis Report", + "==============================", + f"Cache dir: {cache_dir}", + f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", + "", + "Sample Counts", + "-------------", + f" Total: {total_samples}", + f" Kept: {kept_samples} ({pct(kept_samples)}%)", + f" Discarded: {discarded_samples} ({pct(discarded_samples)}%)", + f" Unclassified (root): {root_samples}", + "", + "Kernel Speedup Distribution (primary — used for kept/discarded filtering)", + "-------------------------------------------------------------------------", + _format_speedup_stats(kernel_speedups, "Kernel"), + "", + "E2E Speedup Distribution (secondary — includes framework overhead)", + "-------------------------------------------------------------------", + _format_speedup_stats(e2e_speedups, "E2E"), + ] + + # Failure analysis from discarded logs. 
+ if discarded_log_text: + neg_opt = sum(1 for v in kernel_speedups if 0 < v < 1.0) + error_lines = sum( + 1 for line in discarded_log_text.splitlines() + if re.search(r"ERROR|Exception|Traceback", line) + ) + lines.extend([ + "", + "Failure/Discard Analysis", + "------------------------", + f" Negative optimization (0 < kernel speedup < 1): {neg_opt}", + f" Logs with errors/exceptions: {error_lines} lines", + ]) + + report = "\n".join(lines) + "\n" + summary_file = output_dir / "summary.txt" + summary_file.write_text(report, encoding="utf-8") + + # Also print to console. + logger.info("%s", report) + logger.info("Summary saved to: %s", summary_file) + + return summary_file + + +# --------------------------------------------------------------------------- +# Step 3: Plots +# --------------------------------------------------------------------------- + +def _check_plotting_deps() -> bool: + """Return True if matplotlib and numpy are importable.""" + try: + import matplotlib # noqa: F401 + import numpy # noqa: F401 + return True + except ImportError: + return False + + +def generate_plots( + all_log_text: str, + all_log_path: Path, + output_dir: Path, +) -> None: + """Generate speedup distribution plots. + + Produces: + - ``speedup_histogram.png`` — raw and log2 histograms of kernel speedup. + - ``speedup_cdf.png`` — cumulative distribution function. + - ``violin.png`` — via external ``graph_net_visual.plot_violin`` (best effort). + - ``ESt.png`` — via external ``graph_net_visual.plot_ESt`` (best effort). + """ + if not _check_plotting_deps(): + logger.warning( + " Skipping plots: matplotlib or numpy not installed. 
" + "Run: pip install matplotlib numpy" + ) + return + + kernel_speedups = _parse_speedups(all_log_text, _SPEEDUP_KERNEL_RE) + if not kernel_speedups: + logger.info(" No kernel speedup data for plotting.") + return + + _generate_builtin_plots(kernel_speedups, output_dir) + _run_external_plot("graph_net_visual.plot_violin", all_log_path, output_dir) + _run_external_plot( + "graph_net_visual.plot_ESt", all_log_path, output_dir, + extra_args=["--disable-aggregation-mode"], + ) + + +def _generate_builtin_plots( + kernel_speedups: list[float], + output_dir: Path, +) -> None: + """Generate histogram and CDF plots using matplotlib.""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import numpy as np + + speedups = np.array(kernel_speedups) + + # --- Histogram: raw + log2 side by side --- + fig, axes = plt.subplots(1, 2, figsize=(18, 7)) + + ax1 = axes[0] + bins = np.concatenate([ + np.arange(0, 1.0, 0.1), + np.arange(1.0, 2.0, 0.1), + np.arange(2.0, max(5.0, float(np.percentile(speedups, 99))) + 0.5, 0.5), + ]) + ax1.hist(speedups, bins=bins, color="steelblue", edgecolor="white", alpha=0.85) + ax1.axvline(x=1.0, color="red", linestyle="--", linewidth=1.5, label="speedup = 1.0") + ax1.axvline( + x=float(np.median(speedups)), color="orange", linestyle="-", linewidth=1.5, + label=f"median = {np.median(speedups):.3f}", + ) + ax1.set_xlabel("Kernel Speedup", fontsize=14) + ax1.set_ylabel("Count", fontsize=14) + ax1.set_title("Kernel Speedup Distribution", fontsize=16) + ax1.legend(fontsize=12) + + n_total = len(speedups) + n_pos = int(np.sum(speedups >= 1.0)) + n_neg = int(np.sum(speedups < 1.0)) + stats_text = ( + f"Total: {n_total}\n" + f"Speedup >= 1: {n_pos} ({n_pos/n_total*100:.1f}%)\n" + f"Speedup < 1: {n_neg} ({n_neg/n_total*100:.1f}%)\n" + f"Mean: {np.mean(speedups):.3f}\n" + f"Median: {np.median(speedups):.3f}" + ) + ax1.text( + 0.97, 0.97, stats_text, transform=ax1.transAxes, fontsize=10, + verticalalignment="top", 
def _run_external_plot(
    module: str,
    log_path: Path,
    output_dir: Path,
    *,
    extra_args: list[str] | None = None,
) -> None:
    """Invoke an external GraphNet plotting module (best effort, never fatal).

    Runs ``python -m <module>`` in a subprocess with the benchmark log and
    output directory as arguments. Lines that look interesting (saves,
    errors, warnings) are forwarded to the logger; a missing module, a
    timeout, or an OS-level failure is logged at debug level and ignored.
    """
    command = [sys.executable, "-m", module]
    command += ["--benchmark-path", str(log_path)]
    command += ["--output-dir", str(output_dir)]
    if extra_args:
        command += list(extra_args)

    try:
        proc = subprocess.run(
            command, capture_output=True, text=True, timeout=120,
        )
    except (FileNotFoundError, subprocess.TimeoutExpired, OSError) as exc:
        logger.debug(" %s unavailable: %s", module, exc)
        return

    # Surface only saved/error/warning lines from the combined output.
    short_name = module.rsplit(".", 1)[-1]
    for raw_line in (proc.stdout + proc.stderr).splitlines():
        if re.search(r"saved|Saved|Error|Warning", raw_line, re.IGNORECASE):
            logger.info(" %s: %s", short_name, raw_line.strip())
+ logger.info("") + logger.info("=== Step 2: Summary statistics ===") + generate_summary(cache_dir, all_log_text, discarded_log_text, output_dir) + + # Step 3: Plots. + logger.info("") + logger.info("=== Step 3: Generating plots ===") + generate_plots(all_log_text, all_log, output_dir) + + # Report output files. + logger.info("") + logger.info("======================================================") + logger.info(" Analysis complete!") + logger.info(" Output directory: %s", output_dir) + logger.info("======================================================") + output_files = sorted( + f for f in output_dir.iterdir() + if f.is_file() and f.suffix in {".txt", ".log", ".png"} + ) + for f in output_files: + size_kb = f.stat().st_size / 1024 + logger.info(" %s (%.1f KB)", f.name, size_kb) diff --git a/tools/triton_kernel_extractor/compiler.py b/tools/triton_kernel_extractor/compiler.py new file mode 100644 index 0000000000..d6f6ac4104 --- /dev/null +++ b/tools/triton_kernel_extractor/compiler.py @@ -0,0 +1,249 @@ +"""Step 1: Multi-GPU parallel compilation of graph samples.""" + +from __future__ import annotations + +import base64 +import json +import logging +import os +import shutil +import subprocess +import sys +from concurrent.futures import Future, ProcessPoolExecutor, as_completed +from pathlib import Path + +from .config import PipelineConfig, DatasetDescriptor +from .sample_enumerator import compute_unique_dir, resolve_model_path + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Single-sample compilation +# --------------------------------------------------------------------------- + +def _is_already_compiled( + log_file: Path, + cache_dir: Path, + unique_dir: str, +) -> bool: + """Check whether a sample has already been compiled in a prior run. 
+ + Mirrors the bash resume logic:: + + { [ -f "$log_file" ] && grep -q '[Speedup][kernel]:' "$log_file"; } \ + || [ -d "$cache_dir/kept/$unique_dir" ] \ + || [ -d "$cache_dir/discarded/$unique_dir" ] + """ + if log_file.is_file(): + try: + content = log_file.read_text(encoding="utf-8", errors="replace") + if "[Speedup][kernel]:" in content: + return True + except OSError: + pass + + if (cache_dir / "kept" / unique_dir).is_dir(): + return True + if (cache_dir / "discarded" / unique_dir).is_dir(): + return True + + return False + + +def _compile_one_sample( + sample_path: str, + source: str, + passnet_dir: str, + passnet_hf_dir: str, + cache_dir: Path, + gpu_id: int, + progress_label: str, + compiler_config: str | None = None, +) -> str: + """Compile a single graph sample on a specific GPU. + + Returns one of ``"compiled"``, ``"skipped"``, or ``"failed"``. + """ + unique_dir = compute_unique_dir(source, sample_path, passnet_hf_dir) + full_model_path = resolve_model_path(source, sample_path, passnet_dir) + + graph_cache_dir = cache_dir / unique_dir + log_file = graph_cache_dir / "test_compiler_log.log" + + if _is_already_compiled(log_file, cache_dir, unique_dir): + logger.info("%s SKIP: %s", progress_label, sample_path) + return "skipped" + + # Remove incomplete cache from a prior interrupted attempt. + if graph_cache_dir.exists(): + shutil.rmtree(graph_cache_dir) + graph_cache_dir.mkdir(parents=True) + + logger.info("%s Compiling: %s", progress_label, full_model_path) + + # Build a clean environment for the compiler subprocess. 
def _compile_chunk(
    samples: list[str],
    source: str,
    passnet_dir: str,
    passnet_hf_dir: str,
    cache_dir: Path,
    gpu_id: int,
    compiler_config: str | None = None,
) -> dict[str, int]:
    """Sequentially compile one GPU's share of samples.

    This is the top-level callable submitted to the process pool: each
    invocation runs in its own process so CUDA state stays isolated per GPU.

    Returns:
        Per-status counts, e.g. ``{"compiled": N, "skipped": N, "failed": N}``.
    """
    outcome_counts = {"compiled": 0, "skipped": 0, "failed": 0}
    chunk_size = len(samples)

    for position, sample in enumerate(samples, start=1):
        outcome = _compile_one_sample(
            sample_path=sample,
            source=source,
            passnet_dir=passnet_dir,
            passnet_hf_dir=passnet_hf_dir,
            cache_dir=cache_dir,
            gpu_id=gpu_id,
            progress_label=f"[GPU{gpu_id} {position}/{chunk_size}]",
            compiler_config=compiler_config,
        )
        outcome_counts[outcome] += 1

    logger.info(
        "[GPU%d] Done: %d compiled, %d skipped, %d failed (total: %d)",
        gpu_id,
        outcome_counts["compiled"],
        outcome_counts["skipped"],
        outcome_counts["failed"],
        chunk_size,
    )
    return outcome_counts
+ with ProcessPoolExecutor(max_workers=num_gpus) as executor: + future_to_gpu: dict[Future[dict[str, int]], int] = {} + + for gid in gpu_ids: + chunk = chunks[gid] + if not chunk: + continue + future = executor.submit( + _compile_chunk, + samples=chunk, + source=config.source, + passnet_dir=str(config.passnet_dir), + passnet_hf_dir=str(config.passnet_hf_dir), + cache_dir=dataset.cache_dir, + gpu_id=gid, + compiler_config=compiler_config, + ) + future_to_gpu[future] = gid + logger.info(" Launched worker GPU %d (%d samples)", gid, len(chunk)) + + logger.info(" Waiting for %d workers...", len(future_to_gpu)) + + has_errors = False + for future in as_completed(future_to_gpu): + gid = future_to_gpu[future] + try: + stats = future.result() + for key in aggregated: + aggregated[key] += stats[key] + except Exception: + has_errors = True + logger.exception("Worker GPU %d raised an exception", gid) + + if has_errors: + logger.warning( + "WARNING: Some workers had errors. Check logs for details." + ) + + return aggregated diff --git a/tools/triton_kernel_extractor/config.py b/tools/triton_kernel_extractor/config.py new file mode 100644 index 0000000000..ad4ede16a3 --- /dev/null +++ b/tools/triton_kernel_extractor/config.py @@ -0,0 +1,115 @@ +"""Pipeline configuration types and dataset descriptor construction.""" + +from __future__ import annotations + +import dataclasses +from pathlib import Path + +# The three dataset categories processed by the pipeline. +DATASET_NAMES: tuple[str, ...] = ( + "sole_op_subgraphs", + "fusible_subgraphs", + "typical_subgraphs", +) + +# Log pattern emitted by the external GraphNet test_compiler with --kernel-time. +# Used in speedup_filter to decide whether to keep or discard a compiled sample. +SPEEDUP_KERNEL_PATTERN = r"\[Speedup\]\[kernel\]:\s*([\d.]+)" + +# Subdirectory names reserved for internal bookkeeping inside the cache directory. +# These are skipped when iterating over sample directories. 
def build_dataset_descriptors(
    config: PipelineConfig,
) -> list[DatasetDescriptor]:
    """Build one descriptor per dataset category from the pipeline config.

    The mapping mirrors the original bash arrays ``DATASET_NAMES``,
    ``CACHE_DIRS``, ``GRAPH_LIST_FILES``, and ``HF_SUBDIRS``.

    Raises:
        ValueError: If ``config.source`` is neither ``"list"`` nor ``"hf"``.
    """
    # Each source selects its dataset root and fills exactly one of the
    # per-category inputs (graph list files for "list", HF subdirs for "hf").
    if config.source == "list":
        dataset_dir = config.dataset_base_dir / "GitHubV2"
        graph_list_files: list[Path | None] = [
            config.dataset_base_dir / "hf_sole_op_samples_v2_all_expanded.txt",
            config.dataset_base_dir / "hf_fusible_samples_v2_all_expanded.txt",
            config.dataset_base_dir / "hf_typical_samples_v2_all_expanded.txt",
        ]
        hf_subdirs: list[str | None] = [None, None, None]
    elif config.source == "hf":
        dataset_dir = config.dataset_base_dir / "Huggingface"
        graph_list_files = [None, None, None]
        hf_subdirs = list(DATASET_NAMES)
    else:
        raise ValueError(
            f"Invalid source {config.source!r}. Must be 'list' or 'hf'."
        )

    return [
        DatasetDescriptor(
            name=name,
            cache_dir=dataset_dir / f"{name}_inductor_dump",
            output_dir=dataset_dir / f"{name}_inductor_dump_subgraph_triton_kernel_pair",
            graph_list_file=graph_list_file,
            hf_subdir=hf_subdir,
        )
        for name, graph_list_file, hf_subdir in zip(
            DATASET_NAMES, graph_list_files, hf_subdirs, strict=True
        )
    ]
+ """ + if not output_dir.is_dir(): + logger.warning("Output directory does not exist: %s", output_dir) + return 0, 0 + + total = 0 + removed = 0 + + for sample_dir in sorted(output_dir.iterdir()): + if not sample_dir.is_dir(): + continue + total += 1 + + has_graph = (sample_dir / "original_graph").is_dir() + has_kernel = (sample_dir / "triton_kernel").is_dir() + + if has_graph and not has_kernel: + logger.info(" Removing (no triton_kernel): %s", sample_dir.name) + shutil.rmtree(sample_dir) + removed += 1 + + kept = total - removed + logger.info( + "Cleanup: %d removed (no triton_kernel), %d kept (total: %d)", + removed, + kept, + total, + ) + return removed, kept diff --git a/tools/triton_kernel_extractor/kernel_extractor.py b/tools/triton_kernel_extractor/kernel_extractor.py new file mode 100644 index 0000000000..06fc9e76d8 --- /dev/null +++ b/tools/triton_kernel_extractor/kernel_extractor.py @@ -0,0 +1,254 @@ +"""Step 4: Extract autotuning-selected triton kernels and corresponding PTX.""" + +from __future__ import annotations + +import json +import logging +import re +import shutil +from pathlib import Path + +logger = logging.getLogger(__name__) + +# Compiled regex that replaces the original perl one-liner: +# +# perl -0777 -ne ' +# while (/async_compile\.triton\(\x27([^\x27]+)\x27,\s*\x27\x27\x27(.*?)\x27\x27\x27/gs) { +# print "===KERNEL_NAME===$1\n$2\n===KERNEL_END===\n"; +# }' +# +# Captures: group(1) = kernel name, group(2) = kernel source code. +_TRITON_KERNEL_PATTERN = re.compile( + r"async_compile\.triton\('([^']+)',\s*'''(.*?)'''", + re.DOTALL, +) + + +def _collect_best_config_hashes(graph_dir: Path) -> set[str]: + """Gather all autotuning-selected cache hashes from a sample directory. + + TorchInductor writes ``.best_config`` JSON files (one per autotuned kernel) + in 2-char prefix subdirectories of the sample cache. 
Each file contains a + ``triton_cache_hash`` field identifying the winning configuration among + multiple compiled candidates in ``triton/0/``. + + This function is called once per sample and the result is reused for every + kernel in that sample. + """ + hashes: set[str] = set() + for bc_path in graph_dir.rglob("*.best_config"): + try: + data = json.loads(bc_path.read_text(encoding="utf-8")) + cache_hash = data.get("triton_cache_hash") + if cache_hash: + hashes.add(cache_hash) + except (OSError, json.JSONDecodeError): + logger.debug("Skipping malformed .best_config: %s", bc_path) + return hashes + + +def _find_best_ptx( + graph_dir: Path, + kernel_name: str, + best_hashes: set[str], +) -> str | None: + """Locate the corresponding PTX for a given kernel via autotuning results. + + The inductor cache compiles each triton kernel into one or more candidate + configurations under ``triton/0/{HASH}/``. When autotuning runs, multiple + candidate directories exist for the same kernel and a ``.best_config`` file + records the winning ``triton_cache_hash``. + + Resolution strategy: + - 0 candidates → return ``None`` (no PTX compiled for this kernel). + - 1 candidate → return its PTX (no disambiguation needed). + - N candidates → intersect directory names with *best_hashes*; the match + identifies the autotuning winner. + """ + triton_base = graph_dir / "triton" / "0" + if not triton_base.is_dir(): + return None + + # Collect all triton/0/{hash}/ dirs that contain this kernel's PTX. 
def extract_triton_kernels(
    cache_dir: Path,
    output_dir: Path,
) -> tuple[int, int, int, int, int]:
    """Walk kept samples, extract autotuning-selected triton kernels and corresponding PTX.

    For every kept sample that contains ``original_graph/graph_hash.txt``:

    1. Copy ``original_graph/model.py`` into the output.
    2. Parse every ``output_code.py`` found in the sample tree.
    3. Write each extracted kernel to ``triton_kernel/{name}.py``.
    4. Locate the corresponding PTX for each kernel and write it to
       ``ptx/{name}.ptx``.

    The output uses an atomic ``.tmp`` + ``rename`` pattern so that an
    interrupted run never leaves a half-written sample directory.

    Note: ``processed_files`` counts every ``output_code.py`` inspected,
    including files that yield no kernels; a sample with no kernels at all
    is still renamed into the output (step 5 of the pipeline removes it).

    Returns:
        ``(processed_files, total_kernels, total_ptx, copied_graphs, skip_count)``
    """
    kept_dir = cache_dir / "kept"
    if not kept_dir.is_dir():
        logger.error("Kept directory does not exist: %s", kept_dir)
        return 0, 0, 0, 0, 0

    output_dir.mkdir(parents=True, exist_ok=True)

    # Clean up stale .tmp directories from a previous interrupted run.
    for stale in output_dir.iterdir():
        if stale.is_dir() and stale.name.endswith(".tmp"):
            shutil.rmtree(stale, ignore_errors=True)

    # Collect eligible samples (must contain original_graph/graph_hash.txt).
    eligible: list[Path] = [
        d
        for d in sorted(kept_dir.iterdir())
        if d.is_dir() and (d / "original_graph" / "graph_hash.txt").is_file()
    ]
    total = len(eligible)

    processed_files = 0
    total_kernels = 0
    total_ptx = 0
    copied_graphs = 0
    skip_count = 0

    for idx, graph_dir in enumerate(eligible, 1):
        graph_name = graph_dir.name
        dest_graph_dir = output_dir / graph_name

        # Resume: skip if the final output already exists.
        if dest_graph_dir.exists():
            skip_count += 1
            continue

        logger.info("[%d/%d] Extracting: %s", idx, total, graph_name)

        # Write to a temporary directory; rename atomically on success.
        tmp_dir = dest_graph_dir.with_name(f"{graph_name}.tmp")
        if tmp_dir.exists():
            shutil.rmtree(tmp_dir)
        tmp_dir.mkdir(parents=True)

        # Copy original model source when available.
        model_src = graph_dir / "original_graph" / "model.py"
        if model_src.is_file():
            og_dir = tmp_dir / "original_graph"
            og_dir.mkdir()
            shutil.copy2(str(model_src), str(og_dir / "model.py"))

        # Pre-collect autotuning best-config hashes once per sample; reused
        # by _find_best_ptx for every kernel below.
        best_hashes = _collect_best_config_hashes(graph_dir)

        # Track kernel names already written for this sample to detect
        # duplicates across multiple output_code.py files.
        seen_kernels: set[str] = set()

        # Find and process all output_code.py files within the sample.
        for output_code_path in sorted(graph_dir.rglob("output_code.py")):
            processed_files += 1
            kernels = extract_kernels_from_file(output_code_path)
            if not kernels:
                continue

            triton_dir = tmp_dir / "triton_kernel"
            triton_dir.mkdir(exist_ok=True)

            for name, source in kernels:
                if name in seen_kernels:
                    logger.debug(
                        "Duplicate kernel %s in %s, skipping", name, graph_name
                    )
                    continue
                seen_kernels.add(name)
                # Strip trailing whitespace then add exactly one newline,
                # matching the bash `printf '%s\n'` semantics.
                (triton_dir / f"{name}.py").write_text(
                    source.rstrip() + "\n", encoding="utf-8"
                )
                total_kernels += 1

                # Locate and write the corresponding PTX for this kernel.
                # May be None when no PTX was compiled for the kernel.
                ptx_content = _find_best_ptx(graph_dir, name, best_hashes)
                if ptx_content is not None:
                    ptx_dir = tmp_dir / "ptx"
                    ptx_dir.mkdir(exist_ok=True)
                    (ptx_dir / f"{name}.ptx").write_text(
                        ptx_content, encoding="utf-8"
                    )
                    total_ptx += 1

        # Atomic completion: rename .tmp → final (same filesystem guarantees
        # a single rename(2) syscall).
        tmp_dir.rename(dest_graph_dir)
        copied_graphs += 1

    logger.info(
        "Extraction: %d files, %d kernels, %d ptx, %d graphs, %d skipped (total: %d)",
        processed_files,
        total_kernels,
        total_ptx,
        copied_graphs,
        skip_count,
        total,
    )
    logger.info("Output: %s", output_dir)
    return processed_files, total_kernels, total_ptx, copied_graphs, skip_count
dataset_idx: int, + total_datasets: int, +) -> None: + """Execute all five pipeline steps for a single dataset.""" + logger.info("") + logger.info("======================================================") + logger.info( + " Dataset [%d/%d]: %s", dataset_idx, total_datasets, dataset.name + ) + + samples = _load_samples(config, dataset) + + if config.source == "list": + logger.info(" Graph list: %s", dataset.graph_list_file) + else: + logger.info( + " HF dir: %s/%s", config.passnet_hf_dir, dataset.hf_subdir + ) + + logger.info(" Samples: %d", len(samples)) + logger.info(" Cache dir: %s", dataset.cache_dir) + logger.info(" Output dir: %s", dataset.output_dir) + logger.info(" GPUs: %s", " ".join(str(g) for g in config.gpu_ids)) + if config.max_autotune: + logger.info(" Autotune: max_autotune (via GraphNet inductor config template)") + logger.info("======================================================") + + if not samples: + logger.error("No samples found for dataset: %s", dataset.name) + return + + dataset.cache_dir.mkdir(parents=True, exist_ok=True) + + # Step 1: Parallel compilation. + num_gpus = len(config.gpu_ids) + logger.info("") + logger.info("=== Step 1: Parallel compilation (%d GPUs) ===", num_gpus) + compile_all_samples(samples, config, dataset) + + # Step 2: Filter by speedup. + logger.info("") + logger.info("=== Step 2: Filter by speedup ===") + filter_samples_by_speedup(dataset.cache_dir) + + # Step 3: Clean temp files. + logger.info("") + logger.info("=== Step 3: Clean temp files ===") + clean_temp_files(dataset.cache_dir / "kept") + + # Step 4: Extract triton kernels. + logger.info("") + logger.info("=== Step 4: Extract autotuning-selected triton kernels and corresponding PTX ===") + extract_triton_kernels(dataset.cache_dir, dataset.output_dir) + + # Step 5: Clean samples without triton kernels. 
+ logger.info("") + logger.info("=== Step 5: Clean samples without triton kernels ===") + clean_empty_kernel_samples(dataset.output_dir) + + +def run_pipeline( + config: PipelineConfig, + *, + enable_cache_analysis: bool = False, +) -> None: + """Run the full pipeline across all three dataset categories.""" + descriptors = build_dataset_descriptors(config) + total = len(descriptors) + + for idx, dataset in enumerate(descriptors, 1): + run_dataset_pipeline(config, dataset, idx, total) + + if enable_cache_analysis: + from .cache_analyzer import analyze_cache + + for idx, dataset in enumerate(descriptors, 1): + logger.info("") + logger.info( + "=== Cache analysis [%d/%d]: %s ===", idx, total, dataset.name + ) + analysis_dir = dataset.cache_dir.parent / f"{dataset.cache_dir.name}_analysis" + analyze_cache(dataset.cache_dir, analysis_dir) + + logger.info("") + logger.info("All datasets processed.") diff --git a/tools/triton_kernel_extractor/sample_enumerator.py b/tools/triton_kernel_extractor/sample_enumerator.py new file mode 100644 index 0000000000..7e8523d082 --- /dev/null +++ b/tools/triton_kernel_extractor/sample_enumerator.py @@ -0,0 +1,85 @@ +"""Enumerate graph samples from 'list' or 'hf' data sources.""" + +from __future__ import annotations + +import logging +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def enumerate_list_samples(graph_list_file: Path) -> list[str]: + """Read sample paths from a text file, one per line. + + Blank lines are silently skipped. + + Raises: + FileNotFoundError: If *graph_list_file* does not exist. + """ + lines: list[str] = [] + with open(graph_list_file, encoding="utf-8") as fh: + for raw in fh: + stripped = raw.strip() + if stripped: + lines.append(stripped) + return lines + + +def enumerate_hf_samples(passnet_hf_dir: Path, hf_subdir: str) -> list[str]: + """Discover samples by scanning for ``model.py`` under a HuggingFace dir. + + Returns the sorted list of parent directories that contain a ``model.py``. 
def compute_unique_dir(
    source: str,
    sample_path: str,
    passnet_hf_dir: str,
) -> str:
    """Derive a flat directory name from a sample path.

    For *list* sources the entire ``sample_path`` has ``/`` replaced by ``_``.
    For *hf* sources only the relative portion below ``passnet_hf_dir`` is used.

    This mirrors the bash logic::

        list: unique_dir="${sample_path//\\//_}"
        hf:   rel_path="${sample_path#$PASSNET_HF_DIR/}"
              unique_dir="${rel_path//\\//_}"
    """
    if source == "list":
        relative = sample_path
    else:  # source == "hf"
        hf_prefix = passnet_hf_dir.rstrip("/") + "/"
        # removeprefix is a no-op when the path lies outside the HF dir,
        # matching the bash ${var#prefix} behaviour.
        relative = sample_path.removeprefix(hf_prefix)
    return relative.replace("/", "_")
+ """ + if source == "list": + return f"{passnet_dir}/{sample_path}" + return sample_path diff --git a/tools/triton_kernel_extractor/speedup_filter.py b/tools/triton_kernel_extractor/speedup_filter.py new file mode 100644 index 0000000000..f9ceb94af3 --- /dev/null +++ b/tools/triton_kernel_extractor/speedup_filter.py @@ -0,0 +1,105 @@ +"""Step 2: Partition compiled samples into *kept* and *discarded* by speedup.""" + +from __future__ import annotations + +import logging +import re +import shutil +from pathlib import Path + +from .config import ( + SPEEDUP_KERNEL_PATTERN, + SPEEDUP_THRESHOLD, + is_sample_dir, +) + +logger = logging.getLogger(__name__) + +_SPEEDUP_RE = re.compile(SPEEDUP_KERNEL_PATTERN) + + +def _parse_kernel_speedup(log_file: Path) -> float | None: + """Extract the last ``[Speedup][kernel]:`` value from a compiler log. + + Returns ``None`` when the log does not exist or contains no speedup line. + """ + if not log_file.is_file(): + return None + try: + content = log_file.read_text(encoding="utf-8", errors="replace") + except OSError: + return None + + last_match: re.Match[str] | None = None + for m in _SPEEDUP_RE.finditer(content): + last_match = m + + if last_match is None: + return None + + try: + return float(last_match.group(1)) + except ValueError: + return None + + +def filter_samples_by_speedup(cache_dir: Path) -> tuple[int, int, int]: + """Move compiled samples into ``kept/`` or ``discarded/`` sub-directories. + + A sample is *kept* when the last ``[Speedup][kernel]:`` value in its + ``test_compiler_log.log`` is >= ``SPEEDUP_THRESHOLD``. Samples that + have already been classified are silently skipped. + + Returns: + A tuple of ``(kept_count, discarded_count, skip_count)``. 
+ """ + kept_dir = cache_dir / "kept" + discarded_dir = cache_dir / "discarded" + kept_dir.mkdir(parents=True, exist_ok=True) + discarded_dir.mkdir(parents=True, exist_ok=True) + + # Collect candidate directories (snapshot the listing to avoid mutating + # the iterator while moving entries). + candidates: list[Path] = [ + d + for d in sorted(cache_dir.iterdir()) + if d.is_dir() and is_sample_dir(d.name) + ] + total = len(candidates) + + kept_count = 0 + discarded_count = 0 + skip_count = 0 + + for idx, graph_dir in enumerate(candidates, 1): + graph_name = graph_dir.name + + # Skip if already classified in a previous (possibly interrupted) run. + if (kept_dir / graph_name).exists() or ( + discarded_dir / graph_name + ).exists(): + skip_count += 1 + continue + + log_file = graph_dir / "test_compiler_log.log" + speedup = _parse_kernel_speedup(log_file) + should_keep = speedup is not None and speedup >= SPEEDUP_THRESHOLD + + if should_keep: + shutil.move(str(graph_dir), str(kept_dir / graph_name)) + kept_count += 1 + else: + shutil.move(str(graph_dir), str(discarded_dir / graph_name)) + discarded_count += 1 + + label = "KEPT" if should_keep else "DISCARDED" + logger.info("[%d/%d] %s: %s", idx, total, label, graph_name) + + logger.info( + "Filter: %d kept, %d discarded, %d skipped (total: %d)", + kept_count, + discarded_count, + skip_count, + total, + ) + return kept_count, discarded_count, skip_count diff --git a/tools/triton_kernel_extractor/temp_cleaner.py b/tools/triton_kernel_extractor/temp_cleaner.py new file mode 100644 index 0000000000..b4e5f3e80b --- /dev/null +++ b/tools/triton_kernel_extractor/temp_cleaner.py @@ -0,0 +1,41 @@ +"""Step 3: Remove Python bytecode caches from a directory tree.""" + +from __future__ import annotations + +import logging +import shutil +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def clean_temp_files(directory: Path) -> None: + """Recursively delete ``__pycache__`` dirs, ``*.pyc``, and ``*.pyo`` files. 
+ + Mirrors the original bash implementation:: + + find "$dir" -type d -name "__pycache__" -exec rm -rf {} + + find "$dir" -type f -name "*.pyc" -delete + find "$dir" -type f -name "*.pyo" -delete + + Silently skips entries that vanish between discovery and deletion (e.g. + a ``.pyc`` inside a ``__pycache__`` that was already removed). + """ + if not directory.is_dir(): + logger.warning("Directory does not exist, skipping: %s", directory) + return + + logger.info("Cleaning temp files from %s ...", directory) + + # Collect first, then delete — avoids mutating the tree during iteration. + pycache_dirs = sorted(directory.rglob("__pycache__"), reverse=True) + for d in pycache_dirs: + if d.is_dir(): + shutil.rmtree(d, ignore_errors=True) + + for pattern in ("*.pyc", "*.pyo"): + for f in directory.rglob(pattern): + try: + f.unlink() + except FileNotFoundError: + pass