Skip to content

Commit 7786aa5

Browse files
committed
fix(inference): per-day 过滤零填充行(修历史回测不稳定)
1 parent 8d76282 commit 7786aa5

3 files changed

Lines changed: 6813 additions & 6798 deletions

File tree

build/cache/preds.parquet

9.91 KB
Binary file not shown.

build/inference.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,26 +57,41 @@ def run_inference(panel: pd.DataFrame, ckpt_path: Path, tau: int = 8,
5757

5858
# 5. 对每一日(从第 tau 天起)跑推理
5959
print(" running inference ...")
60-
results = []
61-
feat_arr = panel.pivot(index="ts_code", columns="trade_date", values=fc[0]).fillna(0.0).values
62-
# build a (S, T, F) tensor for the full date range, then slide
60+
# 关键:每天只把"当天实际有 panel 数据"的股票喂给模型;零填充行不进 batch。
61+
# 否则带 inter-stock attention 的模型会被 padding 污染(peer-set 变了 → 所有 pred 跟着变),
62+
# 导致历史回测不稳定(每次 fetch 拉到的 universe 大小不同就会改一遍)。
63+
# 用 (ts_code, trade_date) 元组的集合来判断 raw 行是否存在。
64+
raw_keys = set(zip(panel["ts_code"].to_numpy(), panel["trade_date"].to_numpy()))
65+
has_data = np.array(
66+
[[(ts, d) in raw_keys for d in all_dates] for ts in all_codes]
67+
) # (S, T) bool
68+
6369
pivoted = {}
6470
for col in fc:
6571
pivoted[col] = panel.pivot(index="ts_code", columns="trade_date", values=col).reindex(
66-
index=all_codes, columns=all_dates).fillna(0.0).values # (S, T)
72+
index=all_codes, columns=all_dates).fillna(0.0).values
6773
X_full = np.stack([pivoted[c] for c in fc], axis=-1) # (S, T, F)
6874

75+
results = []
6976
with torch.no_grad():
7077
for t in range(tau - 1, len(all_dates)):
71-
X_win = torch.tensor(X_full[:, t - tau + 1:t + 1, :], dtype=torch.float32, device=device)
72-
out = model(X_win) # (S, τ)
73-
scores = out[:, -1].cpu().numpy() # last-step signal
7478
d = all_dates[t]
75-
for code, s in zip(all_codes, scores):
76-
results.append({"trade_date": d, "ts_code": code, "pred": float(s)})
79+
mask = has_data[:, t] # 当天有数据的股票
80+
if not mask.any():
81+
continue
82+
sel_idx = np.flatnonzero(mask)
83+
X_win = torch.tensor(
84+
X_full[sel_idx, t - tau + 1:t + 1, :],
85+
dtype=torch.float32, device=device,
86+
) # (S_d, τ, F)
87+
out = model(X_win) # (S_d, τ)
88+
scores = out[:, -1].cpu().numpy() # last-step signal
89+
for i, s in zip(sel_idx, scores):
90+
results.append({"trade_date": d, "ts_code": all_codes[i], "pred": float(s)})
7791

7892
df = pd.DataFrame(results)
79-
print(f" → {len(df)} predictions across {len(all_dates) - tau + 1} dates")
93+
print(f" → {len(df)} predictions across {len(all_dates) - tau + 1} dates"
94+
f" (per-day avg {len(df) / max(len(all_dates) - tau + 1, 1):.0f} stocks)")
8095
return df
8196

8297

0 commit comments

Comments
 (0)