Skip to content

Commit 7c7791c

Browse files
committed
add LLaMa attention as a module
1 parent 58123ef commit 7c7791c

4 files changed

Lines changed: 244 additions & 67 deletions

File tree

MiCoMisc.py

Lines changed: 110 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,32 @@ def __init__(self, name, input_names=None, params=None):
1818
ATTENTION_QUANT_FP8 = "fp8"
1919

2020

21+
def _parse_int_quant(quant):
    """Parse an integer quantization spec and return its bit width.

    Accepted spellings (case-insensitive, after ``str()`` conversion) are
    ``"int<N>"``, ``"i<N>"`` and a bare ``"<N>"``. A bare integral decimal
    such as ``"8.0"`` (what ``str(8.0)`` produces) is accepted as well.

    Args:
        quant: Any object; it is converted with ``str(quant).lower()``.

    Returns:
        The bit width as an ``int`` when it lies in ``1..31``; ``None`` for
        anything else (unrecognised text or an out-of-range width).
    """
    text = str(quant).lower()
    if text.startswith("int"):
        digits = text[3:]
    elif text.startswith("i"):
        digits = text[1:]
    else:
        digits = text
        # Accept integral decimals ("8.0", "4.00") for the bare numeric form.
        whole, dot, frac = text.partition(".")
        if dot and frac and not frac.strip("0"):
            digits = whole

    # isdecimal() rather than isdigit(): isdigit() is also true for characters
    # such as superscript digits, which int() rejects with ValueError.
    if not digits.isdecimal():
        return None

    bits = int(digits)
    return bits if 1 <= bits <= 31 else None
