@@ -57,26 +57,41 @@ def run_inference(panel: pd.DataFrame, ckpt_path: Path, tau: int = 8,
5757
5858 # 5. 对每一日(从第 tau 天起)跑推理
5959 print (" running inference ..." )
60- results = []
61- feat_arr = panel .pivot (index = "ts_code" , columns = "trade_date" , values = fc [0 ]).fillna (0.0 ).values
62- # build a (S, T, F) tensor for the full date range, then slide
60+ # 关键:每天只把"当天实际有 panel 数据"的股票喂给模型;零填充行不进 batch。
61+ # 否则带 inter-stock attention 的模型会被 padding 污染(peer-set 变了 → 所有 pred 跟着变),
62+ # 导致历史回测不稳定(每次 fetch 拉到的 universe 大小不同就会改一遍)。
63+ # 用 (ts_code, trade_date) 元组的集合来判断 raw 行是否存在。
64+ raw_keys = set (zip (panel ["ts_code" ].to_numpy (), panel ["trade_date" ].to_numpy ()))
65+ has_data = np .array (
66+ [[(ts , d ) in raw_keys for d in all_dates ] for ts in all_codes ]
67+ ) # (S, T) bool
68+
6369 pivoted = {}
6470 for col in fc :
6571 pivoted [col ] = panel .pivot (index = "ts_code" , columns = "trade_date" , values = col ).reindex (
66- index = all_codes , columns = all_dates ).fillna (0.0 ).values # (S, T)
72+ index = all_codes , columns = all_dates ).fillna (0.0 ).values
6773 X_full = np .stack ([pivoted [c ] for c in fc ], axis = - 1 ) # (S, T, F)
6874
75+ results = []
6976 with torch .no_grad ():
7077 for t in range (tau - 1 , len (all_dates )):
71- X_win = torch .tensor (X_full [:, t - tau + 1 :t + 1 , :], dtype = torch .float32 , device = device )
72- out = model (X_win ) # (S, τ)
73- scores = out [:, - 1 ].cpu ().numpy () # last-step signal
7478 d = all_dates [t ]
75- for code , s in zip (all_codes , scores ):
76- results .append ({"trade_date" : d , "ts_code" : code , "pred" : float (s )})
79+ mask = has_data [:, t ] # 当天有数据的股票
80+ if not mask .any ():
81+ continue
82+ sel_idx = np .flatnonzero (mask )
83+ X_win = torch .tensor (
84+ X_full [sel_idx , t - tau + 1 :t + 1 , :],
85+ dtype = torch .float32 , device = device ,
86+ ) # (S_d, τ, F)
87+ out = model (X_win ) # (S_d, τ)
88+ scores = out [:, - 1 ].cpu ().numpy () # last-step signal
89+ for i , s in zip (sel_idx , scores ):
90+ results .append ({"trade_date" : d , "ts_code" : all_codes [i ], "pred" : float (s )})
7791
7892 df = pd .DataFrame (results )
79- print (f" → { len (df )} predictions across { len (all_dates ) - tau + 1 } dates" )
93+ print (f" → { len (df )} predictions across { len (all_dates ) - tau + 1 } dates"
94+ f" (per-day avg { len (df ) / max (len (all_dates ) - tau + 1 , 1 ):.0f} stocks)" )
8095 return df
8196
8297
0 commit comments