Paper ID: MIRAS
Full Title: It's All Connected: A Journey Through Test-Time Memorization, Attentional Bias, Retention, and Online Optimization
Authors: Ali Behrouz, Meisam Razaviyayn, Peilin Zhong, and Vahab Mirrokni (Google Research)
arXiv: arXiv:2504.13173v1 [cs.LG]
Year: 2025
Key Contribution (1 sentence): MIRAS is a general framework to design deep learning architectures based on four choices: (i) associative memory architecture, (ii) attentional bias objective, (iii) retention gate, and (iv) memory learning algorithm, unifying Transformers, Titans, and modern linear RNNs while introducing three novel sequence models (Moneta, Yaad, Memora).
| Aspect | Count | Key Items |
|---|---|---|
| Equations | 32+ | Attention (Eq 1-2), Memory Updates (Eq 3-32), Novel Variants |
| Algorithms | 3 | Moneta, Yaad, Memora with parallelization |
| Architectures | 3 | 2-layer MLP memory with different objectives/gates |
| Hyperparameters | 15+ | Learning rates, retention gates, chunk sizes, p/q values |
Equation 1: Transformer Attention Projections
LaTeX:
Q = xW_Q, \quad K = xW_K, \quad V = xW_V
Plain Text: Query, Key, and Value matrices are computed as linear projections of input x.
Variables:
| Variable | Shape | Description |
|---|---|---|
| x | (N, d_in) | Input sequence |
| W_Q, W_K, W_V | (d_in, d_in) | Learnable projection matrices |
| Q, K, V | (N, d_in) | Query, Key, Value matrices |
| N | scalar | Sequence length |
| d_in | scalar | Input dimension |
Equation 2: Causal Softmax Attention
LaTeX:
y_i = \sum_{j=1}^{i} \frac{\exp(q_i^\top k_j / \sqrt{d_{in}}) v_j}{\sum_{\ell=1}^{i} \exp(q_i^\top k_\ell / \sqrt{d_{in}})}
Plain Text: Output at position i is weighted sum of values, where weights are softmax of scaled dot-product similarities.
Variables:
| Variable | Shape | Description |
|---|---|---|
| y_i | (d_in,) | Output at position i |
| q_i | (d_in,) | Query vector at position i |
| k_j | (d_in,) | Key vector at position j |
| v_j | (d_in,) | Value vector at position j |
Implementation Notes:
- Temperature scaling: 1/√d_in
- Causal masking: only attend to positions j ≤ i
- Softmax attention is a non-parametric solution to an ℓ2 (MSE) regression objective
Dependencies: None (baseline)
Test Criteria:
- Attention weights sum to 1 per position
- Output shape matches input shape
- Gradient flows through all components
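The first two test criteria can be checked on a toy input. The following is a minimal pure-Python sketch of Eq. 2, not the paper's implementation; the function name `causal_softmax_attention` and the toy vectors are illustrative assumptions, and it works on already-projected q, k, v.

```python
import math

def causal_softmax_attention(Q, K, V):
    """Eq. 2: y_i = sum_{j<=i} softmax_j(q_i . k_j / sqrt(d)) * v_j.

    Q, K, V are lists of N vectors (lists of floats).
    Returns (outputs, weights); weights[i] has length i + 1.
    """
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    outputs, weights = [], []
    for i, q in enumerate(Q):
        # Causal masking: only positions j <= i are scored.
        scores = [scale * sum(a * b for a, b in zip(q, K[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        w = [e / z for e in exps]
        y = [sum(w[j] * V[j][c] for j in range(i + 1)) for c in range(len(V[0]))]
        outputs.append(y)
        weights.append(w)
    return outputs, weights

# Toy inputs (hypothetical values, N = 3, d = 2).
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
Y, A = causal_softmax_attention(Q, K, V)
```

At position 1 the only visible key is k_1, so the output there is exactly v_1; later positions mix the values seen so far.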
Equation 3: General Linear RNN Memory
LaTeX:
M_t = A_t * M_{t-1} + v_t k_t^\top
Plain Text: Memory at time t equals previous memory scaled by A_t plus outer product of current value and key.
Variables:
| Variable | Shape | Description |
|---|---|---|
| M_t | (d, n) | Memory state at time t |
| A_t | (d, d) or scalar | Decay/retention parameter (data-dependent or fixed) |
| v_t | (d,) | Value vector at time t |
| k_t | (d,) | Key vector at time t |
| * | operator | Arbitrary associative operator |
| n | scalar | Memory width (1=vector, d=matrix) |
Implementation Notes:
- When n=1: vector-valued memory (RetNet)
- When n=d: matrix-valued memory (Linear Attention)
- A_t can be scalar (RetNet), diagonal matrix (GLA), or data-dependent (Mamba)
Special Cases (α denotes a scalar A_t):
- α=1: Linear Attention
- α learnable scalar: RetNet, Lightning Attention
- α_t data-dependent: Mamba2, GLA
Dependencies: Key/value projections from input
Test Criteria:
- Memory maintains shape (d, n)
- Causality: M_t only depends on inputs up to t
- Numerical stability for long sequences
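As a sketch of the recurrence in Eq. 3, the snippet below uses a scalar retention `a` in place of A_t and ordinary multiplication for the operator `*`. Both choices are assumptions made for illustration, and the helper names are hypothetical.

```python
def outer(v, k):
    """Outer product v k^T as a list of rows."""
    return [[vi * kj for kj in k] for vi in v]

def linear_memory_step(M, a, k, v):
    """One step of M_t = a * M_{t-1} + v_t k_t^T with scalar retention a."""
    vk = outer(v, k)
    return [[a * M[r][c] + vk[r][c] for c in range(len(k))] for r in range(len(v))]

def run_memory(keys, values, a):
    """Run the recurrence from a zero memory over a whole sequence."""
    M = [[0.0] * len(keys[0]) for _ in range(len(values[0]))]
    for k, v in zip(keys, values):
        M = linear_memory_step(M, a, k, v)
    return M

keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[2.0, 3.0], [4.0, 5.0]]
M_linear_attention = run_memory(keys, values, a=1.0)  # no decay
M_decayed = run_memory(keys, values, a=0.5)           # fixed scalar decay
```

With a = 1 and orthonormal keys, each value is recovered exactly by multiplying the memory with its key; with a = 0.5 the first association is halved by the time the second token arrives.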
Equation 9: Delta Rule with Retention
LaTeX:
M_t = \alpha (I - \eta_t k_t k_t^\top) M_{t-1} + v_t k_t^\top
Plain Text: Memory updated by delta rule: decay previous memory, subtract key-projection, add new key-value pair.
Variables:
| Variable | Shape | Description |
|---|---|---|
| M_t | (d, n) | Memory at time t |
| α | scalar or (d,) | Retention coefficient |
| η_t | scalar | Learning rate at time t |
| I | (d, d) | Identity matrix |
| k_t | (d,) | Key vector |
| v_t | (d,) | Value vector |
Implementation Notes:
- Delta rule removes previous associations proportionally to k_t
- When α=1: DeltaNet
- When α_t data-dependent (scalar): Gated DeltaNet
- When α_t data-dependent (vector): RWKV-7
Special Cases (Table 1):
DeltaNet: α=1, η_t constant
Gated DeltaNet: α_t ∈ ℝ (scalar), data-dependent
RWKV-7: α_t ∈ ℝ^d (vector), data-dependent
Dependencies: Gradient of ℓ2 loss
Test Criteria:
- Memory rank can decrease (value replacement)
- Bounded memory norm for stable α < 1
- Orthogonal key updates don't interfere
Equation 4: Associative Memory with Attentional Bias
LaTeX:
M^* = \arg\min_M \mathcal{L}(M(\mathcal{K}); \mathcal{V})
Plain Text: Optimal memory minimizes attentional bias objective L over mapping from keys K to values V.
Variables:
| Variable | Shape | Description |
|---|---|---|
| M | function | Memory operator: K → V |
| K | set in ℝ^{d_k} | Set of keys |
| V | set in ℝ^{d_v} | Set of values |
| L | function | Attentional bias objective |
Implementation Notes:
- This is meta-learning: inner loop optimizes L, outer loop optimizes other parameters
- L determines similarity metric and prioritization
- Can use any optimizer (GD, momentum, Newton, non-parametric)
Dependencies: None (core definition)
Equation 5: Memory Update via Gradient Descent
LaTeX:
W_t = W_{t-1} - \eta_t \nabla \ell(W_{t-1}; k_t, v_t)
Plain Text: Memory parameters updated by gradient descent on loss for current key-value pair.
Variables:
| Variable | Shape | Description |
|---|---|---|
| W_t | (varies) | Memory parameters at time t |
| η_t | scalar | Learning rate at time t |
| ℓ(W; k, v) | scalar | Loss for single key-value pair |
| ∇ℓ | (same as W) | Gradient of loss |
Implementation Notes:
- W can parameterize linear map, MLP, or other architecture
- ℓ is attentional bias for single example: ℓ(W; k, v) = L(M(W, k), v)
- η_t can be data-dependent (learnable)
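To make Eq. 5 concrete, here is a small sketch for a linear memory M(W, k) = Wk with the loss 0.5·‖Wk − v‖², where the factor 0.5 is an assumption chosen so that the gradient is exactly (Wk − v)kᵀ. Function names are illustrative.

```python
def matvec(W, k):
    """Matrix-vector product W k."""
    return [sum(w * x for w, x in zip(row, k)) for row in W]

def l2_grad(W, k, v):
    """Gradient of 0.5 * ||W k - v||_2^2 with respect to W: (W k - v) k^T."""
    err = [p - t for p, t in zip(matvec(W, k), v)]
    return [[e * x for x in k] for e in err]

def gd_step(W, k, v, eta):
    """Eq. 5: W_t = W_{t-1} - eta_t * grad."""
    G = l2_grad(W, k, v)
    return [[w - eta * g for w, g in zip(w_row, g_row)] for w_row, g_row in zip(W, G)]

W0 = [[0.5, -1.0], [2.0, 0.0]]
k = [0.6, 0.8]   # unit-norm key
v = [1.0, -2.0]
W1 = gd_step(W0, k, v, eta=1.0)
recalled = matvec(W1, k)
```

With a unit-norm key and η = 1, a single step makes the memory return v for k, which is the value-replacement behaviour associated with the delta rule.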
Equation 7: Follow-The-Regularized-Leader
LaTeX:
W_t = \arg\min_W \sum_{i=1}^{t} \langle W - W_{i-1}, \nabla\ell(W_{i-1}; k_i, v_i) \rangle + \frac{1}{2\eta} \|W\|_2^2
Plain Text: Memory minimizes sum of linearized losses plus quadratic regularization.
Variables:
| Variable | Shape | Description |
|---|---|---|
| W_t | (varies) | Memory at time t |
| η | scalar | Learning rate; 1/η is the regularization strength |
| ⟨·,·⟩ | scalar | Inner product |
Implementation Notes:
- First term: learn from all past tokens (linearized)
- Second term: regularize memory size
- Equivalent to online GD with specific initialization
Dependencies: Sequence of gradients
Test Criteria:
- Convex optimization per step
- Equivalent to Eq 5 when W_0 = 0
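The equivalence in the last test criterion can be checked numerically. Setting the derivative of the Eq. 7 objective to zero gives W_t = −η Σ_i ∇ℓ(W_{i−1}; k_i, v_i). The sketch below compares that closed form with online gradient descent started at W_0 = 0, using a scalar memory and the loss 0.5·(wk − v)², both of which are simplifying assumptions.

```python
def grad(w, k, v):
    """Gradient of 0.5 * (w * k - v)^2 for a scalar memory w."""
    return (w * k - v) * k

def online_gd(pairs, eta):
    """Eq. 5 started from W_0 = 0."""
    w, iterates = 0.0, []
    for k, v in pairs:
        w = w - eta * grad(w, k, v)
        iterates.append(w)
    return iterates

def ftrl(pairs, eta):
    """Closed-form minimizer of Eq. 7: W_t = -eta * (sum of past gradients)."""
    w, grad_sum, iterates = 0.0, 0.0, []
    for k, v in pairs:
        grad_sum += grad(w, k, v)  # gradient taken at the previous iterate
        w = -eta * grad_sum
        iterates.append(w)
    return iterates

pairs = [(1.0, 2.0), (0.5, -1.0), (2.0, 0.5)]  # hypothetical (k, v) stream
gd_iterates = online_gd(pairs, eta=0.1)
ftrl_iterates = ftrl(pairs, eta=0.1)
```

Both loops produce the same iterates up to floating-point rounding, since each is a running sum of the same gradients scaled by −η.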
LaTeX:
W_t = \arg\min_{W \in \mathcal{W}} \tilde{\ell}_t(W; k_t, v_t) + \text{Ret}_t(W, W_{t-1})
Plain Text: Memory minimizes loss on current token plus retention term that keeps memory close to previous state.
Components:
\text{Ret}_t(W, W_{t-1}) = \frac{1}{\eta_t} D_t(W, W_{t-1}) + \frac{1}{\alpha_t} G_t(W)
Variables:
| Variable | Shape | Description |
|---|---|---|
| W_t | (varies) | Memory at time t |
| ℓ̃_t | function | Approximation of loss (e.g., linearization) |
| Ret_t | function | Retention function (forget gate) |
| D_t | function | Local retention (premetric) |
| G_t | function | Global retention (regularization) |
| η_t | scalar | Learning rate (controls new learning) |
| α_t | scalar | Global retention rate |
| W | set | Constraint set for memory |
Implementation Notes:
- ℓ̃_t can be exact loss or linearization
- D_t measures deviation from W_{t-1} (retention of past)
- G_t controls memory size/norm
- η_t: larger = more learning, more forgetting
- α_t: controls global memory properties
Relation to FTRL (Proposition 3.2): When h_t(W) = ∑_{i=1}^{t-1} ℓ̃_i(W; k_i, v_i) + (1/η)R(W) is strictly convex, setting Ret_t(W, W') = D_h(W, W') (Bregman divergence) makes Learning-Retaining equivalent to FTRL.
Equation 10-11: ℓ_p Loss and Gradient
LaTeX:
\mathcal{L}(M(W, k_t); v_t) = \|M(k_t) - v_t\|_p^p
Gradient for matrix memory M(W, k) = Wk:
W_t = W_{t-1} - \eta_t \cdot p \cdot (\text{Sign}(W_{t-1}k_t - v_t) \odot |W_{t-1}k_t - v_t|^{p-1}) k_t^\top
Plain Text: ℓ_p norm objective measures distance; gradient involves element-wise sign and power p-1.
Variables:
| Variable | Shape | Description |
|---|---|---|
| p | scalar ≥ 1 | Norm parameter |
| ∥·∥_p | scalar | p-norm |
| Sign(·) | element-wise | Sign function (-1, 0, +1) |
| ⊙ | operator | Hadamard (element-wise) product |
| \|·\| | element-wise | Absolute value |
Special Case p=1 (Equation 12):
W_t = W_{t-1} - \eta_t \text{Sign}(W_{t-1} k_t - v_t) k_t^\top
"Value-less" memory: only stores signs (-1, +1), robust to value magnitudes.
Smooth Approximations (Remark 5):
\text{Sign}(x) \approx \tanh(\alpha x)
|x| \approx \sqrt{x^2 + \epsilon}, \quad \epsilon = 10^{-6}
Implementation Notes:
- p=2: Standard MSE (existing models)
- p=1: Robust to outliers, maps to {-1, +1}
- p=3: Paper uses for Moneta
- p=4: Worst performance in experiments
- Requires smooth approximations for backprop
Dependencies: Memory architecture
Test Criteria:
- Gradient finite and bounded
- p=2 recovers standard delta rule
- Sign approximation smooth
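The ℓ_p gradient and its smooth approximations can be sketched as follows for a matrix memory M(W, k) = Wk. The steepness 10 for tanh follows the "~10" listed later in these notes; the function names and the toy inputs are assumptions.

```python
import math

def smooth_sign(x, alpha=10.0):
    """Sign(x) ~ tanh(alpha * x)."""
    return math.tanh(alpha * x)

def smooth_abs(x, eps=1e-6):
    """|x| ~ sqrt(x^2 + eps)."""
    return math.sqrt(x * x + eps)

def exact_sign(x):
    return (x > 0) - (x < 0)

def lp_grad(W, k, v, p, smooth=False):
    """p * (Sign(Wk - v) * |Wk - v|^(p-1)) k^T, element-wise over the error."""
    err = [sum(w * x for w, x in zip(row, k)) - t for row, t in zip(W, v)]
    if smooth:
        coeff = [p * smooth_sign(e) * smooth_abs(e) ** (p - 1) for e in err]
    else:
        coeff = [p * exact_sign(e) * abs(e) ** (p - 1) for e in err]
    return [[c * x for x in k] for c in coeff]

W = [[1.0, 0.0], [0.0, 1.0]]
k = [1.0, 2.0]
v = [0.0, 0.0]
g2 = lp_grad(W, k, v, p=2)                     # equals 2 * (Wk - v) k^T
g3 = lp_grad(W, k, v, p=3)
g3_smooth = lp_grad(W, k, v, p=3, smooth=True)
```

For p = 2 the coefficient reduces to 2(Wk − v), the usual squared-error gradient up to a constant factor; away from zero error the smooth version stays within about 1e-5 of the exact one here.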
Equation 13-16: Huber Loss Variants
Huber function:
H(a) = \begin{cases}
\frac{1}{2} a^2 & \text{if } |a| \leq \delta \\
\delta (|a| - \frac{1}{2}\delta) & \text{if } |a| > \delta
\end{cases}
Variant 1: Coordinate-wise (Eq 14)
\ell(W; k_t, v_t) = \sum_j H(M(W, k_t)_j - v_{t,j})
Gradient (matrix memory):
W_t = W_{t-1} - \eta_t [((W_{t-1}k_t - v_t)k_t^\top) \odot (I(|W_{t-1}k_t - v_t| \leq \delta_t)1^\top)
+ (\delta_t \text{Sign}(W_{t-1}k_t - v_t)k_t^\top) \odot (I(|W_{t-1}k_t - v_t| > \delta_t)1^\top)]
Variant 2: ℓ2 norm-based (Eq 15)
\ell(W; k_t, v_t) = H(\|M(W, k_t) - v_t\|_2)
Gradient:
W_t = W_{t-1} - \eta_t \begin{cases}
(M(W_{t-1}, k_t) - v_t) k_t^\top & \text{if } \|M(W_{t-1}, k_t) - v_t\|_2 \leq \delta_t \\
\delta_t \frac{M(W_{t-1}, k_t) - v_t}{\|M(W_{t-1}, k_t) - v_t\|_2} k_t^\top & \text{otherwise}
\end{cases}
Variant 3: Smooth mixture (Eq 16, used in Yaad)
W_t = W_{t-1} - \begin{cases}
\eta_t \nabla\ell_2(W_{t-1}; k_t, v_t) & \text{if } \|M(k_t) - v_t\| \leq \delta_t \\
\eta_t \delta_t \nabla\ell_1(W_{t-1}; k_t, v_t) & \text{otherwise}
\end{cases}
Variables:
| Variable | Shape | Description |
|---|---|---|
| δ_t | scalar or (d,) | Huber threshold (can be data-dependent) |
| I(·) | element-wise | Indicator function (0 or 1) |
| H(·) | scalar | Huber function |
| ∇ℓ_1, ∇ℓ_2 | (same as W) | Gradients of ℓ1 and ℓ2 losses |
Implementation Notes:
- δ_t controls transition between quadratic and linear
- Small errors: use ℓ2 (smooth, efficient)
- Large errors: use ℓ1 or normalized (robust)
- δ_t can be learnable, data-dependent
- Variant 3 is smoothest, used in Yaad
Dependencies: Memory architecture
Test Criteria:
- Continuous gradient at threshold
- Reduces to ℓ2 when all errors small
- Robust to outliers (verified on S-NIAH tasks)
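A short sketch of the Huber function and its derivative shows the quadratic-to-linear switch at δ. It covers only the scalar function H and the coordinate-wise loss of Variant 1; the helper names and sample values are illustrative assumptions.

```python
def huber(a, delta):
    """H(a): quadratic for |a| <= delta, linear beyond it."""
    if abs(a) <= delta:
        return 0.5 * a * a
    return delta * (abs(a) - 0.5 * delta)

def huber_grad(a, delta):
    """dH/da: the error itself when small, clipped to +/- delta when large."""
    if abs(a) <= delta:
        return a
    return delta if a > 0 else -delta

def coordinatewise_huber_loss(pred, target, delta):
    """Eq. 14: sum_j H(pred_j - target_j)."""
    return sum(huber(p - t, delta) for p, t in zip(pred, target))

delta = 1.0
small = huber(0.5, delta)                                   # quadratic branch
large = huber(10.0, delta)                                  # linear branch
at_threshold = (huber(1.0, delta), huber(1.0 + 1e-9, delta))
loss = coordinatewise_huber_loss([0.5, 10.0], [0.0, 0.0], delta)
```

An outlier error of 10 contributes 9.5 to the loss rather than the 50 it would under 0.5·a², and its gradient is capped at δ.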
LaTeX:
\mathcal{L}(M(W, k_t); v_t) = \max_{\|\delta v_t\|_2 \leq \Delta} \frac{1}{2} \|M(W, k_t) - (v_t + \delta v_t)\|_2^2
Optimal perturbation:
\delta v_t^* = \Delta \frac{-M(W, k_t) + v_t}{\|M(W, k_t) - v_t\|_2}
Resulting loss:
\mathcal{L}(M(W, k_t); v_t) = \frac{1}{2}\|M(W, k_t) - v_t\|_2^2 + \Delta\|M(W, k_t) - v_t\|_2 + \frac{1}{2}\Delta^2
Gradient (linear memory):
W_t = W_{t-1} - \eta [(M(W_{t-1}, k_t) - v_t)k_t^\top + \Delta \frac{M(W_{t-1}, k_t) - v_t}{\|M(W_{t-1}, k_t) - v_t\|_2} k_t^\top]
Variables:
| Variable | Shape | Description |
|---|---|---|
| Δ | scalar | Robustness radius |
| δv_t | (d,) | Perturbation to value |
Implementation Notes:
- Δ controls trade-off: fit nominal data vs. robustness
- Can make Δ learnable parameter
- Combines ℓ2 loss with ℓ2 norm of error (similar to Huber)
Dependencies: Memory architecture
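The closed-form robust loss can be sanity-checked numerically: plugging the optimal perturbation δv* into the inner objective should give 0.5·(‖e‖ + Δ)², with e = M(W, k_t) − v_t. The sketch below does this for one hand-picked error vector; the names and numbers are assumptions.

```python
import math

def norm2(x):
    return math.sqrt(sum(c * c for c in x))

def robust_loss_closed_form(err, radius):
    """0.5 * ||e||^2 + Delta * ||e|| + 0.5 * Delta^2."""
    n = norm2(err)
    return 0.5 * n * n + radius * n + 0.5 * radius * radius

def inner_objective(err, dv):
    """0.5 * ||e - dv||^2 for a given value perturbation dv."""
    return 0.5 * sum((c - d) ** 2 for c, d in zip(err, dv))

def worst_case_perturbation(err, radius):
    """dv* = -Delta * e / ||e||, pointing away from the prediction."""
    n = norm2(err)
    return [-radius * c / n for c in err]

err = [3.0, 4.0]   # e = M(W, k_t) - v_t, with norm 5
radius = 2.0       # Delta
closed = robust_loss_closed_form(err, radius)
worst = inner_objective(err, worst_case_perturbation(err, radius))
other = inner_objective(err, [2.0, 0.0])  # another perturbation on the ball
```

The worst-case perturbation lengthens the error from 5 to 7, giving 0.5·49 = 24.5; any other admissible perturbation, such as the one stored in `other`, yields a smaller value.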
Equation 18: General f-divergence update
LaTeX:
W_t = W_{t-1} \odot g(-\zeta_t - \eta_t \nabla\ell(W_{t-1}; k_t, v_t))
where g is the inverse of f', and ζ_t is chosen such that ∥W_t∥_1 = c.
Constraint set:
\mathcal{W} = \{W \mid \|W\|_1 = c, W_{jl} \geq 0, \forall j,l\}
KL Divergence Specialization (Equation 19-21):
Retention function:
\text{Ret}_t(W, W_{t-1}) = \frac{1}{\eta_t} \sum_{jl} W_{jl} \log\frac{W_{jl}}{(W_{t-1})_{jl}} + \frac{1}{\alpha_t} \sum_{jl} W_{jl} \log(W_{jl})
Update rule (Equation 21):
W_t \leftarrow c \cdot \text{Softmax}((1-\lambda_t)\log(W_{t-1}) - \eta'_t \nabla\ell(W_{t-1}; k_t, v_t))
where:
\lambda_t = \frac{1/\alpha_t}{1/\alpha_t + 1/\eta_t}, \quad \eta'_t = \frac{1}{1/\alpha_t + 1/\eta_t}
Variables:
| Variable | Shape | Description |
|---|---|---|
| f(·) | ℝ+ → ℝ | Strictly convex function, f(1)=0 |
| g(·) | function | Inverse of f' |
| ζ_t | scalar | Lagrange multiplier for constraint |
| c | scalar | Total "probability mass" (scaling) |
| λ_t | scalar ∈ (0,1) | Retention coefficient |
| η'_t | scalar > 0 | Effective learning rate |
Implementation Notes:
- Softmax ensures W_t stays in probability simplex
- log-Softmax structure prevents numerical overflow
- λ_t interpolates between previous state and gradient update
- For f(τ) = τ log(τ): KL divergence
- Normalizes per-slice for neural networks
Dependencies: Non-negative weights
Test Criteria:
- W_t ≥ 0 element-wise
- ∥W_t∥_1 = c (probability simplex)
- Stable for large sequences
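The first two test criteria follow directly from the Softmax form of Eq. 21. A minimal sketch, treating the memory as one flat list of positive weights (a single slice, which is an assumption) and using an illustrative function name:

```python
import math

def kl_retention_update(W_prev, grad, eta, alpha, c=1.0):
    """Eq. 21: W_t = c * Softmax((1 - lam) * log(W_prev) - eta_eff * grad)."""
    lam = (1.0 / alpha) / (1.0 / alpha + 1.0 / eta)
    eta_eff = 1.0 / (1.0 / alpha + 1.0 / eta)
    logits = [(1.0 - lam) * math.log(w) - eta_eff * g for w, g in zip(W_prev, grad)]
    m = max(logits)  # subtract the max before exponentiating to avoid overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [c * e / total for e in exps]

W_prev = [0.2, 0.3, 0.5]
grad = [1.0, 0.0, -1.0]  # hypothetical gradient
W_next = kl_retention_update(W_prev, grad, eta=0.5, alpha=2.0)
# With zero gradient and a very large alpha (lam close to 0), the state is retained.
W_kept = kl_retention_update(W_prev, [0.0, 0.0, 0.0], eta=0.5, alpha=1e12)
```

Entries with a negative gradient gain mass and entries with a positive gradient lose it, while the total stays at c.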
Learning-Retaining formulation (Eq 22):
Global retention:
G_t(W) = \frac{1}{2\beta}\|W\|_2^2 + \frac{1}{\alpha}\|W\|_1
Local retention:
D_t(W, W_{t-1}) = \frac{1}{2}\|W - W_{t-1}\|_2^2
Update with soft-thresholding:
W_t = S_\gamma(\lambda W_{t-1} - \zeta \nabla\ell(W_{t-1}; k_t, v_t))
where:
\gamma = \frac{\eta\beta}{\alpha(\eta+\beta)}, \quad \lambda = \frac{\beta}{\beta+\eta}, \quad \zeta = \eta\lambda
Soft-thresholding operator:
S_\gamma(z) = \text{sign}(z) \max\{0, |z| - \gamma\}
Smooth approximation:
S_\gamma(z) \approx \frac{|z| \cdot \arctan(z/\gamma)}{\pi/2}
FTRL formulation (Eq 23):
A_t = A_{t-1} - \eta\nabla\ell(W_{t-1}; k_t, v_t)
W_t = S_{\eta/\alpha}(A_t)
Variables:
| Variable | Shape | Description |
|---|---|---|
| α, β | scalars | Regularization parameters: the ℓ1 term is weighted by 1/α, the ℓ2 term by 1/(2β) |
| γ | scalar | Soft-threshold amount |
| λ | scalar ∈ (0,1) | Soft forgetting |
| ζ | scalar | Effective learning rate |
| A_t | (same as W) | Accumulated gradients (FTRL) |
Implementation Notes:
- Soft forgetting: multiply by λ < 1
- Hard forgetting: S_γ sets small values to 0
- FTRL version cleaner for implementation
- Use smooth approximation for differentiability
Dependencies: None
Test Criteria:
- Sparsity increases as α decreases (the threshold γ = ηβ/(α(η+β)) grows)
- Bounded norm with β
- S_γ(0) = 0
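The soft-thresholding update of Eq. 22 can be sketched element-wise as below. The coefficient formulas are taken from the notes above; the function names and the sample values are assumptions.

```python
def soft_threshold(z, gamma):
    """S_gamma(z) = sign(z) * max(0, |z| - gamma)."""
    mag = max(0.0, abs(z) - gamma)
    return mag if z >= 0 else -mag

def elastic_net_coeffs(eta, alpha, beta):
    """Return (gamma, lam, zeta) as defined for Eq. 22."""
    lam = beta / (beta + eta)
    zeta = eta * lam
    gamma = eta * beta / (alpha * (eta + beta))
    return gamma, lam, zeta

def elastic_net_step(W_prev, grad, eta, alpha, beta):
    """W_t = S_gamma(lam * W_prev - zeta * grad), applied element-wise."""
    gamma, lam, zeta = elastic_net_coeffs(eta, alpha, beta)
    return [soft_threshold(lam * w - zeta * g, gamma) for w, g in zip(W_prev, grad)]

gamma, lam, zeta = elastic_net_coeffs(eta=1.0, alpha=2.0, beta=1.0)
# Zero gradient isolates the two forgetting effects: scaling by lam, then thresholding.
W_next = elastic_net_step([1.0, 0.2, -1.0], [0.0, 0.0, 0.0], eta=1.0, alpha=2.0, beta=1.0)
```

Here λ = 0.5 halves every entry (soft forgetting) and γ = 0.25 then zeroes the entry that fell below the threshold (hard forgetting).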
FTRL regularization:
\frac{1}{\eta_t}R(W) = \frac{1}{2\eta(q-1)}\|W\|_q^2
Update rule:
A_t = A_{t-1} - \eta\nabla\ell(W_{t-1}; k_t, v_t)
W_t = \frac{A_t}{\|A_t\|_p^{p-2}}
where p = q/(q-1) (conjugate), 1 < q ≤ 2.
Variables:
| Variable | Shape | Description |
|---|---|---|
| q | scalar ∈ (1,2] | Regularization norm |
| p | scalar ≥ 2 | Conjugate exponent |
| A_t | (same as W) | Accumulated gradients |
Implementation Notes:
- Normalization by ℓ_p norm to power p-2
- q=2: reduces to standard ℓ2 regularization
- Different q values change memory dynamics
Dependencies: Accumulated gradients
Test Criteria:
- Normalization prevents explosion
- q=2 recovers standard GD
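The q = 2 criterion can be checked with a few lines. This sketch implements the normalization exactly as written in these notes, W = A / ‖A‖_p^{p−2} with p = q/(q−1), on a flat list of accumulated gradients; the names are illustrative assumptions.

```python
def lp_norm(A, p):
    """(sum_i |a_i|^p)^(1/p) for a flat list."""
    return sum(abs(a) ** p for a in A) ** (1.0 / p)

def lq_ftrl_weights(A, q):
    """W = A / ||A||_p^(p-2) with the conjugate exponent p = q / (q - 1)."""
    p = q / (q - 1.0)
    scale = lp_norm(A, p) ** (p - 2.0)
    return [a / scale for a in A]

A = [3.0, -4.0]                      # hypothetical accumulated gradients
W_q2 = lq_ftrl_weights(A, q=2.0)     # p = 2: the exponent is 0, so no rescaling
W_q15 = lq_ftrl_weights(A, q=1.5)    # p = 3: divide by ||A||_3
```

For q = 2 the weights equal the accumulator, i.e. plain gradient accumulation; for q = 1.5 every entry is divided by the ℓ_3 norm of A.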
Retention (premetric):
D_t(W, W') = F(W) - F(W') - \langle\nabla F(W'), W - W'\rangle
where F(W) = ∑_{jl} f(W_{jl}) for strictly convex f.
Update:
W_t = g(-\eta\nabla\ell(W_{t-1}; k_t, v_t) + F'(W_{t-1}))
where g is the inverse of F'.
Sigmoid example: Choose f'(τ) = log(τ/(1-τ)) (logit function), then g(τ) = σ(τ) (sigmoid).
Update:
W_t = \sigma(\log(W_{t-1}/(1-W_{t-1})) - \eta\nabla\ell(W_{t-1}; k_t, v_t))
Variables:
| Variable | Shape | Description |
|---|---|---|
| f(·) | ℝ+ → ℝ | Strictly convex function |
| F(W) | scalar | Sum of f applied element-wise |
| D_t | function | Bregman divergence |
| g(·) | function | Inverse of F' |
| σ(·) | sigmoid | 1/(1+exp(-·)) |
Implementation Notes:
- f(τ)=τ²/2: reduces to standard GD (Euclidean)
- Sigmoid: keeps W_t ∈ (0,1) element-wise
- Element-wise operations, including the division W_{t-1}/(1-W_{t-1})
- Adds non-linearity to memory dynamics
Dependencies: Choice of f
Test Criteria:
- W_t ∈ (0,1) for sigmoid
- Gradient finite
- f convex ensures Bregman divergence ≥ 0
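The sigmoid example can be sketched element-wise: map the previous weights to logit space, take a gradient step there, and map back. The function names and gradient values below are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def logit(w):
    """Inverse of the sigmoid, defined for 0 < w < 1."""
    return math.log(w / (1.0 - w))

def sigmoid_bregman_step(W_prev, grad, eta):
    """W_t = sigmoid(logit(W_{t-1}) - eta * grad), element-wise."""
    return [sigmoid(logit(w) - eta * g) for w, g in zip(W_prev, grad)]

W_prev = [0.1, 0.5, 0.9]
W_same = sigmoid_bregman_step(W_prev, [0.0, 0.0, 0.0], eta=0.1)   # zero gradient
W_next = sigmoid_bregman_step(W_prev, [5.0, -5.0, 50.0], eta=1.0) # large gradients
```

A zero gradient leaves the weights unchanged, and even a very large gradient cannot push an entry outside (0, 1) in exact arithmetic; in floating point the logit step can still saturate for extreme inputs.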
Architecture (Moneta):
- Memory: 2-layer MLP with expansion factor 4, GELU activation
- Formula: M(x) = x + LN(W₁σ(W₂x))
Attentional Bias: ℓ_p loss (Equation 10)
\mathcal{L} = \|M(k_t) - v_t\|_p^p
Retention Gate: Hybrid ℓ_q + ℓ_2
\text{Global: } G_t(W) = \frac{1}{2(q-1)}\|W\|_q^2 + \frac{1}{\beta}\|W\|_2^2
Memory Algorithm: Gradient Descent
Update Rule (Equation 24):
A_t = \alpha_t A_{t-1} - \eta_t \nabla\ell_p(W_{t-1}; k_t, v_t)
W_t = \frac{A_t}{\|A_t\|_q^{q-2}}
Gradient (Equation 25; the learning rate η_t is applied in Equation 24):
\nabla\ell_p(W_{t-1}; k_t, v_t) = p (\text{Sign}(W_{t-1}k_t - v_t) \odot |W_{t-1}k_t - v_t|^{p-1}) k_t^\top
Hyperparameters (from paper):
- p = 3 (ℓ_p norm)
- q = 4 (ℓ_q retention)
- MLP expansion: 4
- Activation: GELU
- LayerNorm: applied after MLP
Variables:
| Variable | Shape | Description |
|---|---|---|
| W₁ | (d, 4d) | MLP output projection (4d → d), applied second in M(x) |
| W₂ | (4d, d) | MLP expansion (d → 4d), applied first in M(x) |
| A_t | (same as W) | Accumulated gradients |
| α_t | scalar or (d,) | Retention coefficient (data-dependent) |
| η_t | scalar or (d,) | Learning rate (data-dependent) |
| σ | function | GELU activation |
| LN | function | Layer normalization |
Implementation Notes:
- Residual connection: x + LN(...)
- Smooth approximations for Sign and |·| (see Eq 11 notes)
- α_t and η_t are channel-wise, use low-rank projections (rank 32 or 64)
- Normalization by ℓ_q^{q-2} norm applied at end of each chunk (training)
Dependencies: None
Test Criteria:
- p=3 performs best (ablation study)
- Memory stable with α < 1
- Better on noisy synthetic tasks (S-NIAH-PK)
Architecture (Yaad):
- Memory: 2-layer MLP (same as Moneta)
- Formula: M(x) = x + LN(W₁σ(W₂x))
Attentional Bias: Huber loss (Equation 16, Variant 3)
\ell(W; k_t, v_t) = \begin{cases}
\text{ℓ2 loss} & \text{if } \|M(k_t) - v_t\| \leq \delta_t \\
\text{ℓ1 loss (scaled)} & \text{otherwise}
\end{cases}
Retention Gate: ℓ_2 local + ℓ_2 global (Titans-style)
\text{Ret}_t(W, W_{t-1}) = \frac{1}{2\theta_t}\|W - W_{t-1}\|_F^2 + \frac{1}{\beta_t}\|W\|_2^2
Equivalent to "forget gate" mechanism from Titans.
Memory Algorithm: Gradient Descent
Update Rule (Equation 26):
W_t = \alpha_t W_{t-1} - \begin{cases}
\eta_t \nabla\ell_2(W_{t-1}; k_t, v_t) & \text{if } \|M(k_t) - v_t\| \leq \delta_t \\
\eta_t \delta_t \nabla\ell_1(W_{t-1}; k_t, v_t) & \text{otherwise}
\end{cases}
where:
\alpha_t = \frac{\beta_t}{\beta_t + \theta_t}
Hyperparameters:
- δ_t: data-dependent threshold (learnable)
- α_t ∈ [0,1]^d: channel-wise retention
- η_t: channel-wise learning rate
- β_t: decoupled from η_t (independent parameter)
Variables:
| Variable | Shape | Description |
|---|---|---|
| δ_t | scalar or (d,) | Huber threshold (data-dependent) |
| α_t | (d,) | Retention gate (channel-wise) |
| η_t | (d,) | Learning rate (channel-wise) |
| β_t | (d,) | Global retention parameter |
| θ_t | (d,) | Local retention parameter |
Implementation Notes:
- Adaptive: switches between ℓ_2 and ℓ_1 based on error magnitude
- More robust to outliers than pure ℓ_2
- "Coping mechanism": protects memory from extreme events
- Decoupling η and α improves expressivity
Dependencies: Error magnitude ∥M(k_t) - v_t∥
Test Criteria:
- Smooth transition at threshold
- Robust to outliers (verified experimentally)
- Comparable to Moneta and Memora on most tasks
Architecture (Memora):
- Memory: 2-layer MLP (same as Moneta, Yaad)
- Formula: M(x) = x + LN(W₁σ(W₂x))
Attentional Bias: ℓ_2 regression
\mathcal{L} = \|M(k_t) - v_t\|_2^2
Retention Gate: KL divergence (Equation 21)
\text{Ret}_t(W, W_{t-1}) = \frac{1}{\eta_t}\sum_{jl} W_{jl}\log\frac{W_{jl}}{(W_{t-1})_{jl}} + \frac{1}{\alpha_t}\sum_{jl} W_{jl}\log(W_{jl})
Memory Algorithm: Gradient Descent
Update Rule (Equation 27):
W_t = \text{Softmax}(\alpha_t \log(W_{t-1}) - \eta_t \nabla\ell_2(W_{t-1}; k_t, v_t))
Hyperparameters:
- α_t: retention coefficient (data-dependent)
- η_t: learning rate (data-dependent)
- Softmax scaling: c (constant, usually 1)
Variables:
| Variable | Shape | Description |
|---|---|---|
| α_t | scalar or (d,) | Retention coefficient |
| η_t | scalar or (d,) | Learning rate |
| log(W) | element-wise | Natural logarithm |
| Softmax | per-slice | Applied per slice for neural networks |
Implementation Notes:
- Softmax ensures non-negativity and normalization
- Hard + soft forgetting: combines KL divergence (soft) with thresholding behavior
- log-domain updates prevent numerical issues
- Apply Softmax per-slice (e.g., per column or specific dimension)
- Lag tokens: apply full nonlinearity (Softmax + log) at chunk boundaries (index kb+1)
Parallelization (Section 5.4): Inside chunk (tokens kb+2 to (k+1)b):
- Use linear approximation (skip log and Softmax)
- Calculate: W_t ≈ W_{t-1} - η_t ∇ℓ_2(W_1; k_t, v_t), where W_1 is the state after the lag token
At chunk boundary (token kb+1):
M_1 = \text{Softmax}(\alpha_1 \log(M_0) - \eta_1 \nabla\ell_2(M_0; k_1, v_1))
Dependencies: Non-negative weights
Test Criteria:
- W_t ≥ 0 (enforced by Softmax)
- Sum constraints maintained
- Stable for long sequences
Core Idea:
- Divide sequence into chunks of size b (typically 16 or 64)
- Calculate gradients for chunk w.r.t. last state of previous chunk
- Use linearization + matrix operations within chunk
Equation 28: Expanded Recurrence (Moneta example, q=2)
LaTeX:
M_t = \alpha_t M_{t-1} - \eta_t \nabla\ell(M_{t-1}; k_t, v_t)
= \beta_t M_0 - \sum_{i=1}^{t} \eta_i \frac{\beta_t}{\beta_i} \nabla\ell(M_{t'}; k_i, v_i)
where t' = t - mod(t, b) (last token of previous chunk), β_i = ∏_{j=1}^i α_j.
For linear memory M(W, k) = Wk with ℓ_p loss:
Equation 29: Batch Gradient Computation
\sum_{i=1}^{b} \eta_i \frac{\beta_b}{\beta_i} \nabla\ell(W_0; k_i, v_i) = p \mathbf{E}_b \odot \mathbf{B}_b \odot \text{Sign}(W_0K - V) \odot (|W_0K - V|^{p-1}) K^\top
where:
- K = [k₁, k₂, ..., k_b] ∈ ℝ^{d×b} (stacked keys)
- V = [v₁, v₂, ..., v_b] ∈ ℝ^{d×b} (stacked values)
- E_b = [η₁, η₂, ..., η_b] ∈ ℝ^{1×b}
- B_b = [β_b/β₁, β_b/β₂, ..., β_b/β_b] ∈ ℝ^{1×b}
Variables:
| Variable | Shape | Description |
|---|---|---|
| b | scalar | Chunk size (16 or 64) |
| t' | scalar | Last token of previous chunk |
| M_0 | (varies) | Memory at start of chunk |
| K, V | (d, b) | Batched keys/values |
| E_b, B_b | (1, b) | Learning rate and decay masks |
| β_i | scalar | Cumulative product of α |
Smooth Approximations:
\text{Sign}(x) \approx \tanh(\alpha x), \quad \alpha \text{ large (e.g., 10)}
|x| \approx \sqrt{x^2 + \epsilon}, \quad \epsilon = 10^{-6}
Implementation Notes:
- All operations in Eq 29 are matrix multiplications or element-wise
- Can run on GPU/TPU efficiently
- Non-linearity (normalization for q≠2) applied only at chunk boundaries
- For 2-layer MLP: apply same process to each layer's parameters
Test Criteria:
- Chunk processing much faster than sequential
- Gradient approximation error small (controlled by chunk size)
- Memory overflow prevented by periodic normalization
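The expanded recurrence in Eq. 28 can be verified on scalars: unrolling M_t = α_t M_{t−1} − η_t g_t over one chunk gives β_b M_0 − Σ_i η_i (β_b/β_i) g_i. In this sketch the memory is a scalar and the gradients g_i are fixed numbers, standing in for gradients evaluated at the chunk's starting state; both are simplifying assumptions.

```python
def sequential(M0, alphas, etas, grads):
    """M_t = alpha_t * M_{t-1} - eta_t * g_t, one token at a time."""
    M = M0
    for a, e, g in zip(alphas, etas, grads):
        M = a * M - e * g
    return M

def chunked(M0, alphas, etas, grads):
    """beta_b * M_0 - sum_i eta_i * (beta_b / beta_i) * g_i."""
    betas, running = [], 1.0
    for a in alphas:
        running *= a          # beta_i = alpha_1 * ... * alpha_i
        betas.append(running)
    beta_b = betas[-1]
    weighted = sum(e * (beta_b / b_i) * g for e, b_i, g in zip(etas, betas, grads))
    return beta_b * M0 - weighted

# Hypothetical per-token gates and gradients for a chunk of size b = 4.
alphas = [0.9, 0.8, 0.95, 0.7]
etas = [0.1, 0.2, 0.05, 0.3]
grads = [1.0, -2.0, 0.5, 4.0]
M_seq = sequential(2.0, alphas, etas, grads)
M_chunk = chunked(2.0, alphas, etas, grads)
```

The two results agree up to rounding. The chunked form needs only cumulative products and one weighted sum, which is what allows the per-chunk work to be expressed as matrix operations.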
Process:
- Calculate gradients for both ℓ₁ and ℓ₂ losses
- Use masking based on error magnitude
Pseudo-algorithm:
import numpy as np
# Inside a chunk: W0 is (d, d); K, V are (d, b); E_b, B_b are (1, b); delta_t is a scalar here
err = W0 @ K - V                               # (d, b) per-token errors
mask = np.linalg.norm(err, axis=0) <= delta_t  # (b,) True where the ℓ2 branch applies
w = E_b * B_b                                  # (1, b) learning-rate and decay weights
# Mask per token before contracting with the keys
grad_l2 = (w * mask * err) @ K.T               # (d, d)
grad_l1 = (w * ~mask * np.sign(err)) @ K.T     # (d, d)
# beta_b is the cumulative product of alpha over the chunk (Eq 28)
W_new = beta_b * W0 - (grad_l2 + delta_t * grad_l1)
Variables:
| Variable | Shape | Description |
|---|---|---|
| mask | (b,) | Boolean: True if error ≤ δ_t |
| grad_l1 | (d, d) or MLP | Gradient for ℓ₁ loss |
| grad_l2 | (d, d) or MLP | Gradient for ℓ₂ loss |
Implementation Notes:
- Two gradient computations per chunk
- Masking is element-wise per token
- Can still use batched matrix operations
Lag Token (Equation 30): At first token of chunk (i = kb + 1):
M_1 = \text{Softmax}(\alpha_1 \log(M_0) - \eta_1 \nabla\ell_2(M_0; k_1, v_1))
Linear Approximation Inside Chunk: For i = kb+2 to (k+1)b, use M_1 instead of M_{i-1}:
Equation 31-32: Batch Computation
\nabla\ell(W_1; k_t, v_t) = (W_1 k_t - v_t) k_t^\top
\sum_{i=1}^{b} \eta_i \frac{\beta_b}{\beta_i} \nabla\ell(W_1; k_i, v_i) = (\mathbf{E}_b \odot \mathbf{B}_b \odot (W_1 K - V)) K^\top
Implementation:
# Step 1: Lag token (exact update)
M_1 = softmax(alpha_1 * log(M_0) - eta_1 * grad_l2(M_0, k_1, v_1))
# Step 2: Rest of chunk (linear approximation around M_1)
errors = M_1 @ K - V                      # shape (d, b)
grad_batch = (E_b * B_b * errors) @ K.T   # shape (d, d); * is element-wise with broadcasting
M_next = beta_b * M_1 - grad_batch
# Step 3: Apply Softmax at next boundary
M_new = softmax(log(M_next))              # or keep M_next for the next lag token; log needs M_next > 0
Variables:
| Variable | Shape | Description |
|---|---|---|
| M_0 | (varies) | Memory from previous chunk |
| M_1 | (varies) | Memory after lag token |
| log(M_0) | element-wise | Natural logarithm |
Implementation Notes:
- Log and Softmax only at chunk boundaries
- Prevents compounding of non-linearity
- Linear approximation sufficient inside chunk
- Softmax can be per-slice for neural network parameters
Components (from Figure 2 and Section 5.4):
- Input Projections:
  q = x @ W_q
  k = x @ W_k
  v = x @ W_v
- 1D Depthwise-Separable Convolution:
  q = conv1d(q, kernel_size=4)
  k = conv1d(k, kernel_size=4)
  v = conv1d(v, kernel_size=4)
- ℓ_2 Normalization:
  q = q / ||q||_2
  k = k / ||k||_2
- Memory Module:
  # One of: Moneta, Yaad, Memora
  output = MemoryModule(q, k, v)
- Output Processing:
  output = LayerNorm(output)
  output = output * sigmoid(output @ W_gate)
Channel-wise Parameters (Section 5.4):
- η_t, δ_t, α_t ∈ ℝ^d (channel-wise)
- Use low-rank projection to reduce parameters:
  param_t = Linear_1(x)        # x → ℝ^k (k=32 or 64)
  param_t = Linear_2(param_t)  # ℝ^k → ℝ^d
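To see what the low-rank projection saves, the sketch below counts weights for a full d → d linear map versus the two-step d → k → d factorization. It ignores biases and uses d = 1024 (the 340M hidden dimension) purely as an example; the function names are illustrative.

```python
def full_projection_params(d_in, d_out):
    """Weights in a single dense linear layer."""
    return d_in * d_out

def low_rank_projection_params(d_in, d_out, rank):
    """Linear_1: d_in -> rank, then Linear_2: rank -> d_out (biases ignored)."""
    return d_in * rank + rank * d_out

d = 1024
full = full_projection_params(d, d)
low_32 = low_rank_projection_params(d, d, 32)
low_64 = low_rank_projection_params(d, d, 64)
```

At rank 32 each channel-wise parameter needs 65,536 weights instead of 1,048,576, a 16x reduction; rank 64 gives 8x.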
Full Architecture:
Input (N, d)
↓
[Q, K, V Projections] (d → d each)
↓
[Conv1D (kernel=4)] for each
↓
[ℓ2 Norm] for Q, K
↓
[MIRAS Memory Module]
↓
[LayerNorm + Gating]
↓
Output (N, d)
Variables:
| Variable | Shape | Description |
|---|---|---|
| W_q, W_k, W_v | (d, d) | Projection matrices |
| conv1d | kernel=4 | Depthwise-separable convolution |
| W_gate | (d, d) | Gating linear layer |
Implementation Notes:
- Follow modern linear RNN design (GLA, Gated DeltaNet)
- Convolution adds local context
- ℓ2 norm improves training stability
- Gating from Mehta et al. 2023
Macro Architecture (Section 5.4):
- Based on Llama architecture
- Replace attention with MIRAS layer
Components:
Embedding
↓
[MIRAS Layer
↓
MLP (SwiGLU)
↓
RMSNorm
] × L layers
↓
RMSNorm
↓
LM Head
MLP:
- SwiGLU activation (Gated Linear Unit with Swish)
- Formula: MLP(x) = W_3 (W_1 x ⊙ swish(W_2 x))
Positional Encoding:
- RoPE (Rotary Position Embedding)
Normalization:
- RMSNorm (Root Mean Square Layer Normalization)
Model Sizes (Table 5):
| Size | Layers | Dim | Heads | Tokens | Peak LR |
|---|---|---|---|---|---|
| 170M | 12 | 768 | 16 | 15B | 3e-3 |
| 340M | 24 | 1024 | 16 | 15B | 1.5e-3 |
| 760M | 24 | 1536 | 16 | 30B | 1.25e-3 |
| 1.3B | - | - | - | 100B | - |
ASCII Diagram:
Input Tokens
↓
[Token Embedding]
↓
┌────────────────────┐
│ MIRAS Block 1 │
│ ┌──────────────┐ │
│ │ Q/K/V Proj │ │
│ │ Conv1D │ │
│ │ Memory(q,k,v)│ │
│ │ Gate │ │
│ └──────────────┘ │
│ RMSNorm │
│ MLP (SwiGLU) │
│ RMSNorm │
└────────────────────┘
↓
[... L blocks ...]
↓
[Final RMSNorm]
↓
[LM Head]
↓
Output Logits
Architecture (Figure 2, Samba-style):
[MIRAS Layer
↓
Sliding Window Attention (SWA)
↓
MLP
↓
RMSNorm
] × L layers
Sliding Window Attention:
- Window size: typically 128 or 256
- Local attention pattern
- Complements recurrent MIRAS memory
Implementation Notes:
- Alternate MIRAS and SWA
- Or stack: MIRAS → SWA in each block
- SWA provides local inductive bias
- MIRAS provides long-range memory
| Parameter | 340M | 760M | 1.3B | Description | Paper Reference |
|---|---|---|---|---|---|
| Layers | 24 | 24 | - | Number of blocks | Table 5 |
| Hidden dim | 1024 | 1536 | - | Model dimension | Table 5 |
| Heads | 16 | 16 | - | Number of attention heads (for hybrid) | Table 5 |
| MLP expansion | 4 | 4 | 4 | Memory MLP expansion factor | Section 5.3 |
| Conv kernel | 4 | 4 | 4 | 1D conv kernel size | Section 5.4 |
| Vocab size | - | - | - | Tokenizer dependent (not specified) | - |
| Parameter | Value | Description | Paper Reference |
|---|---|---|---|
| Training context | 4096 | Context window during training | Section 6 |
| Batch size | - | Not specified | - |
| Tokens (340M) | 15B | Total training tokens | Table 5, Section 6 |
| Tokens (760M) | 30B | Total training tokens | Table 5 |
| Tokens (1.3B) | 100B | Total training tokens | Section 6 |
| Peak LR (340M) | 1.5e-3 | Maximum learning rate | Table 5 |
| Peak LR (760M) | 1.25e-3 | Maximum learning rate | Table 5 |
| LR schedule | - | Not specified (likely cosine) | - |
| Warmup | - | Not specified | - |
| Weight decay | - | Not specified | - |
| Optimizer | AdamW | Assumed (standard) | - |
Moneta:
| Parameter | Value | Description | Paper Reference |
|---|---|---|---|
| p (attentional bias) | 3 | ℓ_p norm for loss | Section 5.3, Figure 4 |
| q (retention) | 4 | ℓ_q norm for retention | Section 5.3, Figure 4 |
| α smooth param | ~10 | For tanh approximation of Sign | Remark 5 |
| ε (abs approx) | 1e-6 | For smooth \|x\| approximation | Section 5.4 |
Yaad:
| Parameter | Value | Description | Paper Reference |
|---|---|---|---|
| δ_t | data-dependent | Huber threshold (learnable) | Section 5.3 |
| α_t | [0,1]^d | Retention gate (channel-wise) | Section 5.3 |
| β_t | learnable | Global retention (decoupled from η) | Section 5.3 |
Memora:
| Parameter | Value | Description | Paper Reference |
|---|---|---|---|
| α_t | data-dependent | Retention coefficient | Section 5.3 |
| λ_t | (0,1) | Derived from α_t and η_t | Equation 21 |
| c | 1 (assumed) | Softmax scaling constant | Section 5.2 |
Channel-wise Parametrization:
| Parameter | Value | Description | Paper Reference |
|---|---|---|---|
| Low-rank k | 32 or 64 | Rank for channel-wise parameter projection | Section 5.4 |
| Parameter | Value | Description | Paper Reference |
|---|---|---|---|
| Chunk size (b) | 16 or 64 | Tokens per parallel chunk | Section 5.4 |
| Lag tokens (Memora) | 1 | Tokens with full nonlinearity at boundary | Section 5.4 |
| Variant | Memory Arch | Attentional Bias | Retention Gate | Memory Algorithm |
|---|---|---|---|---|
| Moneta | 2-layer MLP | ℓ_p (p=3) | ℓ_q (q=4) + ℓ_2 | GD |
| Yaad | 2-layer MLP | Huber | ℓ_2 local + ℓ_2 global | GD |
| Memora | 2-layer MLP | ℓ_2 | KL divergence | GD |
From Table 1 in paper.
Language Modeling:
- FineWeb-Edu (Penedo et al. 2024)
  - Used for LM and common-sense reasoning
  - Training: 15B tokens (small), 30B tokens (medium), 100B tokens (large)
  - Evaluation metrics: Perplexity on validation set
- C4 (Raffel et al. 2020)
  - Used for scaling pattern experiments
  - Training context: 4096 tokens
Specific LM Tasks:
- WikiText (Merity et al. 2017): Perplexity
- LAMBADA (Paperno et al. 2016): Perplexity and Accuracy
Common-sense Reasoning (zero-shot):
- PIQA (Bisk et al. 2020): Physical commonsense, Accuracy
- HellaSwag (Zellers et al. 2019): Sentence completion, Accuracy (acc_n)
- WinoGrande (Sakaguchi et al. 2021): Pronoun resolution, Accuracy
- ARC-easy, ARC-challenge (Clark et al. 2018): Science questions, Accuracy (acc, acc_n)
- SIQA (Social IQA): Social reasoning, Accuracy
- BoolQ (Clark et al. 2019): Yes/No questions, Accuracy
Long Context:
- RULER S-NIAH (Hsieh et al. 2024): Needle-in-haystack variants
- S-NIAH-PK: Passkey retrieval (synthetic noise)
- S-NIAH-N: Number retrieval
- S-NIAH-W: Word retrieval
- Context lengths: 1K, 2K, 4K, 8K tokens
- Metric: Accuracy (%)
340M Models (Table 2):
Best pure recurrent (Moneta):
- WikiText perplexity: 26.19
- LAMBADA perplexity: 29.31
- PIQA: 63.99%
- HellaSwag: 39.23%
- ARC-c: 27.15%
760M Models (Table 2):
Best pure recurrent (Yaad):
- WikiText perplexity: 20.99
- LAMBADA perplexity: 21.57
- PIQA: 69.14%
- HellaSwag: 50.02%
- ARC-c: 36.27%
Best hybrid (Moneta-H):
- WikiText perplexity: 18.72
- LAMBADA perplexity: 20.13
- PIQA: 70.84%
1.3B Models (Table 2):
Best pure recurrent (Yaad):
- WikiText perplexity: 15.18
- LAMBADA perplexity: 11.89
- PIQA: 72.81%
- HellaSwag: 56.46%
- ARC-c: 40.05%
Needle-in-Haystack (Table 3, 760M models):
Moneta (best on synthetic):
- S-NIAH-PK: 99.4% (2K), 98.8% (4K), 98.8% (8K)
- Average: 93.5%
Yaad (best overall):
- Average: 92.9%
Scaling (Figure 3):
- FLOPs vs perplexity: All variants outperform baselines
- Context scaling: Superior to Transformers, Mamba2, GSA when training context 2K→32K
Training:
- 340M model: ~15B tokens, context 4K
- 760M model: ~30B tokens, context 4K
- 1.3B model: ~100B tokens, context 4K
- Hardware: Not specified (likely TPU/GPU cluster)
Inference:
- Linear time in sequence length: O(Nd²) for a d×d matrix memory vs. O(N²d) for Transformers
- Fixed memory footprint (no KV cache growth)
- Parallel training: O(N/b) sequential steps with chunk size b
Comparison to Baselines (from context):
- Competitive or better than Transformer++, RetNet, Mamba2, Gated DeltaNet
- Outperforms Titans (predecessor with momentum-based updates)
- Pure recurrent models competitive with hybrid (attention+recurrent)
TITANS (Behrouz et al. 2024c) - CRITICAL DEPENDENCY:
MIRAS generalizes and extends TITANS:
-
Memory Architecture:
- TITANS: k-layer MLP (deep memory)
- MIRAS: General (vector, matrix, MLP, etc.)
- Moneta/Yaad/Memora: 2-layer MLP (simpler than TITANS)
-
Attentional Bias:
- TITANS: ℓ_2 MSE loss only
- MIRAS: General L (ℓ_p, Huber, robust, etc.)
- Extends beyond MSE to handle different data distributions
-
Retention Gate (Forget Gate):
- TITANS: ℓ_2 local + ℓ_2 global (specific form)
- MIRAS: General Ret (ℓ_q, KL, Elastic Net, Bregman, etc.)
- Reinterprets TITANS forget gate as ℓ_2 retention regularization
-
Memory Algorithm:
- TITANS: GD with momentum
- MIRAS: General optimizer (GD, momentum, Newton, non-parametric)
- Moneta/Yaad/Memora: GD (simpler, no momentum)
-
Specific Relation (from Table 1):
Titans-LMM:
- Memory: k-layer MLP
- Attentional Bias: ℓ_2
- Retention: ℓ_2 local + ℓ_2 global
- Algorithm: GD + Momentum
- Update: M_t = α_t M_{t-1} - S_t, where S_t = η_t S_{t-1} - θ_t ∇L(M_{t-1}; k_t, v_t)
-
Cold Start Strategy (Footnote 2):
- TITANS uses "cold start" for full memory erase
- Different from Mamba2/Gated DeltaNet which treat next token as first-ever
- TITANS uses previous state to measure surprise before erasing
Implementation Order:
- Implement basic memory module (matrix-valued, linear)
- Implement ℓ_2 attentional bias and retention (baseline)
- Test against TITANS-LMM (should be special case)
- Add novel attentional biases (ℓ_p, Huber)
- Add novel retention gates (KL, Elastic Net)
- Implement parallelization
TTT (Test-Time Training) (Sun et al. 2024):
- TTT-Linear, TTT-MLP: Use ℓ_2 loss, no retention
- MIRAS unifies TTT as non-parametric/parametric solution to MSE
- TTT-Linear: M_t = M_{t-1} - η∇L(M_{t-1}, x_t)
- TTT-MLP: Same but with MLP memory
Linear Attention (Katharopoulos et al. 2020):
- Hebbian rule: M_t = M_{t-1} + v_t k_t^⊤
- MIRAS special case: α=1, dot-product bias, no retention
RetNet (Sun et al. 2023):
- M_t = αM_{t-1} + v_t k_t^⊤
- MIRAS special case: scalar α, dot-product bias, ℓ_2 retention
Mamba2 (Dao et al. 2024):
- Data-dependent α_t: M_t = α_t M_{t-1} + v_t k_t^⊤
- MIRAS special case: dot-product bias, ℓ_2 retention
GLA (Yang et al. 2024b):
- Diagonal α_t: M_t = Diag(α_t)M_{t-1} + v_t k_t^⊤
- MIRAS special case: dot-product bias, ℓ_2 retention
DeltaNet (Schlag et al. 2021):
- M_t = (I - βk_t k_t^⊤)M_{t-1} + βv_t k_t^⊤
- MIRAS special case: ℓ_2 bias, no retention (α=1)
Gated DeltaNet (Yang et al. 2024a):
- M_t = α_t(I - βk_t k_t^⊤)M_{t-1} + βv_t k_t^⊤
- MIRAS special case: ℓ_2 bias, ℓ_2 retention
RWKV-7 (Peng et al. 2025b):
- Channel-wise gated delta rule
- MIRAS special case: ℓ_2 bias, ℓ_2 retention, channel-wise α
HGRN2 (Qin et al. 2024):
- Uses ℓ_1 loss (outside standard MIRAS instantiations in Table 1)
- Shows MIRAS can express non-ℓ_2 objectives
Transformer (Vaswani et al. 2017):
- Non-parametric solution to ℓ_2 MSE with Nadaraya-Watson estimator
- No retention (keeps all key-value pairs)
- Equation 1-2: Transformer attention (baseline reference)
- Equation 3: Hebbian memory update (baseline)
- Equation 4: Associative memory definition (core framework)
- Equation 5: Gradient descent update (general)
- Equation 7: FTRL viewpoint (alternative formulation)
- Learning-Retaining viewpoint (core implementation approach)
- Equation 9: Delta rule memory update (baseline)
- Equation 10: ℓ_p attentional bias definition
- Equation 11: ℓ_p gradient (general p)
- Equation 12: ℓ_1 special case (value-less memory)
- Remark 5: Smooth approximations (Sign, |·|)
- Equation 24: Moneta update rule (main)
- Equation 25: Moneta gradient (explicit)
- Equation 28-29: Moneta parallelization
- ℓ_q retention normalization (W_t = A_t / ∥A_t∥_q^{q-2})
- Test p=3, q=4 configuration
- Equation 13: Huber function H(a)
- Equation 14: Coordinate-wise Huber (variant 1)
- Equation 15: ℓ_2 norm-based Huber (variant 2)
- Equation 16: Smooth mixture Huber (variant 3, main)
- Equation 26: Yaad update rule (main)
- ℓ_2 retention (local + global)
- Data-dependent δ_t threshold
- Decoupled η_t and α_t (via β_t)
- Parallelization with dual gradients (ℓ_1 and ℓ_2)
- Equation 18: f-divergence update (general)
- Equation 19-20: KL divergence retention formulation
- Equation 21: KL update with Softmax (main)
- Equation 27: Memora update rule (simplified form)
- Equation 30: Lag token (chunk boundary)
- Equation 31-32: Memora parallelization (linear approx)
- Softmax per-slice for neural networks
- Log-domain stability
- Equation 22: Elastic Net retention (Learning-Retaining)
- Equation 23: Elastic Net retention (FTRL)
- Soft-thresholding operator S_γ
- ℓ_q memory stability (Section 5.2, Variant 4)
- Bregman divergence retention (Section 5.2, Variant 5)
- Robust to value shifts (Equation 17)
- 2-layer MLP memory (M(x) = x + LN(W₁σ(W₂x)))
- GELU activation
- Layer normalization
- Residual connections
- Q/K/V projections
- 1D depthwise-separable convolution (kernel=4)
- ℓ_2 normalization for q, k
- Output gating (sigmoid)
- RMSNorm (full model)
- RoPE positional encoding
- SwiGLU MLP (full model)
- Low-rank projection for η_t (d → k → d, k=32 or 64)
- Low-rank projection for α_t
- Low-rank projection for δ_t (Yaad)
- Data-dependent parameter computation
- Chunk-based processing (chunk size b=16 or 64)
- Batch gradient computation (Equation 29)
- E_b and B_b construction (learning rate and decay masks)
- Matrix operations (K, V stacking)
- Smooth approximations (tanh, sqrt)
- Chunk boundary handling (normalization for Moneta)
- Lag token processing (Memora)
- Dual gradient computation (Yaad)
- AdamW optimizer
- Learning rate schedule (cosine decay)
- Gradient clipping (if needed)
- Mixed precision training (FP16/BF16)
- Data loading (FineWeb-Edu, C4)
- Tokenization (Llama tokenizer or similar)
- Evaluation harness (lm-eval or similar)
- Unit tests: each equation
- Integration tests: full forward/backward pass
- Numerical gradient checks
- Causality verification (M_t only depends on inputs ≤ t)
- Memory stability tests (no overflow/underflow)
- Chunk parallelization equivalence (vs. sequential)
- Reproduction of baseline results (RetNet, DeltaNet, etc.)
- Ablation studies (p, q, δ values)
-
Smooth Approximations:
- Sign(x) and |x| are non-differentiable at x = 0, and Sign(x) has zero gradient everywhere else
- Use smooth approximations so backprop receives useful gradients
- Sign(x) ≈ tanh(αx), α large (e.g., 10)
- |x| ≈ sqrt(x² + ε), ε = 1e-6
- Test gradients numerically
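The approximations above can be sketched in a few lines. This is a minimal illustration, not the paper's code: the function names and the `sharpness` and `eps` defaults are assumptions taken from the values quoted in these notes.

```python
import math

def smooth_sign(x: float, sharpness: float = 10.0) -> float:
    """tanh surrogate for Sign(x); closer to Sign as `sharpness` grows."""
    return math.tanh(sharpness * x)

def smooth_abs(x: float, eps: float = 1e-6) -> float:
    """sqrt surrogate for |x|; returns sqrt(eps) rather than 0 at x = 0."""
    return math.sqrt(x * x + eps)

def smooth_abs_grad(x: float, eps: float = 1e-6) -> float:
    """Analytic derivative of smooth_abs; well defined (and 0) at x = 0."""
    return x / math.sqrt(x * x + eps)

def central_diff(f, x: float, h: float = 1e-5) -> float:
    """Central finite difference used to test the analytic gradient."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

# Compare analytic and numerical gradients at a few points, including x = 0.
for x in (-2.0, -0.3, 0.0, 0.3, 2.0):
    assert abs(smooth_abs_grad(x) - central_diff(smooth_abs, x)) < 1e-4
```

A larger `sharpness` tracks Sign more closely but concentrates the gradient near zero, so the value is a tuning choice rather than a fixed constant.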
-
Normalization Timing (Moneta):
- ℓ_q normalization only at chunk boundaries during training
- Inside chunk: linear accumulation
- Prevents compounding non-linearity
- At inference: can normalize every step
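A rough sketch of this timing, assuming the normalization form W_t = A_t / ∥A_t∥_q^{q-2} listed earlier in these notes. The helper names, the `eps` guard, and the assumption that per-token gradients are already available are illustrative choices, not details from the paper.

```python
import numpy as np

def lq_normalize(A: np.ndarray, q: float, eps: float = 1e-12) -> np.ndarray:
    """Map the accumulator A to memory weights W = A / ||A||_q^(q-2)."""
    norm_q = np.sum(np.abs(A) ** q) ** (1.0 / q)
    return A / (norm_q ** (q - 2.0) + eps)  # eps guards against division by zero

def accumulate_chunk(A, grads, alpha, eta, q):
    """Linear accumulation inside a chunk, one normalization at the boundary."""
    for g in grads:
        A = alpha * A - eta * g  # no normalization between tokens
    return A, lq_normalize(A, q)

# Toy example: two identical gradients, no decay.
A0 = np.zeros((2, 2))
grads = [-np.ones((2, 2)), -np.ones((2, 2))]
A_end, W_end = accumulate_chunk(A0, grads, alpha=1.0, eta=0.5, q=4.0)
```

With q = 2 the exponent is zero, so the normalization reduces to the identity and the update falls back to plain ℓ_2 retention.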
-
Log/Softmax Stability (Memora):
- Never compute exp(large number)
- Use log-sum-exp trick in Softmax
- Softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
- Log-domain updates: log(W_t) = ...
- Clamp log(W) to avoid -∞
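The max-subtraction trick above can be sketched as follows. This shows only the generic stable Softmax and log-Softmax, not the Memora update itself; the function names are illustrative.

```python
import numpy as np

def stable_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Softmax that subtracts the max first, so exp never sees a large argument."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)

def stable_log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """log(Softmax(x)) computed without forming Softmax, avoiding log(0)."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

# Logits this large would overflow a naive exp(x) / sum(exp(x)).
logits = np.array([1000.0, 1001.0, 1002.0])
probs = stable_softmax(logits)
log_probs = stable_log_softmax(logits)
```

Working with `stable_log_softmax` directly is one way to keep updates in the log domain without ever taking the log of a value that underflowed to zero.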
-
Chunk Boundary Handling:
- Memory state from previous chunk is frozen (M_0)
- Gradients computed w.r.t. M_0, not M_{t-1}
- Approximation error controlled by chunk size
- Smaller chunk = better approximation, slower training
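To make the approximation concrete, here is a small sketch under simplifying assumptions: a linear memory read as `M @ k`, an ℓ_2 attentional bias, and scalar `alpha` and `eta`. All names are hypothetical. The inner loop is written sequentially for clarity; because the chunked gradients depend only on the frozen state, they could be computed as one batch.

```python
import numpy as np

def l2_grad(M, k, v):
    """Gradient of 0.5 * ||M k - v||^2 with respect to M."""
    return np.outer(M @ k - v, k)

def sequential_update(M, K, V, alpha, eta):
    """Reference recurrence: each gradient is taken at the previous state M_{t-1}."""
    for k, v in zip(K, V):
        M = alpha * M - eta * l2_grad(M, k, v)
    return M

def chunked_update(M, K, V, alpha, eta, b):
    """Chunked recurrence: gradients inside a chunk use the frozen chunk-start state."""
    for start in range(0, len(K), b):
        M0 = M  # frozen state for this chunk (M is rebound below, never mutated)
        for k, v in zip(K[start:start + b], V[start:start + b]):
            M = alpha * M - eta * l2_grad(M0, k, v)
    return M

rng = np.random.default_rng(0)
N, d = 16, 4
K = rng.normal(size=(N, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)  # unit-norm keys
V = rng.normal(size=(N, d))
M_init = np.zeros((d, d))

M_seq = sequential_update(M_init, K, V, alpha=0.95, eta=0.1)
errors = {
    b: float(np.linalg.norm(chunked_update(M_init, K, V, 0.95, 0.1, b) - M_seq))
    for b in (1, 4, 16)
}
```

With `b = 1` the chunked and sequential recurrences coincide; for larger chunks `errors[b]` measures how far the frozen-state approximation drifts on this toy input.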
-
Channel-wise Parameters:
- η_t, α_t, δ_t all shape (d,) not scalar
- Use broadcasting carefully
- Low-rank projection reduces parameters
- Make sure gradients flow through projection
-
Retention Gate vs. Learning Rate:
- In Yaad/Memora: η_t and α_t are DECOUPLED
- Don't assume α_t = f(η_t)
- β_t is independent parameter
- More expressive but more parameters
-
Huber Threshold (Yaad):
- δ_t can be data-dependent (learnable)
- Shape: scalar or (d,)
- Needs careful initialization (e.g., 1.0)
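- Too small: behaves like ℓ_1 (almost every residual exceeds δ and falls in the linear regime)
- Too large: behaves like ℓ_2 (almost every residual stays in the quadratic regime)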
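A scalar sketch of how δ moves the loss between its quadratic and linear regimes. It uses the standard Huber definition, which is assumed here to match Equation 13; the function names are illustrative.

```python
import math

def huber(a: float, delta: float) -> float:
    """Standard Huber function: quadratic for |a| <= delta, linear beyond it."""
    if abs(a) <= delta:
        return 0.5 * a * a
    return delta * (abs(a) - 0.5 * delta)

def huber_grad(a: float, delta: float) -> float:
    """Derivative of huber: the residual itself, clipped to magnitude delta."""
    if abs(a) <= delta:
        return a
    return delta * math.copysign(1.0, a)

residual = 2.0
# Large threshold: the residual sits in the quadratic regime, i.e. 0.5 * a^2.
quadratic_like = huber(residual, delta=100.0)
# Tiny threshold: the residual sits in the linear regime, roughly delta * |a|.
linear_like = huber(residual, delta=1e-3)
```

The gradient is bounded by δ in magnitude, which is what limits the influence of outlier residuals on the memory update.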
-
MLP Memory Architecture:
- Expansion factor 4: (d → 4d → d)
- GELU activation between layers
- Layer norm AFTER MLP
- Residual connection: x + LN(MLP(x))
- Gradient flow through residual critical
-
KV Cache Equivalent:
- Unlike Transformers, no growing KV cache
- Memory size fixed: O(d²) for matrix, O(d × 4d) for 2-layer MLP
- Inference memory is constant in sequence length (O(1) in N)
- BUT: can't perfectly recall like Transformer
-
Parallelization Trade-off:
- Chunk size b: larger = faster, less accurate
- b=64 good for training, b=1 for inference
- Parallelization only for training
- Inference is sequential (but fast)
-
What is "Memory"?
- For linear: M(W, k) = Wk, W is (d, d) or (d, n)
- For MLP: M(W, k) = W₁σ(W₂k), W₁ is (d, 4d), W₂ is (4d, d)
- For full: M(W, k) = k + LN(W₁σ(W₂k)) with residual
- Memory "state" = parameters W (not hidden state like RNN)
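A forward-pass sketch of the full form M(W, k) = k + LN(W₁σ(W₂k)) with the shapes given above. Several details are assumptions made for illustration: the tanh approximation of GELU, a layer norm without learnable scale or shift, the small `d`, and the random initialization.

```python
import numpy as np

def gelu(x: np.ndarray) -> np.ndarray:
    """tanh approximation of GELU (assumed variant; the notes only say GELU)."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Layer norm over the last axis, without learnable affine parameters."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def mlp_memory(k: np.ndarray, W1: np.ndarray, W2: np.ndarray) -> np.ndarray:
    """M(W, k) = k + LN(W1 @ gelu(W2 @ k)); W1 is (d, 4d), W2 is (4d, d)."""
    hidden = gelu(W2 @ k)               # shape (4d,)
    return k + layer_norm(W1 @ hidden)  # shape (d,), residual added last

d = 8
rng = np.random.default_rng(0)
W1 = rng.normal(scale=0.1, size=(d, 4 * d))
W2 = rng.normal(scale=0.1, size=(4 * d, d))
k = rng.normal(size=d)
out = mlp_memory(k, W1, W2)
```

Here the memory "state" is the pair `(W1, W2)`: a test-time update changes these arrays, while `k` is only the query being read out.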
-
Attentional Bias vs. Outer Loss:
- Attentional bias L: memory's internal objective
- Outer loss: language modeling cross-entropy
- Meta-learning: inner loop (L), outer loop (LM loss)
- Gradients from outer loss flow through memory update
-
What is "Test-Time" Learning?
- During forward pass (inference or training), memory updates
- Memory learns to map (k_t, v_t) as tokens arrive
- Parameters W updated based on current input
- Different from weight updates via backprop (those happen in outer loop)
-
Retention ≠ Forget:
- Paper: "Retention gate" not "forget gate"
- No actual erasing, just weighted retention
- α_t close to 1: retain strongly
- α_t close to 0: don't retain (effective forgetting)
- Matches human memory: retrieval failure, not erasure
-
FTRL vs. Learning-Retaining:
- Mathematically equivalent under conditions (Prop 3.2)
- Learning-Retaining more general (doesn't need strict convexity)
- FTRL useful for analysis
- Implement Learning-Retaining (cleaner)
-
Gradient Computation:
- For linear memory: gradient is rank-1 matrix (v_t k_t^⊤ for Hebbian)
- For MLP: use autodiff, but can derive analytically
- Parallelization: batch all gradients in chunk
- Use einsum for efficient computation
-
Data-Dependent Parameters:
- η_t = MLP_η(x_t): function of input
- α_t = sigmoid(MLP_α(x_t)): keeps α_t in (0,1)
- δ_t = softplus(MLP_δ(x_t)): ensures > 0
- Low-rank: x_t → Linear₁ → ReLU → Linear₂ → param
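The parameterization above could look roughly like this in NumPy. The rank `r = 4`, the random weights, and all helper names are placeholders; η_t is left unconstrained because these notes do not state an activation for it.

```python
import numpy as np

def low_rank_param(x, W_down, W_up):
    """x -> Linear1 (d -> r) -> ReLU -> Linear2 (r -> d)."""
    return W_up @ np.maximum(W_down @ x, 0.0)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def softplus(z):
    return np.logaddexp(0.0, z)  # stable log(1 + exp(z))

d, r = 16, 4  # r is a hypothetical rank for this toy example
rng = np.random.default_rng(0)
x_t = rng.normal(size=d)

# One (down, up) projection pair per data-dependent parameter.
weights = {
    name: (rng.normal(scale=0.1, size=(r, d)), rng.normal(scale=0.1, size=(d, r)))
    for name in ("eta", "alpha", "delta")
}

eta_t = low_rank_param(x_t, *weights["eta"])               # unconstrained here
alpha_t = sigmoid(low_rank_param(x_t, *weights["alpha"]))  # in (0, 1)
delta_t = softplus(low_rank_param(x_t, *weights["delta"])) # strictly positive
```

Each projection pair costs 2·d·r parameters instead of d² for a full d × d map, which is the saving the low-rank form is meant to provide.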
-
Hybrid Models:
- Sequential: [MIRAS → SWA → MLP] per block
- SWA: sliding window attention (window=128 or 256)
- Use flash-attention or similar for SWA
- MIRAS provides global, SWA provides local
-
Kernel Fusion:
- Fuse Sign/abs/power operations
- Fuse Softmax + log for Memora
- Fuse element-wise ⊙ operations
- Custom CUDA kernels for Equation 29
-
Memory Layout:
- Contiguous tensors for K, V batches
- Transpose once, use multiple times
- Avoid strided operations in inner loops
-
Numerical Precision:
- BF16 for most operations
- FP32 for normalization (ℓ_q, Softmax)
- Mixed precision: store W in FP32, compute in BF16
- Gradient accumulation in FP32
-
Checkpoint/Recomputation:
- Save M_0 at chunk boundaries
- Recompute chunk internals during backward
- Trade memory for computation
- Critical for long sequences
-
Distributed Training:
- Tensor parallelism: split d dimension
- Sequence parallelism: split N dimension
- Pipeline parallelism: split layers
- Chunk parallelism: process chunks on different devices
-
Initialization:
- Memory W: Xavier/He initialization
- α_t: initialize close to 1 (high retention initially)
- η_t: initialize small (1e-2 to 1e-3)
- δ_t: initialize to 1.0
-
Gradient Clipping:
- Memory gradients can be large (especially ℓ_p with p>2)
- Clip by norm: max_norm = 1.0
- Or use adaptive clipping
-
Learning Rate Schedule:
- Cosine decay from peak LR
- Warmup: 2000-5000 steps
- Peak LR: see Table 5 (1.5e-3 for 340M)
- Min LR: 10% of peak
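One way to combine these settings into a schedule function. The linear warmup shape and `total_steps` are assumptions; the peak value, warmup length, and 10% floor come from the bullets above.

```python
import math

def lr_at(step: int,
          peak_lr: float = 1.5e-3,
          warmup_steps: int = 2000,
          total_steps: int = 100_000,  # hypothetical training length
          min_ratio: float = 0.1) -> float:
    """Linear warmup to peak_lr, then cosine decay down to min_ratio * peak_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    min_lr = min_ratio * peak_lr
    return min_lr + (peak_lr - min_lr) * cosine

# Sample the schedule at a few steps.
schedule = [lr_at(s) for s in (0, 1999, 2000, 51_000, 100_000)]
```

Steps past `total_steps` stay at the floor value because `progress` is clamped to 1.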
-
Gradient Checks:
- Numerical gradient vs. autodiff
- Check each component separately (ℓ_p, ℓ_q, Huber, KL)
- Use small models (d=32) for debugging
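A finite-difference check for one component, the ℓ_p attentional bias on a linear memory. The objective is taken here as sum_i |(Mk − v)_i|^p, whose derivative is p (Sign(r) ⊙ |r|^{p−1}) k^⊤ with r = Mk − v. Whether the paper's Equation 11 carries the same constant factor is not verified in this sketch, and the dimensions are arbitrary.

```python
import numpy as np

def lp_loss(M, k, v, p):
    """l_p attentional bias for a linear memory: sum_i |(M k - v)_i|^p."""
    return float(np.sum(np.abs(M @ k - v) ** p))

def lp_grad(M, k, v, p):
    """Analytic gradient w.r.t. M: p * (Sign(r) * |r|^(p-1)) k^T, r = M k - v."""
    r = M @ k - v
    return p * np.outer(np.sign(r) * np.abs(r) ** (p - 1), k)

def numerical_grad(f, M, h=1e-5):
    """Central finite differences, one matrix entry at a time."""
    G = np.zeros_like(M)
    for idx in np.ndindex(*M.shape):
        step = np.zeros_like(M)
        step[idx] = h
        G[idx] = (f(M + step) - f(M - step)) / (2.0 * h)
    return G

rng = np.random.default_rng(0)
d, p = 8, 3.0
M = rng.normal(scale=0.5, size=(d, d))
k = rng.normal(size=d)
v = rng.normal(size=d)

analytic = lp_grad(M, k, v, p)
numeric = numerical_grad(lambda W: lp_loss(W, k, v, p), M)
max_abs_err = float(np.max(np.abs(analytic - numeric)))
```

The same `numerical_grad` helper can be pointed at the Huber or KL components by swapping the loss function.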
-
Monitor These:
- Memory norm: should be bounded
- α_t distribution: should be in [0,1]
- η_t distribution: should be positive, reasonable scale
- Gradient norms: should not explode
- Loss: should decrease (eventually)
-
Common Errors:
- NaN: check log of non-positive values (e.g., log(W) in Memora), sqrt of negative
- Inf: check exp(large), division by zero
- Exploding: check α_t > 1, η_t too large
- No learning: check α_t ≈ 0, η_t ≈ 0
-
Unit Tests:
- RetNet: linear (matrix) memory, dot-product bias, ℓ_2 retention with constant scalar α → should match RetNet exactly
- DeltaNet: linear memory, ℓ_2 bias (p=2), α=1 (no retention) → should match DeltaNet
- Chunk=1: should match sequential (up to numerical precision)
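A sketch of what the baseline tests might assert, for a single step of a linear memory read as `M @ k`. Under that read convention the DeltaNet factor (I − βkk^⊤) multiplies M on the right, which matches the form written in these notes up to transposing M. The function name, the ½ scaling of the ℓ_2 objective, and the constants are assumptions.

```python
import numpy as np

def gd_memory_step(M, k, v, alpha, eta, bias):
    """One step M_t = alpha * M_{t-1} - eta * grad for a linear memory M @ k."""
    if bias == "dot":   # objective: -<M k, v>
        grad = -np.outer(v, k)
    elif bias == "l2":  # objective: 0.5 * ||M k - v||^2
        grad = np.outer(M @ k - v, k)
    else:
        raise ValueError(f"unknown bias: {bias}")
    return alpha * M - eta * grad

rng = np.random.default_rng(0)
d = 4
M = rng.normal(size=(d, d))
k = rng.normal(size=d)
v = rng.normal(size=d)
alpha, beta = 0.9, 0.5

# RetNet: M_t = alpha * M_{t-1} + v k^T  (dot-product bias, constant decay).
retnet_closed = alpha * M + np.outer(v, k)
retnet_gd = gd_memory_step(M, k, v, alpha=alpha, eta=1.0, bias="dot")

# DeltaNet: M_t = M_{t-1} (I - beta k k^T) + beta v k^T  (l_2 bias, alpha = 1).
deltanet_closed = M @ (np.eye(d) - beta * np.outer(k, k)) + beta * np.outer(v, k)
deltanet_gd = gd_memory_step(M, k, v, alpha=1.0, eta=beta, bias="l2")

assert np.allclose(retnet_gd, retnet_closed)
assert np.allclose(deltanet_gd, deltanet_closed)
```

Both baselines come out of the same `gd_memory_step` with different bias and gate settings, so a single parametrized test could cover them.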
-
Ablations to Run:
- Vary p: {1, 1.5, 2, 2.8, 3, 3.2, 4}
- Vary q: {2, 3, 4, 5}
- Vary chunk size: {1, 4, 16, 64}
- Vary δ: {0.1, 0.5, 1.0, 2.0}
-
Relationship to TITANS:
- MIRAS is generalization, not replacement
- Titans-LMM: ℓ_2 bias, ℓ_2 retention, GD+momentum, k-layer MLP
- Moneta/Yaad/Memora: simpler (2-layer, GD only) but different objectives/gates
- TITANS likely better on some tasks (momentum helps)
- MIRAS provides framework to understand why
-
Table 1 Unification:
- All existing models fit in MIRAS framework
- Most use ℓ_2 or dot-product bias
- Most use ℓ_2 retention
- Novel contribution: go beyond these defaults
- Opportunity: explore vast design space
-
Performance Claims:
- Moneta/Yaad/Memora beat Mamba2, Gated DeltaNet, TTT
- Competitive with or beat Transformers (small scale)
- Better scaling with context length
- Hybrid versions best overall
-
When to Use Which Variant:
- Moneta: noisy data, need robustness (best on S-NIAH-PK)
- Yaad: outliers, need adaptive robustness (best overall balance)
- Memora: need hard forgetting, sparsity (KL divergence)
- Hybrid: when maximum performance needed (adds SWA)
-
Limitations (not in paper, inferred):
- Fixed memory size: can't grow like Transformer KV cache
- Approximation in chunk parallelization
- Many hyperparameters (p, q, δ, α, η)
- Unclear how to choose attentional bias for new task
- No theoretical guarantees on memorization capacity
-
Future Work Suggestions (from paper):
- Explore other attentional biases (KL, Wasserstein, etc.)
- Other retention gates (Section 5.2 has many)
- Other memory architectures (beyond MLP)
- Other optimizers (Newton, momentum, adaptive)
- Theoretical analysis (capacity, expressivity)
- Application to other domains (vision, biology)
paper_id: "MIRAS"
paper_title: "It's All Connected: A Journey Through Test-Time Memorization, Attentional Bias, Retention, and Online Optimization"
authors: "Ali Behrouz, Meisam Razaviyayn, Peilin Zhong, Vahab Mirrokni"
affiliation: "Google Research"
arxiv: "arXiv:2504.13173v1"
year: 2025
date_published: "2025-04-17"
equations_extracted: 32
algorithms_extracted: 3
architectures_extracted: 3
hyperparameters_extracted: 20+
key_concepts:
- "Associative Memory"
- "Attentional Bias"
- "Retention Gate (Forget Gate)"
- "Test-Time Memorization"
- "Meta-Learning"
- "Online Optimization"
- "Linear RNNs"
- "Efficient Transformers"
novel_contributions:
- "MIRAS framework (4 design choices)"
- "Moneta (ℓ_p bias + ℓ_q retention)"
- "Yaad (Huber bias + ℓ_2 retention)"
- "Memora (ℓ_2 bias + KL retention)"
- "Unification of existing models (Table 1)"
- "Novel retention gates (f-divergence, Elastic Net, etc.)"
- "Parallelization for non-linear RNNs"
dependencies:
  critical:
    - "TITANS (Behrouz et al. 2024c)"
  important:
    - "TTT (Sun et al. 2024)"
    - "RetNet (Sun et al. 2023)"
    - "Mamba2 (Dao et al. 2024)"
    - "GLA (Yang et al. 2024b)"
    - "Gated DeltaNet (Yang et al. 2024a)"
implementation_priority:
  high:
    - "Core framework (Equation 4, Learning-Retaining)"
    - "Moneta (Equations 24-25, 28-29)"
    - "Yaad (Equations 16, 26)"
    - "Memora (Equations 21, 27, 30-32)"
    - "Parallelization (Section 5.4)"
    - "Architecture (Section 5.4, Figure 2)"
  medium:
    - "Alternative retention gates (Section 5.2)"
    - "Hybrid models (with SWA)"
    - "Channel-wise parametrization"
  low:
    - "Other attentional biases (Equation 17)"
    - "Theoretical analysis"
code_modules:
- "memory.py: Memory architectures (linear, MLP)"
- "attentional_bias.py: Loss functions (ℓ_p, Huber, robust, KL)"
- "retention_gate.py: Retention functions (ℓ_q, KL, Elastic Net, Bregman)"
- "optimizer.py: Memory learning algorithms (GD, momentum)"
- "moneta.py: Moneta model"
- "yaad.py: Yaad model"
- "memora.py: Memora model"
- "miras_layer.py: MIRAS layer block"
- "miras_model.py: Full model (with embedding, LM head)"
- "parallelization.py: Chunk-based training"
- "utils.py: Smooth approximations, helper functions"
testing_requirements:
- "Unit tests for each equation"
- "Gradient checks (numerical vs. autodiff)"
- "Baseline reproductions (RetNet, DeltaNet)"
- "Chunk parallelization equivalence"
- "Causality verification"
- "Memory stability tests"
- "Ablation studies (p, q, δ, chunk size)"
estimated_complexity:
  implementation: "High (many components, careful numerical handling)"
  debugging: "High (meta-learning, multiple loss terms, parallelization)"
  tuning: "High (many hyperparameters, task-dependent optimal choices)"
recommended_approach:
1. "Implement basic framework (linear memory, ℓ_2 bias, ℓ_2 retention, GD)"
2. "Test against baselines: DeltaNet/Gated DeltaNet (ℓ_2 bias) and RetNet (dot-product bias) should match exactly"
3. "Add MLP memory architecture"
4. "Add Moneta components (ℓ_p, ℓ_q)"
5. "Add Yaad components (Huber)"
6. "Add Memora components (KL, Softmax)"
7. "Implement parallelization (critical for scale)"
8. "Full architecture (MIRAS layer, full model)"
9. "Training infrastructure"
10. "Evaluation harness"
11. "Ablation studies"
12. "Reproduce paper results"
notes:
- "MIRAS is framework, not single model"
- "3 novel models: Moneta, Yaad, Memora"
- "Unifies 15+ existing models (Table 1)"
- "Generalizes TITANS (critical dependency)"
- "Extensive design space (4 choices × many options each)"
- "Parallelization essential for practical training"
- "Paper lacks some implementation details (will need experimentation)"
- "Many hyperparameters (p, q, δ, α, η, chunk size, etc.)"
- "Task-dependent optimal configuration"
Total Equations Extracted: 32+
Total Algorithms Extracted: 3 (Moneta, Yaad, Memora) + parallelization variants
Total Architectures: 3 variants + hybrid + full model architecture
Total Hyperparameters: 20+
This context file provides mathematical formulations, implementation guidance, and architectural details for reproducing the MIRAS paper. Equations include LaTeX, plain-English descriptions, variable tables with shapes, implementation notes, dependencies, and test criteria.