Conversation
Force-pushed from 16f12fa to 88b6869.
Could you also add the corresponding unit tests for the impacted functions in `quant_utils.py` here? Thanks!
Codecov Report
❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #736      +/-   ##
==========================================
- Coverage   74.02%   73.82%   -0.21%
==========================================
  Files         192      193       +1
  Lines       19664    19745      +81
==========================================
+ Hits        14557    14577      +20
- Misses       5107     5168      +61
```
```python
assert dequant_tensor.shape == input_shape, (
    f"Expected dequantized shape {input_shape}, got {dequant_tensor.shape}"
)
assert torch.allclose(dequant_tensor, test_tensor, rtol=5e-2, atol=5e-2), (
```
We can also compare with the fake quant here.
As far as I understand, fake quant is already tested by `test_qtensor_accuracy` (part of class `TestQTensor`), in the code below the comment `# compare with fake quant as well`.
I added a test case for MXFP8 to `test_qtensor_accuracy` and checked that it works using this command:
pytest --maxfail 1 tests/gpu/torch/quantization/test_qtensor_cuda.py -k "test_qtensor_accuracy"
All the new MXFP8 tests also pass, using this command:
pytest tests/gpu/torch/quantization/test_qtensor_cuda.py -k "test_mxfp8"
Force-pushed from 9b0c088 to a764b32.
📝 Walkthrough

This change introduces support for the MXFP8 quantization format across the codebase. A new `MXFP8QTensor` class implements block-based FP8 E4M3 quantization with E8M0 shared scales. MXFP8 support is integrated into configuration, quantization utilities, export workflows, and test coverage.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@modelopt/torch/quantization/qtensor/mxfp8_tensor.py`:
- Around lines 93-138: `get_weights_scaling_factor_from_quantizer` currently assumes 2D weights and computes `expected_shape = (out_dim, in_dim // BLOCK_SIZE)`, which breaks for 3D MoE weights `(num_experts, out_dim, in_dim)` because `reduce_block_amax` yields a 3D scale. Update the method to detect the MoE case by checking `weight.dim() == 3` and set `expected_shape = (num_experts, out_dim, in_dim // cls.BLOCK_SIZE)` there (or mirror the NVFP4 transpose-guard behavior before calling this method). Then, after pulling `weight_quantizer._scale`, ensure `scale.shape` exactly equals `expected_shape` (allowing a reshape only when the numel matches) and raise/assert with a clear message if it does not. Reference symbols: `get_weights_scaling_factor_from_quantizer`, `get_weights_scaling_factor`, `cls.BLOCK_SIZE`, `cls.SCALE_DTYPE`, and `weight_quantizer._scale`. (A sketch of the shape handling follows below.)
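One way that shape handling could look, sketched as standalone helpers: the function names below are illustrative and not part of the PR; only `BLOCK_SIZE = 32`, `SCALE_DTYPE = torch.uint8`, and the MoE weight layout come from the comment above.

```python
# Illustrative sketch of the suggested MoE-aware shape check; not the merged code.
import math

import torch

BLOCK_SIZE = 32            # mirrors MXFP8QTensor.BLOCK_SIZE
SCALE_DTYPE = torch.uint8  # mirrors MXFP8QTensor.SCALE_DTYPE


def expected_mxfp8_scale_shape(weight: torch.Tensor) -> tuple[int, ...]:
    """Expected E8M0 scale shape for dense (2D) and MoE (3D) weights."""
    if weight.dim() == 3:  # MoE weights: (num_experts, out_dim, in_dim)
        num_experts, out_dim, in_dim = weight.shape
        return (num_experts, out_dim, in_dim // BLOCK_SIZE)
    out_dim, in_dim = weight.shape  # dense weights: (out_dim, in_dim)
    return (out_dim, in_dim // BLOCK_SIZE)


def check_quantizer_scale(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Validate (and, only when numel matches, reshape) a quantizer scale."""
    expected = expected_mxfp8_scale_shape(weight)
    if tuple(scale.shape) != expected:
        assert scale.numel() == math.prod(expected), (
            f"Unexpected MXFP8 scale shape {tuple(scale.shape)}, expected {expected}"
        )
        scale = scale.reshape(expected)
    return scale.to(SCALE_DTYPE)


moe_w = torch.randn(4, 16, 64)            # (num_experts, out_dim, in_dim)
print(expected_mxfp8_scale_shape(moe_w))  # (4, 16, 2)
```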
🧹 Nitpick comments (1)
modelopt/torch/export/unified_export_hf.py (1)
301-309: Minor redundancy: `weight` is already available from line 250. The MXFP8 export logic is correct. However, `weight` is already fetched at line 250 via `getattr(sub_module, weight_name)`, so line 303 re-fetches the same value unnecessarily.

♻️ Optional: reuse the existing `weight` variable

```diff
 elif quantization_format == QUANTIZATION_MXFP8:
     # MXFP8 uses dynamic block quantization with E8M0 scales (uint8)
-    weight = getattr(sub_module, weight_name)
     e8m0_scale = MXFP8QTensor.get_weights_scaling_factor_from_quantizer(
         weight, weight_quantizer
     )
```
📜 Review details
Configuration used: `.coderabbit.yaml`
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- examples/llm_ptq/hf_ptq.py
- examples/llm_ptq/scripts/huggingface_example.sh
- modelopt/torch/export/model_config.py
- modelopt/torch/export/quant_utils.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py
- modelopt/torch/quantization/qtensor/__init__.py
- modelopt/torch/quantization/qtensor/mxfp8_tensor.py
- tests/examples/llm_ptq/test_llm_ptq.py
- tests/gpu/torch/quantization/test_qtensor_cuda.py
🧰 Additional context used
🧬 Code graph analysis (4)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
- modelopt/torch/quantization/qtensor/mxfp8_tensor.py (2)
  - `MXFP8QTensor` (26-269)
  - `quantize` (195-222)

modelopt/torch/export/unified_export_hf.py (1)
- modelopt/torch/quantization/qtensor/mxfp8_tensor.py (2)
  - `MXFP8QTensor` (26-269)
  - `get_weights_scaling_factor_from_quantizer` (94-138)

tests/examples/llm_ptq/test_llm_ptq.py (2)
- tests/_test_utils/examples/llm_ptq_utils.py (1)
  - `PTQCommand` (28-87)
- tests/_test_utils/torch/quantization/quant_utils.py (1)
  - `quant` (19-30)

modelopt/torch/export/quant_utils.py (2)
- modelopt/torch/quantization/qtensor/mxfp8_tensor.py (3)
  - `MXFP8QTensor` (26-269)
  - `get_weights_scaling_factor_from_quantizer` (94-138)
  - `quantize_with_scale` (141-192)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)
  - `block_sizes` (360-362)
  - `block_sizes` (365-367)
🔇 Additional comments (32)
examples/llm_ptq/scripts/huggingface_example.sh (2)
56-58: LGTM! The `mxfp8` format is correctly added to the valid quantization formats list, and the error message is updated consistently.

207-210: Verify whether `mxfp8` should be added to the supported formats for the TRT-LLM torch runtime. The `mxfp8` format is not included in the check on line 207, so MXFP8-quantized models will exit early with a message to use TensorRT-LLM for deployment. If MXFP8 should support the same workflow as `fp8` and `nvfp4` (continuing to `run_tensorrt_llm.py`), consider adding it:

```diff
- if [[ ! " fp8 nvfp4 bf16 fp16 " =~ " ${QFORMAT} " ]]; then
+ if [[ ! " fp8 nvfp4 bf16 fp16 mxfp8 " =~ " ${QFORMAT} " ]]; then
```

modelopt/torch/export/model_config.py (1)

38-38: LGTM! The new `QUANTIZATION_MXFP8` constant follows the established naming convention and is correctly placed among related quantization format identifiers.

tests/examples/llm_ptq/test_llm_ptq.py (1)

117-117: LGTM! The MXFP8 test case is correctly configured with `min_sm=100` so it only runs on Blackwell GPUs, which have hardware acceleration for MXFP8.

modelopt/torch/quantization/qtensor/__init__.py (1)

23-23: LGTM! The `mxfp8_tensor` module export follows the established pattern and is correctly positioned alphabetically among the other tensor module imports.

examples/llm_ptq/hf_ptq.py (3)

175-191: LGTM! The `mxfp8` format is correctly added to the auto-quantize validation list, enabling MXFP8 as a valid format option for automatic per-layer quantization search.

759-774: LGTM! The `mxfp8` format is correctly added to the mono-quantize validation list for the HF export path.

86-86: LGTM! The `mxfp8` format is correctly mapped to `mtq.MXFP8_DEFAULT_CFG` in the quantization configuration choices dictionary. The constant is properly defined and exported in the mtq module.

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (2)

52-52: LGTM! The import addition for `MXFP8QTensor` is correctly placed alongside other quantized tensor imports.

693-703: LGTM! The MXFP8 branch correctly:
- Validates that the block size matches the MXFP8 spec (32)
- Uses `MXFP8QTensor.quantize()`, which handles block-based quantization internally
- Stores scales in the same manner as MXFP4

The distinction between MXFP4 `(2, 1)` and MXFP8 `(4, 3)` num_bits is clear and properly separated.
tests/gpu/torch/quantization/test_qtensor_cuda.py (9)

18-18: LGTM! The imports for `math` and `MXFP8QTensor` are correctly added to support the new MXFP8 tests. Also applies to: 27-27.

253-260: LGTM! The MXFP8 test case is correctly added to `test_qtensor_accuracy` with an appropriate configuration matching the MXFP8 spec (block size 32, dynamic type, scale_bits (8, 0)).

616-676: LGTM! Comprehensive test for MXFP8 quantize/dequantize covering:
- Multiple devices (cuda, cpu)
- Multiple dtypes (float32, float16, bfloat16)
- Various shapes, including 3D MoE-like tensors
- Padding scenarios (dimensions not divisible by 32)
- Proper assertions for scale dtype, shapes, and quantized data format

The tolerance of `rtol=5e-2, atol=5e-2` is reasonable for FP8 quantization precision.

678-716: LGTM! Excellent test for verifying E8M0 scale computation with known input values. The test validates that per-block max values are preserved through the quantize-dequantize cycle.

718-751: LGTM! Good boundary-value testing for FP8 E4M3 limits (max 448, powers of 2, positive/negative values). The `# fmt: off/on` markers appropriately preserve the readable tensor formatting.

753-782: LGTM! The memory usage test follows the same pattern as the existing NVFP4 test. The 3x threshold is reasonable given that MXFP8 stores FP8 data plus uint8 scales.

784-806: LGTM! Tests for `get_weights_scaling_factor` with proper shape and dtype validation. The check for E8M0 values ≤ 254 correctly excludes the NaN representation (255).

808-824: LGTM! Good coverage of edge cases for `_compute_e8m0_exponent`:
- Zero amax → minimum exponent (-127)
- E4M3_MAX (448) → exponent 0
- Normal value (1.0) → computed exponent
- Very large/small values → clamped to the valid range

826-889: LGTM! Comprehensive error-handling tests covering:
- 1D tensor assertions
- Non-divisible dimensions
- Wrong scale dtype
- Empty tensor
- 0D tensor (scalar)
- Non-floating-point input
- Missing scale in dequantize

This ensures robust input validation.
modelopt/torch/export/unified_export_hf.py (1)
35-35: LGTM! The imports for `MXFP8QTensor` and `QUANTIZATION_MXFP8` are correctly added to support MXFP8 export handling. Also applies to: 54-54.
modelopt/torch/export/quant_utils.py (5)
33-33: LGTM! The imports for `MXFP8QTensor` and `QUANTIZATION_MXFP8` are correctly added. Also applies to: 58-58.

296-297: LGTM! The MXFP8 weight scaling factor retrieval correctly delegates to `MXFP8QTensor.get_weights_scaling_factor_from_quantizer`, which handles both extracting existing scales and computing new ones.
482-489: LGTM! The MXFP8 detection logic correctly identifies the format by checking:
- `block_sizes` is a dict
- `type` is `"dynamic"`
- `scale_bits` is `(8, 0)` (E8M0 format)

This check is properly positioned before the FP8_PB_WO/FP8_PB_REAL checks at lines 490-493, ensuring MXFP8 is correctly distinguished from other FP8 block quantization formats. (A minimal sketch of this check follows.)
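For illustration only, the check described above could look roughly like this; the function name is hypothetical, while the `"type"` and `"scale_bits"` keys follow the block_sizes config shown elsewhere in this PR, e.g. `{-1: 32, "type": "dynamic", "scale_bits": (8, 0)}`.

```python
# Rough sketch of the MXFP8 detection described above; not the actual quant_utils.py code.
def looks_like_mxfp8(block_sizes) -> bool:
    return (
        isinstance(block_sizes, dict)
        and block_sizes.get("type") == "dynamic"
        and tuple(block_sizes.get("scale_bits", ())) == (8, 0)
    )


assert looks_like_mxfp8({-1: 32, "type": "dynamic", "scale_bits": (8, 0)})
assert not looks_like_mxfp8({-1: 128})  # e.g. a static per-block FP8 config
```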
685-689: LGTM! The MXFP8 layer config processing correctly maps the `"mxfp8"` format to the `"MXFP8"` quant_algo with the appropriate group_size, following the same pattern as other quantization formats.

794-795: LGTM! The `to_quantized_weight` function correctly uses `MXFP8QTensor.quantize_with_scale` to apply the pre-computed E8M0 scale to the weight tensor.
modelopt/torch/quantization/qtensor/mxfp8_tensor.py (7)

1-23: LGTM! Clean module structure with a proper license, docstring, imports from existing utilities (`reduce_block_amax`, `reduce_block_padding`), and an explicit `__all__` export.

26-40: LGTM! Class constants are correctly defined:
- `E4M3_MAX = 448.0` matches the FP8 E4M3 max value
- `BLOCK_SIZE = 32` per the MXFP8 specification
- `SCALE_DTYPE = torch.uint8` for E8M0 biased exponent storage
42-66: LGTM! The `_compute_e8m0_exponent` implementation:
- Converts to float32 for numerical stability
- Handles zero values by using `torch.where` with a min_value fallback
- Correctly computes `ceil(log2(amax / E4M3_MAX))`
- Clamps to the valid E8M0 range [-127, 127]

(A standalone sketch of this arithmetic follows.)
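As a standalone illustration of the same arithmetic (not the module's code; the zero-amax fallback value below is an assumption, since the real implementation uses its own min_value):

```python
# Reference sketch of the E8M0 exponent math described above; not MXFP8QTensor itself.
import torch

E4M3_MAX = 448.0  # max representable FP8 E4M3 value


def e8m0_exponent(amax: torch.Tensor) -> torch.Tensor:
    amax = amax.float()  # float32 for numerical stability
    # Fall back to a tiny positive value for zero amax so log2 stays finite
    # (assumed fallback; the real code uses its own min_value).
    safe = torch.where(amax > 0, amax, torch.full_like(amax, torch.finfo(torch.float32).tiny))
    exponent = torch.ceil(torch.log2(safe / E4M3_MAX))
    return exponent.clamp(-127.0, 127.0)


# Biased uint8 storage (exponent + 127), as used by get_weights_scaling_factor:
biased = (e8m0_exponent(torch.tensor([0.0, 1.0, 448.0])) + 127).to(torch.uint8)
print(biased)  # tensor([  0, 119, 127], dtype=torch.uint8)
```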
68-91: LGTM! The `get_weights_scaling_factor` implementation correctly:
- Validates a 2D minimum dimension
- Validates divisibility by BLOCK_SIZE
- Uses the existing `reduce_block_amax` utility
- Converts to the biased uint8 format (exponent + 127)
140-192: LGTM! The `quantize_with_scale` implementation is well structured:
- Proper input validation for dimensions and dtype
- Flexible scale reshaping to handle different input shapes
- Correct scale factor computation: `2^(127 - exponent)`
- Proper clamping to the E4M3 range before FP8 conversion
- The NOTE comment documents a potential vLLM/flashinfer compatibility consideration
194-222: LGTM! The `quantize` method correctly implements the full quantization flow:
- Input validation for empty tensors, dimensions, and dtype
- Padding alignment via `reduce_block_padding`
- Per-block amax computation
- E8M0 exponent computation and biasing
- Shape restoration via cropping
224-269: LGTM! The `dequantize` method correctly reverses the quantization:
- Requires the scale in kwargs (enforced by assertion)
- Converts quantized data to float for computation
- Applies padding for block alignment
- Computes the descale as `2^(exponent - 127)`
- Handles scale shape broadcasting
- Restores the original shape via cropping

(A compact round-trip sketch follows.)
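Putting the two directions together, here is a minimal numerical sketch of the quantize/dequantize round trip described in the comments above. It is illustrative only: dense 2D input, in_dim a multiple of 32, no padding or MoE handling, and the names are not the MXFP8QTensor API.

```python
# Minimal MXFP8-style round trip: per-block E8M0 scale, FP8 E4M3 payload. Not the real API.
import torch

E4M3_MAX = 448.0
BLOCK_SIZE = 32


def mxfp8_roundtrip(x: torch.Tensor) -> torch.Tensor:
    out_dim, in_dim = x.shape
    blocks = x.reshape(out_dim, in_dim // BLOCK_SIZE, BLOCK_SIZE).float()
    amax = blocks.abs().amax(dim=-1, keepdim=True)  # per-block amax
    exponent = torch.ceil(torch.log2(amax.clamp(min=1e-30) / E4M3_MAX)).clamp(-127, 127)
    scaled = (blocks * torch.pow(2.0, -exponent)).clamp(-E4M3_MAX, E4M3_MAX)
    q = scaled.to(torch.float8_e4m3fn)              # FP8 E4M3 payload
    dq = q.float() * torch.pow(2.0, exponent)       # descale (2^(scale - 127) for biased storage)
    return dq.reshape(out_dim, in_dim).to(x.dtype)


x = torch.randn(4, 64)
# Loose tolerance: E4M3 carries ~3 mantissa bits, so per-element relative error is
# bounded by roughly 2^-4.
assert torch.allclose(mxfp8_roundtrip(x), x, rtol=6.5e-2, atol=1e-3)
```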
```python
        return cls.get_weights_scaling_factor(weight)

    @classmethod
    def quantize_with_scale(
```
Is it possible to add this logic inside the `quantize()` function? We have a similar use case in `NVFP4QTensor`, where we use the precomputed `weight_scaling_factor2` to quantize the weights.
Added `weights_scaling_factor` as an optional input to the `quantize` method of `MXFP8QTensor`.
mxinO left a comment:
LGTM! Did you run some benchmarks to make sure the accuracy is as expected?
```python
{-1: 32, "type": "dynamic", "scale_bits": (8, 0)},
None,
torch.randn([512, 512], dtype=torch.float32),
None,
```
If `test_output` is None, it seems this test only exercises quantize and dequantize, which is already covered by `test_mxfp8_quantize_dequantize`.
Removed it and added a new fake-quant test instead: `test_mxfp8_fake_quant`.
@mxinO Once this PR is merged, I can continue working on support in vLLM. See this draft PR with my vLLM branch with support for ModelOpt MXFP8:

Using my vLLM branch, I was able to run gsm8k:

Results for

Results for

You can see an accuracy drop, but that's expected (BF16 has higher accuracy than MXFP8).

Note:
Force-pushed from b7ed5ce to 23570a1.
Force-pushed from 2bc8bca to 4a2a15e.
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Head branch was pushed to by a user without write access
Force-pushed from 4a2a15e to 08a0524.
/ok to test 08a0524
What does this PR do?
Type of change: new feature
Overview: Add support for MXFP8 PTQ, enabling MXFP8 hardware acceleration during inference on Blackwell GPUs.
Usage
The `hf_quant_config.json` of the output checkpoint:

```json
{
  "producer": {
    "name": "modelopt",
    "version": "0.41.0.dev50+g7a796a875"
  },
  "quantization": {
    "quant_algo": "MXFP8",
    "kv_cache_quant_algo": "FP8",
    "group_size": 32,
    "exclude_modules": [
      "lm_head"
    ]
  }
}
```

And the `config.json` (only the `quantization_config`):

Testing
Used `hf_ptq.py` to quantize the model `nvidia/OpenMath2-Llama3.1-8B` (available on Hugging Face); see the example command above.

Checked that the generated MXFP8 checkpoint can be loaded with vLLM (this required changes in vLLM that are not yet merged to main).

Added tests for `MXFP8QTensor` in `tests/gpu/torch/quantization/test_qtensor_cuda.py`.

Added `"mxfp8"` in `tests/examples/llm_ptq/test_llm_ptq.py`.
tests/examples/llm_ptq/test_llm_ptq.pySupport for Nemotron Models
Verify that Nemotron Nano V3 BF16 can be converted to MXFP8 using `hf_ptq.py`: https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Tests