66
77from ..utils import FPQuantConfig , FPQuantDtype , validate_config
88from .linear_fns import (
9- HAS_QUTLASS ,
109 FPQuant4x16MasterFn ,
10+ FPQuant4x4MasterFn ,
11+ FPQuant4x8MasterFn ,
12+ FPQuant4x8NoMasterFn ,
1113 FPQuant4x16NoMasterFn ,
1214 forward_quantize ,
1315)
16+ from .qutlass_ops import HAS_QUTLASS
1417from .pseudoquant_linear_fns import (
1518 PseudoQuant4x16MasterFn ,
1619 PseudoQuant4x16NoMasterFn ,
2023
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    """Return a randomized orthonormal Hadamard matrix of shape (group_size, group_size).

    The Sylvester Hadamard matrix is scaled by ``group_size ** -0.5`` so the
    transform is orthonormal, then each row is multiplied by an independent
    random sign in {-1, +1}. Row sign flips keep the matrix orthogonal while
    randomizing the rotation applied to each group.

    Args:
        group_size: Side length of the square transform matrix. Must be a
            size for which ``scipy.linalg.hadamard`` exists (a power of 2).
        dtype: Torch dtype of the returned matrix.
        device: Torch device the matrix is created on.

    Returns:
        A ``(group_size, group_size)`` tensor with ``requires_grad=False``.
    """
    base = torch.tensor(
        hadamard(group_size) * group_size**-0.5,
        dtype=dtype,
        device=device,
        requires_grad=False,
    )
    # Random per-row signs: randint yields {0, 1}; *2 - 1 maps to {-1, +1}.
    # Shape (group_size, 1) broadcasts across columns, flipping whole rows.
    signs = torch.randint(0, 2, (group_size, 1), device=device, dtype=dtype) * 2 - 1
    return base * signs
2531
2632
def get_identity_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    """Return the (group_size, group_size) identity matrix (a no-op transform).

    Created with ``requires_grad=False`` since the transform is a fixed
    buffer, never a trainable parameter.
    """
    return torch.eye(group_size, dtype=dtype, device=device, requires_grad=False)
2935
3036
def get_gsr_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
    """Return a Hadamard matrix with columns reordered by sign-change count.

    Builds the (randomized) Hadamard matrix, counts how many times each
    column changes sign down its rows, and sorts columns by that count
    (ascending) — i.e. a sequency-style ordering of the transform.

    Args:
        group_size: Side length of the square transform matrix.
        dtype: Torch dtype of the returned matrix.
        device: Torch device the matrix is created on.

    Returns:
        A contiguous ``(group_size, group_size)`` tensor.
    """
    hadamard_matrix = get_hadamard_matrix(group_size, dtype, device)
    # diff along rows is nonzero exactly where consecutive entries in a
    # column differ in sign; summing per column counts the sign changes.
    sign_changes = torch.diff(hadamard_matrix, dim=0).ne(0).sum(dim=0)
    sorted_indices = torch.argsort(sign_changes)
    return hadamard_matrix[:, sorted_indices].contiguous()
3642
@@ -148,15 +154,15 @@ def __init__(
148154 @torch .no_grad ()
149155 def pre_forward (self ):
150156 # Generate rotation matrices
151- assert self . weight . shape [ 1 ] % self . config . hadamard_group_size == 0 , (
152- f"Weight shape must be divisible by hadamard group size: { self .weight .shape [1 ]} % { self .config .hadamard_group_size } = { self . weight . shape [ 1 ] % self . config . hadamard_group_size } "
153- )
157+ assert (
158+ self .weight .shape [1 ] % self .config .hadamard_group_size == 0
159+ ), f"Weight shape must be divisible by hadamard group size: { self . weight . shape [ 1 ] } % { self . config . hadamard_group_size } = { self . weight . shape [ 1 ] % self . config . hadamard_group_size } "
154160
155- weight_in_device = ( self .weight .data .device .type in ["cuda" , "xpu" ])
161+ weight_in_device = self .weight .data .device .type in ["cuda" , "xpu" ]
156162 if not self .config .pseudoquantization :
157- assert weight_in_device , (
158- f"Weight must be on CUDA or XPU, but is on { self . weight . device } "
159- )
163+ assert (
164+ weight_in_device
165+ ), f"Weight must be on CUDA or XPU, but is on { self . weight . device } "
160166 if self .config .transform_init == "hadamard" :
161167 transform_init_fn = get_hadamard_matrix
162168 elif self .config .transform_init == "identity" :
@@ -166,47 +172,49 @@ def pre_forward(self):
166172 else :
167173 raise ValueError (f"Invalid transform init: { self .config .transform_init } " )
168174
169- self .forward_hadamard_matrix = nn .Parameter (
175+ self .forward_hadamard_matrix = nn .Buffer (
170176 transform_init_fn (
171177 self .config .hadamard_group_size ,
172178 self .weight .dtype ,
173179 self .weight .device ,
174180 ),
175- requires_grad = False ,
176181 )
177- self .backward_hadamard_matrix = nn .Parameter (
182+ self .backward_hadamard_matrix = nn .Buffer (
178183 transform_init_fn (
179184 self .config .hadamard_group_size ,
180185 self .weight .dtype ,
181186 self .weight .device ,
182187 ),
183- requires_grad = False ,
184188 )
185189
186- if self .config .forward_dtype == FPQuantDtype .MXFP4 :
187- # MXFP4 quantization implicitly multiplies by 3.0
188- self .weight_global_scale = nn .Parameter (
189- torch .tensor ([3.0 ], dtype = self .weight .dtype , device = self .weight .device ),
190- requires_grad = False ,
191- )
192- self .act_global_scale = nn .Parameter (
193- torch .tensor ([3.0 ], dtype = self .weight .dtype , device = self .weight .device ),
194- requires_grad = False ,
195- )
190+ if (
191+ self .config .forward_dtype == FPQuantDtype .MXFP4
192+ and self .config .forward_method == "quest"
193+ ):
194+ global_scale_val = 1.0
195+ elif self .config .forward_method == "abs_max" :
196+ # MXFP4 abs_max quantization implicitly multiplies by 3.0
197+ global_scale_val = 3.0
196198 elif self .config .forward_dtype == FPQuantDtype .NVFP4 :
197- # MXFP4 quantization implicitly multiplies by 6.0
198- self .weight_global_scale = nn .Parameter (
199- torch .tensor (
200- [10.0 ], dtype = self .weight .dtype , device = self .weight .device
201- ),
199+ # 10.0 ensures no underflows/overflows in most models
200+ global_scale_val = 10.0
201+
202+ self .weight_global_scale = nn .Buffer (
203+ torch .tensor (
204+ [global_scale_val ],
205+ dtype = self .weight .dtype ,
206+ device = self .weight .device ,
202207 requires_grad = False ,
203- )
204- self .act_global_scale = nn .Parameter (
205- torch .tensor (
206- [10.0 ], dtype = self .weight .dtype , device = self .weight .device
207- ),
208+ ),
209+ )
210+ self .act_global_scale = nn .Buffer (
211+ torch .tensor (
212+ [global_scale_val ],
213+ dtype = self .weight .dtype ,
214+ device = self .weight .device ,
208215 requires_grad = False ,
209- )
216+ ),
217+ )
210218
211219 if self .config .store_master_weights :
212220 self .qweight = None
@@ -241,6 +249,55 @@ def pre_forward(self):
241249
242250 def forward (self , x ) -> torch .Tensor :
243251 if (
252+ self .config .forward_dtype == FPQuantDtype .MXFP4
253+ and self .config .backward_dtype == FPQuantDtype .MXFP4
254+ and self .config .store_master_weights == True
255+ and self .config .pseudoquantization == False
256+ ):
257+ return FPQuant4x4MasterFn .apply (
258+ x ,
259+ self .weight ,
260+ self .weight_global_scale ,
261+ self .act_global_scale ,
262+ self .bias ,
263+ self .forward_hadamard_matrix ,
264+ self .config .forward_dtype ,
265+ self .config .forward_method ,
266+ )
267+ elif (
268+ self .config .forward_dtype == FPQuantDtype .MXFP4
269+ and self .config .backward_dtype == FPQuantDtype .MXFP8
270+ and self .config .store_master_weights == True
271+ and self .config .pseudoquantization == False
272+ ):
273+ return FPQuant4x8MasterFn .apply (
274+ x ,
275+ self .weight ,
276+ self .weight_global_scale ,
277+ self .act_global_scale ,
278+ self .bias ,
279+ self .forward_hadamard_matrix ,
280+ self .config .forward_dtype ,
281+ self .config .forward_method ,
282+ )
283+ elif (
284+ self .config .forward_dtype == FPQuantDtype .MXFP4
285+ and self .config .backward_dtype == FPQuantDtype .MXFP8
286+ and self .config .store_master_weights == False
287+ and self .config .pseudoquantization == False
288+ ):
289+ return FPQuant4x8NoMasterFn .apply (
290+ x ,
291+ self .qweight ,
292+ self .scales ,
293+ self .weight_global_scale ,
294+ self .act_global_scale ,
295+ self .bias ,
296+ self .forward_hadamard_matrix ,
297+ self .config .forward_dtype ,
298+ self .config .forward_method ,
299+ )
300+ elif (
244301 self .config .forward_dtype in (FPQuantDtype .MXFP4 , FPQuantDtype .NVFP4 )
245302 and self .config .backward_dtype == FPQuantDtype .BF16
246303 and self .config .store_master_weights == True
0 commit comments