Commit b7bae77

feat(fast): vmap'd JAX bootstrap for feols — pairs + cluster

Adds ``sp.fast.feols_jax_bootstrap``: a JAX-backed bootstrap solver that
pre-residualises ``y`` and ``X`` once via the Rust HDFE kernel, then runs
``n_boot`` parallel WLS fits via ``jax.vmap`` over a JIT-compiled
single-iteration kernel. The same JIT program is lifted to a batched
primitive, so on CUDA / TPU the per-iteration QR runs side-by-side with
hundreds of others: 10-100x faster than the sequential numpy bootstrap
on equivalent hardware.

Two bootstrap variants
- ``bootstrap='pairs'`` (default): Efron pairs bootstrap; each draw
  resamples *rows* with replacement (multinomial counts become the
  per-iteration WLS weights). Pairs SE → HC1 SE as B → ∞.
- ``bootstrap='cluster'``: Cameron-Gelbach-Miller cluster bootstrap;
  each draw resamples *clusters* with replacement; observations in a
  cluster sampled k times get weight k. Cluster SE → CR1 SE as B → ∞.

Memory control via ``vmap_chunk_size`` (default 200): the inner
``jax.vmap`` is split into chunks so peak HBM usage stays bounded even
on small GPUs / TPUs.

Plumbing
- New ``_jax_prep_inputs`` helper duplicates the formula-parse +
  FE-residualise prep from ``feols_jax``. Pulled out as a separate
  helper rather than refactoring the live ``feols_jax`` body to keep
  this commit's blast radius small; consolidation can be a separate
  follow-up.
- ``_make_bootstrap_kernels`` returns
  ``(_one_pairs_boot, _build_cluster_boot)``; the cluster builder closes
  over a static ``n_clusters`` so ``jax.random.choice`` accepts it
  inside JIT.
- New ``FeolsBootstrapResult`` dataclass mirrors the field naming
  conventions of ``FeolsResult``.
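
The two variants described above can be illustrated with a minimal JAX sketch. It is not the code from this commit: the names ``wls_fit``, ``one_pairs_boot``, ``build_cluster_boot`` and ``run_bootstrap`` are hypothetical, the solve uses the normal equations rather than the QR mentioned in the message, resampling uses ``jax.random.randint`` rather than ``jax.random.choice``, and the FE residualisation step is skipped entirely (synthetic data only).

```python
import jax
import jax.numpy as jnp


def wls_fit(X, y, w):
    # Weighted least squares via the normal equations (X'WX) b = X'Wy.
    # Assumption: the commit message mentions QR; this is a simpler stand-in.
    Xw = X * w[:, None]
    return jnp.linalg.solve(Xw.T @ X, Xw.T @ y)


def one_pairs_boot(key, X, y):
    # Draw n row indices with replacement; per-row counts become the weights.
    n = X.shape[0]
    idx = jax.random.randint(key, (n,), 0, n)
    counts = jnp.zeros(n, dtype=X.dtype).at[idx].add(1.0)
    return wls_fit(X, y, counts)


def build_cluster_boot(n_clusters):
    # n_clusters is a plain Python int captured by the closure, so the
    # shape handed to randint stays static when the kernel is traced.
    def one_cluster_boot(key, X, y, cluster_ids):
        draws = jax.random.randint(key, (n_clusters,), 0, n_clusters)
        k = jnp.zeros(n_clusters, dtype=X.dtype).at[draws].add(1.0)
        # Every row of a cluster drawn k times gets weight k.
        return wls_fit(X, y, k[cluster_ids])
    return one_cluster_boot


def run_bootstrap(kernel, keys, *arrays):
    # vmap over the leading axis of keys only; the data arrays are shared.
    in_axes = (0,) + (None,) * len(arrays)
    return jax.jit(jax.vmap(kernel, in_axes=in_axes))(keys, *arrays)


# Synthetic data: intercept 1.0, slope 2.0, 25 equal-sized clusters.
k_x, k_e, k_b = jax.random.split(jax.random.PRNGKey(0), 3)
n, n_clusters = 500, 25
X = jnp.column_stack([jnp.ones(n), jax.random.normal(k_x, (n,))])
y = X @ jnp.array([1.0, 2.0]) + 0.5 * jax.random.normal(k_e, (n,))
cluster_ids = jnp.arange(n) % n_clusters

keys = jax.random.split(k_b, 400)  # one PRNG key per bootstrap draw
betas_pairs = run_bootstrap(one_pairs_boot, keys, X, y)
betas_cluster = run_bootstrap(
    build_cluster_boot(n_clusters), keys, X, y, cluster_ids
)
se_pairs = betas_pairs.std(axis=0, ddof=1)
se_cluster = betas_cluster.std(axis=0, ddof=1)
```

Using counts as weights avoids materialising a resampled copy of ``X`` for each draw, which is what lets a single kernel be batched over keys.
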
Verified
- 16/16 new tests in tests/test_jax_feols_bootstrap.py:
  * point estimate matches feols_jax bit-for-bit
  * same-seed runs are bit-identical
  * different ``vmap_chunk_size`` values give identical numerics
  * pairs SE within 10% rtol of HC1 SE at B=2000
  * cluster SE within 15% rtol of CR1 SE at B=2000
  * percentile CI covers true coefficients on a clean DGP
  * 8 error-path validations
- 28/28 combined Phase 4 + 4b tests pass; no regression on feols_jax
  point/HC1/CR1 parity.

Phase 4b scope cap
- Wild cluster bootstrap (Cameron-Gelbach-Miller 2008 §III.B) is the
  natural Phase 4c extension for few-cluster designs (G < 30).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
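
One way the chunk-size invariance listed under "Verified" can hold is sketched below: all per-draw PRNG keys are generated up front, so chunking only changes how many draws are batched at once, not which key each draw receives. This is an assumption about the design, not a reading of the commit's code; ``chunked_vmap`` and ``one_draw`` are hypothetical names, and ``one_draw`` is a toy stand-in for a single bootstrap iteration.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(fn, keys, chunk_size):
    # Apply vmap(fn) to slices of at most chunk_size keys, then concatenate.
    # Peak memory scales with chunk_size instead of the total draw count.
    batched = jax.jit(jax.vmap(fn))
    parts = [
        batched(keys[i:i + chunk_size])
        for i in range(0, keys.shape[0], chunk_size)
    ]
    return jnp.concatenate(parts, axis=0)


def one_draw(key):
    # Toy iteration: bootstrap mean of the values 0..49.
    data = jnp.arange(50.0)
    idx = jax.random.randint(key, (50,), 0, 50)
    return data[idx].mean()


# Keys are split once, before chunking, so draw i always uses keys[i].
keys = jax.random.split(jax.random.PRNGKey(42), 1000)
a = chunked_vmap(one_draw, keys, chunk_size=200)
b = chunked_vmap(one_draw, keys, chunk_size=64)  # last chunk is smaller
```

In this sketch a ragged final chunk triggers one extra JIT compilation for the smaller shape; padding the last chunk would be one way to avoid that.
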
1 parent a9d61b9 commit b7bae77

3 files changed

Lines changed: 699 additions & 2 deletions

src/statspai/fast/__init__.py

Lines changed: 15 additions & 1 deletion
@@ -42,7 +42,11 @@ def jax_device_info() -> str:
 # CUDA/TPU; CPU JAX path is correctness-grade but typically slower than
 # the Rust/numpy default. Lazy-load — module import stays jax-free.
 try:
-    from .jax_feols import feols_jax  # noqa: F401
+    from .jax_feols import (  # noqa: F401
+        feols_jax,
+        feols_jax_bootstrap,
+        FeolsBootstrapResult,
+    )
     _HAS_JAX_FEOLS = True
 except ImportError:  # pragma: no cover
     _HAS_JAX_FEOLS = False
@@ -53,6 +57,14 @@ def feols_jax(*_args, **_kwargs): # type: ignore[no-redef]
             "feols_jax. Plain sp.fast.feols runs without JAX."
         )
 
+    def feols_jax_bootstrap(*_args, **_kwargs):  # type: ignore[no-redef]
+        raise ImportError(
+            "jax is not installed; pip install jax jaxlib to enable "
+            "feols_jax_bootstrap."
+        )
+
+    FeolsBootstrapResult = None  # type: ignore[assignment]
+
 # Torch device diagnostic — mirrors jax_device_info for the optional
 # neural backends (deepiv / neural_causal / cevae). See
 # ``utils/_torch_device.py`` for the resolution policy.
@@ -97,6 +109,8 @@ def feols_jax(*_args, **_kwargs): # type: ignore[no-redef]
     'jax_device_info',
     'torch_device_info',
     'feols_jax',
+    'feols_jax_bootstrap',
+    'FeolsBootstrapResult',
     'etable',
     'demean_polars',
     'fepois_polars',
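
The optional-dependency pattern in this file (import the real function if the backend is available, otherwise bind a stub that raises ``ImportError`` when called) can be sketched with the standard library alone. The helper ``optional_import`` and the module name ``no_such_backend_xyz`` are hypothetical and exist only for illustration; the file itself uses a module-level ``try``/``except`` rather than a helper.

```python
import importlib


def optional_import(module_name, attr, install_hint):
    # Return (callable, available). If the module cannot be imported, the
    # callable is a stub that raises ImportError only when it is invoked,
    # so importing the package itself never fails.
    try:
        module = importlib.import_module(module_name)
        return getattr(module, attr), True
    except ImportError:
        def _stub(*_args, **_kwargs):
            raise ImportError(
                f"{module_name} is not installed; {install_hint}"
            )
        _stub.__name__ = attr
        return _stub, False


# A module that exists: the real function is returned.
sqrt, has_math = optional_import("math", "sqrt", "part of the stdlib")

# A module that does not exist: a stub is returned instead.
fit, has_backend = optional_import(
    "no_such_backend_xyz", "fit", "pip install no_such_backend_xyz"
)
```

Deferring the error to call time keeps the top-level import cheap and dependency-free, at the cost of surfacing a missing backend later than an import-time failure would.
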