@@ -18,12 +18,32 @@ def __init__(self, name, input_names=None, params=None):
1818ATTENTION_QUANT_FP8 = "fp8"
1919
2020
21+ def _parse_int_quant (quant ):
22+ text = str (quant ).lower ()
23+ if text .startswith ("int" ) and text [3 :].isdigit ():
24+ bits = int (text [3 :])
25+ elif text .startswith ("i" ) and text [1 :].isdigit ():
26+ bits = int (text [1 :])
27+ elif text .isdigit ():
28+ bits = int (text )
29+ else :
30+ return None
31+ if 1 <= bits <= 31 :
32+ return bits
33+ return None
34+
35+
2136def _normalize_attention_quant (quant ):
2237 if quant is None or quant is False :
2338 return ATTENTION_QUANT_NONE
2439 if quant is True :
2540 return ATTENTION_QUANT_INT8
2641 quant = str (quant ).lower ()
42+ int_bits = _parse_int_quant (quant )
43+ if int_bits is not None :
44+ if int_bits >= 32 :
45+ return ATTENTION_QUANT_NONE
46+ return f"int{ int_bits } "
2747 aliases = {
2848 "none" : ATTENTION_QUANT_NONE ,
2949 "fp32" : ATTENTION_QUANT_NONE ,
@@ -36,10 +56,6 @@ def _normalize_attention_quant(quant):
3656 "1.5" : ATTENTION_QUANT_BITNET ,
3757 "1.58bit" : ATTENTION_QUANT_BITNET ,
3858 "ternary" : ATTENTION_QUANT_BITNET ,
39- "int8" : ATTENTION_QUANT_INT8 ,
40- "i8" : ATTENTION_QUANT_INT8 ,
41- "8" : ATTENTION_QUANT_INT8 ,
42- "8.0" : ATTENTION_QUANT_INT8 ,
4359 "fp8" : ATTENTION_QUANT_FP8 ,
4460 "float8" : ATTENTION_QUANT_FP8 ,
4561 "e4m3" : ATTENTION_QUANT_FP8 ,
@@ -58,23 +74,24 @@ def attention_qtype_to_quant(qtype):
5874 return ATTENTION_QUANT_NONE
5975 if qtype >= 32 :
6076 return ATTENTION_QUANT_NONE
61- if qtype == 8 :
62- return ATTENTION_QUANT_INT8
6377 if 0 < qtype < 2 :
6478 return ATTENTION_QUANT_BITNET
65- if qtype == 2 :
66- return ATTENTION_QUANT_BITNET
79+ if float ( qtype ). is_integer () :
80+ return f"int { int ( qtype ) } "
6781 raise ValueError (f"Unsupported attention qtype: { qtype } " )
6882
6983
7084def attention_quant_to_bits (quant ):
7185 quant = _normalize_attention_quant (quant )
7286 if quant == ATTENTION_QUANT_NONE :
7387 return 32
74- if quant == ATTENTION_QUANT_INT8 or quant == ATTENTION_QUANT_FP8 :
88+ if quant == ATTENTION_QUANT_FP8 :
7589 return 8
7690 if quant == ATTENTION_QUANT_BITNET :
7791 return 1.58
92+ int_bits = _parse_int_quant (quant )
93+ if int_bits is not None :
94+ return int_bits
7895 raise ValueError (f"Unsupported attention quantization mode: { quant } " )
7996
8097
@@ -94,6 +111,10 @@ def _resolve_fp8_dtype(fp8_dtype):
94111
95112
96113def fake_quant_int8 (x , dim = None , eps = 1e-8 ):
114+ return fake_quant_int (x , qbit = 8 , dim = dim , eps = eps )
115+
116+
117+ def fake_quant_int (x , qbit = 8 , dim = None , eps = 1e-8 ):
97118 reduce_dims = dim
98119 if dim is None :
99120 max_abs = x .detach ().abs ().amax ()
@@ -102,25 +123,53 @@ def fake_quant_int8(x, dim=None, eps=1e-8):
102123 reduce_dims = (dim ,)
103124 max_abs = x .detach ().abs ().amax (dim = reduce_dims , keepdim = True )
104125
105- scale = max_abs .clamp (min = eps ) / 127.0
106- q = torch .round (x / scale ).clamp (- 128 , 127 )
126+ qbit = int (qbit )
127+ if qbit <= 0 :
128+ raise ValueError (f"qbit must be positive, got { qbit } " )
129+ if qbit == 1 :
130+ if dim is None :
131+ scale = x .detach ().abs ().mean ().clamp (min = eps )
132+ else :
133+ scale = x .detach ().abs ().mean (dim = reduce_dims , keepdim = True ).clamp (min = eps )
134+ q = torch .sign (x )
135+ q = torch .where (q == 0.0 , torch .ones_like (q ), q )
136+ return q * scale
137+
138+ qmax = 2 ** (qbit - 1 ) - 1
139+ qmin = - (2 ** (qbit - 1 ))
140+ scale = max_abs .clamp (min = eps ) / float (qmax )
141+ q = torch .round (x / scale ).clamp (qmin , qmax )
107142 return q * scale
108143
109144
110- def fake_quant_bitnet (x , dim = None , eps = 1e-8 ):
145+ def _normalize_bitnet_scale (bitnet_scale ):
146+ bitnet_scale = str (bitnet_scale ).lower ()
147+ if bitnet_scale not in ["max" , "mean" ]:
148+ raise ValueError (f"Unsupported bitnet scale mode: { bitnet_scale } " )
149+ return bitnet_scale
150+
151+
152+ def fake_quant_bitnet (x , dim = None , eps = 1e-8 , mode = "max" ):
111153 reduce_dims = dim
112- if dim is None :
113- max_abs = x .detach ().abs ().amax ()
154+ mode = _normalize_bitnet_scale (mode )
155+ if dim is not None and isinstance (dim , int ):
156+ reduce_dims = (dim ,)
157+
158+ if mode == "max" :
159+ if dim is None :
160+ denom = x .detach ().abs ().amax ()
161+ else :
162+ denom = x .detach ().abs ().amax (dim = reduce_dims , keepdim = True )
114163 else :
115- if isinstance (dim , int ):
116- reduce_dims = (dim ,)
117- max_abs = x .detach ().abs ().amax (dim = reduce_dims , keepdim = True )
164+ if dim is None :
165+ denom = x .detach ().abs ().mean ()
166+ else :
167+ denom = x .detach ().abs ().mean (dim = reduce_dims , keepdim = True )
118168
119- scale = max_abs .clamp (min = eps )
169+ scale = denom .clamp (min = eps )
120170 q = torch .round (x / scale ).clamp (- 1 , 1 )
121171 return q * scale
122172
123-
124173def fake_quant_fp8 (x , fp8_dtype = "e4m3fn" ):
125174 dtype = _resolve_fp8_dtype (fp8_dtype )
126175 return x .to (dtype ).to (x .dtype )
@@ -135,6 +184,7 @@ def _init_attention_quant(
135184 k_quant = None ,
136185 v_quant = None ,
137186 score_quant = None ,
187+ bitnet_scale = "max" ,
138188 fp8_dtype = "e4m3fn" ,
139189 int_dim = None ,
140190 int_dim_q = None ,
@@ -157,6 +207,7 @@ def _init_attention_quant(
157207 self .score_attention_quant = _normalize_attention_quant (
158208 score_quant if score_quant is not None else quant
159209 )
210+ self .bitnet_scale = _normalize_bitnet_scale (bitnet_scale )
160211 self .fp8_dtype = fp8_dtype
161212 self .int_dim = int_dim
162213 self .int_dim_q = int_dim_q if int_dim_q is not None else int_dim
@@ -178,6 +229,7 @@ def set_quantization(
178229 k_quant = None ,
179230 v_quant = None ,
180231 score_quant = None ,
232+ bitnet_scale = None ,
181233 fp8_dtype = None ,
182234 int_dim = None ,
183235 int_dim_q = None ,
@@ -200,6 +252,8 @@ def set_quantization(
200252 self .score_attention_quant = _normalize_attention_quant (
201253 score_quant if score_quant is not None else quant
202254 )
255+ if bitnet_scale is not None :
256+ self .bitnet_scale = _normalize_bitnet_scale (bitnet_scale )
203257 if fp8_dtype is not None :
204258 self .fp8_dtype = fp8_dtype
205259 if int_dim is not None :
@@ -233,10 +287,11 @@ def set_quantization(
233287
234288 def _quantize_attention_tensor (self , x , int_dim = None , quant = None ):
235289 quant = self .attention_quant if quant is None else _normalize_attention_quant (quant )
236- if quant == ATTENTION_QUANT_INT8 :
237- return fake_quant_int8 (x , dim = int_dim )
290+ int_bits = _parse_int_quant (quant )
291+ if int_bits is not None :
292+ return fake_quant_int (x , qbit = int_bits , dim = int_dim )
238293 if quant == ATTENTION_QUANT_BITNET :
239- return fake_quant_bitnet (x , dim = int_dim )
294+ return fake_quant_bitnet (x , dim = int_dim , mode = self . bitnet_scale )
240295 if quant == ATTENTION_QUANT_FP8 :
241296 return fake_quant_fp8 (x , self .fp8_dtype )
242297 return x
@@ -300,6 +355,37 @@ def forward(self, q, k, v):
300355 return num / (den + self .eps )
301356
302357
358+ class LLaMaAttention (nn .Module ):
359+ def __init__ (self , head_dim : int , dropout : float = 0.0 ,
360+ max_seq_len : int = 256 , use_flash : bool = True ):
361+ super ().__init__ ()
362+ self .head_dim = head_dim
363+ self .dropout = dropout
364+ self .flash = use_flash and hasattr (torch .nn .functional , "scaled_dot_product_attention" )
365+ mask = torch .full ((1 , 1 , max_seq_len , max_seq_len ), float ("-inf" ))
366+ mask = torch .triu (mask , diagonal = 1 )
367+ self .register_buffer ("mask" , mask , persistent = False )
368+
369+ def forward (self , q , k , v ):
370+ if self .flash :
371+ return torch .nn .functional .scaled_dot_product_attention (
372+ q ,
373+ k ,
374+ v ,
375+ attn_mask = None ,
376+ dropout_p = self .dropout if self .training else 0.0 ,
377+ is_causal = True ,
378+ )
379+
380+ seqlen = q .shape [2 ]
381+ scores = torch .matmul (q , k .transpose (2 , 3 )) / (self .head_dim ** 0.5 )
382+ mask = self .mask [:, :, :seqlen , :seqlen ].to (device = scores .device )
383+ scores = scores + mask
384+ scores = F .softmax (scores .float (), dim = - 1 ).type_as (q )
385+ scores = F .dropout (scores , p = self .dropout , training = self .training )
386+ return torch .matmul (scores , v )
387+
388+
303389class LinearAttention (nn .Module ):
304390 def __init__ (self , dim , num_heads = 8 , attention_dropout = 0.1 ,
305391 projection_dropout = 0.1 , eps = 1e-6 , ** kwargs ):
@@ -332,7 +418,7 @@ def forward(self, x):
332418
333419def set_attention_quantization (model , quant = ATTENTION_QUANT_INT8 ,
334420 q_quant = None , kv_quant = None , k_quant = None , v_quant = None ,
335- score_quant = None , fp8_dtype = "e4m3fn" ,
421+ score_quant = None , bitnet_scale = "max" , fp8_dtype = "e4m3fn" ,
336422 int_dim = None , int_dim_q = None , int_dim_k = None ,
337423 int_dim_v = None , int_dim_score = None ,
338424 quantize_q = True , quantize_kv = True ,
@@ -347,6 +433,7 @@ def set_attention_quantization(model, quant=ATTENTION_QUANT_INT8,
347433 k_quant = k_quant ,
348434 v_quant = v_quant ,
349435 score_quant = score_quant ,
436+ bitnet_scale = bitnet_scale ,
350437 fp8_dtype = fp8_dtype ,
351438 int_dim = int_dim ,
352439 int_dim_q = int_dim_q ,
0 commit comments