@@ -128,6 +128,11 @@ def _result_id_schema(description: str) -> Dict[str, Any]:
128128 '(Rambachan-Roth default); '
129129 'RM = relative magnitude.' ),
130130 },
131+ 'e' : {
132+ 'type' : 'integer' ,
133+ 'default' : 0 ,
134+ 'description' : "Relative event time to audit." ,
135+ },
131136 'm_bar' : {
132137 'type' : 'number' ,
133138 'description' : "Bound on deviation magnitude (optional)." ,
@@ -567,57 +572,113 @@ def _tool_honest_did_from_result(rid: Optional[str],
567572 if isinstance (obj , dict ) and 'error' in obj :
568573 return obj
569574
570- betas , sigma , n_pre , n_post = _extract_event_study (obj )
571- if betas is None or sigma is None :
572- return {
573- 'error' : ("could not extract event-study coefficients + "
574- "covariance from the cached result" ),
575- 'hint' : ("honest_did_from_result expects a result fitted by "
576- "sp.event_study / sp.callaway_santanna / "
577- "sp.did_imputation / sp.sun_abraham. Run one of "
578- "those with as_handle=true first." ),
579- }
580-
581- method = arguments .get ('method' , 'SD' )
582- m_bar = arguments .get ('m_bar' )
583-
584575 import statspai as sp
585576 fn = getattr (sp , 'honest_did' , None )
586577 if fn is None :
587578 return {'error' : "sp.honest_did is not available in this build" }
588- kwargs = dict (betas = list (betas ), sigma = _listify_sigma (sigma ),
589- num_pre_periods = int (n_pre ),
590- num_post_periods = int (n_post ),
591- method = method )
592- if m_bar is not None :
593- kwargs ['m_bar' ] = float (m_bar )
579+ method_arg = str (arguments .get ('method' , 'SD' ))
580+ method_key = method_arg .lower ()
581+ method = {
582+ 'sd' : 'smoothness' ,
583+ 'rm' : 'relative_magnitude' ,
584+ 'smoothness' : 'smoothness' ,
585+ 'relative_magnitude' : 'relative_magnitude' ,
586+ }.get (method_key , method_arg )
587+ legacy_method = {
588+ 'sd' : 'SD' ,
589+ 'smoothness' : 'SD' ,
590+ 'rm' : 'RM' ,
591+ 'relative_magnitude' : 'RM' ,
592+ }.get (method_key , method_arg )
593+ event_time = int (arguments .get ('e' , 0 ))
594+ m_bar = arguments .get ('m_bar' )
595+ m_grid = [float (m_bar )] if m_bar is not None else None
596+
597+ event_result = _coerce_event_study_result (obj )
598+ current_api_failed : Optional [Exception ] = None
594599 try :
595- result = fn (** kwargs )
596- except Exception as e :
597- from .remediation import remediate
598- return {
599- 'error' : f"{ type (e ).__name__ } : { e } " ,
600- 'remediation' : remediate (e , context = {'tool' : 'honest_did_from_result' }),
600+ kwargs = {'e' : event_time , 'method' : method }
601+ if m_grid is not None :
602+ kwargs ['m_grid' ] = m_grid
603+ result = fn (event_result , ** kwargs )
604+ except Exception as exc :
605+ current_api_failed = exc
606+
607+ if current_api_failed is not None :
608+ betas , sigma , n_pre , n_post = _extract_event_study (obj )
609+ if betas is None or sigma is None :
610+ return {
611+ 'error' : ("could not extract event-study coefficients + "
612+ "covariance from the cached result" ),
613+ 'hint' : ("honest_did_from_result expects a result fitted by "
614+ "sp.event_study / sp.callaway_santanna / "
615+ "sp.did_imputation / sp.sun_abraham. Run one of "
616+ "those with as_handle=true first." ),
617+ 'upstream_error' : (
618+ f"{ type (current_api_failed ).__name__ } : { current_api_failed } "
619+ ),
620+ }
621+ kwargs = dict (betas = list (betas ), sigma = _listify_sigma (sigma ),
622+ num_pre_periods = int (n_pre ),
623+ num_post_periods = int (n_post ),
624+ method = legacy_method )
625+ if m_bar is not None :
626+ kwargs ['m_bar' ] = float (m_bar )
627+ try :
628+ result = fn (** kwargs )
629+ except Exception as e :
630+ from .remediation import remediate
631+ return {
632+ 'error' : f"{ type (e ).__name__ } : { e } " ,
633+ 'remediation' : remediate (e , context = {'tool' : 'honest_did_from_result' }),
634+ }
635+
636+ if isinstance (result , pd .DataFrame ):
637+ out = {
638+ 'method' : 'Rambachan-Roth (2023) honest DiD' ,
639+ 'restriction' : method ,
640+ 'event_time' : event_time ,
641+ 'rows' : result .to_dict (orient = 'records' ),
642+ 'max_rejecting_M' : (
643+ float (result .loc [result ['rejects_zero' ], 'M' ].max ())
644+ if 'rejects_zero' in result and bool (result ['rejects_zero' ].any ())
645+ else 0.0
646+ ),
601647 }
602-
603- from .tools import _default_serializer
604- out = _default_serializer (result , detail = detail )
648+ else :
649+ from .tools import _default_serializer
650+ out = _default_serializer (result , detail = detail )
605651 if not isinstance (out , dict ):
606652 out = {'value' : out }
607653 out ['source_result_id' ] = rid
608- out ['extracted_n_pre' ] = int (n_pre )
609- out ['extracted_n_post' ] = int (n_post )
610654 new_rid : Optional [str ] = None
611655 if as_handle :
612656 new_rid = RESULT_CACHE .put (result , tool = 'honest_did_from_result' ,
613- arguments = {'source' : rid , 'method' : method })
657+ arguments = {'source' : rid , 'method' : method ,
658+ 'e' : event_time })
614659 out ['result_id' ] = new_rid
615660 out ['result_uri' ] = f"statspai://result/{ new_rid } "
616661 from ._enrichment import enrich_payload
617662 enrich_payload (out , tool_name = 'honest_did' , result_id = new_rid )
618663 return out
619664
620665
666+ def _coerce_event_study_result (obj : Any ) -> Any :
667+ """Return an object shaped for the current ``sp.honest_did`` API."""
668+ detail = getattr (obj , 'detail' , None )
669+ if isinstance (detail , pd .DataFrame ) and {'relative_time' , 'att' , 'se' } <= set (detail .columns ):
670+ return obj
671+
672+ method = str (getattr (obj , 'method' , '' )).lower ()
673+ if 'callaway' in method and detail is not None :
674+ import statspai as sp
675+ try :
676+ return sp .aggte (obj , type = 'dynamic' , bstrap = False )
677+ except TypeError :
678+ return sp .aggte (obj , type = 'dynamic' )
679+ return obj
680+
681+
621682def _extract_event_study (obj : Any ):
622683 """Best-effort extraction of (betas, sigma, n_pre, n_post)."""
623684 import numpy as np
0 commit comments