2-Simplicial Attention: Geometry on the Grassmannian and KV Cache Filtering
Thesis project: implementation of 2-simplicial attention for Transformers (trilinear and Gram determinant variants), geometric analysis of the planes spanned by LLaMA key vectors on the Grassmannian, and KV cache filtering via a Grassmannian Q-filter score.
Standard Transformer attention scores a query q against a key k with a dot product (a bilinear form). This project explores two generalizations that score a query against a pair of keys (k₁, k₂):
| Type | Formula | What it measures |
|---|---|---|
| Dot-product (standard) | $\langle q, k \rangle$ | Similarity between 2 vectors |
| Trilinear | $\langle q, k_1, k_2 \rangle = \sum_i q_i k_{1,i} k_{2,i}$ | 3-body interaction |
| Gram determinant | $\det \mathrm{Gram}(q, k_1, k_2)$ | Squared volume of the parallelepiped spanned by q, k₁, k₂ |
The GramDet variant has a direct geometric interpretation: for vectors of fixed norm, the higher the determinant, the further the three vectors are from linear dependence, and thus the more complementary the information they carry.
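For concreteness, the three scores in the table can be computed for single vectors as below. This is a minimal NumPy sketch for illustration (function names are hypothetical, not the repo's API). It also checks the squared-volume reading of the Gram determinant: doubling q multiplies the score by 4, and a degenerate pair k₁ = k₂ always scores 0.

```python
import numpy as np

def dot_score(q, k):
    """Standard bilinear score <q, k>."""
    return float(q @ k)

def trilinear_score(q, k1, k2):
    """Trilinear form <q, k1, k2> = sum_i q_i * k1_i * k2_i."""
    return float(np.sum(q * k1 * k2))

def gram_det_score(q, k1, k2):
    """det(Gram(q, k1, k2)): squared volume of the parallelepiped spanned by q, k1, k2."""
    M = np.stack([q, k1, k2])  # 3 x d
    G = M @ M.T                # 3 x 3 matrix of pairwise inner products
    return float(np.linalg.det(G))

e = np.eye(4)
print(gram_det_score(e[0], e[1], e[2]))          # orthonormal triple: volume 1
print(gram_det_score(e[0], e[1], e[0] + e[1]))   # linearly dependent: ~0
print(gram_det_score(2 * e[0], e[1], e[2]))      # doubling q quadruples the score
print(gram_det_score(e[0], e[1], e[1]))          # k1 == k2 always scores ~0
```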
Core contribution: LLaMA exhibits latent Grassmannian structure on Gr(2,d). The planes spanned by pairs of key vectors have geodesic variance more than 30% below a random Haar sample, indicating that key space is not isotropic but has intrinsic geometry. This structure is robust across datasets (Wikitext, C4) and scales (1B, 8B), and it predicts compatibility with Q-filters-style eviction.
```bash
git clone https://github.com/and-per-i/simplex-filters.git
cd simplex-filters
pip install -r requirements.txt
pip install -e simplicial_attention/  # optional Triton/TLX kernels
```

Requirements: Python 3.10+, a CUDA GPU, and an `HF_TOKEN` to download weights from Hugging Face.

```bash
export HF_TOKEN=hf_your_token
```

The hybrid model replaces standard dot-product attention in selected LLaMA layers with one of two variants:
```python
from src.modeling.gram_det_attention import GramDetAttention

# GramDet: score = det(Gram(q, k1, k2)) with configurable window
# Window of 17 tokens (W=8) → 153 pairs, 65 tokens (W=32) → 2145 pairs
# (unordered pairs j1 <= j2 among n = 2W+1 tokens: n(n+1)/2)
```

Conversion is handled by `convert_llama_to_hybrid()`:
- Trilinear (`--attention-type simplicial`): 5 input projections (q, k1, k2, v1, v2) plus the output projection o
- GramDet (`--attention-type gram_det`): 3 input projections (q, k, v) plus the output projection o
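The pair counts in the comment above follow from scoring every unordered pair j₁ ≤ j₂ inside a window of n = 2W+1 keys. The sketch below enumerates those pairs and computes all GramDet scores for one query in a single batched determinant. It is a minimal NumPy illustration, assuming that window size and pairing rule; how the window is positioned relative to the query (causal or centered) is not modeled, and the function names are hypothetical, not the repo's `GramDetAttention`.

```python
import numpy as np
from itertools import combinations_with_replacement

def window_pairs(n_tokens):
    """Unordered index pairs (j1, j2) with j1 <= j2 inside a window of n_tokens keys."""
    return list(combinations_with_replacement(range(n_tokens), 2))

def gram_det_window_scores(q, keys):
    """GramDet score of one query q (d,) against every key pair of a window (keys: n x d)."""
    pairs = np.array(window_pairs(len(keys)))        # P x 2
    k1, k2 = keys[pairs[:, 0]], keys[pairs[:, 1]]    # P x d each
    Q = np.broadcast_to(q, k1.shape)                 # same query for every pair
    M = np.stack([Q, k1, k2], axis=1)                # P x 3 x d
    G = M @ np.swapaxes(M, 1, 2)                     # P x 3 x 3 Gram matrices
    return pairs, np.linalg.det(G)                   # batched determinants, shape (P,)

for W in (8, 32):
    n = 2 * W + 1
    print(W, n, len(window_pairs(n)))   # W=8: 17 keys, 153 pairs; W=32: 65 keys, 2145 pairs
```

Note that the j₁ = j₂ pairs are included in the count but always score 0, since a repeated key spans no area.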
Lightweight finetuning on 4 GramDet layers [8, 10, 12, 14]:

```bash
python finetuning/train_hybrid.py --config finetuning/config_gramdet.yaml
```

Config `finetuning/config_gramdet.yaml`:
- `gram_window`: 8 (default) or 32 for a wider window
- Only 67M trainable parameters out of 1.2B total
- Learning rate 5e-4, cosine annealing, effective batch size 16
- Checkpoints every 1000 steps in `./gramdet_1b/`
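The 67M figure can be sanity-checked with a short calculation. Assuming LLaMA 3.2 1B's attention shapes (hidden size 2048, 32 query heads and 8 KV heads of dimension 64, bias-free projections) and that only the q, k, v, o projections of the four converted layers are trainable after the GQA 4:1 expansion done by `convert_to_hybrid.py`, the count lands at about 67.1M; without the expansion it would be about 41.9M. This is a back-of-the-envelope sketch, not output from the training script.

```python
# Back-of-the-envelope count, assuming LLaMA 3.2 1B attention shapes and bias-free projections
HIDDEN = 2048                 # model width
N_HEADS, HEAD_DIM = 32, 64    # query heads
N_KV_HEADS = 8                # GQA: 4 query heads share each KV head
GRAMDET_LAYERS = [8, 10, 12, 14]

def attn_params(n_kv_heads):
    """Weights in the q, k, v, o projections of one attention layer."""
    q = HIDDEN * N_HEADS * HEAD_DIM
    k = v = HIDDEN * n_kv_heads * HEAD_DIM
    o = N_HEADS * HEAD_DIM * HIDDEN
    return q + k + v + o

gqa_total = attn_params(N_KV_HEADS) * len(GRAMDET_LAYERS)
expanded_total = attn_params(N_HEADS) * len(GRAMDET_LAYERS)   # K/V expanded 4:1 to 32 heads
print(f"GQA layers: {gqa_total / 1e6:.1f}M, expanded: {expanded_total / 1e6:.1f}M")
# GQA layers: 41.9M, expanded: 67.1M
```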
On pure LLaMA (no conversion, no training):

```bash
python main.py --analyze-llama
```

Or cross-model:

```bash
python scripts/analyze_qwen.py
python scripts/analyze_olmo.py
```

Each key kⱼ is a vector in ℝᵈ. Two linearly independent keys (kⱼ₁, kⱼ₂) span a 2D plane, i.e. a point on the Grassmannian Gr(2,d), the space of 2D planes through the origin in ℝᵈ. The analysis computes:
| Metric | What it says |
|---|---|
| Geodesic variance | How spread out the planes are (low = structured) |
| Query-plane angle | How orthogonal the query is to the mean plane |
| Anisotropy σ₁/σ₂ | Distribution concentration along one axis |
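The pairwise quantities behind these metrics can be sketched as follows: a key pair is orthonormalized into a d×2 basis, the geodesic distance between two planes is the norm of their principal angles, and the query-plane angle compares ‖Pq‖ with ‖q‖. This is a minimal NumPy sketch assuming the standard arc-length metric on Gr(2,d); the function names are hypothetical and `src/geometry/plane.py` may differ in detail.

```python
import numpy as np

def plane_basis(k1, k2):
    """Orthonormal d x 2 basis of span(k1, k2), i.e. a point on Gr(2, d)."""
    Q, _ = np.linalg.qr(np.stack([k1, k2], axis=1))
    return Q

def principal_angles(U, V):
    """Principal angles between the planes with orthonormal bases U and V."""
    s = np.linalg.svd(U.T @ V, compute_uv=False)   # cosines of the angles
    return np.arccos(np.clip(s, -1.0, 1.0))

def geodesic_distance(U, V):
    """Arc-length distance on Gr(2, d): sqrt(theta_1^2 + theta_2^2)."""
    return float(np.linalg.norm(principal_angles(U, V)))

def query_plane_angle(q, U):
    """Angle between q and the plane: arccos(||P q|| / ||q||), with P = U U^T."""
    cos = np.linalg.norm(U.T @ q) / np.linalg.norm(q)   # ||U U^T q|| = ||U^T q||
    return float(np.arccos(np.clip(cos, 0.0, 1.0)))

e = np.eye(4)
A = plane_basis(e[0], e[1])
B = plane_basis(e[2], e[3])
print(geodesic_distance(A, B))               # orthogonal planes: sqrt(2) * pi / 2
print(query_plane_angle(e[0] + e[2], A))     # pi / 4
```

Taking arccos of singular values loses precision for very small angles; a version that needs accuracy near zero would use a sine-based formula instead.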
Key results on LLaMA 3.2 1B:
| Dataset | Geodesic variance | Reduction vs random |
|---|---|---|
| Random Haar Gr(2,64) | 3.85 | — |
| Wikitext | 2.14-2.55 | 34-44% |
| C4 | 2.14-2.50 | 35-44% |
The structure is marginal (not relational): each individual key occupies a preferred region of the space regardless of which key it is paired with. The shuffle test, which re-pairs keys at random, yields the same variance as the original pairing (ratio ≈ 1.0×).
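The geodesic variance, the Haar baseline and the shuffle test can be sketched end to end on synthetic data. Two substitutions are assumptions: the mean plane is the extrinsic (projection) mean, i.e. the top-2 eigenvectors of the averaged projector, rather than the intrinsic Fréchet mean computed in `src/geometry/grassmann.py`; and the keys are i.i.d. Gaussian vectors stretched along two axes, not LLaMA activations. The sketch is not expected to reproduce the 3.85 baseline above, and all names are hypothetical.

```python
import numpy as np

def plane_basis(k1, k2):
    """Orthonormal d x 2 basis of span(k1, k2)."""
    return np.linalg.qr(np.stack([k1, k2], axis=1))[0]

def geodesic_distance(U, V):
    """Arc-length distance on Gr(2, d): norm of the principal angles."""
    s = np.linalg.svd(U.T @ V, compute_uv=False)
    return float(np.linalg.norm(np.arccos(np.clip(s, -1.0, 1.0))))

def extrinsic_mean(bases):
    """Top-2 eigenvectors of the averaged projector U U^T (stand-in for the Fréchet mean)."""
    P_bar = np.mean([U @ U.T for U in bases], axis=0)
    _, vecs = np.linalg.eigh(P_bar)   # eigenvalues in ascending order
    return vecs[:, -2:]

def geodesic_variance(bases):
    """Mean squared geodesic distance of the planes to their mean plane."""
    M = extrinsic_mean(bases)
    return float(np.mean([geodesic_distance(U, M) ** 2 for U in bases]))

def haar_planes(n, d, rng):
    """Haar-random points of Gr(2, d): span of two Gaussian vectors."""
    return [np.linalg.qr(rng.standard_normal((d, 2)))[0] for _ in range(n)]

def shuffle_ratio(keys, pairs, rng):
    """Variance after randomly re-pairing keys, divided by the original variance."""
    pairs = np.asarray(pairs)
    original = [plane_basis(keys[a], keys[b]) for a, b in pairs]
    second = rng.permutation(pairs[:, 1])          # swap partners across pairs
    shuffled = [plane_basis(keys[a], keys[b])
                for a, b in zip(pairs[:, 0], second) if a != b]
    return geodesic_variance(shuffled) / geodesic_variance(original)

rng = np.random.default_rng(0)
d = 16
keys = rng.standard_normal((1000, d))
keys[:, :2] *= 20.0                                # synthetic anisotropy: keys hug span(e1, e2)
pairs = [(i, i + 1) for i in range(0, len(keys), 2)]

structured = geodesic_variance([plane_basis(keys[a], keys[b]) for a, b in pairs])
haar = geodesic_variance(haar_planes(len(pairs), d, rng))
ratio = shuffle_ratio(keys, pairs, rng)
print(f"structured {structured:.2f} | Haar {haar:.2f} | shuffle ratio {ratio:.2f}")
```

Because the synthetic keys are drawn independently, re-pairing them leaves the distribution of planes unchanged and the ratio stays near 1.0, which is what a marginal structure predicts; a relational structure would instead make the shuffled variance noticeably larger.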
Cross-model robustness: Qwen 2.5-0.5B shows a similar range of geodesic variance (2.24-2.58), so the pattern is not an artifact of a single model. The structure also holds across scale: LLaMA 3.1 8B (Gr(2,128), Haar baseline 4.09) shows variance 2.34-2.52, a relative reduction comparable to that of LLaMA 3.2 1B.
An end-to-end benchmark runs on pure LLaMA with no architectural conversion, using kvpress (NVIDIA) to apply eviction via hooks:

```bash
pip install kvpress
python scripts/benchmark_kvpress.py                             # Pure LLaMA
python scripts/benchmark_kvpress.py --gramdet                   # GramDet step 0
python scripts/benchmark_kvpress.py --gramdet --gram-window 32  # 65-token window
```

Strategies compared:
- QFilterPress (NVIDIA): standard score for dot-product attention
- GrassmannianPress (ours): ‖k − P̄k‖, orthogonal component to the Fréchet mean plane
- RandomPress: random baseline
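A minimal sketch of the orthogonal score and of budgeted selection follows. Assumptions: P̄ is built from the extrinsic mean of the key-pair planes, the keys with the largest residual ‖k − P̄k‖ are kept, and the budget is the fraction of keys kept. This is neither the `GrassmannianPress` code nor the kvpress interface; all names are hypothetical.

```python
import numpy as np

def mean_plane_projector(keys, pairs):
    """P̄: projector onto the extrinsic mean of the planes span(k_a, k_b) (assumption)."""
    projectors = []
    for a, b in pairs:
        U = np.linalg.qr(np.stack([keys[a], keys[b]], axis=1))[0]
        projectors.append(U @ U.T)
    _, vecs = np.linalg.eigh(np.mean(projectors, axis=0))
    M = vecs[:, -2:]                       # top-2 eigenvectors span the mean plane
    return M @ M.T

def grassmann_scores(keys, P_bar):
    """Orthogonal residual ||k - P̄k|| for every key (one key per row)."""
    return np.linalg.norm(keys - keys @ P_bar, axis=1)   # P̄ is symmetric

def keep_indices(scores, budget):
    """Positions of the round(budget * n) highest-scoring keys, in cache order."""
    n_keep = max(1, int(round(budget * len(scores))))
    top = np.argsort(-scores, kind="stable")[:n_keep]
    return np.sort(top)

e = np.eye(4)
keys = np.stack([e[0], e[1], e[0] + e[1], e[0] - e[1], e[2], e[0] + 2 * e[3]])
P_bar = mean_plane_projector(keys, [(0, 1), (2, 3)])   # both planes are span(e1, e2)
scores = grassmann_scores(keys, P_bar)
print(scores.round(3), keep_indices(scores, 1 / 3))
```

Here the four in-plane keys score 0 and are evicted first, while the two off-plane keys survive a 1/3 budget.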
Perplexity on pure LLaMA (lower is better; budget = fraction of the KV cache kept):
| KV budget | GrassmannianPress | QFilterPress | RandomPress |
|---|---|---|---|
| 100% | 11.42 | 11.42 | 11.42 |
| 50% | 697.06 | 16.20 | 157.00 |
| 30% | 703.63 | 15.21 | 217.29 |
| 10% | 445.86 | 14.38 | 363.90 |
QFilter dominates on standard LLaMA because the ‖k − P̄k‖ score was validated for GramDet (ρ=+0.61), while standard LLaMA uses softmax(q·k/√d): there, a token's relevance depends on its alignment with the current query, not on its distance from the mean plane.
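A toy example with made-up vectors shows why a query-independent score can miss what softmax attention rewards: two keys with the same residual ‖k − P̄k‖ can receive very different attention depending on the sign of their alignment with the query. The mean plane, query and keys below are hypothetical.

```python
import numpy as np

d = 4
P_bar = np.diag([1.0, 1.0, 0.0, 0.0])   # projector onto a hypothetical mean plane span(e1, e2)
q = np.array([0.0, 0.0, 3.0, 0.0])      # current query, orthogonal to that plane
keys = np.array([
    [0.0, 0.0, 1.0, 0.0],    # off-plane, aligned with q
    [0.0, 0.0, -1.0, 0.0],   # off-plane, anti-aligned with q
    [1.0, 0.0, 0.0, 0.0],    # in-plane
])

orth_score = np.linalg.norm(keys - keys @ P_bar, axis=1)   # ||k - P̄k|| (P̄ is symmetric)
logits = keys @ q / np.sqrt(d)                             # q·k/√d
attn = np.exp(logits - logits.max())
attn /= attn.sum()                                         # softmax over the three keys
print("orthogonal score:", orth_score)
print("attention:       ", attn.round(3))
```

The aligned and anti-aligned keys tie on the orthogonal score, yet the anti-aligned key receives the least attention of all three, even less than the in-plane key that the orthogonal score ranks last.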
Perplexity on GramDet step 0 (W=32):
| KV budget | GrassmannianPress | RandomPress |
|---|---|---|
| 100% | 28.95 | 28.95 |
| 50% | 28.97 | 29.01 |
| 30% | 29.17 | 29.24 |
| 10% | 29.60 | 30.11 |
GrassmannianPress is marginally better than random at every reduced budget, but the signal is small (max Δ=0.51 PPL, at 10%) and not robust to prefix/num_sequences variation; with only 10 sequences, it is within statistical noise.
Depth-selective eviction: `scripts/validate_proxy.py` reveals a clear shallow-to-deep gradient in the correlation of the Grassmannian proxy on GramDet 1B step 0:
| Layer | Spearman ρ | p-value |
|---|---|---|
| 8 | -0.14 | 10⁻⁶ |
| 10 | -0.20 | 10⁻¹² |
| 12 | +0.10 | 10⁻⁴ |
| 14 | +0.31 | 10⁻³⁰ |
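Spearman ρ in the table above is the Pearson correlation of ranks. A minimal sketch, assuming (hypothetically) that the validation correlates each key's proxy score with the attention mass it actually receives in a layer; the script's exact target is not spelled out here, and the data below is synthetic. For real data with ties, `scipy.stats.spearmanr` handles tied ranks.

```python
import numpy as np

def ranks(x):
    """Ranks 0..n-1 of the entries of x (assumes no ties)."""
    r = np.empty(len(x))
    r[np.argsort(x)] = np.arange(len(x))
    return r

def spearman_rho(x, y):
    """Spearman rank correlation: Pearson correlation of the two rank vectors."""
    return float(np.corrcoef(ranks(x), ranks(y))[0, 1])

# Hypothetical per-layer check: proxy score vs. attention mass received by each key
rng = np.random.default_rng(0)
attention_mass = rng.random(500)
proxy = attention_mass + 0.5 * rng.standard_normal(500)   # noisy, positively related proxy
print(f"rho = {spearman_rho(proxy, attention_mass):+.2f}")
```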
Lower layers (8-10) are anti-correlated, while deeper layers (12-14) are positively correlated (ρ up to +0.31 at layer 14). The score appears informative only where the query-plane angle is high. Selective eviction is enabled with:

```bash
python scripts/benchmark_kvpress.py --gramdet --gram-window 32 --eviction-layers 12,14
```

The differences are in the 0.02-0.05 PPL range, entirely within statistical noise, so the end-to-end benchmark on step 0 did not produce robust results. Without a genuinely trained GramDet model (pretrained from scratch), the eviction benchmark is not informative.
```bash
python main.py --analyze-llama --dataset-name wikitext
python main.py --analyze-llama --dataset-name c4
```

The geometric structure is robust across datasets: geodesic variance differs by at most ±0.1 between Wikitext and C4 across all layers. The shuffle test confirms marginality (ratio ~1.0×) on both datasets.
```bash
python scripts/analyze_qwen.py  # Qwen2.5-0.5B (QKV bias, altered structure)
python scripts/analyze_olmo.py  # OLMo 2 7B (QK-norm, no structure)
```

The analysis automatically supports diverse architectures (GQA, MHA, RoPE, etc.). Predictive criterion: geodesic variance ≥30% below random → compatible with Q-filters-style eviction. OLMo (QK-norm) is not compatible; Qwen (QKV bias) has altered structure.
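The predictive criterion reduces to a one-line check. A sketch, assuming it is applied to a single geodesic-variance value (e.g. per layer) against the Haar baseline of the same Gr(2,d); the function names are hypothetical. With the LLaMA ranges reported above, every endpoint clears the 30% threshold.

```python
def variance_reduction(variance, haar_baseline):
    """Relative reduction of geodesic variance w.r.t. the Haar-random baseline."""
    return 1.0 - variance / haar_baseline

def qfilter_compatible(variance, haar_baseline, threshold=0.30):
    """Criterion: geodesic variance at least 30% below the random baseline."""
    return variance_reduction(variance, haar_baseline) >= threshold

# Endpoints of the ranges reported for LLaMA 3.2 1B (Gr(2,64)) and LLaMA 3.1 8B (Gr(2,128))
for name, var, base in [("1B low", 2.14, 3.85), ("1B high", 2.55, 3.85),
                        ("8B low", 2.34, 4.09), ("8B high", 2.52, 4.09)]:
    print(f"{name}: {variance_reduction(var, base):.3f} compatible={qfilter_compatible(var, base)}")
# reductions 0.444, 0.338, 0.428, 0.384: all compatible
```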
Monte Carlo Haar baseline:

```bash
python scripts/grassmann_baseline.py --dim 64 --n-planes 320 --runs 5
```

Needle-in-a-Haystack benchmark:

```bash
python src/kv_cache/ruler/niah_benchmark.py --model meta-llama/Llama-3.2-1B
```

Repository layout:

```
simplex-filters/
├── main.py                            # Entry point (analyze-llama, analyze, finetune, ...)
├── finetuning/
│   ├── config_gramdet.yaml            # GramDet 1B config
│   ├── config.yaml                    # Trilinear 8B config (legacy)
│   ├── train_hybrid.py                # Training loop
│   └── utils/                         # Data loader, optimizer, metrics, wandb
├── src/
│   ├── modeling/
│   │   ├── gram_det_attention.py      # GramDet vectorized (pure PyTorch)
│   │   ├── simplicial_attention.py    # Trilinear (Triton kernel)
│   │   └── convert_to_hybrid.py       # LLaMA → hybrid (GQA 4:1 expansion)
│   ├── geometry/
│   │   ├── plane.py                   # SVD, projector, geodesic distance
│   │   ├── grassmann.py               # Fréchet mean, geodesic variance
│   │   ├── hooks.py                   # Forward hooks for activations
│   │   └── analyzer.py                # Cross-model analysis pipeline
│   └── kv_cache/
│       ├── grassmann_press.py         # GrassmannianPress for kvpress
│       ├── qfilter_score.py           # Orthogonal score ‖k − P̄k‖
│       ├── eviction.py                # Eviction via eviction_params
│       ├── benchmark.py               # Benchmark (legacy)
│       └── ruler/                     # Needle-in-a-Haystack
├── scripts/
│   ├── benchmark_kvpress.py           # Eviction benchmark with kvpress
│   ├── benchmark_gramdet_step0.py     # GramDet step 0 benchmark (legacy)
│   ├── analyze_qwen.py                # Qwen geometric analysis
│   ├── analyze_olmo.py                # OLMo geometric analysis
│   ├── baseline_ppl.py                # Baseline PPL computation
│   ├── download_c4.py                 # Local C4 download
│   ├── validate_proxy.py              # Proxy score validation
│   └── grassmann_baseline.py          # Monte Carlo baseline
├── tests/                             # Structural, forward/backward, numerical tests
└── simplicial_attention/              # Triton/TLX kernels (optional)
```
| Finding | Detail |
|---|---|
| LLaMA has latent Grassmannian structure | Geodesic variance 2.14-2.55 vs 3.85 for random Haar planes (1B), robust cross-dataset and cross-model |
| Structure is marginal, not relational | Shuffle test ratio ≈ 1.0× |
| GramDet creates more structure than trilinear | Variance 2.34 vs 3.1 |
| Orthogonal Q-filter score is a valid proxy for GramDet | Spearman ρ=+0.61 |
| QFilterPress dominates on standard LLaMA | PPL 14.38 vs 363.90 (random) at 10% budget |
| GrassmannianPress on GramDet step 0 | Non-robust signal; differences within statistical noise (Δ PPL 0.02-0.51) |
Based on:
- Clift et al., "Logic and the 2-Simplicial Transformer", 2019
- Roy et al., "Fast and Simplex: 2-Simplicial Attention in Triton", 2025
- Godey et al., "Q-Filters", 2024
- NVIDIA, kvpress library, 2025