Skip to content

Commit 2f61a54

Browse files
Merge pull request #9 from brycewang-stanford/feat/arima-se-accessor
feat(arima): expose standard errors on ARIMAResult (#7)
2 parents d092bf6 + b5c8eea commit 2f61a54

7 files changed

Lines changed: 560 additions & 19 deletions

File tree

.github/workflows/citation-audit.yml

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,27 @@ jobs:
138138
# Live verification against arXiv / NBER / Crossref. --strict:
139139
# unresolved IDs fail alongside mismatches, so a typo that
140140
# breaks primary-source lookup is caught early.
141-
run: python tools/audit_citations.py --strict --out audit_report.md
141+
#
142+
# Exit-code contract (tools/audit_citations.py main()):
143+
# 0 — clean.
144+
# 1 — mismatch, or a GENUINE unresolved id (source reachable
145+
# but the id is absent) → a real §10 zero-hallucination
146+
# failure. Blocks the merge.
147+
# 2 — soft failure: the ONLY unresolved ids were transient
148+
# upstream errors (arXiv / Crossref 429 rate-limit on the
149+
# shared runner IP, or a network blip). NOT a bad citation,
150+
# so it must not block a merge — we surface it as a warning
151+
# and pass. The auditor already retries 429/5xx with
152+
# back-off before giving up.
153+
run: |
154+
set +e
155+
python tools/audit_citations.py --strict --out audit_report.md
156+
code=$?
157+
if [ "$code" -eq 2 ]; then
158+
echo "::warning title=Citation audit soft failure::Auditor could not reach arXiv/Crossref (rate limit / network); no mismatch detected — treating as a soft pass (exit 2)."
159+
exit 0
160+
fi
161+
exit "$code"
142162
143163
- name: Upload citation audit report
144164
if: always()

src/statspai/timeseries/arima.py

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class ARIMAResult:
2525
order: Tuple[int, int, int]
2626
seasonal_order: Optional[Tuple[int, int, int, int]]
2727
params: pd.Series
28+
se: pd.Series # asymptotic standard errors (param index)
2829
aic: float
2930
bic: float
3031
aicc: float
@@ -34,6 +35,44 @@ class ARIMAResult:
3435
n: int
3536
_model: object # statsmodels result (opaque)
3637

38+
# --- inference accessors -------------------------------------------------
39+
@property
40+
def std_errors(self) -> pd.Series:
41+
"""Alias for :attr:`se` (regression-style naming)."""
42+
return self.se
43+
44+
@property
45+
def tvalues(self) -> pd.Series:
46+
"""z-statistics ``params / se`` (SARIMAX uses a normal reference)."""
47+
return self.params / self.se
48+
49+
@property
50+
def pvalues(self) -> pd.Series:
51+
"""Two-sided p-values from the normal reference distribution."""
52+
from scipy import stats
53+
z = (self.params / self.se).to_numpy()
54+
return pd.Series(2.0 * stats.norm.sf(np.abs(z)), index=self.params.index)
55+
56+
def conf_int(self, alpha: float = 0.05) -> pd.DataFrame:
57+
"""Confidence intervals for each parameter.
58+
59+
Parameters
60+
----------
61+
alpha : float, default 0.05
62+
``1 - alpha`` is the coverage (0.05 → 95% CI).
63+
64+
Returns
65+
-------
66+
pd.DataFrame
67+
Indexed by parameter name with ``lower`` / ``upper`` columns.
68+
"""
69+
from scipy import stats
70+
z = stats.norm.ppf(1.0 - alpha / 2.0)
71+
lower = self.params - z * self.se
72+
upper = self.params + z * self.se
73+
return pd.DataFrame({"lower": lower, "upper": upper},
74+
index=self.params.index)
75+
3776
def forecast(self, horizon: int = 10, alpha: float = 0.05) -> pd.DataFrame:
3877
fc = self._model.get_forecast(steps=horizon)
3978
pred = np.asarray(fc.predicted_mean).ravel()
@@ -71,10 +110,16 @@ def summary(self) -> str:
71110
f"AICc : {self.aicc:.2f}",
72111
f"Log-Lik : {self.log_likelihood:.2f}",
73112
"",
74-
"Parameters:",
113+
f" {'':<15s} {'coef':>10s} {'std err':>10s} {'z':>8s} {'P>|z|':>8s}",
75114
]
115+
pvals = self.pvalues
76116
for nm, val in self.params.items():
77-
lines.append(f" {nm:<15s} {val: .4f}")
117+
s = float(self.se.get(nm, np.nan))
118+
z = val / s if s and np.isfinite(s) else np.nan
119+
p = float(pvals.get(nm, np.nan))
120+
lines.append(
121+
f" {nm:<15s} {val:>10.4f} {s:>10.4f} {z:>8.3f} {p:>8.3f}"
122+
)
78123
return "\n".join(lines)
79124

