-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
108 lines (89 loc) · 3.88 KB
/
Copy pathmodel.py
File metadata and controls
108 lines (89 loc) · 3.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Flax transformer with classification and GPD heads."""
from __future__ import annotations

from typing import Any, Optional, Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Returns ``(output, attn_probs)`` so callers can optionally plot the
    attention pattern.

    Attributes:
        num_heads: number of attention heads; must divide ``d_model``.
        d_model: width of the query/key/value projections and of the output.
        dropout_rate: dropout applied to the attention weights before they
            are used to mix the values.
        dtype: computation dtype passed to the Dense projections.
    """

    num_heads: int
    d_model: int
    dropout_rate: float = 0.1
    dtype: Any = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, deterministic: bool) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """Apply self-attention to ``x``.

        Args:
            x: three-dimensional input, unpacked as (B, T, D_in).
            deterministic: when True, attention dropout is disabled.

        Returns:
            out: (B, T, d_model) attended and output-projected features.
            attn_probs: (B, num_heads, T, T) softmax attention weights taken
                *before* dropout, so every row sums to 1 even in training mode.

        Raises:
            ValueError: if ``d_model`` is not divisible by ``num_heads``.
        """
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"num_heads ({self.num_heads})"
            )
        batch, seq_len, _ = x.shape
        d_head = self.d_model // self.num_heads

        def split_heads(y: jnp.ndarray) -> jnp.ndarray:
            # (B, T, d_model) -> (B, H, T, d_head)
            return y.reshape(batch, seq_len, self.num_heads, d_head).transpose(0, 2, 1, 3)

        # Submodule names/creation order are kept stable for checkpoint compatibility.
        q = split_heads(nn.Dense(self.d_model, dtype=self.dtype, name="query")(x))
        k = split_heads(nn.Dense(self.d_model, dtype=self.dtype, name="key")(x))
        v = split_heads(nn.Dense(self.d_model, dtype=self.dtype, name="value")(x))

        scale = jnp.asarray(d_head**-0.5, dtype=x.dtype)
        logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * scale
        attn_probs = jax.nn.softmax(logits, axis=-1)

        # Dropout only on the copy used to mix values; the returned weights
        # stay a proper distribution for plotting/inspection.
        attn_dropped = nn.Dropout(rate=self.dropout_rate)(attn_probs, deterministic=deterministic)
        out = jnp.einsum("bhqk,bhkd->bhqd", attn_dropped, v)

        # Merge heads: (B, H, T, d_head) -> (B, T, d_model).
        out = out.transpose(0, 2, 1, 3).reshape(batch, seq_len, self.d_model)
        out = nn.Dense(self.d_model, dtype=self.dtype, name="out")(out)
        return out, attn_probs
class TailAwareTransformer(nn.Module):
    """Pre-LayerNorm transformer encoder with pooled classification and GPD heads.

    The (batch, time, features) input is projected to ``d_model``, a learned
    positional embedding is added, and ``num_layers`` residual blocks
    (LayerNorm -> self-attention, LayerNorm -> GELU MLP) are applied. Hidden
    states are then mean-pooled over time and fed to three linear heads.

    Attributes:
        d_model: hidden width of the residual stream.
        num_heads: attention heads per layer; must divide ``d_model``.
        num_layers: number of transformer blocks.
        mlp_dim: hidden width of each block's feed-forward layer.
        num_classes: number of classification logits.
        dropout_rate: dropout on attention weights and after the first
            feed-forward layer.
        dtype: computation dtype passed to the Dense/LayerNorm submodules.
        max_len: number of rows in the learned positional table. ``None``
            (default) sizes it from the sequence length seen at ``init``, so
            ``apply`` must use that same length. An integer accepts any input
            with ``T <= max_len`` (the table is sliced to the first T rows).
    """

    d_model: int = 128
    num_heads: int = 4
    num_layers: int = 2
    mlp_dim: int = 256
    num_classes: int = 2
    dropout_rate: float = 0.1
    dtype: Any = jnp.float32
    max_len: Optional[int] = None

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        *,
        train: bool = False,
        return_attention: bool = False,
    ) -> Tuple[jnp.ndarray, ...]:
        """Run the encoder and the three output heads.

        Args:
            x: (B, T, F) raw features (scaled upstream).
            train: when True, dropout is active (``deterministic=False``).
            return_attention: also return the last layer's attention weights.

        Returns:
            logits: (B, num_classes).
            xi: (B,) GPD shape parameter, squashed into (-1, 1) by tanh.
            sigma: (B,) GPD scale, ``softplus(raw) + 1e-4`` so it is > 0.
            attn_weights: only when ``return_attention``; (B, H, T, T) from the
                last self-attention layer, or None if ``num_layers == 0``.

        Raises:
            ValueError: if ``max_len`` is set and ``T > max_len``.
        """
        _, seq_len, _ = x.shape
        deterministic = not train

        table_len = seq_len if self.max_len is None else self.max_len
        if seq_len > table_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}"
            )

        h = nn.Dense(self.d_model, dtype=self.dtype, name="in_proj")(x)
        positions = self.param(
            "pos_embed", nn.initializers.normal(0.02), (1, table_len, self.d_model)
        )
        # Slicing lets a fixed-size table serve shorter sequences; it is a
        # no-op when the table was sized from this input (max_len=None).
        h = h + positions[:, :seq_len]

        last_attn = None
        for i in range(self.num_layers):
            # Attention sub-block: pre-LN, then residual add.
            y = nn.LayerNorm(dtype=self.dtype, name=f"ln1_{i}")(h)
            attn_out, attn_probs = MultiHeadSelfAttention(
                num_heads=self.num_heads,
                d_model=self.d_model,
                dropout_rate=self.dropout_rate,
                dtype=self.dtype,
                name=f"attn_{i}",
            )(y, deterministic=deterministic)
            h = h + attn_out
            # Overwritten each layer, so after the loop it holds the final layer's weights.
            last_attn = attn_probs

            # Feed-forward sub-block: pre-LN, GELU MLP, then residual add.
            z = nn.LayerNorm(dtype=self.dtype, name=f"ln2_{i}")(h)
            z = nn.Dense(self.mlp_dim, dtype=self.dtype, name=f"ff1_{i}")(z)
            z = nn.gelu(z)
            z = nn.Dropout(rate=self.dropout_rate, name=f"dr1_{i}")(z, deterministic=deterministic)
            z = nn.Dense(self.d_model, dtype=self.dtype, name=f"ff2_{i}")(z)
            h = h + z

        pooled = jnp.mean(h, axis=1)  # mean over the time axis -> (B, d_model)
        logits = nn.Dense(self.num_classes, dtype=self.dtype, name="cls_head")(pooled)
        xi_raw = nn.Dense(1, dtype=self.dtype, name="xi_head")(pooled).squeeze(-1)
        sigma_raw = nn.Dense(1, dtype=self.dtype, name="sigma_head")(pooled).squeeze(-1)

        # tanh bounds the shape parameter to (-1, 1); softplus + epsilon keeps
        # the scale strictly positive.
        xi = jnp.tanh(xi_raw)
        sigma = jax.nn.softplus(sigma_raw) + jnp.asarray(1e-4, dtype=sigma_raw.dtype)

        if return_attention:
            return logits, xi, sigma, last_attn
        return logits, xi, sigma