Commit a87d788

chore(release): cut v1.14.0 — GPU-acceleration sprint
Bumps version 1.13.1 → 1.14.0. Adds the v1.14.0 CHANGELOG entry covering the 7-commit GPU sprint landed earlier this session:

- 54d83b4 feat(neural): opt-in GPU/MPS routing via STATSPAI_TORCH_DEVICE
- 2fc18e4 feat(fast): Rust cluster_meat kernel with Rayon over clusters
- c339063 feat(iv): sp.iv(absorb=...) HDFE Phase 3 — 2SLS with absorbed FEs
- 1ac1ba3 feat(fast): JAX-backed feols_jax for end-to-end GPU/TPU OLS
- b7bae77 feat(fast): vmap'd JAX bootstrap (pairs + cluster)
- 4782152 feat(fast): Phase 4c — wild + wild_cluster bootstrap (score form)
- 39d6ced docs(gpu): GPU acceleration guide + README pointer

Also extends paper.md (JSS) with a fifth bullet under *Unique features* documenting the accelerator story; cites cameron2008bootstrap (verified DOI in paper.bib) for the wild cluster bootstrap and notes the new Rust HDFE / cluster-meat kernel in the implementation paragraph.

The CHANGELOG headline + Verified section together make the v1.14 GPU promise explicit and bounded: most StatsPAI estimators remain CPU-only by design.

Verified

- 60/60 tests in the v1.14 bundle pass (test_torch_device_resolver, test_iv_absorb, test_jax_feols, test_jax_feols_bootstrap; test_cluster_meat_rust skips without maturin built).
- ``python -c "import statspai as sp; print(sp.__version__)"`` reports 1.14.0; ``sp.fast.feols_jax_bootstrap`` resolves cleanly.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 4782152 commit a87d788
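
The first commit in the list routes neural estimators through `STATSPAI_TORCH_DEVICE`. The CHANGELOG entry in this commit lists the accepted values (`cpu`, `cuda`, `cuda:N`, `mps`, `auto`) and says an explicit `cuda` request raises when the device is unavailable. A minimal sketch of that contract follows. The function name `resolve_device`, the boolean availability flags (standing in for real PyTorch probes), the `auto` preference order, and the error on an unavailable `mps` are all assumptions, not the package's actual resolver.

```python
import os
import re

# Accepted spellings, taken from the CHANGELOG entry in this commit.
_VALID_SPEC = re.compile(r"^(cpu|cuda(:\d+)?|mps|auto)$")


def resolve_device(spec=None, *, cuda_available=False, mps_available=False):
    """Hypothetical resolver; the flags stand in for real torch availability checks."""
    if spec is None:
        # Default "cpu" keeps existing behaviour when the variable is unset.
        spec = os.environ.get("STATSPAI_TORCH_DEVICE", "cpu")
    spec = spec.strip().lower()
    if not _VALID_SPEC.match(spec):
        raise ValueError(f"unrecognised device spec: {spec!r}")
    if spec == "auto":
        # Assumed preference order: CUDA, then MPS, then CPU.
        if cuda_available:
            return "cuda"
        return "mps" if mps_available else "cpu"
    if spec.startswith("cuda") and not cuda_available:
        # Explicit request: fail loudly instead of silently falling back.
        raise RuntimeError("CUDA was requested but is not available")
    if spec == "mps" and not mps_available:
        # Assumption: MPS is treated the same way as CUDA here.
        raise RuntimeError("MPS was requested but is not available")
    return spec
```

Passing availability in as arguments keeps the sketch testable on a machine without PyTorch or a GPU.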

4 files changed

Lines changed: 142 additions & 4 deletions

CHANGELOG.md

Lines changed: 111 additions & 0 deletions
@@ -2,6 +2,117 @@
 
 All notable changes to StatsPAI will be documented in this file.
 
+## [1.14.0] — 2026-05-05
+
+### Headline
+
+GPU-acceleration sprint. Three workloads now opt into accelerator
+backends without changing their public API: (1) the neural causal
+estimators (`sp.deepiv`, `sp.tarnet`, `sp.cfrnet`, `sp.dragonnet`,
+`sp.cevae`) route through PyTorch CUDA / MPS via a centralised
+device resolver; (2) `sp.fast.feols_jax` runs the full WLS solve on
+JAX / XLA; and (3) `sp.fast.feols_jax_bootstrap` lifts a JIT-compiled
+single-iteration WLS kernel to a `jax.vmap` batched primitive,
+giving a 10–100x speedup over sequential CPU bootstrap on CUDA / TPU
+at B ≥ 1000. Four bootstrap variants share the same JAX kernel
+infrastructure: pairs (multinomial-weight resampling), cluster
+(Cameron–Gelbach–Miller 2008 §III.A), wild (row-level Rademacher),
+and wild cluster (Cameron–Gelbach–Miller 2008 §III.B); the wild
+variants use the score formulation
+`β* = β̂ + (X'WX)⁻¹ X'W (η ⊙ û)` which is mathematically identical
+to refitting on `y* = X β̂ + η ⊙ û` but needs one mat-vec per
+iteration instead of a full QR. A new `cluster_meat` Rust kernel
+in `statspai_hdfe` (PyO3 + Rayon, parallel over clusters) is wired
+behind `statspai.core._numba_kernels.cluster_meat` with the existing
+numba kernel as automatic fallback. `sp.iv(absorb=...)` is the new
+2SLS-with-HDFE entry point: residualises `y`, exogenous controls,
+endogenous regressors, and instruments by one or more FE columns
+via the Phase-1 Rust demean kernel before fitting, with the residual
+DOF adjusted by `Σ(G_k - 1)` to charge the absorbed FE rank against
+iid / HC1 / CR1 SEs. A new `docs/guides/gpu_acceleration.md` is the
+canonical landing page for the accelerator story; the README and
+`paper.md` link to it and explicitly bound the GPU promise (most
+estimators are CPU-only by design — DiD / RD / synth / GMM are
+bandwidth-bound or small-K convex programs where a tuned CPU
+kernel matches GPU performance).
+
+### Added
+
+- `sp.fast.feols_jax` — JAX-backed end-to-end OLS / WLS with HDFE.
+Same formula DSL and `FeolsResult` return type as `sp.fast.feols`;
+the WLS solve and HC1 sandwich run on the default JAX device.
+CR1 cluster sandwich delegates to the existing `crve` (which
+itself dispatches to the new `cluster_meat` Rust kernel when
+built). Default `dtype="float64"` preserves bit-comparable
+numerics; `dtype="float32"` available for the GPU fast path.
+- `sp.fast.feols_jax_bootstrap` — vmap'd bootstrap with four
+variants (`pairs`, `cluster`, `wild`, `wild_cluster`).
+`vmap_chunk_size` parameter for memory control on tight devices.
+Same-seed → bit-identical reproducibility via `jax.random` PRNG.
+Returns a `FeolsBootstrapResult` dataclass with `coef`, `se_boot`,
+percentile `ci_lower` / `ci_upper`, and the full `boot_betas`
+table for custom CI methods.
+- `sp.iv(absorb=...)` — 2SLS with HDFE residualisation. Accepts
+`"firm + year"` string syntax or `["firm", "year"]` list.
+LIML / Fuller / GMM / JIVE raise `NotImplementedError` (Phase 3b).
+- `STATSPAI_TORCH_DEVICE` environment variable (`cpu` / `cuda` /
+`cuda:N` / `mps` / `auto`) routes neural causal estimators
+through the requested device. Default `cpu` preserves existing
+pinned numerics; explicit `cuda` raises if the device is
+unavailable rather than silently falling back. New
+`sp.fast.torch_device_info()` mirrors `sp.fast.jax_device_info()`.
+- `statspai_hdfe::cluster_meat` Rust kernel — Rayon parallel over
+clusters with thread-local k×k upper-triangle accumulator and
+elementwise reduction. Bumped the crate version 0.5.0-alpha.1 →
+0.7.0-alpha.1. Activation requires a one-time
+`pip install maturin && cd rust/statspai_hdfe && maturin develop --release`;
+Python falls back to the numba kernel transparently when the
+Rust extension is absent.
+- `docs/guides/gpu_acceleration.md` — accelerator landing page
+with activation recipes, a Google Colab quickstart benchmark,
+and an explicit "what is *not* GPU-accelerated and why" table.
+
+### Changed
+
+- `paper.md` adds a fifth bullet to the *Unique features*
+list documenting the accelerator story, and notes the Rust
+HDFE / cluster-meat kernel in the implementation paragraph.
+- `README.md` comparison-table accelerator row now links to the
+new GPU guide; the *What StatsPAI is — and is not* bullet
+expands to explicitly mention `feols_jax`,
+`feols_jax_bootstrap`, and the vmap mechanism.
+
+### Internal
+
+- New helper `_jax_prep_inputs` shares formula-parse + FE-residualise
+logic between `feols_jax` and `feols_jax_bootstrap`.
+`feols_jax` itself is unchanged in this release; consolidation
+into a shared call site is a candidate follow-up.
+- Rust crate adds `src/cluster.rs` (kernel) and a `cluster_meat`
+PyO3 binding in `src/lib.rs`. 3 cargo unit tests cover small-DGP
+reference parity, k=1 closed form, and empty-input safety.
+
+### Verified
+
+- 10 PyTorch device-resolver tests
+(`tests/test_torch_device_resolver.py`); 51 existing neural
+tests pass without numerical drift on default CPU.
+- 9 `cluster_meat` Rust parity tests
+(`tests/test_cluster_meat_rust.py`) — auto-skip when
+`statspai_hdfe` is not built.
+- 13 `sp.iv(absorb=)` parity tests vs explicit drop-first
+dummies (`tests/test_iv_absorb.py`); coefficients agree to
+`atol=1e-9`, iid SE to `rtol=1e-3`, cluster SE to `rtol=1e-2`.
+- 12 `feols_jax` parity tests vs `feols`
+(`tests/test_jax_feols.py`); iid / hc1 / cr1 / weighted /
+float32 / 6 error-path validations.
+- 24 `feols_jax_bootstrap` tests
+(`tests/test_jax_feols_bootstrap.py`); convergence to HC1 SE
+for pairs / wild and CR1 SE for cluster / wild_cluster at
+B=2000 (rtol 10–15%); algebraic identity check that the wild
+score formulation reproduces the literal "refit on pseudo-y"
+bootstrap bit-for-bit on a no-FE DGP (`atol=1e-9`).
+
 ## [1.13.1] — 2026-05-05
 
 ### Headline

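The CHANGELOG headline states that the wild-bootstrap score form `β* = β̂ + (X'WX)⁻¹ X'W (η ⊙ û)` gives the same coefficients as refitting on `y* = X β̂ + η ⊙ û`. That claim can be checked numerically. What follows is a minimal NumPy sketch on simulated data, not the package's JAX kernel; every variable name and the data-generating process are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, k, B = 200, 3, 500

# Simulated design, weights and outcome (illustrative only).
X = np.column_stack([np.ones(n), rng.normal(size=(n, k - 1))])
w = rng.uniform(0.5, 2.0, size=n)               # diagonal of W
y = X @ np.array([1.0, -0.5, 2.0]) + rng.normal(size=n)

# A = (X'WX)^{-1} X'W is computed once and reused for every draw.
XtW = X.T * w
A = np.linalg.solve(XtW @ X, XtW)
beta_hat = A @ y
u_hat = y - X @ beta_hat

# Score form: one matrix product covers all B row-level Rademacher draws.
eta = rng.choice([-1.0, 1.0], size=(B, n))
boot_betas = beta_hat + (eta * u_hat) @ A.T     # shape (B, k)
se_boot = boot_betas.std(axis=0, ddof=1)

# Literal refit on pseudo-outcomes for the first draw, via weighted least squares.
y_star = X @ beta_hat + eta[0] * u_hat
sw = np.sqrt(w)
beta_refit, *_ = np.linalg.lstsq(X * sw[:, None], y_star * sw, rcond=None)

identical = np.allclose(boot_betas[0], beta_refit, atol=1e-9)
```

The identity holds because `A @ X` is the identity matrix, so the refit on `y*` only adds `A @ (η ⊙ û)` to `β̂`. The batched product over all draws plays the role that `jax.vmap` plays in the entry above.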
paper.md

Lines changed: 29 additions & 2 deletions
@@ -234,11 +234,38 @@ interface: `.summary()`, `.plot()`, `.to_latex()`, `.to_docx()`, and
 progress notifications, and `sp.citation()` together with a
 `paper.bib`-checked DOI verifier that refuses to emit citations
 not present in the curated bibliography.
+5. **Selective accelerator backends** for the workloads that
+benefit from them. Neural causal estimators (TARNet, CFRNet,
+DragonNet, CEVAE, DeepIV) route through PyTorch CUDA / MPS via
+the `STATSPAI_TORCH_DEVICE` environment variable. The HDFE
+residualiser exposes `backend="jax"` and `sp.fast.feols_jax`
+runs the full WLS solve on JAX / XLA. The largest GPU win in
+v1.14 is `sp.fast.feols_jax_bootstrap`: the same JIT-compiled
+WLS kernel is lifted to a `jax.vmap` batched primitive, giving
+a 10--100x speedup over sequential CPU bootstrap on CUDA / TPU
+at $B \geq 1{,}000$. Four bootstrap variants share the same
+JAX kernel infrastructure --- pairs, cluster, wild, and wild
+cluster [@cameron2008bootstrap] --- with the wild
+variants using the score formulation $\hat\beta^*_b =
+\hat\beta + (X'WX)^{-1}\, X'W\, (\eta_b \odot \hat u)$ which
+is mathematically identical to refitting on
+$y^* = X\hat\beta + \eta \odot \hat u$ but needs only a single
+matrix--vector multiply per iteration. The remainder of the
+package (DiD, RD, synthetic control, GMM) is CPU-only by
+design: these workloads are bandwidth-bound or involve small-$K$
+convex programs where a tuned Rust + Numba kernel matches GPU
+performance. Cluster-robust meat matrices use a Rust + Rayon
+kernel parallel over clusters
+(`statspai.core._numba_kernels.cluster_meat`) introduced in
+v1.14, and `sp.iv(absorb=...)` newly wires the Rust HDFE
+residualiser into the 2SLS first stage.
 
 The package is implemented in pure Python atop NumPy, SciPy, Pandas,
 statsmodels, scikit-learn, and linearmodels, with optional PyTorch and
-JAX backends. It supports Python $\geq$ 3.9 and is distributed via
-PyPI under the MIT license.
+JAX backends and a Rust HDFE / cluster-meat kernel
+(`statspai_hdfe`, PyO3 + Rayon) that the Python wrappers prefer when
+built and silently fall through when absent. It supports Python
+$\geq$ 3.9 and is distributed via PyPI under the MIT license.
 
 # Validation

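The paper.md bullet mentions a cluster-robust "meat" kernel that runs in parallel over clusters. As a point of reference, the quantity is usually written as the sum over clusters g of s_g s_g', where s_g = X_g' u_g is the cluster's score sum. The NumPy sketch below computes it single-threaded. The function name `cluster_meat_reference` is hypothetical, and this is not the Rust or numba implementation shipped with the package.

```python
import numpy as np


def cluster_meat_reference(X, resid, cluster_ids):
    """Hypothetical reference: sum over clusters of s_g s_g', with s_g = X_g' u_g."""
    X = np.asarray(X, dtype=float)
    resid = np.asarray(resid, dtype=float)
    uniques, codes = np.unique(np.asarray(cluster_ids), return_inverse=True)
    # One k-vector of score sums per cluster; zero clusters yields a zero matrix.
    scores = np.zeros((uniques.size, X.shape[1]))
    np.add.at(scores, codes, X * resid[:, None])
    return scores.T @ scores


# Tiny worked example with two clusters of two rows each.
X_demo = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
u_demo = np.array([1.0, -1.0, 2.0, 0.5])
g_demo = np.array(["a", "b", "a", "b"])
meat_demo = cluster_meat_reference(X_demo, u_demo, g_demo)
# Cluster "a" score: [11, 14]; cluster "b" score: [0.5, 0].
# meat_demo == [[121.25, 154.0], [154.0, 196.0]]
```

Each cluster's contribution is independent of the others, which is why the per-cluster sums can be accumulated in parallel and reduced at the end.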
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "StatsPAI"
-version = "1.13.1"
+version = "1.14.0"
 description = "The Agent-Native Causal Inference & Econometrics Toolkit for Python"
 readme = "README.md"
 license = {text = "MIT"}

src/statspai/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 >>> sp.outreg2(result, filename="results.xlsx")
 """
 
-__version__ = "1.13.1"
+__version__ = "1.14.0"
 __author__ = "Biaoyue Wang"
 __email__ = "brycew6m@stanford.edu"

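This commit bumps the version string in two places, `pyproject.toml` and `src/statspai/__init__.py`, and the two must stay in step. A small release check could look like the sketch below. The helper `read_version` is hypothetical, and the file contents are inlined as strings (abridged from the diffs above) so the example runs without the repository.

```python
import re


def read_version(text, key):
    """Return the quoted value assigned to `key` at the start of a line."""
    pattern = rf'^{re.escape(key)}\s*=\s*"([^"]+)"'
    match = re.search(pattern, text, flags=re.MULTILINE)
    if match is None:
        raise ValueError(f"no assignment found for {key!r}")
    return match.group(1)


# Abridged stand-ins for the two files touched by this commit.
PYPROJECT_TEXT = '[project]\nname = "StatsPAI"\nversion = "1.14.0"\n'
INIT_TEXT = '__version__ = "1.14.0"\n__author__ = "Biaoyue Wang"\n'

pyproject_version = read_version(PYPROJECT_TEXT, "version")
package_version = read_version(INIT_TEXT, "__version__")
versions_match = pyproject_version == package_version
```

In a real checkout the two strings would be read from disk; a regex is used here only to keep the sketch dependency-free.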