|
| 1 | +# GPU acceleration in StatsPAI |
| 2 | + |
| 3 | +> **TL;DR.** As of v1.14, three workloads in StatsPAI route to CUDA / TPU |
| 4 | +> when an accelerator is available: (1) the neural causal estimators |
| 5 | +> (`sp.deepiv`, `sp.tarnet`, `sp.cfrnet`, `sp.dragonnet`, `sp.cevae`) |
| 6 | +> via PyTorch; (2) end-to-end OLS with HDFE via `sp.fast.feols_jax`; |
| 7 | +> and (3) **vmap'd bootstrap** via `sp.fast.feols_jax_bootstrap` — |
| 8 | +> the largest GPU win per line of user code. |
| 9 | +> |
| 10 | +> Everything else in StatsPAI is CPU-only and intentionally so: most |
| 11 | +> econometric estimators (DiD, IV, RD, synthetic control, fixest-style |
| 12 | +> HDFE OLS, GMM) are dominated by combinatorial / memory-bound work |
| 13 | +> where GPUs offer no speedup over a tuned Rust + Numba kernel. |
| 14 | +
|
| 15 | +--- |
| 16 | + |
| 17 | +## What is GPU-accelerated today? |
| 18 | + |
| 19 | +| Workload | Function | Backend | Activation | |
| 20 | +| --- | --- | --- | --- | |
| 21 | +| Neural causal: representation networks (TARNet / CFRNet / DragonNet / CEVAE) | `sp.tarnet` / `sp.cfrnet` / `sp.dragonnet` / `sp.cevae` | PyTorch | `STATSPAI_TORCH_DEVICE={cuda,mps,auto}` env var | |
| 22 | +| Neural IV: Deep IV (Hartford et al. 2017) | `sp.deepiv` | PyTorch | same env var | |
| 23 | +| HDFE demean (alternating projection) | `sp.fast.demean(backend="jax")` | JAX / XLA | install `jax[cuda]` | |
| 24 | +| OLS / WLS with HDFE | `sp.fast.feols_jax` | JAX / XLA | install `jax[cuda]` | |
| 25 | +| **Bootstrap (pairs / cluster)** | `sp.fast.feols_jax_bootstrap` | JAX / XLA `vmap` | install `jax[cuda]` | |
| 26 | + |
| 27 | +The CPU paths (`sp.fast.demean`, `sp.fast.feols`, `sp.fast.fepois`, |
| 28 | +`sp.fast.boottest`, `sp.iv`, `sp.did`, `sp.rd`, `sp.synth`, …) all |
| 29 | +remain the production defaults and ship without any accelerator |
| 30 | +dependency. |
| 31 | + |
| 32 | +--- |
| 33 | + |
| 34 | +## Why bootstrap is the headline GPU win |
| 35 | + |
| 36 | +Single-shot OLS / WLS is **dominated by host↔device transfer overhead** |
| 37 | +on small-to-medium datasets — the actual QR factorisation is too cheap |
| 38 | +for GPU speedup to recover the wire cost. |
| 39 | + |
| 40 | +Bootstrap inverts this: the *same* JIT-compiled WLS program is lifted |
| 41 | +to a `jax.vmap` batch primitive and runs B times in lock-step on the |
| 42 | +device. On CUDA / TPU this approaches `B / utilisation × per-iteration |
| 43 | +time`; on CPU JAX it's still ~equal to a numpy sequential bootstrap |
| 44 | +(JIT overhead amortises around B ≈ 100). The speedup curve crosses |
| 45 | +favourably very quickly. |
| 46 | + |
| 47 | +**Pairs bootstrap** (Efron 1979): each draw resamples *rows* with |
| 48 | +replacement; multinomial counts become per-row WLS weights. Asymptotic |
| 49 | +target: HC1 standard errors. |
| 50 | + |
| 51 | +**Cluster bootstrap** (Cameron, Gelbach & Miller 2008 §III.A): each |
| 52 | +draw resamples *clusters* with replacement; observations in a cluster |
| 53 | +sampled k times get weight k. Asymptotic target: CR1 standard errors. |
| 54 | + |
| 55 | +```python |
| 56 | +import statspai as sp |
| 57 | + |
| 58 | +boot = sp.fast.feols_jax_bootstrap( |
| 59 | + "log_wage ~ schooling + experience | firm + year", |
| 60 | + data=df, |
| 61 | + n_boot=2_000, |
| 62 | + bootstrap="cluster", # or "pairs" |
| 63 | + cluster="firm", |
| 64 | + ci_alpha=0.05, |
| 65 | +) |
| 66 | +print(boot.summary()) |
| 67 | +print(boot.se_boot) |
| 68 | +print(boot.boot_betas) # full B × p draws for custom CI methods |
| 69 | +``` |
| 70 | + |
| 71 | +--- |
| 72 | + |
| 73 | +## Quickstart on Google Colab |
| 74 | + |
| 75 | +The fastest way to verify GPU acceleration without buying hardware is |
| 76 | +[Google Colab](https://colab.research.google.com/) Pro (≈ USD 10/month |
| 77 | +for T4 / V100, USD 50/month for A100). The free tier is also enough |
| 78 | +for proof-of-concept benchmarks. |
| 79 | + |
| 80 | +```python |
| 81 | +# In a Colab notebook with a GPU runtime selected |
| 82 | +!pip install -q statspai jax[cuda12] jaxlib |
| 83 | + |
| 84 | +import statspai as sp |
| 85 | +print(sp.fast.jax_device_info()) |
| 86 | +# Expect: jax: <version>, default device: cuda |
| 87 | + |
| 88 | +# Build a benchmark dataset |
| 89 | +import numpy as np, pandas as pd |
| 90 | +rng = np.random.default_rng(0) |
| 91 | +n, n_firm = 1_000_000, 5_000 |
| 92 | +firm = rng.integers(0, n_firm, size=n) |
| 93 | +fe = rng.normal(size=n_firm)[firm] |
| 94 | +df = pd.DataFrame({ |
| 95 | + "y": 0.5 * rng.normal(size=n) + fe, |
| 96 | + "x1": rng.normal(size=n), |
| 97 | + "x2": rng.normal(size=n), |
| 98 | + "firm": firm, |
| 99 | +}) |
| 100 | + |
| 101 | +# Time CPU vs GPU |
| 102 | +import time |
| 103 | + |
| 104 | +t0 = time.perf_counter() |
| 105 | +for _ in range(2_000): |
| 106 | + _ = sp.fast.feols("y ~ x1 + x2 | firm", df, vcov="hc1") |
| 107 | +print(f"CPU sequential bootstrap (B=2000): {time.perf_counter() - t0:.1f}s") |
| 108 | + |
| 109 | +t0 = time.perf_counter() |
| 110 | +boot = sp.fast.feols_jax_bootstrap( |
| 111 | + "y ~ x1 + x2 | firm", df, n_boot=2_000, bootstrap="pairs", |
| 112 | + vmap_chunk_size=500, # tune up for big-HBM GPUs |
| 113 | +) |
| 114 | +print(f"GPU vmap'd bootstrap (B=2000): {time.perf_counter() - t0:.1f}s") |
| 115 | +``` |
| 116 | + |
| 117 | +**Expected result on a T4 / V100 / A100:** the JAX path beats the |
| 118 | +sequential CPU loop by 10–100x once `n` × `B` is large enough to |
| 119 | +saturate the device. |
| 120 | + |
| 121 | +--- |
| 122 | + |
| 123 | +## PyTorch GPU for neural causal |
| 124 | + |
| 125 | +Setting the `STATSPAI_TORCH_DEVICE` environment variable (or having |
| 126 | +`torch.cuda.is_available()` true with `auto`) routes neural backends |
| 127 | +through CUDA / MPS: |
| 128 | + |
| 129 | +```bash |
| 130 | +export STATSPAI_TORCH_DEVICE=cuda # or 'auto', 'mps', 'cpu' |
| 131 | +``` |
| 132 | + |
| 133 | +```python |
| 134 | +import statspai as sp |
| 135 | +print(sp.fast.torch_device_info()) |
| 136 | +# Expect: torch <version> | cuda available (1 device(s)) | resolved=cuda |
| 137 | + |
| 138 | +# All of these will train on GPU when the env var resolves to cuda/mps |
| 139 | +sp.tarnet(df, y="y", treat="d", covariates=["x1", "x2"]) |
| 140 | +sp.cfrnet(df, y="y", treat="d", covariates=["x1", "x2"]) |
| 141 | +sp.dragonnet(df, y="y", treat="d", covariates=["x1", "x2"]) |
| 142 | +sp.cevae(df, y="y", treat="d", covariates=["x1", "x2"]) |
| 143 | +sp.deepiv(df, y="y", treat="d", instruments=["z"], covariates=["x1"]) |
| 144 | +``` |
| 145 | + |
| 146 | +The default is `cpu` to preserve bit-for-bit numerics on existing |
| 147 | +pinned tests; `auto` on Apple Silicon falls through to MPS (Metal |
| 148 | +Performance Shaders) when CUDA is unavailable. |
| 149 | + |
| 150 | +--- |
| 151 | + |
| 152 | +## What is *not* GPU-accelerated, and why |
| 153 | + |
| 154 | +| Family | Status | Why no GPU | |
| 155 | +| --- | --- | --- | |
| 156 | +| HDFE alternating-projection demean (CPU default) | Rust + Rayon | Bincount-style memory pattern is bandwidth-bound; tuned Rust matches GPU at typical FE counts. | |
| 157 | +| Cluster-robust sandwich `crve()` | Rust + Rayon (Phase 2) | Same — the per-cluster reduce is bandwidth-bound. | |
| 158 | +| Synthetic control (Abadie 2003 family, GSC, Augmented SC) | NumPy + scipy | Optimisation is small-K convex programs; no batch dimension to vmap over. | |
| 159 | +| DiD estimators (Callaway-Sant'Anna, Sun-Abraham, BJS) | NumPy + pandas | Group-by-time accumulation is sequential; per-cohort fits are tiny. | |
| 160 | +| Regression discontinuity | NumPy + scipy | Local-poly bandwidth choice is sequential. | |
| 161 | +| GMM / IV / 2SLS | NumPy + scipy | Single-shot dense linalg; same constant-cost story as `feols_jax`. | |
| 162 | +| Bayesian causal (PyMC) | NumPyro / JAX backend optional | Routing to GPU works *via PyMC*; we don't reimplement. | |
| 163 | + |
| 164 | +Future GPU candidates (open issues welcome): |
| 165 | +- **Permutation tests / placebo studies** — `vmap` over permutations is |
| 166 | + the obvious follow-up to bootstrap. |
| 167 | +- **DML cross-fitting** — k-fold parallel nuisance fits. |
| 168 | +- **Synthetic control matrix completion** — large-K SVD on GPU. |
| 169 | +- **Wild cluster bootstrap (Cameron-Gelbach-Miller §III.B)** — |
| 170 | + Phase 4c; closely related to the existing pairs / cluster bootstrap. |
| 171 | +- **Causal forest training** — wire `xgboost` / `cuml` for tree fits. |
| 172 | + |
| 173 | +--- |
| 174 | + |
| 175 | +## Reproducibility |
| 176 | + |
| 177 | +JAX uses an explicit PRNG. `seed=` is honoured; same seed → bit- |
| 178 | +identical bootstrap draws on the same device: |
| 179 | + |
| 180 | +```python |
| 181 | +b1 = sp.fast.feols_jax_bootstrap("y ~ x1 | firm", df, n_boot=500, seed=42) |
| 182 | +b2 = sp.fast.feols_jax_bootstrap("y ~ x1 | firm", df, n_boot=500, seed=42) |
| 183 | +assert (b1.boot_betas.values == b2.boot_betas.values).all() |
| 184 | +``` |
| 185 | + |
| 186 | +Numerics across devices (CPU JAX vs CUDA vs TPU) can differ by ~1–2 ulp |
| 187 | +because XLA reduction order is not guaranteed identical across |
| 188 | +hardware. For coefficient-level reporting this is well below |
| 189 | +econometric noise; for SE-level reporting see the convergence-rate |
| 190 | +notes in the docstrings. |
| 191 | + |
| 192 | +--- |
| 193 | + |
| 194 | +## Honesty check |
| 195 | + |
| 196 | +The GPU story in v1.14 is **opt-in and selective**. We deliberately |
| 197 | +don't claim "StatsPAI is GPU-accelerated" — most of the package is |
| 198 | +CPU-only and that's the right design for the workloads we cover. The |
| 199 | +GPU path matters for two specific cases: |
| 200 | + |
| 201 | +1. **Neural causal training** — already a torch-native workload; the |
| 202 | + only thing we contributed was the unified device routing. |
| 203 | +2. **Bootstrap-heavy inference** — where the speedup is real and |
| 204 | + measurable, especially at B ≥ 1000 on n ≥ 100k. |
| 205 | + |
| 206 | +If your workflow is "fit one DiD / IV / RD on a 10k-row sample," a |
| 207 | +laptop CPU is probably already as fast as a cloud GPU once you account |
| 208 | +for package import + JIT compile time. **Buy a GPU if you're either |
| 209 | +training neural causal models in volume, or doing high-B cluster |
| 210 | +bootstrap on large panels.** |
0 commit comments