80125
def __repr__(self) -> str:
@@ -104,6 +149,21 @@ def arima(
104149
If True, select (p, d, q) by AICc grid search (ignores ``order``).
105150
max_p, max_q, max_d : int
106151
Bounds for the auto search.
152+
153+
Returns
154+
-------
155+
ARIMAResult
156+
Exposes ``params`` and the matching standard errors ``se`` (alias
157+
``std_errors``), plus ``tvalues``, ``pvalues``, and
158+
``conf_int(alpha)`` for inference, alongside ``aic`` / ``bic`` /
159+
``aicc`` / ``log_likelihood`` and ``forecast`` / ``plot``.
160+
161+
Examples
162+
--------
163+
>>> import statspai as sp
164+
>>> res = sp.arima(df["gdp"], order=(2, 0, 0))
165+
>>> res.se # standard errors, indexed by parameter name
166+
>>> res.conf_int() # 95% confidence intervals
107167
"""
108168
try:
109169
from statsmodels.tsa.statespace.sarimax import SARIMAX
@@ -148,10 +208,21 @@ def arima(
148208
k = sum(order) + 1
149209
aicc = res.aic + 2 * k * (k + 1) / max(n - k - 1, 1)
150210

211+
_param_index = res.param_names if hasattr(res, "param_names") else None
212+
_params = pd.Series(res.params, index=_param_index)
213+
# statsmodels computes the asymptotic SEs (sqrt of the diagonal of the
214+
# covariance of the MLE) but we never surfaced them before; expose them.
215+
_bse = getattr(res, "bse", None)
216+
if _bse is not None:
217+
_se = pd.Series(np.asarray(_bse, dtype=float), index=_param_index)
218+
else: # pragma: no cover - defensive; SARIMAX always populates bse
219+
_se = pd.Series(np.full(len(_params), np.nan), index=_param_index)
220+
151221
_result = ARIMAResult(
152222
order=order,
153223
seasonal_order=seasonal_order,
154-
params=pd.Series(res.params, index=res.param_names) if hasattr(res, "param_names") else pd.Series(res.params),
224+
params=_params,
225+
se=_se,
155226
aic=float(res.aic),
156227
bic=float(res.bic),
157228
aicc=float(aicc),

tests/r_parity/39_arima.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""StatsPAI ARIMA parity (Python side) -- Module 39.
2+
3+
DGP: AR(2) with phi1=0.6, phi2=-0.2. Fits ARIMA(2,0,0). The
4+
companion R/Stata sides fit the same model.
5+
6+
sp.arima now exposes ``ARIMAResult.se`` (statsmodels' asymptotic SEs),
7+
so we compare standard errors alongside the point estimates and logLik.
8+
9+
Tolerance: rel < 1e-3 on AR coefficients.
10+
"""
11+
from __future__ import annotations
12+
13+
import numpy as np
14+
import pandas as pd
15+
import statspai as sp
16+
17+
from _common import PARITY_SEED, ParityRecord, dump_csv, write_results
18+
19+
20+
MODULE = "39_arima"
21+
22+
23+
def make_data(T: int = 300, seed: int = PARITY_SEED) -> pd.DataFrame:
24+
rng = np.random.default_rng(seed)
25+
y = np.zeros(T)
26+
eps = rng.normal(0, 0.7, T)
27+
for t in range(2, T):
28+
y[t] = 0.6 * y[t - 1] - 0.2 * y[t - 2] + eps[t]
29+
return pd.DataFrame({"y": y})
30+
31+
32+
def main() -> None:
33+
df = make_data()
34+
dump_csv(df, MODULE)
35+
36+
res = sp.arima(df["y"].values, order=(2, 0, 0))
37+
38+
rows: list[ParityRecord] = [
39+
ParityRecord(MODULE, "py", "ar1",
40+
estimate=float(res.params["ar.L1"]),
41+
se=float(res.se["ar.L1"]),
42+
n=int(len(df))),
43+
ParityRecord(MODULE, "py", "ar2",
44+
estimate=float(res.params["ar.L2"]),
45+
se=float(res.se["ar.L2"]),
46+
n=int(len(df))),
47+
ParityRecord(MODULE, "py", "sigma2",
48+
estimate=float(res.params["sigma2"]),
49+
se=float(res.se["sigma2"]),
50+
n=int(len(df))),
51+
ParityRecord(MODULE, "py", "logLik",
52+
estimate=float(res.log_likelihood),
53+
n=int(len(df))),
54+
]
55+
56+
write_results(MODULE, "py", rows,
57+
extra={"order": "(2,0,0)", "engine": "statsmodels"})
58+
59+
60+
if __name__ == "__main__":
61+
main()
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
{
2+
"module": "39_arima",
3+
"side": "py",
4+
"rows": [
5+
{
6+
"module": "39_arima",
7+
"side": "py",
8+
"statistic": "ar1",
9+
"estimate": 0.7018913469353353,
10+
"se": 0.04784543441281086,
11+
"ci_lo": null,
12+
"ci_hi": null,
13+
"n": 300,
14+
"extra": {}
15+
},
16+
{
17+
"module": "39_arima",
18+
"side": "py",
19+
"statistic": "ar2",
20+
"estimate": -0.34289324080723854,
21+
"se": 0.05372735783018966,
22+
"ci_lo": null,
23+
"ci_hi": null,
24+
"n": 300,
25+
"extra": {}
26+
},
27+
{
28+
"module": "39_arima",
29+
"side": "py",
30+
"statistic": "sigma2",
31+
"estimate": 0.4143918597001537,
32+
"se": 0.03353450058875001,
33+
"ci_lo": null,
34+
"ci_hi": null,
35+
"n": 300,
36+
"extra": {}
37+
},
38+
{
39+
"module": "39_arima",
40+
"side": "py",
41+
"statistic": "logLik",
42+
"estimate": -291.58314136058783,
43+
"se": null,
44+
"ci_lo": null,
45+
"ci_hi": null,
46+
"n": 300,
47+
"extra": {}
48+
}
49+
],
50+
"extra": {
51+
"order": "(2,0,0)",
52+
"engine": "statsmodels"
53+
}
54+
}

tests/test_arima.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,41 @@ def test_arima_auto_selects(rw):
2727
assert res.aicc < 560 # should beat a bad model
2828

2929

30+
def test_arima_standard_errors(rw):
31+
res = arima(rw, order=(1, 1, 0))
32+
# se is exposed, aligned with params, positive and finite
33+
assert res.se is not None
34+
assert list(res.se.index) == list(res.params.index)
35+
assert np.all(np.isfinite(res.se.to_numpy()))
36+
assert np.all(res.se.to_numpy() > 0)
37+
# std_errors is an alias for se
38+
assert res.std_errors.equals(res.se)
39+
40+
41+
def test_arima_conf_int_and_pvalues(rw):
42+
res = arima(rw, order=(2, 0, 0))
43+
ci = res.conf_int(alpha=0.05)
44+
assert list(ci.columns) == ["lower", "upper"]
45+
assert list(ci.index) == list(res.params.index)
46+
# params lie inside their own CI; bounds ordered
47+
assert np.all(ci["lower"].to_numpy() <= res.params.to_numpy())
48+
assert np.all(res.params.to_numpy() <= ci["upper"].to_numpy())
49+
assert np.all(ci["lower"].to_numpy() < ci["upper"].to_numpy())
50+
# pvalues in [0, 1], z = params / se
51+
pv = res.pvalues
52+
assert np.all((pv.to_numpy() >= 0) & (pv.to_numpy() <= 1))
53+
np.testing.assert_allclose(res.tvalues.to_numpy(),
54+
(res.params / res.se).to_numpy())
55+
56+
57+
def test_arima_se_matches_statsmodels(rw):
58+
# the exposed se must equal statsmodels' bse on the underlying fit
59+
res = arima(rw, order=(1, 1, 1))
60+
np.testing.assert_allclose(res.se.to_numpy(),
61+
np.asarray(res._model.bse, dtype=float),
62+
rtol=1e-12, atol=0)
63+
64+
3065
def test_exported():
3166
import statspai as sp
3267
assert callable(sp.arima)

0 commit comments

Comments
 (0)