Skip to content

Commit 958c235

Browse files
fix(agent): chain Honest-DiD from result handles
1 parent 0974aca commit 958c235

2 files changed

Lines changed: 123 additions & 32 deletions

File tree

src/statspai/agent/workflow_tools.py

Lines changed: 93 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ def _result_id_schema(description: str) -> Dict[str, Any]:
128128
'(Rambachan-Roth default); '
129129
'RM = relative magnitude.'),
130130
},
131+
'e': {
132+
'type': 'integer',
133+
'default': 0,
134+
'description': "Relative event time to audit.",
135+
},
131136
'm_bar': {
132137
'type': 'number',
133138
'description': "Bound on deviation magnitude (optional).",
@@ -567,57 +572,113 @@ def _tool_honest_did_from_result(rid: Optional[str],
567572
if isinstance(obj, dict) and 'error' in obj:
568573
return obj
569574

570-
betas, sigma, n_pre, n_post = _extract_event_study(obj)
571-
if betas is None or sigma is None:
572-
return {
573-
'error': ("could not extract event-study coefficients + "
574-
"covariance from the cached result"),
575-
'hint': ("honest_did_from_result expects a result fitted by "
576-
"sp.event_study / sp.callaway_santanna / "
577-
"sp.did_imputation / sp.sun_abraham. Run one of "
578-
"those with as_handle=true first."),
579-
}
580-
581-
method = arguments.get('method', 'SD')
582-
m_bar = arguments.get('m_bar')
583-
584575
import statspai as sp
585576
fn = getattr(sp, 'honest_did', None)
586577
if fn is None:
587578
return {'error': "sp.honest_did is not available in this build"}
588-
kwargs = dict(betas=list(betas), sigma=_listify_sigma(sigma),
589-
num_pre_periods=int(n_pre),
590-
num_post_periods=int(n_post),
591-
method=method)
592-
if m_bar is not None:
593-
kwargs['m_bar'] = float(m_bar)
579+
method_arg = str(arguments.get('method', 'SD'))
580+
method_key = method_arg.lower()
581+
method = {
582+
'sd': 'smoothness',
583+
'rm': 'relative_magnitude',
584+
'smoothness': 'smoothness',
585+
'relative_magnitude': 'relative_magnitude',
586+
}.get(method_key, method_arg)
587+
legacy_method = {
588+
'sd': 'SD',
589+
'smoothness': 'SD',
590+
'rm': 'RM',
591+
'relative_magnitude': 'RM',
592+
}.get(method_key, method_arg)
593+
event_time = int(arguments.get('e', 0))
594+
m_bar = arguments.get('m_bar')
595+
m_grid = [float(m_bar)] if m_bar is not None else None
596+
597+
event_result = _coerce_event_study_result(obj)
598+
current_api_failed: Optional[Exception] = None
594599
try:
595-
result = fn(**kwargs)
596-
except Exception as e:
597-
from .remediation import remediate
598-
return {
599-
'error': f"{type(e).__name__}: {e}",
600-
'remediation': remediate(e, context={'tool': 'honest_did_from_result'}),
600+
kwargs = {'e': event_time, 'method': method}
601+
if m_grid is not None:
602+
kwargs['m_grid'] = m_grid
603+
result = fn(event_result, **kwargs)
604+
except Exception as exc:
605+
current_api_failed = exc
606+
607+
if current_api_failed is not None:
608+
betas, sigma, n_pre, n_post = _extract_event_study(obj)
609+
if betas is None or sigma is None:
610+
return {
611+
'error': ("could not extract event-study coefficients + "
612+
"covariance from the cached result"),
613+
'hint': ("honest_did_from_result expects a result fitted by "
614+
"sp.event_study / sp.callaway_santanna / "
615+
"sp.did_imputation / sp.sun_abraham. Run one of "
616+
"those with as_handle=true first."),
617+
'upstream_error': (
618+
f"{type(current_api_failed).__name__}: {current_api_failed}"
619+
),
620+
}
621+
kwargs = dict(betas=list(betas), sigma=_listify_sigma(sigma),
622+
num_pre_periods=int(n_pre),
623+
num_post_periods=int(n_post),
624+
method=legacy_method)
625+
if m_bar is not None:
626+
kwargs['m_bar'] = float(m_bar)
627+
try:
628+
result = fn(**kwargs)
629+
except Exception as e:
630+
from .remediation import remediate
631+
return {
632+
'error': f"{type(e).__name__}: {e}",
633+
'remediation': remediate(e, context={'tool': 'honest_did_from_result'}),
634+
}
635+
636+
if isinstance(result, pd.DataFrame):
637+
out = {
638+
'method': 'Rambachan-Roth (2023) honest DiD',
639+
'restriction': method,
640+
'event_time': event_time,
641+
'rows': result.to_dict(orient='records'),
642+
'max_rejecting_M': (
643+
float(result.loc[result['rejects_zero'], 'M'].max())
644+
if 'rejects_zero' in result and bool(result['rejects_zero'].any())
645+
else 0.0
646+
),
601647
}
602-
603-
from .tools import _default_serializer
604-
out = _default_serializer(result, detail=detail)
648+
else:
649+
from .tools import _default_serializer
650+
out = _default_serializer(result, detail=detail)
605651
if not isinstance(out, dict):
606652
out = {'value': out}
607653
out['source_result_id'] = rid
608-
out['extracted_n_pre'] = int(n_pre)
609-
out['extracted_n_post'] = int(n_post)
610654
new_rid: Optional[str] = None
611655
if as_handle:
612656
new_rid = RESULT_CACHE.put(result, tool='honest_did_from_result',
613-
arguments={'source': rid, 'method': method})
657+
arguments={'source': rid, 'method': method,
658+
'e': event_time})
614659
out['result_id'] = new_rid
615660
out['result_uri'] = f"statspai://result/{new_rid}"
616661
from ._enrichment import enrich_payload
617662
enrich_payload(out, tool_name='honest_did', result_id=new_rid)
618663
return out
619664

620665

666+
def _coerce_event_study_result(obj: Any) -> Any:
667+
"""Return an object shaped for the current ``sp.honest_did`` API."""
668+
detail = getattr(obj, 'detail', None)
669+
if isinstance(detail, pd.DataFrame) and {'relative_time', 'att', 'se'} <= set(detail.columns):
670+
return obj
671+
672+
method = str(getattr(obj, 'method', '')).lower()
673+
if 'callaway' in method and detail is not None:
674+
import statspai as sp
675+
try:
676+
return sp.aggte(obj, type='dynamic', bstrap=False)
677+
except TypeError:
678+
return sp.aggte(obj, type='dynamic')
679+
return obj
680+
681+
621682
def _extract_event_study(obj: Any):
622683
"""Best-effort extraction of (betas, sigma, n_pre, n_post)."""
623684
import numpy as np

tests/test_mcp_result_handle.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,36 @@ def test_audit_with_missing_handle_returns_friendly_error(self):
174174
assert "error" in out
175175
assert "result_id" in out["error"] or "not found" in out["error"]
176176

177+
def test_honest_did_from_callaway_handle(self):
178+
import statspai as sp
179+
180+
df = sp.datasets.mpdta()
181+
fit = execute_tool(
182+
"callaway_santanna",
183+
{
184+
"y": "lemp",
185+
"g": "first_treat",
186+
"t": "year",
187+
"i": "countyreal",
188+
"estimator": "reg",
189+
"control_group": "nevertreated",
190+
},
191+
data=df,
192+
detail="minimal",
193+
as_handle=True,
194+
)
195+
rid = fit["result_id"]
196+
out = execute_tool(
197+
"honest_did_from_result",
198+
{"result_id": rid, "method": "SD", "e": 0},
199+
detail="minimal",
200+
)
201+
assert out["source_result_id"] == rid
202+
assert out["restriction"] == "smoothness"
203+
assert out["event_time"] == 0
204+
assert out["max_rejecting_M"] > 0
205+
assert any(row["rejects_zero"] for row in out["rows"])
206+
177207

178208
# ----------------------------------------------------------------------
179209
# Schema injection: result_id, as_handle, data_columns, data_sample_n

0 commit comments

Comments
 (0)