
and-per-i/simplex-filters


simplex-filters

2-Simplicial Attention: Geometry on the Grassmannian and KV Cache Filtering

Thesis project: an implementation of 2-simplicial attention for Transformers (trilinear and Gram-determinant variants), a geometric analysis of the planes spanned by LLaMA key vectors on the Grassmannian, and KV cache filtering via a Grassmannian Q-filter score.



Why this project

Standard Transformer attention scores a query q against a key k with a dot product (a bilinear form). This project explores two 2-simplicial generalizations, in which a query interacts with a pair of keys:

| Type | Formula | What it measures |
|---|---|---|
| Dot product (standard) | ⟨q, k⟩ = Σᵢ qᵢ kᵢ | Similarity between 2 vectors |
| Trilinear | ⟨q, k₁, k₂⟩ = Σᵢ qᵢ (k₁)ᵢ (k₂)ᵢ | 3-body interaction |
| Gram determinant | det(Gram(q, k₁, k₂)) | Squared volume of the parallelepiped spanned by q, k₁, k₂ |

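The GramDet variant has a direct geometric interpretation: the determinant is zero exactly when q, k₁, k₂ are linearly dependent, and it grows as the vectors get longer and closer to mutually orthogonal. For vectors of fixed norm, a higher determinant therefore means the three vectors are more linearly independent, and thus carry more complementary information.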
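A minimal NumPy sketch of the three scores in the table, for a single query and one key pair. It follows the formulas above but is not the project's vectorized implementation in src/modeling/gram_det_attention.py; the function names are illustrative.

```python
import numpy as np

def dot_score(q, k):
    """Standard bilinear score <q, k> = sum_i q_i k_i."""
    return float(q @ k)

def trilinear_score(q, k1, k2):
    """Trilinear form <q, k1, k2> = sum_i q_i (k1)_i (k2)_i."""
    return float(np.sum(q * k1 * k2))

def gram_det_score(q, k1, k2):
    """det of the 3x3 Gram matrix of (q, k1, k2) = squared volume of their parallelepiped."""
    X = np.stack([q, k1, k2])   # shape (3, d)
    return float(np.linalg.det(X @ X.T))

e = np.eye(8)
print(gram_det_score(e[0], e[1], e[2]))          # orthonormal triple: 1.0
print(gram_det_score(e[0], e[1], e[0] + e[1]))   # dependent triple: ~0 (rounding only)
print(gram_det_score(2 * e[0], e[1], e[2]))      # doubling q quadruples the score: 4.0
```

The last line shows why the determinant mixes norm and independence: scaling one vector by c scales the score by c².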

Core contribution: LLaMA has latent Grassmannian structure on Gr(2,d). The planes spanned by pairs of key vectors have a geodesic variance roughly 34-44% below that of a random Haar sample, which indicates that key space is not isotropic but has intrinsic geometry. This structure is robust across datasets (Wikitext, C4) and scales (1B, 8B), and appears to predict compatibility with Q-filters-style eviction.


Installation

git clone https://github.com/and-per-i/simplex-filters.git
cd simplex-filters
pip install -r requirements.txt
pip install -e simplicial_attention/   # optional Triton/TLX kernels

Requirements: Python 3.10+, a CUDA GPU, and a Hugging Face access token (HF_TOKEN) to download model weights from Hugging Face.

export HF_TOKEN=hf_your_token

What the project does

1. Two 2-simplicial attention mechanisms

Replaces standard dot-product attention in selected LLaMA layers with one of two variants:

from src.modeling.gram_det_attention import GramDetAttention

# GramDet: score = det(Gram(q, k_j1, k_j2)) over key pairs in a configurable window of 2W+1 tokens
# W=8 → 17-token window → 153 pairs (j1 ≤ j2); W=32 → 65-token window → 2145 pairs

Conversion handled by convert_llama_to_hybrid():

  • Trilinear (--attention-type simplicial): 5 input projections (q, k1, k2, v1, v2) plus the output projection o
  • GramDet (--attention-type gram_det): 3 input projections (q, k, v) plus the output projection o
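The pair counts quoted above correspond to unordered pairs with j₁ ≤ j₂ inside a window of n = 2W+1 tokens, i.e. n(n+1)/2 pairs. The sketch below enumerates those pairs and computes every Gram determinant in closed form. It is a hypothetical illustration: the window layout, the helper names, and the omission of softmax scaling and value mixing are assumptions, since this README does not specify them.

```python
import numpy as np

def window_pair_count(W):
    """Unordered pairs (j1 <= j2) in a window of n = 2W + 1 tokens (assumed layout)."""
    n = 2 * W + 1
    return n * (n + 1) // 2

def gram_det_pairs(q, keys):
    """det(Gram(q, k_j1, k_j2)) for every pair j1 <= j2 of the window's keys.

    q: (d,) query; keys: (n, d). Returns (j1, j2, scores), each of length n(n+1)/2.
    """
    j1, j2 = np.triu_indices(keys.shape[0])   # includes the diagonal j1 == j2
    k1, k2 = keys[j1], keys[j2]
    # Entries of the symmetric 3x3 Gram matrix, one set per pair
    qq = q @ q
    qa, qb = k1 @ q, k2 @ q
    aa = np.einsum("pd,pd->p", k1, k1)
    bb = np.einsum("pd,pd->p", k2, k2)
    ab = np.einsum("pd,pd->p", k1, k2)
    # Cofactor expansion of det([[qq, qa, qb], [qa, aa, ab], [qb, ab, bb]])
    det = qq * (aa * bb - ab**2) - qa * (qa * bb - ab * qb) + qb * (qa * ab - aa * qb)
    return j1, j2, det

rng = np.random.default_rng(0)
W, d = 8, 64
q = rng.standard_normal(d)
keys = rng.standard_normal((2 * W + 1, d))
j1, j2, scores = gram_det_pairs(q, keys)
print(window_pair_count(8), window_pair_count(32), scores.shape)   # 153 2145 (153,)
```

Diagonal pairs (j₁ = j₂) always score 0, because q, k, k are linearly dependent.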

2. Training on C4 (LLaMA 3.2 1B)

Lightweight fine-tuning of 4 layers converted to GramDet (layers 8, 10, 12, 14):

python finetuning/train_hybrid.py --config finetuning/config_gramdet.yaml

Key settings in finetuning/config_gramdet.yaml:

  • gram_window: 8 (default, 17-token window) or 32 (65-token window)
  • Only 67M trainable parameters out of 1.2B total
  • Learning rate 5e-4 with cosine annealing, effective batch size 16
  • Checkpoints saved every 1000 steps to ./gramdet_1b/
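The 67M figure can be checked with a back-of-the-envelope count. This sketch assumes, without it being stated in the config, that each of the 4 converted layers trains only its q, k, v and o projections, and that the GQA 4:1 expansion makes k and v full 2048×2048 matrices like q and o (LLaMA 3.2 1B has hidden size 2048 and 32 heads of dimension 64).

```python
hidden_size = 2048                 # LLaMA 3.2 1B hidden size
num_heads, head_dim = 32, 64       # after expansion, k and v also use 32 heads
converted_layers = [8, 10, 12, 14]

# Assumption: trainable = q, k, v, o projections of the converted layers, no biases
proj_params = hidden_size * num_heads * head_dim   # one 2048 x 2048 projection
per_layer = 4 * proj_params                        # q, k, v, o
total = per_layer * len(converted_layers)
print(f"{total:,}")                                # 67,108,864 ≈ 67M
```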

3. Grassmannian geometric analysis

On pure LLaMA (no conversion, no training):

python main.py --analyze-llama

Or, for other models:

python scripts/analyze_qwen.py
python scripts/analyze_olmo.py

Each key kⱼ is a vector in ℝᵈ, where d is the head dimension. Two linearly independent keys (kⱼ₁, kⱼ₂) span a plane, i.e. a point on the Grassmannian Gr(2,d), the space of 2-dimensional linear subspaces of ℝᵈ. The analysis computes:

| Metric | What it says |
|---|---|
| Geodesic variance | How spread out the planes are (low = structured) |
| Query-plane angle | How orthogonal the query is to the mean plane |
| Anisotropy σ₁/σ₂ | How concentrated the distribution is along one axis |
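The first two metrics can be sketched with standard principal-angle machinery: the geodesic distance between two planes is √(θ₁² + θ₂²), where θ₁, θ₂ are the principal angles obtained from the SVD of U₁ᵀU₂ for orthonormal bases U₁, U₂. This NumPy sketch uses that textbook definition; whether src/geometry/plane.py uses the same normalization is an assumption, and the function names are hypothetical.

```python
import numpy as np

def plane_basis(k1, k2):
    """Orthonormal basis (d, 2) of span(k1, k2); assumes k1, k2 are linearly independent."""
    Q, _ = np.linalg.qr(np.stack([k1, k2], axis=1))
    return Q

def principal_angles(U, V):
    """Principal angles between the planes spanned by orthonormal bases U and V."""
    cosines = np.linalg.svd(U.T @ V, compute_uv=False)
    return np.arccos(np.clip(cosines, -1.0, 1.0))

def geodesic_distance(U, V):
    """Arc-length distance on Gr(2, d): sqrt(theta_1^2 + theta_2^2)."""
    return float(np.sqrt(np.sum(principal_angles(U, V) ** 2)))

def query_plane_angle(q, U):
    """Angle between q and its orthogonal projection onto the plane spanned by U."""
    proj_norm = np.linalg.norm(U.T @ q)   # equals ||U U^T q|| since U is orthonormal
    return float(np.arccos(np.clip(proj_norm / np.linalg.norm(q), 0.0, 1.0)))

e = np.eye(6)
U, V = plane_basis(e[0], e[1]), plane_basis(e[2], e[3])
print(geodesic_distance(U, V))       # orthogonal planes: sqrt(2) * pi / 2
print(query_plane_angle(e[2], U))    # query orthogonal to the plane: pi / 2
```

An angle of π/2 means the query has no component in the plane; 0 means it lies inside it.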

Key results on LLaMA 3.2 1B:

| Dataset | Geodesic variance | Reduction vs random |
|---|---|---|
| Random Haar, Gr(2,64) | 3.85 | (baseline) |
| Wikitext | 2.14-2.55 | 34-44% |
| C4 | 2.14-2.50 | 35-44% |

The structure is marginal, not relational: each individual key has a preferred position in key space, independent of the key it is paired with. In the shuffle test, keys are re-paired at random, and the resulting geodesic variance is identical to the original (ratio ≈ 1.0×).
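A sketch of the shuffle test, assuming that "pair swapping" means re-pairing the keys at random. For brevity it replaces the iterative Fréchet (Karcher) mean with the extrinsic mean, the top-2 eigenspace of the averaged projector; this is a swapped-in approximation, not necessarily what src/geometry/grassmann.py does. On synthetic keys that all prefer the same two directions (marginal structure by construction), the ratio comes out close to 1. Batched QR requires NumPy 1.22+; all names are hypothetical.

```python
import numpy as np

def plane_bases(A, B):
    """Orthonormal bases (P, d, 2) for the planes span(A[p], B[p])."""
    Q, _ = np.linalg.qr(np.stack([A, B], axis=2))
    return Q

def geodesic_variance(U):
    """Mean squared geodesic distance from each plane to an extrinsic mean plane."""
    P_bar = np.einsum("pik,pjk->ij", U, U) / U.shape[0]   # average projector U_p U_p^T
    M = np.linalg.eigh(P_bar)[1][:, -2:]                    # top-2 eigenvectors
    cosines = np.linalg.svd(np.einsum("di,pdj->pij", M, U), compute_uv=False)
    theta = np.arccos(np.clip(cosines, -1.0, 1.0))          # principal angles, (P, 2)
    return float(np.mean(np.sum(theta ** 2, axis=1)))

def shuffle_ratio(A, B, rng):
    """Geodesic variance after randomly re-pairing keys, divided by the original."""
    original = geodesic_variance(plane_bases(A, B))
    shuffled = geodesic_variance(plane_bases(A, rng.permutation(B)))
    return shuffled / original

# Synthetic keys with purely marginal structure: every key prefers the same 2 axes
rng = np.random.default_rng(0)
d, P = 64, 400
scale = np.ones(d)
scale[:2] = 8.0
A = rng.standard_normal((P, d)) * scale
B = rng.standard_normal((P, d)) * scale
print(f"shuffle ratio: {shuffle_ratio(A, B, rng):.2f}")   # expected to be near 1.0
```

A ratio well above 1 would instead indicate relational structure, where the pairing itself carries information.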

Cross-model robustness: Qwen2.5-0.5B shows a similar pattern (variance 2.24-2.58), suggesting this is not an artifact of a single model. The structure also holds across scale: LLaMA 3.1 8B (Gr(2,128), random baseline 4.09) shows variance 2.34-2.52, a reduction of about 38-43% relative to its own baseline, in line with LLaMA 3.2 1B.

4. KV Cache Eviction Benchmark (kvpress)

End-to-end perplexity benchmark using kvpress (NVIDIA), which applies eviction through hooks. By default it runs on pure LLaMA with no architectural conversion; --gramdet switches to the GramDet model:

pip install kvpress
python scripts/benchmark_kvpress.py                    # Pure LLaMA
python scripts/benchmark_kvpress.py --gramdet           # GramDet step 0
python scripts/benchmark_kvpress.py --gramdet --gram-window 32  # 65-token window

Strategies compared:

  • QFilterPress (NVIDIA): standard score for dot-product attention
  • GrassmannianPress (ours): ‖k − P̄k‖, the norm of the component of k orthogonal to the Fréchet mean plane (P̄ is the projector onto that plane)
  • RandomPress: random baseline
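The GrassmannianPress score and a keep-the-top-budget rule can be sketched in plain NumPy. This does not use the kvpress API. It assumes the mean-plane basis M has already been estimated (for example on calibration data) and that higher scores mark tokens to keep; neither detail is specified here, and the function names are hypothetical.

```python
import numpy as np

def orthogonal_score(keys, M):
    """||k - P_bar k|| per key, with P_bar = M M^T.

    keys: (n, d); M: (d, 2) orthonormal basis of the Fréchet mean plane.
    """
    residual = keys - (keys @ M) @ M.T
    return np.linalg.norm(residual, axis=1)

def keep_indices(scores, budget):
    """Indices of the ceil(budget * n) highest-scoring tokens, in original order."""
    n_keep = max(1, int(np.ceil(budget * len(scores))))
    top = np.argsort(scores)[::-1][:n_keep]
    return np.sort(top)

M = np.eye(4)[:, :2]                        # mean plane = span(e0, e1)
keys = np.array([[1.0, 1.0, 0.0, 0.0],      # inside the plane: score 0
                 [0.0, 0.0, 3.0, 0.0],      # fully orthogonal: score 3
                 [1.0, 0.0, 0.0, 2.0]])     # orthogonal part (0, 0, 0, 2): score 2
scores = orthogonal_score(keys, M)
print(scores)                               # [0. 3. 2.]
print(keep_indices(scores, budget=0.5))     # [1 2]
```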

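Results on pure LLaMA (perplexity; budget = fraction of the KV cache kept):

| Budget | GrassmannianPress | QFilterPress | RandomPress |
|---|---|---|---|
| 100% | 11.42 | 11.42 | 11.42 |
| 50% | 697.06 | 16.20 | 157.00 |
| 30% | 703.63 | 15.21 | 217.29 |
| 10% | 445.86 | 14.38 | 363.90 |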

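QFilterPress dominates on standard LLaMA, and GrassmannianPress even does worse than random eviction at every budget below 100%. A likely explanation is that the ‖k − P̄k‖ score was validated for GramDet (Spearman ρ=+0.61), whereas standard LLaMA uses softmax(q·k/√d): there, token relevance depends on alignment with the current query, not on distance from the mean plane.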
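A toy example (hypothetical vectors, not benchmark data) shows how the two criteria can disagree: a key aligned with the query receives most of the softmax attention but has a zero orthogonal score, so Grassmannian eviction would drop it first.

```python
import numpy as np

d = 4
M = np.eye(d)[:, :2]                      # mean plane = span(e0, e1)
q = np.array([1.0, 0.0, 0.0, 0.0])        # query lying in the mean plane
k_aligned = np.array([2.0, 0.0, 0.0, 0.0])
k_orth = np.array([0.0, 0.0, 2.0, 0.0])

def orth_score(k):
    """||k - M M^T k||: distance of k from the mean plane."""
    return float(np.linalg.norm(k - M @ (M.T @ k)))

logits = np.array([q @ k_aligned, q @ k_orth]) / np.sqrt(d)   # q·k/√d
attn = np.exp(logits) / np.exp(logits).sum()                   # softmax over the 2 keys

print(orth_score(k_aligned), orth_score(k_orth))   # 0.0 2.0: k_aligned is evicted first
print(attn.round(3))                                # [0.731 0.269]: yet it gets most attention
```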

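Results on GramDet at step 0, before fine-tuning (W=32, perplexity):

| Budget | GrassmannianPress | RandomPress |
|---|---|---|
| 100% | 28.95 | 28.95 |
| 50% | 28.97 | 29.01 |
| 30% | 29.17 | 29.24 |
| 10% | 29.60 | 30.11 |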

Grassmann beats random at every budget below 100%, but the margin is small (max Δ = 0.51 PPL, at 10%) and not robust to changes in the prefix and the number of sequences: with only 10 sequences, it is within statistical noise.

Depth-selective eviction: scripts/validate_proxy.py reveals a clear shallow-to-deep gradient in the correlation of the Grassmannian proxy score on GramDet 1B at step 0:

| Layer | Spearman ρ | p-value |
|---|---|---|
| 8 | -0.14 | 10⁻⁶ |
| 10 | -0.20 | 10⁻¹² |
| 12 | +0.10 | 10⁻⁴ |
| 14 | +0.31 | 10⁻³⁰ |

Shallower layers (8-10) are anti-correlated, while deeper layers (12-14) are positively correlated (ρ up to +0.31). The score appears to work only where the query-plane angle is high. Selective eviction, restricted to the deeper layers, is enabled with:

python scripts/benchmark_kvpress.py --gramdet --gram-window 32 --eviction-layers 12,14

With selective eviction, the PPL differences are in the 0.02-0.05 range, entirely within statistical noise, so the end-to-end benchmark at step 0 did not produce robust results. Without a genuinely trained GramDet model (pretrained from scratch), the eviction benchmark is not informative.
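The per-layer check behind the Spearman table above can be sketched as follows. The importance target used here, the total causal softmax attention each key receives, is an assumption about what validate_proxy.py correlates against, and the helper names are hypothetical. scipy.stats.spearmanr would also return a p-value; a NumPy-only rank correlation keeps the sketch dependency-free.

```python
import numpy as np

def spearman_rho(x, y):
    """Spearman rank correlation without tie handling: Pearson correlation of ranks."""
    rx = np.argsort(np.argsort(x)).astype(float)
    ry = np.argsort(np.argsort(y)).astype(float)
    return float(np.corrcoef(rx, ry)[0, 1])

def orthogonal_score(keys, M):
    """||k - M M^T k|| per key; M is a (d, 2) orthonormal basis of the mean plane."""
    return np.linalg.norm(keys - (keys @ M) @ M.T, axis=1)

def received_attention(Q, K):
    """Hypothetical importance target: total causal softmax attention received per key."""
    n, d = Q.shape
    logits = Q @ K.T / np.sqrt(d)
    logits[np.triu_indices(n, k=1)] = -np.inf     # query i only sees keys 0..i
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)             # each query's weights sum to 1
    return w.sum(axis=0)

# Usage on random activations for one layer (real runs would use hooked Q/K)
rng = np.random.default_rng(0)
Q, K = rng.standard_normal((2, 128, 64))
M, _ = np.linalg.qr(rng.standard_normal((64, 2)))  # placeholder mean plane
rho = spearman_rho(orthogonal_score(K, M), received_attention(Q, K))
print(f"Spearman rho: {rho:+.2f}")
```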

5. Cross-dataset validation (Wikitext vs C4)

python main.py --analyze-llama --dataset-name wikitext
python main.py --analyze-llama --dataset-name c4

The geometric structure is robust across datasets: per-layer geodesic variance differs by at most about 0.1 between Wikitext and C4. The shuffle test confirms marginality (ratio ≈ 1.0×) on both datasets.

6. Cross-model analysis

python scripts/analyze_qwen.py  # Qwen2.5-0.5B (QKV bias, altered structure)
python scripts/analyze_olmo.py  # OLMo 2 7B (QK-norm, no structure)

The analysis automatically supports diverse architectures (GQA, MHA, RoPE, etc.). Proposed predictive criterion: a geodesic variance at least 30% below random indicates compatibility with Q-filters-style eviction. OLMo (QK-norm) is not compatible; Qwen (QKV bias) shows altered structure.
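The criterion is plain arithmetic on the reported variances. The sketch below applies it to figures quoted in this README, using the highest (worst-layer) variance of each range as the conservative choice; taking the Gr(2,64) baseline of 3.85 for Qwen2.5-0.5B assumes its head dimension is 64. Function names are hypothetical.

```python
def variance_reduction(v_model, v_random):
    """Relative drop of geodesic variance below the random Haar baseline."""
    return 1.0 - v_model / v_random

def q_filter_compatible(v_model, v_random, threshold=0.30):
    """Proposed criterion: variance at least `threshold` below the random baseline."""
    return variance_reduction(v_model, v_random) >= threshold

# (label, highest reported geodesic variance, random baseline), from this README
cases = [
    ("LLaMA 3.2 1B, Wikitext", 2.55, 3.85),
    ("LLaMA 3.2 1B, C4", 2.50, 3.85),
    ("Qwen2.5-0.5B", 2.58, 3.85),   # assumes the Gr(2,64) baseline (head dim 64)
    ("LLaMA 3.1 8B", 2.52, 4.09),
]
for label, v, base in cases:
    print(f"{label}: {variance_reduction(v, base):.1%} below random, "
          f"compatible={q_filter_compatible(v, base)}")
```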

7. Monte Carlo baseline for Grassmannian

python scripts/grassmann_baseline.py --dim 64 --n-planes 320 --runs 5
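What the Monte Carlo baseline estimates can be sketched as follows: sample Haar-random planes (QR of Gaussian matrices), then compute their geodesic variance. The extrinsic mean (top-2 eigenspace of the averaged projector) stands in for the Fréchet mean, and the script's flags are mirrored as keyword arguments by assumption, so values may differ slightly from grassmann_baseline.py. Batched QR requires NumPy 1.22+; all names are hypothetical.

```python
import numpy as np

def haar_planes(n_planes, dim, rng):
    """Haar-random 2-planes in R^dim as orthonormal bases of shape (n_planes, dim, 2)."""
    G = rng.standard_normal((n_planes, dim, 2))
    Q, _ = np.linalg.qr(G)             # the span of Gaussian columns is Haar-distributed
    return Q

def geodesic_variance(U):
    """Mean squared geodesic distance from each plane to an extrinsic mean plane."""
    P_bar = np.einsum("pik,pjk->ij", U, U) / U.shape[0]   # average projector
    M = np.linalg.eigh(P_bar)[1][:, -2:]                    # top-2 eigenvectors
    cosines = np.linalg.svd(np.einsum("di,pdj->pij", M, U), compute_uv=False)
    theta = np.arccos(np.clip(cosines, -1.0, 1.0))          # principal angles
    return float(np.mean(np.sum(theta ** 2, axis=1)))

def haar_baseline(dim=64, n_planes=320, runs=5, seed=0):
    """Mean and standard deviation of the geodesic variance over independent runs."""
    rng = np.random.default_rng(seed)
    vals = [geodesic_variance(haar_planes(n_planes, dim, rng)) for _ in range(runs)]
    return float(np.mean(vals)), float(np.std(vals))

mean, std = haar_baseline(dim=64, n_planes=320, runs=5)
print(f"Gr(2,64) baseline: {mean:.2f} ± {std:.2f}")
```

Since each principal angle is at most π/2, the variance is bounded by 2(π/2)² ≈ 4.93; random planes in high dimension are nearly orthogonal to any fixed plane, which is why the baseline sits close to that bound and rises with d.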

8. RULER Benchmark (Needle-in-a-Haystack)

python src/kv_cache/ruler/niah_benchmark.py --model meta-llama/Llama-3.2-1B

Project structure

simplex-filters/
├── main.py                              # Entry point (analyze-llama, analyze, finetune, ...)
├── finetuning/
│   ├── config_gramdet.yaml               # GramDet 1B config
│   ├── config.yaml                       # Trilinear 8B config (legacy)
│   ├── train_hybrid.py                   # Training loop
│   └── utils/                            # Data loader, optimizer, metrics, wandb
├── src/
│   ├── modeling/
│   │   ├── gram_det_attention.py          # GramDet vectorized (pure PyTorch)
│   │   ├── simplicial_attention.py        # Trilinear (Triton kernel)
│   │   └── convert_to_hybrid.py           # LLaMA → hybrid (GQA 4:1 expansion)
│   ├── geometry/
│   │   ├── plane.py                       # SVD, projector, geodesic distance
│   │   ├── grassmann.py                   # Fréchet mean, geodesic variance
│   │   ├── hooks.py                       # Forward hooks for activations
│   │   └── analyzer.py                    # Cross-model analysis pipeline
│   └── kv_cache/
│       ├── grassmann_press.py             # GrassmannianPress for kvpress
│       ├── qfilter_score.py               # Orthogonal score ‖k − P̄k‖
│       ├── eviction.py                    # Eviction via eviction_params
│       ├── benchmark.py                   # Benchmark (legacy)
│       └── ruler/                         # Needle-in-a-Haystack
├── scripts/
│   ├── benchmark_kvpress.py               # Eviction benchmark with kvpress
│   ├── benchmark_gramdet_step0.py         # GramDet step 0 benchmark (legacy)
│   ├── analyze_qwen.py                    # Qwen geometric analysis
│   ├── analyze_olmo.py                    # OLMo geometric analysis
│   ├── baseline_ppl.py                    # Baseline PPL computation
│   ├── download_c4.py                     # Local C4 download
│   ├── validate_proxy.py                  # Proxy score validation
│   └── grassmann_baseline.py              # Monte Carlo baseline
├── tests/                                 # Structural, forward/backward, numerical tests
└── simplicial_attention/                  # Triton/TLX kernels (optional)

Experimental findings (summary)

| Finding | Detail |
|---|---|
| LLaMA has latent Grassmannian structure | Variance 34-44% below random, robust across datasets and models |
| Structure is marginal, not relational | Shuffle test ratio ≈ 1.0× |
| GramDet creates more structure than trilinear | Variance 2.34 vs 3.1 |
| Orthogonal Q-filter score is a valid proxy for GramDet | Spearman ρ=+0.61 |
| QFilterPress dominates on standard LLaMA | PPL 14.38 vs 445.86 (GrassmannianPress) and 363.90 (random) at 10% budget |
| GrassmannianPress on GramDet step 0 | Non-robust signal; differences within statistical noise (Δ PPL 0.02-0.51) |

Citing

Based on:

  • Clift et al., "Logic and the 2-Simplicial Transformer", 2019
  • Roy et al., "Fast and Simplex: 2-Simplicial Attention in Triton", 2025
  • Godey et al., "Q-Filters: Leveraging QK Geometry for Efficient KV Cache Compression", 2025
  • NVIDIA, kvpress library, 2025

About

Grassmannian geometry meets LLM inference: latent structure in LLaMA key planes and Q-filters-style KV cache eviction
