Skip to content

Commit e980b45

Browse files
authored
Merge pull request #12 from IST-DASLab/backwards
MXFP4 and MXFP8 backward passes
2 parents 617317d + d9a1e4b commit e980b45

File tree

8 files changed

+965
-173
lines changed

8 files changed

+965
-173
lines changed

inference_lib/setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="fp_quant",
5-
version="0.2.0",
5+
version="0.3.0",
66
packages=find_packages(where="src"),
77
package_dir={"": "src"},
88
author="Andrei Panferov",
@@ -18,7 +18,7 @@
1818
],
1919
python_requires=">=3.9",
2020
install_requires=[
21-
"torch>=2.7.0",
21+
"torch>=2.8.0",
2222
"scipy>=1.13.0",
2323
"triton>=3.3.0",
2424
],

inference_lib/src/fp_quant/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,10 @@
33
"""
44

55
from .module import FPQuantLinear
6-
from .utils import FPQuantConfig, FPQuantDtype, replace_with_fp_quant_linear
6+
from .utils import (
7+
FPQuantConfig,
8+
FPQuantDtype,
9+
replace_with_fp_quant_linear,
10+
replace_quantize_with_fp_quant_linear,
11+
finalize_master_weights,
12+
)

inference_lib/src/fp_quant/module/linear.py

Lines changed: 94 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@
66

77
from ..utils import FPQuantConfig, FPQuantDtype, validate_config
88
from .linear_fns import (
9-
HAS_QUTLASS,
109
FPQuant4x16MasterFn,
10+
FPQuant4x4MasterFn,
11+
FPQuant4x8MasterFn,
12+
FPQuant4x8NoMasterFn,
1113
FPQuant4x16NoMasterFn,
1214
forward_quantize,
1315
)
16+
from .qutlass_ops import HAS_QUTLASS
1417
from .pseudoquant_linear_fns import (
1518
PseudoQuant4x16MasterFn,
1619
PseudoQuant4x16NoMasterFn,
@@ -20,17 +23,20 @@
2023

2124
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
2225
return torch.tensor(
23-
hadamard(group_size) * group_size**-0.5, dtype=dtype, device=device
24-
)
26+
hadamard(group_size) * group_size**-0.5,
27+
dtype=dtype,
28+
device=device,
29+
requires_grad=False,
30+
) * (torch.randint(0, 2, (group_size, 1), device=device, dtype=dtype) * 2 - 1)
2531

2632

2733
def get_identity_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
28-
return torch.eye(group_size, dtype=dtype, device=device)
34+
return torch.eye(group_size, dtype=dtype, device=device, requires_grad=False)
2935

3036

3137
def get_gsr_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
3238
hadamard_matrix = get_hadamard_matrix(group_size, dtype, device)
33-
sign_changes = torch.diff(hadamard_matrix, dim=0).ne(0).sum(dim=0)
39+
sign_changes = torch.diff(hadamard_matrix, dim=0).ne(0).sum(dim=0)
3440
sorted_indices = torch.argsort(sign_changes)
3541
return hadamard_matrix[:, sorted_indices].contiguous()
3642

@@ -148,15 +154,15 @@ def __init__(
148154
@torch.no_grad()
149155
def pre_forward(self):
150156
# Generate rotation matrices
151-
assert self.weight.shape[1] % self.config.hadamard_group_size == 0, (
152-
f"Weight shape must be divisible by hadamard group size: {self.weight.shape[1]} % {self.config.hadamard_group_size} = {self.weight.shape[1] % self.config.hadamard_group_size}"
153-
)
157+
assert (
158+
self.weight.shape[1] % self.config.hadamard_group_size == 0
159+
), f"Weight shape must be divisible by hadamard group size: {self.weight.shape[1]} % {self.config.hadamard_group_size} = {self.weight.shape[1] % self.config.hadamard_group_size}"
154160

155-
weight_in_device = (self.weight.data.device.type in ["cuda", "xpu"])
161+
weight_in_device = self.weight.data.device.type in ["cuda", "xpu"]
156162
if not self.config.pseudoquantization:
157-
assert weight_in_device, (
158-
f"Weight must be on CUDA or XPU, but is on {self.weight.device}"
159-
)
163+
assert (
164+
weight_in_device
165+
), f"Weight must be on CUDA or XPU, but is on {self.weight.device}"
160166
if self.config.transform_init == "hadamard":
161167
transform_init_fn = get_hadamard_matrix
162168
elif self.config.transform_init == "identity":
@@ -166,47 +172,49 @@ def pre_forward(self):
166172
else:
167173
raise ValueError(f"Invalid transform init: {self.config.transform_init}")
168174

169-
self.forward_hadamard_matrix = nn.Parameter(
175+
self.forward_hadamard_matrix = nn.Buffer(
170176
transform_init_fn(
171177
self.config.hadamard_group_size,
172178
self.weight.dtype,
173179
self.weight.device,
174180
),
175-
requires_grad=False,
176181
)
177-
self.backward_hadamard_matrix = nn.Parameter(
182+
self.backward_hadamard_matrix = nn.Buffer(
178183
transform_init_fn(
179184
self.config.hadamard_group_size,
180185
self.weight.dtype,
181186
self.weight.device,
182187
),
183-
requires_grad=False,
184188
)
185189

186-
if self.config.forward_dtype == FPQuantDtype.MXFP4:
187-
# MXFP4 quantization implicitly multiplies by 3.0
188-
self.weight_global_scale = nn.Parameter(
189-
torch.tensor([3.0], dtype=self.weight.dtype, device=self.weight.device),
190-
requires_grad=False,
191-
)
192-
self.act_global_scale = nn.Parameter(
193-
torch.tensor([3.0], dtype=self.weight.dtype, device=self.weight.device),
194-
requires_grad=False,
195-
)
190+
if (
191+
self.config.forward_dtype == FPQuantDtype.MXFP4
192+
and self.config.forward_method == "quest"
193+
):
194+
global_scale_val = 1.0
195+
elif self.config.forward_method == "abs_max":
196+
# MXFP4 abs_max quantization implicitly multiplies by 3.0
197+
global_scale_val = 3.0
196198
elif self.config.forward_dtype == FPQuantDtype.NVFP4:
197-
# MXFP4 quantization implicitly multiplies by 6.0
198-
self.weight_global_scale = nn.Parameter(
199-
torch.tensor(
200-
[10.0], dtype=self.weight.dtype, device=self.weight.device
201-
),
199+
# 10.0 ensures no underflows/overflows in most models
200+
global_scale_val = 10.0
201+
202+
self.weight_global_scale = nn.Buffer(
203+
torch.tensor(
204+
[global_scale_val],
205+
dtype=self.weight.dtype,
206+
device=self.weight.device,
202207
requires_grad=False,
203-
)
204-
self.act_global_scale = nn.Parameter(
205-
torch.tensor(
206-
[10.0], dtype=self.weight.dtype, device=self.weight.device
207-
),
208+
),
209+
)
210+
self.act_global_scale = nn.Buffer(
211+
torch.tensor(
212+
[global_scale_val],
213+
dtype=self.weight.dtype,
214+
device=self.weight.device,
208215
requires_grad=False,
209-
)
216+
),
217+
)
210218

211219
if self.config.store_master_weights:
212220
self.qweight = None
@@ -241,6 +249,55 @@ def pre_forward(self):
241249

242250
def forward(self, x) -> torch.Tensor:
243251
if (
252+
self.config.forward_dtype == FPQuantDtype.MXFP4
253+
and self.config.backward_dtype == FPQuantDtype.MXFP4
254+
and self.config.store_master_weights == True
255+
and self.config.pseudoquantization == False
256+
):
257+
return FPQuant4x4MasterFn.apply(
258+
x,
259+
self.weight,
260+
self.weight_global_scale,
261+
self.act_global_scale,
262+
self.bias,
263+
self.forward_hadamard_matrix,
264+
self.config.forward_dtype,
265+
self.config.forward_method,
266+
)
267+
elif (
268+
self.config.forward_dtype == FPQuantDtype.MXFP4
269+
and self.config.backward_dtype == FPQuantDtype.MXFP8
270+
and self.config.store_master_weights == True
271+
and self.config.pseudoquantization == False
272+
):
273+
return FPQuant4x8MasterFn.apply(
274+
x,
275+
self.weight,
276+
self.weight_global_scale,
277+
self.act_global_scale,
278+
self.bias,
279+
self.forward_hadamard_matrix,
280+
self.config.forward_dtype,
281+
self.config.forward_method,
282+
)
283+
elif (
284+
self.config.forward_dtype == FPQuantDtype.MXFP4
285+
and self.config.backward_dtype == FPQuantDtype.MXFP8
286+
and self.config.store_master_weights == False
287+
and self.config.pseudoquantization == False
288+
):
289+
return FPQuant4x8NoMasterFn.apply(
290+
x,
291+
self.qweight,
292+
self.scales,
293+
self.weight_global_scale,
294+
self.act_global_scale,
295+
self.bias,
296+
self.forward_hadamard_matrix,
297+
self.config.forward_dtype,
298+
self.config.forward_method,
299+
)
300+
elif (
244301
self.config.forward_dtype in (FPQuantDtype.MXFP4, FPQuantDtype.NVFP4)
245302
and self.config.backward_dtype == FPQuantDtype.BF16
246303
and self.config.store_master_weights == True

0 commit comments

Comments
 (0)