Skip to content

Commit 7aba2b6

Browse files
Harden schema and MCP serialization
Improve agent schema metadata extraction, strict JSON serialization for MCP payloads, optional text dependencies, stability audit evidence accounting, and related regression coverage. Include the current full-suite validation record and performance benchmark harness updates present in the working tree. Co-authored-by: Codex <noreply@openai.com>
1 parent 49d5c26 commit 7aba2b6

17 files changed

Lines changed: 931 additions & 191 deletions

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,7 @@ pip install statspai[plotting] # matplotlib, seaborn
830830
pip install statspai[fixest] # pyfixest for high-dimensional FE
831831
pip install statspai[deepiv] # PyTorch for DeepIV
832832
pip install statspai[neural] # PyTorch for TARNet/CFRNet/DragonNet
833+
pip install statspai[text] # sentence-transformers for sbert text embeddings
833834
pip install statspai[performance] # JAX CPU backend for sp.fast.demean
834835
```
835836

README_CN.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ pip install statspai[plotting] # matplotlib, seaborn
423423
pip install statspai[fixest] # pyfixest 高维固定效应
424424
pip install statspai[deepiv] # PyTorch (Deep IV)
425425
pip install statspai[neural] # PyTorch (TARNet / CFRNet / DragonNet)
426+
pip install statspai[text] # sentence-transformers,用于 sbert 文本嵌入
426427
pip install statspai[performance] # JAX CPU 后端,用于 sp.fast.demean
427428
```
428429

