@@ -41,6 +41,7 @@ def activation_nquant_2d(x: torch.Tensor, qbit = 8):
4141 if qbit == 1 :
4242 x_absmean = torch .mean (x .abs (), dim = (- 2 ,- 1 ), keepdim = True )
4343 y = x .sign () * x_absmean
44+ y = torch .where (y == 0.0 , - x_absmean , y )
4445 elif qbit < 2 : # Ternary quantization
4546 x_absmean = torch .mean (x .abs (), dim = (- 2 ,- 1 ), keepdim = True )
4647 scale = 1.0 / x_absmean .clamp_ (min = 1e-5 )
@@ -140,7 +141,7 @@ def weight_quantnb(w: torch.Tensor, qbit = 8, mode = "max"):
140141def weight_quantnb_group (w : torch .Tensor , qbit : int = 8 , mode : str = "max" ,
141142 dim : int = - 1 , group_size : int = 32 , return_expanded : bool = True ):
142143 """
143- Group-wise symmetric quantization for qbit >= 2 .
144+ Group-wise weight quantization for qbit >= 1 .
144145 - dim: dimension to group over
145146 - group_size: number of contiguous elements per group along 'dim'
146147 - mode: "max" (per-group max) for qbit > 2, otherwise "mean"
@@ -149,7 +150,7 @@ def weight_quantnb_group(w: torch.Tensor, qbit: int = 8, mode: str = "max",
149150 u: integer-quantized weights (same shape as w)
150151 inv_scale: inverse scale tensor (broadcastable to w if return_expanded=True)
151152 """
152- assert qbit > 1 , "qbit should be larger than 1"
153+ assert qbit >= 1 , "qbit should be larger than or equal to 1"
153154 assert isinstance (group_size , int ) and group_size > 0 , "group_size must be a positive int"
154155
155156 # Normalize dim to positive index
@@ -166,16 +167,25 @@ def weight_quantnb_group(w: torch.Tensor, qbit: int = 8, mode: str = "max",
166167 x = w .reshape (new_shape )
167168
168169 reduce_dim = dim + 1 # the 'group_size' axis
169- if (mode == "max" ) and (qbit > 2 ):
170+ if qbit == 1 :
171+ denom = x .abs ().mean (dim = reduce_dim , keepdim = True )
172+ scale = 1.0 / denom .clamp (min = 1e-5 )
173+ u_group = x .sign ()
174+ elif 1 < qbit < 2 :
175+ denom = x .abs ().mean (dim = reduce_dim , keepdim = True )
176+ scale = 1.0 / denom .clamp (min = 1e-5 )
177+ u_group = (x * scale ).round ().clamp_ (- 1 , 1 )
178+ elif (mode == "max" ) and (qbit > 2 ):
170179 denom = x .abs ().amax (dim = reduce_dim , keepdim = True )
180+ scale = (2 ** (qbit - 1 ) - 1 ) / denom .clamp (min = 1e-5 )
181+ u_group = (x * scale ).round ().clamp_ (- (2 ** (qbit - 1 )), 2 ** (qbit - 1 ) - 1 )
171182 elif (mode == "mean" ) or (qbit <= 2 ):
172183 denom = x .abs ().mean (dim = reduce_dim , keepdim = True )
184+ scale = (2 ** (qbit - 1 ) - 1 ) / denom .clamp (min = 1e-5 )
185+ u_group = (x * scale ).round ().clamp_ (- (2 ** (qbit - 1 )), 2 ** (qbit - 1 ) - 1 )
173186 else :
174187 raise ValueError ("Invalid mode" )
175188
176- scale = (2 ** (qbit - 1 ) - 1 ) / denom .clamp (min = 1e-5 )
177-
178- u_group = (x * scale ).round ().clamp_ (- (2 ** (qbit - 1 )), 2 ** (qbit - 1 ) - 1 )
179189 u = u_group .reshape_as (w )
180190
181191 inv_scale_group = 1.0 / scale
@@ -264,6 +274,19 @@ def weight_quant(self, w: torch.Tensor):
264274
265275 def save_qweight (self ):
266276 self .qw , self .qw_scale = self ._weight_quant_impl (self .weight .data )
277+
278+ def qweight (self ):
279+ if self .qw is None or self .qw_scale is None :
280+ self .save_qweight ()
281+ if self .qw .device != self .weight .device :
282+ self .qw = self .qw .to (self .weight .device )
283+ if self .qw_scale .device != self .weight .device :
284+ self .qw_scale = self .qw_scale .to (self .weight .device )
285+ return self .qw * self .qw_scale
286+
287+ def ste_weight_quant (self ):
288+ w = self .weight
289+ return w + (self .weight_quant (w ) - w ).detach ()
267290
268291 def export_qweight (self ):
269292 return {
@@ -315,14 +338,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
315338 x_norm = SimpleRMSNorm (self .in_features )(x ) if self .use_norm else x
316339 x_norm = HadamardTransform ()(x ) if self .haramard else x_norm
317340 x_quant = x_norm + (self .act_quant (x_norm ) - x_norm ).detach ()
318- w_quant = w + ( self .weight_quant ( w ) - w ). detach ()
341+ w_quant = self .ste_weight_quant ()
319342 y = F .linear (x_quant , w_quant , bias = self .bias )
320343 return y
321344 elif self .qforward is True :
322345 # Forward with Post Training Quantization (PTQ)
323346 # Only for inference
324- qx = self .act_quant (x )
325- y = F .linear (qx , self .qw * self .qw_scale , bias = self .bias )
347+ x_norm = SimpleRMSNorm (self .in_features )(x ) if self .use_norm else x
348+ x_norm = HadamardTransform ()(x ) if self .haramard else x_norm
349+ qx = self .act_quant (x_norm )
350+ y = F .linear (qx , self .qweight (), bias = self .bias )
326351 return y
327352 else :
328353 return F .linear (x , w , bias = self .bias )
@@ -387,15 +412,16 @@ def forward(self, x: torch.Tensor):
387412 # Using Straight-Through-Estimator (STE)
388413 x_norm = self .rmsnorm (x ) if self .use_norm else x
389414 x_quant = x_norm + (activation_nquant_2d (x_norm , self .act_q ) - x_norm ).detach ()
390- w_quant = w + ( self .weight_quant ( w ) - w ). detach ()
415+ w_quant = self .ste_weight_quant ()
391416 y = F .conv2d (x_quant , w_quant , self .bias , self .stride , self .padding ,
392417 self .dilation , self .groups )
393418 return y
394419 elif self .qforward :
395420 # Forward with Post Training Quantization (PTQ)
396421 # Only for inference
397- qx = activation_nquant_2d (x , self .act_q )
398- y = F .conv2d (qx , self .qw * self .qw_scale , self .bias , self .stride , self .padding ,
422+ x_norm = self .rmsnorm (x ) if self .use_norm else x
423+ qx = activation_nquant_2d (x_norm , self .act_q )
424+ y = F .conv2d (qx , self .qweight (), self .bias , self .stride , self .padding ,
399425 self .dilation , self .groups )
400426 return y
401427 else :
@@ -449,12 +475,13 @@ def forward(self, x: torch.Tensor):
449475 if self .qat :
450476 x_norm = self .rmsnorm (x ) if self .use_norm else x
451477 x_quant = x_norm + (activation_nquant (x_norm , self .act_q ) - x_norm ).detach ()
452- w_quant = w + ( self .weight_quant ( w ) - w ). detach ()
478+ w_quant = self .ste_weight_quant ()
453479 return F .conv1d (x_quant , w_quant , self .bias , self .stride ,
454480 self .padding , self .dilation , self .groups )
455481 elif self .qforward is True :
456- qx = activation_nquant (x , self .act_q )
457- return F .conv1d (qx , self .qw * self .qw_scale , self .bias ,
482+ x_norm = self .rmsnorm (x ) if self .use_norm else x
483+ qx = activation_nquant (x_norm , self .act_q )
484+ return F .conv1d (qx , self .qweight (), self .bias ,
458485 self .stride , self .padding , self .dilation , self .groups )
459486 else :
460487 return F .conv1d (x , w , self .bias , self .stride ,
0 commit comments