Commit b7bae77
feat(fast): vmap'd JAX bootstrap for feols — pairs + cluster
Adds ``sp.fast.feols_jax_bootstrap``: a JAX-backed bootstrap solver
that pre-residualises ``y`` and ``X`` once via the Rust HDFE kernel,
then runs ``n_boot`` parallel WLS fits via ``jax.vmap`` over a JIT-
compiled single-iteration kernel. The same JIT program is lifted to
a batched primitive, so on CUDA / TPU the per-iteration QR runs
side-by-side with hundreds of others — 10-100x faster than the
sequential numpy bootstrap on equivalent hardware.
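
A minimal sketch of the vmap-over-weights idea described above, not the code from this commit. It assumes ``y`` and ``X`` are already residualised, and it solves the weighted normal equations instead of the QR step the message mentions. The names ``one_wls_fit`` and ``batched_fit`` are hypothetical, and the exponential weights are placeholders for real bootstrap counts.

```python
import jax
import jax.numpy as jnp

def one_wls_fit(X, y, w):
    """One weighted least-squares fit: solve (X'WX) b = X'Wy."""
    Xw = X * w[:, None]                      # scale each row by its weight
    return jnp.linalg.solve(Xw.T @ X, Xw.T @ y)

# Batch over the weights axis only; X and y are shared by every draw.
batched_fit = jax.jit(jax.vmap(one_wls_fit, in_axes=(None, None, 0)))

k_x, k_e, k_w = jax.random.split(jax.random.PRNGKey(0), 3)
n, p, n_boot = 500, 3, 64
X = jax.random.normal(k_x, (n, p))
beta_true = jnp.array([1.0, -2.0, 0.5])
y = X @ beta_true + 0.1 * jax.random.normal(k_e, (n,))

# Placeholder positive weights, one row per bootstrap draw.
W = jax.random.exponential(k_w, (n_boot, n))
betas = batched_fit(X, y, W)                 # shape (n_boot, p)
```

Each row of ``betas`` is one bootstrap coefficient vector; the single-draw function is written once and ``jax.vmap`` supplies the batch dimension.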
Two bootstrap variants
- ``bootstrap='pairs'`` (default): Efron pairs bootstrap; each draw
resamples *rows* with replacement (multinomial counts become the
per-iteration WLS weights). Pairs SE approaches HC1 SE for large n as B → ∞.
- ``bootstrap='cluster'``: Cameron-Gelbach-Miller cluster bootstrap;
each draw resamples *clusters* with replacement; observations in a
cluster sampled k times get weight k. Cluster SE approaches CR1 SE for large G as B → ∞.
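
The two weighting schemes can be sketched as follows. This is an illustration under assumptions, not the kernels added in this commit: ``pairs_weights`` and ``cluster_weights`` are hypothetical names, and ``cluster_ids`` is assumed to hold integer codes in ``0..n_clusters-1``.

```python
import jax
import jax.numpy as jnp

def pairs_weights(key, n):
    """Row-level resampling: how many times each row was drawn."""
    idx = jax.random.randint(key, (n,), 0, n)        # n draws with replacement
    return jnp.bincount(idx, length=n)               # multinomial counts, sum to n

def cluster_weights(key, cluster_ids, n_clusters):
    """Cluster-level resampling: every row inherits its cluster's count."""
    drawn = jax.random.randint(key, (n_clusters,), 0, n_clusters)
    counts = jnp.bincount(drawn, length=n_clusters)  # times each cluster was drawn
    return counts[cluster_ids]                       # broadcast counts to rows

key = jax.random.PRNGKey(42)
cluster_ids = jnp.array([0, 0, 1, 1, 1, 2])
w_pairs = pairs_weights(key, n=6)
w_cluster = cluster_weights(key, cluster_ids, n_clusters=3)
```

Rows with weight 0 drop out of that draw's fit, and rows with weight k count k times, which matches resampling with replacement without ever materialising the resampled data.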
Memory control via ``vmap_chunk_size`` (default 200): the inner
``jax.vmap`` is split into chunks so peak HBM usage stays bounded
even on small GPUs / TPUs.
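
One way to bound memory as described is to loop over slices of the per-draw PRNG keys and vmap each slice. The sketch below assumes that approach; ``chunked_vmap`` and ``one_draw`` are hypothetical names, and the actual chunking in this commit may differ.

```python
import jax
import jax.numpy as jnp

def one_draw(key):
    # Stand-in for a single bootstrap fit: any function of one PRNG key.
    return jax.random.normal(key, (3,)).mean()

def chunked_vmap(fn, keys, chunk_size):
    """Apply vmap(fn) to `keys` in slices so only one chunk is live at a time."""
    batched = jax.jit(jax.vmap(fn))
    pieces = []
    for start in range(0, keys.shape[0], chunk_size):
        pieces.append(batched(keys[start:start + chunk_size]))
    return jnp.concatenate(pieces, axis=0)

keys = jax.random.split(jax.random.PRNGKey(0), 1000)
full = jax.vmap(one_draw)(keys)                          # everything in one batch
chunked = chunked_vmap(one_draw, keys, chunk_size=200)   # five batches of 200
ragged = chunked_vmap(one_draw, keys, chunk_size=7)      # last chunk is shorter
```

Because each draw depends only on its own key, the chunk size changes peak memory but not the results. A ragged final chunk has a different shape, so it triggers one extra compilation.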
Plumbing
- New ``_jax_prep_inputs`` helper duplicates the formula-parse + FE-
residualise prep from ``feols_jax``. Pulled out as a separate
helper rather than refactoring the live ``feols_jax`` body to keep
this commit's blast radius small; consolidation can be a separate
follow-up.
- ``_make_bootstrap_kernels`` returns ``(_one_pairs_boot,
_build_cluster_boot)``; the cluster builder closes over a static
``n_clusters`` so ``jax.random.choice`` accepts it inside JIT.
- New ``FeolsBootstrapResult`` dataclass mirrors the field naming
conventions of ``FeolsResult``.
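
The closure pattern mentioned for the cluster builder can be illustrated as below. This is a sketch under assumptions, not the body of ``_build_cluster_boot``: the factory name ``build_cluster_boot_sketch`` is hypothetical. The point is that ``n_clusters`` is a plain Python int captured at build time, so it is a compile-time constant and not a traced value when ``jax.random.choice`` sees it.

```python
import jax
import jax.numpy as jnp

def build_cluster_boot_sketch(n_clusters):
    """Return a jitted kernel with n_clusters baked in as a static constant."""
    @jax.jit
    def one_cluster_boot(key, cluster_ids):
        # `n_clusters` comes from the enclosing scope, so JIT sees a concrete int.
        drawn = jax.random.choice(key, n_clusters, shape=(n_clusters,), replace=True)
        counts = jnp.bincount(drawn, length=n_clusters)
        return counts[cluster_ids]           # per-row weights for this draw
    return one_cluster_boot

cluster_ids = jnp.array([0, 0, 1, 2, 2, 2, 3])
kernel = build_cluster_boot_sketch(n_clusters=4)
weights = kernel(jax.random.PRNGKey(7), cluster_ids)
```

Passing ``n_clusters`` as a regular traced argument would fail at trace time, because both the output shape and the population size must be known when the program is compiled.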
Verified
- 16/16 new tests in tests/test_jax_feols_bootstrap.py:
* point estimate matches feols_jax bit-for-bit
* same-seed runs are bit-identical
* different ``vmap_chunk_size`` values give identical numerics
* pairs SE within 10% rtol of HC1 SE at B=2000
* cluster SE within 15% rtol of CR1 SE at B=2000
* percentile CI covers true coefficients on a clean DGP
* 8 error-path validations
- 28/28 combined Phase 4 + 4b tests pass — no regression on
feols_jax point/HC1/CR1 parity.
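
For context on the SE and percentile-CI checks above, a bootstrap summary can be computed from the matrix of draws roughly as follows. This is a sketch with synthetic draws; ``summarize_draws`` is a hypothetical helper, not part of ``FeolsBootstrapResult``.

```python
import jax
import jax.numpy as jnp

def summarize_draws(draws, alpha=0.05):
    """Bootstrap SE and percentile CI from an (n_boot, p) matrix of coefficients."""
    se = jnp.std(draws, axis=0, ddof=1)      # sample std across draws
    q = jnp.array([alpha / 2, 1 - alpha / 2])
    lo, hi = jnp.quantile(draws, q, axis=0)  # each has shape (p,)
    return se, lo, hi

# Synthetic draws centred at (0.5, -1.0) with spreads (1.0, 2.0).
center = jnp.array([0.5, -1.0])
scale = jnp.array([1.0, 2.0])
draws = center + scale * jax.random.normal(jax.random.PRNGKey(1), (2000, 2))
se, lo, hi = summarize_draws(draws)
```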
Phase 4b scope cap
- Wild cluster bootstrap (Cameron-Gelbach-Miller 2008 §III.B) is the
natural Phase 4c extension for few-cluster designs (G < 30).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a9d61b9 commit b7bae77
3 files changed
Lines changed: 699 additions & 2 deletions