docs/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ pip install statspai # core
156156
pip install 'statspai[plotting]' # matplotlib + seaborn
157157
pip install 'statspai[fixest]' # pyfixest HDFE
158158
pip install 'statspai[deepiv]' # PyTorch (Deep IV, TARNet)
159+
pip install 'statspai[text]' # sentence-transformers for sbert
159160
```
160161

161162
## Citation

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ bayes = [
9494
"pymc>=5.0",
9595
"arviz>=0.15",
9696
]
97+
text = [
98+
"sentence-transformers>=2.2.0",
99+
]
97100
tune = [
98101
"optuna>=3.0",
99102
]

scripts/stability_audit.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@
33
StatsPAI now separates API lifecycle from numerical validation evidence:
44
``stability='stable'`` means the public signature is locked, while
55
``validation_status='certified'`` / ``'validated'`` carries the
6-
parity-evidence signal. This script keeps the old risk visible by
7-
counting stable API entries that still lack a parity-test reference in
6+
validation-evidence signal. This script keeps the old risk visible by
7+
counting stable API entries that still lack either registry-attached
8+
validation evidence or a parity-test reference in
89
``tests/reference_parity/`` + ``tests/external_parity/``.
910
1011
The catch: until v1.13 every newly-registered function was *implicitly*
1112
``stable`` (the field's default), so the catalogue's ~970 stable
1213
entries currently mix two populations:
1314
14-
* **Parity-test backed** — at least one test in
15+
* **Validation-backed** — the registry marks the function
16+
``certified`` / ``validated`` or at least one test in
1517
``tests/reference_parity/`` or ``tests/external_parity/`` exercises
1618
the function with R / Stata / paper-replication numbers.
1719
* **API-stable but unbacked** — the public API is stable, but no
@@ -60,7 +62,7 @@
6062
#: a parity test before --check fails. Bumped when we deliberately add
6163
#: hand-written entries faster than parity tests. Decrease over time as
6264
#: the audit gets cleaned up.
63-
UNBACKED_HANDWRITTEN_FLOOR = 220
65+
UNBACKED_HANDWRITTEN_FLOOR = 190
6466

6567
#: Regex matching ``sp.<name>(`` references in test source. Used to
6668
#: attribute parity coverage to public ``sp.*`` symbols.
@@ -125,6 +127,7 @@ def _registry_specs():
125127
def collect() -> dict:
126128
registry, hand_written = _registry_specs()
127129
backed, sources = _backed_functions()
130+
evidence_sources: Dict[str, List[str]] = {k: list(v) for k, v in sources.items()}
128131

129132
stable_handwritten: List[str] = []
130133
stable_auto: List[str] = []
@@ -144,12 +147,22 @@ def collect() -> dict:
144147
continue
145148
# spec.stability == "stable"
146149
is_hand = name in hand_written
150+
registry_backed = spec.validation_status in {"certified", "validated"}
151+
if registry_backed:
152+
notes = list(getattr(spec, "validation_notes", []) or [])
153+
if not notes:
154+
notes = [f"registry validation_status={spec.validation_status}"]
155+
evidence_sources.setdefault(name, [])
156+
for note in notes:
157+
if note not in evidence_sources[name]:
158+
evidence_sources[name].append(note)
159+
is_backed = name in backed or registry_backed
147160
if is_hand:
148161
stable_handwritten.append(name)
149-
(backed_handwritten if name in backed else unbacked_handwritten).append(name)
162+
(backed_handwritten if is_backed else unbacked_handwritten).append(name)
150163
else:
151164
stable_auto.append(name)
152-
(backed_auto if name in backed else unbacked_auto).append(name)
165+
(backed_auto if is_backed else unbacked_auto).append(name)
153166

154167
return {
155168
"totals": {
@@ -170,6 +183,11 @@ def collect() -> dict:
170183
for _ in p.rglob("test_*.py")
171184
),
172185
"symbols_referenced_in_parity_tests": len(backed),
186+
"registry_validated_symbols": sum(
187+
1 for spec in registry.values()
188+
if spec.stability == "stable"
189+
and spec.validation_status in {"certified", "validated"}
190+
),
173191
},
174192
"lists": {
175193
"unbacked_handwritten": sorted(unbacked_handwritten),
@@ -178,7 +196,7 @@ def collect() -> dict:
178196
"deprecated": sorted(deprecated),
179197
},
180198
"sources": {
181-
name: srcs for name, srcs in sources.items()
199+
name: srcs for name, srcs in evidence_sources.items()
182200
# Only carry backed-handwritten sources in the JSON payload —
183201
# auto-registered specs aren't the focus of this audit.
184202
if name in set(backed_handwritten)
@@ -206,7 +224,7 @@ def render_report(stats: dict, *, show_unbacked: bool = False) -> str:
206224
lines.append(f" experimental : {t['experimental']}")
207225
lines.append(f" deprecated : {t['deprecated']}")
208226
lines.append("")
209-
lines.append("Parity coverage (sp.<name> referenced in parity tests)")
227+
lines.append("Validation coverage")
210228
lines.append("-" * 50)
211229
lines.append(
212230
f" parity test files : "
@@ -216,6 +234,10 @@ def render_report(stats: dict, *, show_unbacked: bool = False) -> str:
216234
f" distinct sp.* symbols referenced : "
217235
f"{p['symbols_referenced_in_parity_tests']}"
218236
)
237+
lines.append(
238+
f" registry certified/validated : "
239+
f"{p['registry_validated_symbols']}"
240+
)
219241
lines.append(
220242
f" stable hand-written, BACKED : "
221243
f"{p['backed_handwritten']}"
@@ -238,9 +260,9 @@ def render_report(stats: dict, *, show_unbacked: bool = False) -> str:
238260
lines.append("-" * 50)
239261
lines.append(
240262
"* UNBACKED hand-written: a maintainer wrote a stable public "
241-
"API, but this audit found no parity-test reference. Add a "
242-
"test, attach validation evidence, or mark immature APIs "
243-
"experimental."
263+
"API, but this audit found no registry validation evidence and "
264+
"no parity-test reference. Add evidence, add a test, or mark "
265+
"immature APIs experimental."
244266
)
245267
lines.append(
246268
"* UNBACKED auto-registered: classified as stable by default. "
@@ -263,14 +285,15 @@ def check_drift(stats: dict) -> int:
263285
if n > floor:
264286
print(
265287
f"FAIL: {n} hand-written stable API entries lack parity tests "
266-
f"(floor: {floor}). Either add tests, attach validation "
267-
f"evidence, or downgrade immature APIs to experimental.",
288+
f"or registry validation evidence (floor: {floor}). Either "
289+
f"add evidence, add tests, or downgrade immature APIs to "
290+
f"experimental.",
268291
file=sys.stderr,
269292
)
270293
return 1
271294
print(
272295
f"OK: {n} hand-written stable API entries lack parity tests "
273-
f"(floor: {floor})."
296+
f"or registry validation evidence (floor: {floor})."
274297
)
275298
return 0
276299

src/statspai/agent/_resources.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,19 @@ def handle_resources_read(
156156
server_version: str,
157157
InvalidParamsError,
158158
ResourceNotFoundError,
159+
clean_for_json: Optional[Callable[[Any], Any]] = None,
159160
) -> Dict[str, Any]:
160161
"""Dispatch a ``resources/read`` URI to its renderer.
161162
162163
The encoder + error classes are passed in to avoid a circular
163164
import through ``mcp_server`` — this module is meant to be a leaf.
165+
166+
``clean_for_json`` is the recursive nan/inf scrubber from
167+
``mcp_server`` — passed in (rather than imported) for the same
168+
leaf-module reason as ``json_default``. Falls back to identity when
169+
the caller doesn't supply one (older callers / legacy tests).
164170
"""
171+
_clean = clean_for_json if clean_for_json is not None else (lambda x: x)
165172
uri = params.get("uri")
166173
if not isinstance(uri, str):
167174
raise InvalidParamsError(
@@ -183,8 +190,9 @@ def handle_resources_read(
183190
{
184191
"uri": uri,
185192
"mimeType": "application/json",
186-
"text": json.dumps(functions_index(),
187-
default=json_default),
193+
"text": json.dumps(_clean(functions_index()),
194+
default=json_default,
195+
allow_nan=False),
188196
},
189197
],
190198
}
@@ -212,7 +220,8 @@ def handle_resources_read(
212220
{
213221
"uri": uri,
214222
"mimeType": "application/json",
215-
"text": json.dumps(card, default=json_default),
223+
"text": json.dumps(_clean(card), default=json_default,
224+
allow_nan=False),
216225
},
217226
],
218227
}
@@ -247,7 +256,9 @@ def handle_resources_read(
247256
{
248257
"uri": uri,
249258
"mimeType": "application/json",
250-
"text": json.dumps(payload, default=json_default),
259+
"text": json.dumps(_clean(payload),
260+
default=json_default,
261+
allow_nan=False),
251262
},
252263
],
253264
}

0 commit comments

Comments
 (0)