WIP: TurboQuant for ORT WebGPU #2084
Conversation
Pull request overview
This PR adds initial (“WIP”) plumbing and tooling to support TurboQuant KV-cache compression for ORT WebGPU by allowing GenAI to use a reduced KV head dimension instead of the model’s original head_size.
Changes:
- Add `kv_cache_head_size` to `genai_config.json` parsing and use it when allocating the default KV cache.
- Add a new `tools/prepare_turbo_quant.py` helper to update ONNX KV tensor shapes (to a dynamic dim) and update `genai_config.json` with TurboQuant WebGPU provider options.
- Add new C++ attention-quality test executables (NIAH + RULER-inspired) and make small build/benchmark updates (ORT_HOME test env handling, WebGPU provider validation).
Reviewed changes
Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| tools/python/util/dependency_resolver.py | Avoid copying directories when copying dependency artifacts. |
| tools/prepare_turbo_quant.py | New script to rewrite ONNX KV cache shapes and update GenAI config for TurboQuant. |
| src/models/kv_cache.cpp | Use kv_cache_head_size (if provided) when sizing the default KV cache allocation. |
| src/config.h | Add optional kv_cache_head_size to decoder config. |
| src/config.cpp | Parse kv_cache_head_size from genai_config.json. |
| examples/c/src/niah_test.cpp | New NIAH attention-quality test program. |
| examples/c/src/ruler_test.cpp | New RULER-inspired attention-quality benchmark program. |
| examples/c/CMakeLists.txt | Add build options/targets for the new NIAH/RULER test executables. |
| build.py | Update Windows generator default/list, improve env/PATH handling for ORT_HOME during tests, adjust examples build cleanup. |
| benchmark/c/options.cpp | Allow webgpu as a benchmark execution provider option. |
```diff
 if build_dir.exists():
     log.info(f"Removing existing build directory: {build_dir}")
-    shutil.rmtree(build_dir)
+    shutil.rmtree(build_dir, onexc=lambda func, path, exc: (os.chmod(path, stat.S_IWRITE), func(path)))
```
`shutil.rmtree` is called with an `onexc` keyword argument, but this repo targets Python 3.10 (ruff `target-version=py310`), and `onexc` was only added in Python 3.12; earlier versions support only `onerror`. This will raise `TypeError: rmtree() got an unexpected keyword argument 'onexc'` at runtime when rebuilding examples.
Use the `onerror` callback (same callback signature) instead, or gate the `onexc` usage behind a Python version check if you need the newer behavior.
```diff
-shutil.rmtree(build_dir, onexc=lambda func, path, exc: (os.chmod(path, stat.S_IWRITE), func(path)))
+shutil.rmtree(build_dir, onerror=lambda func, path, exc: (os.chmod(path, stat.S_IWRITE), func(path)))
```
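A version-gated variant keeps the read-only-file workaround on both interpreter lines. This is a minimal sketch, not the PR's code; the names `_force_remove` and `remove_build_dir` are illustrative:

```python
import os
import shutil
import stat
import sys


def _force_remove(func, path, exc_info):
    # Clear the read-only bit (common for checked-out files on Windows)
    # and retry the failing os function. Both callbacks receive
    # (function, path, exception-info), so one handler serves both.
    os.chmod(path, stat.S_IWRITE)
    func(path)


def remove_build_dir(build_dir):
    # shutil.rmtree grew the `onexc` keyword in Python 3.12;
    # `onerror` is the pre-3.12 spelling (deprecated but still accepted).
    if sys.version_info >= (3, 12):
        shutil.rmtree(build_dir, onexc=_force_remove)
    else:
        shutil.rmtree(build_dir, onerror=_force_remove)
```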
```diff
 DefaultKeyValueCache::DefaultKeyValueCache(State& state)
     : state_{state},
       layer_count_{model_.config_->model.decoder.num_hidden_layers},
       past_present_share_buffer_{state_.params_->IsPastPresentShareBufferEnabled(model_.config_->model.type)},
-      shape_{state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 0, model_.config_->model.decoder.head_size} {
+      shape_{state_.params_->BatchBeamSize(), model_.config_->model.decoder.num_key_value_heads, 0,
+             model_.config_->model.decoder.kv_cache_head_size.value_or(model_.config_->model.decoder.head_size)} {
```
`kv_cache_head_size` is now applied to the `DefaultKeyValueCache` shape allocation, but `Gpt_State` uses `CombinedKeyValueCache` (see `src/models/gpt.h`), which still hardcodes `decoder.head_size` for the KV last dimension. If TurboQuant relies on a smaller KV head dimension, GPT-style models may still allocate/expect the uncompressed size and fail at runtime.
To fully support TurboQuant, update `CombinedKeyValueCache` to use `decoder.kv_cache_head_size.value_or(decoder.head_size)` for the head dim as well (and ensure any other KV-cache shape computations use the same value).
```python
    [padding] (to vec4 alignment)

    Formula: ((head_size / 4 + 2 + 3) / 4) * 4
    """
```
`compute_compressed_dim` uses integer division (`head_size // 4`), but the docstring/formula assumes `head_size` is divisible by 4. For a non-multiple-of-4 head size this silently computes an incorrect compressed dimension.
Add explicit validation (e.g., `head_size > 0 and head_size % 4 == 0`) and raise a clear error message when the input is unsupported, so the script can't generate a broken config/model update.
```diff
     """
+    if head_size <= 0:
+        raise ValueError(f"Unsupported head_size={head_size}: head_size must be a positive integer.")
+    if head_size % 4 != 0:
+        raise ValueError(
+            f"Unsupported head_size={head_size}: head_size must be divisible by 4 "
+            "for TurboQuant KV cache compression."
+        )
```
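Putting the suggested validation together with the docstring's formula, the whole helper could look like the sketch below. The interpretation of the `+ 2` term as scale/zero-point slots is an assumption, not taken from the PR; only the formula itself comes from the script's docstring. For instance, a head_size of 128 compresses to 36, matching the `kv_cache_head_size: 36` example elsewhere in this thread:

```python
def compute_compressed_dim(head_size: int) -> int:
    """Compressed KV head dim: packed 4-bit values plus extra slots,
    rounded up to vec4 alignment.

    Formula (from the script's docstring): ((head_size / 4 + 2 + 3) / 4) * 4
    """
    if head_size <= 0:
        raise ValueError(f"Unsupported head_size={head_size}: must be a positive integer.")
    if head_size % 4 != 0:
        raise ValueError(
            f"Unsupported head_size={head_size}: must be divisible by 4 "
            "for TurboQuant KV cache compression."
        )
    packed = head_size // 4          # four 4-bit values per fp16/fp32 element
    packed += 2                      # assumed: per-row scale/zero-point slots
    return ((packed + 3) // 4) * 4   # round up to vec4 alignment
```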
```cpp
auto next = [&]() -> std::string {
  if (i + 1 >= argc) {
    std::cerr << "Missing value for " << arg << "\n";
    std::exit(1);
  }
```
This file uses `std::exit(...)` but does not include `<cstdlib>`, which can lead to a compile error (`'exit' is not a member of 'std'`) depending on the toolchain.
Include `<cstdlib>` (or use `return 1;` instead of `std::exit`) to ensure the example builds cleanly across compilers.
```cpp
std::string arg = argv[i];
auto next = [&]() -> std::string {
  if (i + 1 >= argc) { std::cerr << "Missing value for " << arg << "\n"; std::exit(1); }
  return argv[++i];
};
```
This file uses `std::exit(1)` but does not include `<cstdlib>`, which can cause a compile error (`'exit' is not a member of 'std'`) depending on the standard library headers pulled in by the toolchain.
Add `#include <cstdlib>` (or avoid `std::exit` and return an error code) to keep the example portable.
After we quantize the KV cache, the 4-bit values are still stored packed in fp16/fp32.
From ORT GenAI's perspective, just the head_dim needs to become smaller; it is oblivious to the quantization going on.
This change introduces two fields in `genai_config.json`:
- `"provider_options": [ { "webgpu": { "turboQuant": "1" } } ]` and
- `"kv_cache_head_size": 36` directly inside `decoder`.
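The config update can be sketched roughly as below. This is an illustration built from the two fields quoted above, not the actual `tools/prepare_turbo_quant.py` logic, and it assumes provider options live under `model.decoder.session_options` as in typical `genai_config.json` files:

```python
import json


def add_turbo_quant_options(config_path: str, kv_cache_head_size: int) -> None:
    # Illustrative sketch: inject the two TurboQuant fields into
    # genai_config.json. The real prepare_turbo_quant.py may differ.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    decoder = config["model"]["decoder"]
    decoder["kv_cache_head_size"] = kv_cache_head_size

    # Enable the TurboQuant path in the WebGPU provider options,
    # reusing an existing webgpu entry if one is present.
    session = decoder.setdefault("session_options", {})
    providers = session.setdefault("provider_options", [])
    for entry in providers:
        if "webgpu" in entry:
            entry["webgpu"]["turboQuant"] = "1"
            break
    else:
        providers.append({"webgpu": {"turboQuant": "1"}})

    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
```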
ORT side changes microsoft/onnxruntime#28059