34+
35+
2136
def _normalize_attention_quant(quant):
2237
if quant is None or quant is False:
2338
return ATTENTION_QUANT_NONE
2439
if quant is True:
2540
return ATTENTION_QUANT_INT8
2641
quant = str(quant).lower()
42+
int_bits = _parse_int_quant(quant)
43+
if int_bits is not None:
44+
if int_bits >= 32:
45+
return ATTENTION_QUANT_NONE
46+
return f"int{int_bits}"
2747
aliases = {
2848
"none": ATTENTION_QUANT_NONE,
2949
"fp32": ATTENTION_QUANT_NONE,
@@ -36,10 +56,6 @@ def _normalize_attention_quant(quant):
3656
"1.5": ATTENTION_QUANT_BITNET,
3757
"1.58bit": ATTENTION_QUANT_BITNET,
3858
"ternary": ATTENTION_QUANT_BITNET,
39-
"int8": ATTENTION_QUANT_INT8,
40-
"i8": ATTENTION_QUANT_INT8,
41-
"8": ATTENTION_QUANT_INT8,
42-
"8.0": ATTENTION_QUANT_INT8,
4359
"fp8": ATTENTION_QUANT_FP8,
4460
"float8": ATTENTION_QUANT_FP8,
4561
"e4m3": ATTENTION_QUANT_FP8,
@@ -58,23 +74,24 @@ def attention_qtype_to_quant(qtype):
5874
return ATTENTION_QUANT_NONE
5975
if qtype >= 32:
6076
return ATTENTION_QUANT_NONE
61-
if qtype == 8:
62-
return ATTENTION_QUANT_INT8
6377
if 0 < qtype < 2:
6478
return ATTENTION_QUANT_BITNET
65-
if qtype == 2:
66-
return ATTENTION_QUANT_BITNET
79+
if float(qtype).is_integer():
80+
return f"int{int(qtype)}"
6781
raise ValueError(f"Unsupported attention qtype: {qtype}")
6882

6983

7084
def attention_quant_to_bits(quant):
    """Return the bit width associated with an attention quantization mode.

    The input is first passed through ``_normalize_attention_quant``. The
    "none" mode reports 32, FP8 reports 8, BitNet reports 1.58, and integer
    modes report the width parsed by ``_parse_int_quant``.

    Raises:
        ValueError: If the normalized mode matches none of the above.
    """
    mode = _normalize_attention_quant(quant)

    # Fixed-width modes, checked in the same order as a plain if-chain would.
    fixed_widths = (
        (ATTENTION_QUANT_NONE, 32),
        (ATTENTION_QUANT_FP8, 8),
        (ATTENTION_QUANT_BITNET, 1.58),
    )
    for name, width in fixed_widths:
        if mode == name:
            return width

    parsed = _parse_int_quant(mode)
    if parsed is None:
        raise ValueError(f"Unsupported attention quantization mode: {mode}")
    return parsed
7996

8097

@@ -94,6 +111,10 @@ def _resolve_fp8_dtype(fp8_dtype):
94111

95112

96113
def fake_quant_int8(x, dim=None, eps=1e-8):
    """Fake-quantize ``x`` at 8 bits by delegating to ``fake_quant_int``.

    ``dim`` and ``eps`` are forwarded unchanged.
    """
    return fake_quant_int(x, 8, dim=dim, eps=eps)
115+
116+
117+
def fake_quant_int(x, qbit=8, dim=None, eps=1e-8):
97118
reduce_dims = dim
98119
if dim is None:
99120
max_abs = x.detach().abs().amax()
@@ -102,25 +123,53 @@ def fake_quant_int8(x, dim=None, eps=1e-8):
102123
reduce_dims = (dim,)
103124
max_abs = x.detach().abs().amax(dim=reduce_dims, keepdim=True)
104125

105-
scale = max_abs.clamp(min=eps) / 127.0
106-
q = torch.round(x / scale).clamp(-128, 127)
126+
qbit = int(qbit)
127+
if qbit <= 0:
128+
raise ValueError(f"qbit must be positive, got {qbit}")
129+
if qbit == 1:
130+
if dim is None:
131+
scale = x.detach().abs().mean().clamp(min=eps)
132+
else:
133+
scale = x.detach().abs().mean(dim=reduce_dims, keepdim=True).clamp(min=eps)
134+
q = torch.sign(x)
135+
q = torch.where(q == 0.0, torch.ones_like(q), q)
136+
return q * scale
137+
138+
qmax = 2 ** (qbit - 1) - 1
139+
qmin = -(2 ** (qbit - 1))
140+
scale = max_abs.clamp(min=eps) / float(qmax)
141+
q = torch.round(x / scale).clamp(qmin, qmax)
107142
return q * scale
108143

109144

110-
def fake_quant_bitnet(x, dim=None, eps=1e-8):
145+
def _normalize_bitnet_scale(bitnet_scale):
    """Return the lower-cased BitNet scale mode, validating it on the way.

    Raises:
        ValueError: If the lower-cased string is neither "max" nor "mean".
    """
    mode = str(bitnet_scale).lower()
    if mode in ("max", "mean"):
        return mode
    raise ValueError(f"Unsupported bitnet scale mode: {mode}")
150+
151+
152+
def fake_quant_bitnet(x, dim=None, eps=1e-8, mode="max"):
    """Fake-quantize ``x`` to the three levels ``{-scale, 0, +scale}``.

    Args:
        x: Tensor to quantize.
        dim: ``None`` for one scale over the whole tensor, or an int /
            sequence of dims to reduce over (kept with ``keepdim=True``).
        eps: Lower clamp applied to the scale before dividing.
        mode: "max" scales by the largest absolute value, "mean" by the
            mean absolute value; validated by ``_normalize_bitnet_scale``.

    Returns:
        ``round(x / scale).clamp(-1, 1) * scale``.
    """
    mode = _normalize_bitnet_scale(mode)
    use_max = mode == "max"

    # The scale is computed from a detached copy of |x|.
    magnitude = x.detach().abs()
    if dim is None:
        denom = magnitude.amax() if use_max else magnitude.mean()
    else:
        axes = (dim,) if isinstance(dim, int) else dim
        if use_max:
            denom = magnitude.amax(dim=axes, keepdim=True)
        else:
            denom = magnitude.mean(dim=axes, keepdim=True)

    scale = denom.clamp(min=eps)
    ternary = torch.round(x / scale).clamp(-1, 1)
    return ternary * scale
122172

123-
124173
def fake_quant_fp8(x, fp8_dtype="e4m3fn"):
    """Cast ``x`` to the dtype given by ``_resolve_fp8_dtype`` and back.

    The result has the same dtype as the input.
    """
    target = _resolve_fp8_dtype(fp8_dtype)
    narrowed = x.to(target)
    return narrowed.to(x.dtype)
@@ -135,6 +184,7 @@ def _init_attention_quant(
135184
k_quant=None,
136185
v_quant=None,
137186
score_quant=None,
187+
bitnet_scale="max",
138188
fp8_dtype="e4m3fn",
139189
int_dim=None,
140190
int_dim_q=None,
@@ -157,6 +207,7 @@ def _init_attention_quant(
157207
self.score_attention_quant = _normalize_attention_quant(
158208
score_quant if score_quant is not None else quant
159209
)
210+
self.bitnet_scale = _normalize_bitnet_scale(bitnet_scale)
160211
self.fp8_dtype = fp8_dtype
161212
self.int_dim = int_dim
162213
self.int_dim_q = int_dim_q if int_dim_q is not None else int_dim
@@ -178,6 +229,7 @@ def set_quantization(
178229
k_quant=None,
179230
v_quant=None,
180231
score_quant=None,
232+
bitnet_scale=None,
181233
fp8_dtype=None,
182234
int_dim=None,
183235
int_dim_q=None,
@@ -200,6 +252,8 @@ def set_quantization(
200252
self.score_attention_quant = _normalize_attention_quant(
201253
score_quant if score_quant is not None else quant
202254
)
255+
if bitnet_scale is not None:
256+
self.bitnet_scale = _normalize_bitnet_scale(bitnet_scale)
203257
if fp8_dtype is not None:
204258
self.fp8_dtype = fp8_dtype
205259
if int_dim is not None:
@@ -233,10 +287,11 @@ def set_quantization(
233287

234288
def _quantize_attention_tensor(self, x, int_dim=None, quant=None):
    """Fake-quantize ``x`` according to an attention quantization mode.

    Args:
        x: Tensor to quantize.
        int_dim: Reduction dim(s) forwarded as ``dim`` to the integer and
            BitNet quantizers; not used for FP8.
        quant: Mode override. When ``None``, ``self.attention_quant`` is used
            as-is; otherwise the value is normalized first.

    Returns:
        The fake-quantized tensor, or ``x`` unchanged when the mode is not an
        integer, BitNet or FP8 mode.
    """
    if quant is None:
        mode = self.attention_quant
    else:
        mode = _normalize_attention_quant(quant)

    # Integer modes are tried first; the parsed width selects the bit count.
    width = _parse_int_quant(mode)
    if width is not None:
        return fake_quant_int(x, qbit=width, dim=int_dim)

    if mode == ATTENTION_QUANT_BITNET:
        return fake_quant_bitnet(x, dim=int_dim, mode=self.bitnet_scale)

    if mode == ATTENTION_QUANT_FP8:
        return fake_quant_fp8(x, self.fp8_dtype)

    return x
@@ -300,6 +355,37 @@ def forward(self, q, k, v):
300355
return num / (den + self.eps)
301356

302357

358+
class LLaMaAttention(nn.Module):
    """Causal scaled dot-product attention over already-projected q/k/v.

    ``forward`` indexes its inputs as 4-D tensors with the sequence length on
    dim 2 and the feature size on dim 3 (presumably
    ``(batch, heads, seq, head_dim)`` -- confirm against callers).

    Args:
        head_dim: Per-head feature size; the manual path divides the scores
            by ``head_dim ** 0.5``.
        dropout: Attention dropout probability, applied only in training.
        max_seq_len: Side length of the precomputed causal mask buffer.
        use_flash: Use ``torch.nn.functional.scaled_dot_product_attention``
            when the installed torch provides it.
    """

    def __init__(self, head_dim: int, dropout: float = 0.0,
                 max_seq_len: int = 256, use_flash: bool = True):
        super().__init__()
        self.head_dim = head_dim
        self.dropout = dropout
        self.flash = use_flash and hasattr(torch.nn.functional, "scaled_dot_product_attention")
        # Additive mask: 0 on and below the diagonal, -inf strictly above it.
        mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        # Non-persistent: the mask is not written to the state dict.
        self.register_buffer("mask", mask, persistent=False)

    def _causal_mask(self, q_len, k_len, device):
        """Return a ``(1, 1, q_len, k_len)`` additive causal mask on ``device``."""
        if q_len <= self.mask.shape[2] and k_len <= self.mask.shape[3]:
            return self.mask[:, :, :q_len, :k_len].to(device=device)
        # Longer than the precomputed buffer: build the mask on demand rather
        # than failing with a broadcast error.
        fresh = torch.full((1, 1, q_len, k_len), float("-inf"), device=device)
        return torch.triu(fresh, diagonal=1)

    def forward(self, q, k, v):
        """Compute causal attention of ``q`` over ``k``/``v``.

        Returns:
            The attention output, ``softmax(q @ k^T / sqrt(head_dim) + mask) @ v``
            on the manual path, or the fused kernel's result on the flash path.
        """
        if self.flash:
            # NOTE(review): the fused kernel uses its own default scaling, not
            # self.head_dim -- the two paths agree only if head_dim equals
            # q.shape[-1]; confirm.
            return torch.nn.functional.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )

        q_len = q.shape[2]
        k_len = k.shape[2]
        scores = torch.matmul(q, k.transpose(2, 3)) / (self.head_dim ** 0.5)
        # Slice by both lengths so a key length different from the query
        # length no longer produces a shape mismatch.
        scores = scores + self._causal_mask(q_len, k_len, scores.device)
        # Softmax in float32, then cast back to the query dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(q)
        scores = F.dropout(scores, p=self.dropout, training=self.training)
        return torch.matmul(scores, v)
387+
388+
303389
class LinearAttention(nn.Module):
304390
def __init__(self, dim, num_heads=8, attention_dropout=0.1,
305391
projection_dropout=0.1, eps=1e-6, **kwargs):
@@ -332,7 +418,7 @@ def forward(self, x):
332418

333419
def set_attention_quantization(model, quant=ATTENTION_QUANT_INT8,
334420
q_quant=None, kv_quant=None, k_quant=None, v_quant=None,
335-
score_quant=None, fp8_dtype="e4m3fn",
421+
score_quant=None, bitnet_scale="max", fp8_dtype="e4m3fn",
336422
int_dim=None, int_dim_q=None, int_dim_k=None,
337423
int_dim_v=None, int_dim_score=None,
338424
quantize_q=True, quantize_kv=True,
@@ -347,6 +433,7 @@ def set_attention_quantization(model, quant=ATTENTION_QUANT_INT8,
347433
k_quant=k_quant,
348434
v_quant=v_quant,
349435
score_quant=score_quant,
436+
bitnet_scale=bitnet_scale,
350437
fp8_dtype=fp8_dtype,
351438
int_dim=int_dim,
352439
int_dim_q=int_dim_q,

MiCoQLayers.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from MiCoMisc import (
66
AttentionQuantMixin,
77
AttentionScore,
8+
LLaMaAttention,
89
LinearAttentionScore,
910
ATTENTION_QUANT_NONE,
1011
attention_qtype_to_quant,
@@ -470,6 +471,7 @@ def _init_bit_attention(
470471
v_qtype=DEFAULT_W_Q,
471472
score_qtype=DEFAULT_ACT_Q,
472473
qat=False,
474+
bitnet_scale="max",
473475
fp8_dtype="e4m3fn",
474476
int_dim=None,
475477
int_dim_q=None,
@@ -496,6 +498,7 @@ def _init_bit_attention(
496498
k_quant=attention_qtype_to_quant(k_qtype),
497499
v_quant=attention_qtype_to_quant(v_qtype),
498500
score_quant=attention_qtype_to_quant(score_qtype),
501+
bitnet_scale=bitnet_scale,
499502
fp8_dtype=fp8_dtype,
500503
int_dim=int_dim,
501504
int_dim_q=int_dim_q,
@@ -562,6 +565,7 @@ def __init__(self, scale: float,
562565
v_qtype=DEFAULT_W_Q,
563566
score_qtype=DEFAULT_ACT_Q,
564567
qat=False,
568+
bitnet_scale="max",
565569
fp8_dtype="e4m3fn",
566570
int_dim=None,
567571
int_dim_q=None,
@@ -580,6 +584,7 @@ def __init__(self, scale: float,
580584
v_qtype=v_qtype,
581585
score_qtype=score_qtype,
582586
qat=qat,
587+
bitnet_scale=bitnet_scale,
583588
fp8_dtype=fp8_dtype,
584589
int_dim=int_dim,
585590
int_dim_q=int_dim_q,
@@ -616,6 +621,7 @@ def __init__(self, eps=1e-6,
616621
v_qtype=DEFAULT_W_Q,
617622
score_qtype=DEFAULT_ACT_Q,
618623
qat=False,
624+
bitnet_scale="max",
619625
fp8_dtype="e4m3fn",
620626
int_dim=None,
621627
int_dim_q=None,
@@ -634,6 +640,7 @@ def __init__(self, eps=1e-6,
634640
v_qtype=v_qtype,
635641
score_qtype=score_qtype,
636642
qat=qat,
643+
bitnet_scale=bitnet_scale,
637644
fp8_dtype=fp8_dtype,
638645
int_dim=int_dim,
639646
int_dim_q=int_dim_q,
@@ -666,3 +673,69 @@ def forward(self, q, k, v):
666673
num = torch.einsum("bhnd,bhdm->bnhm", q, context)
667674
den = torch.einsum("bhnd,bhd->bnh", q, k_sum).unsqueeze(-1)
668675
return num / (den + self.eps)
676+
677+
678+
class BitLLaMaAttention(LLaMaAttention, _BitAttentionBase):
    """Causal attention with fake quantization of q, k, v and the scores.

    Always uses the manual (non-flash) attention path so that the softmax
    output can be quantized before it is multiplied with ``v``. All
    quantization arguments are forwarded unchanged to
    ``self._init_bit_attention``.

    Args:
        head_dim: Per-head feature size used to scale the scores.
        dropout: Attention dropout probability, applied only in training.
        max_seq_len: Side length of the precomputed causal mask buffer.
    """

    def __init__(self, head_dim: int, dropout: float = 0.0,
                 max_seq_len: int = 256,
                 q_qtype=DEFAULT_ACT_Q,
                 k_qtype=DEFAULT_W_Q,
                 v_qtype=DEFAULT_W_Q,
                 score_qtype=DEFAULT_ACT_Q,
                 qat=False,
                 bitnet_scale="max",
                 fp8_dtype="e4m3fn",
                 int_dim=None,
                 int_dim_q=None,
                 int_dim_k=None,
                 int_dim_v=None,
                 int_dim_score=None,
                 quantize_q=True,
                 quantize_k=True,
                 quantize_v=True,
                 quantize_score=True):
        LLaMaAttention.__init__(
            self,
            head_dim=head_dim,
            dropout=dropout,
            max_seq_len=max_seq_len,
            use_flash=False,
        )
        self.layer_type = "LLaMaAttention"
        self._init_bit_attention(
            q_qtype=q_qtype,
            k_qtype=k_qtype,
            v_qtype=v_qtype,
            score_qtype=score_qtype,
            qat=qat,
            bitnet_scale=bitnet_scale,
            fp8_dtype=fp8_dtype,
            int_dim=int_dim,
            int_dim_q=int_dim_q,
            int_dim_k=int_dim_k,
            int_dim_v=int_dim_v,
            int_dim_score=int_dim_score,
            quantize_q=quantize_q,
            quantize_k=quantize_k,
            quantize_v=quantize_v,
            quantize_score=quantize_score,
        )

    def forward(self, q, k, v):
        """Compute quantized causal attention and record MAC counts.

        ``q`` is unpacked as a 4-D shape ``(B, H, I, Fdim)`` and ``k`` supplies
        the key length ``J`` from its dim 2. Sets ``score_macs``,
        ``context_macs``, ``macs`` and ``layer_features`` on the module.
        """
        B, H, I, Fdim = q.shape
        J = k.shape[2]
        # q @ k^T and scores @ v are each counted as B * H * I * J * Fdim.
        self.score_macs = B * H * I * J * Fdim
        self.context_macs = B * H * I * J * Fdim
        self.macs = self.score_macs + self.context_macs
        self.layer_features = [B, H, I, J, Fdim]

        q = self._quantize_q(q)
        k = self._quantize_k(k)
        v = self._quantize_v(v)
        scores = torch.matmul(q, k.transpose(2, 3)) / (self.head_dim ** 0.5)
        if I <= self.mask.shape[2] and J <= self.mask.shape[3]:
            mask = self.mask[:, :, :I, :J].to(device=scores.device)
        else:
            # Sequence longer than the precomputed buffer: build the additive
            # causal mask on demand instead of failing with a broadcast error.
            mask = torch.triu(
                torch.full((1, 1, I, J), float("-inf"), device=scores.device),
                diagonal=1,
            )
        scores = scores + mask
        # Softmax in float32, then cast back to the (quantized) query dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(q)
        scores = self._quantize_score(scores)
        scores = F.dropout(scores, p=self.dropout, training=self.training)
        return torch.matmul(scores, v)

0 commit comments

Comments
 (0)