|
24 | 24 | HERE = Path(__file__).resolve().parent |
25 | 25 | CACHE = HERE / "cache" |
26 | 26 | CACHE.mkdir(exist_ok=True) |
| 27 | +INDUSTRY_MAP_CSV = HERE / "industry_map.csv" # tushare 申万行业映射 (committed in repo) |
27 | 28 |
|
28 | 29 |
|
29 | 30 | def _bs_code_to_ts(bs_code: str) -> str: |
@@ -87,15 +88,29 @@ def get_csi300_components(bs) -> pd.DataFrame: |
87 | 88 | rs = bs.query_hs300_stocks() |
88 | 89 | comps = _bs_query_to_df(rs) |
89 | 90 | comps["ts_code"] = comps["code"].apply(_bs_code_to_ts) |
| 91 | + comps["name"] = comps["code_name"] |
90 | 92 | print(f" → {len(comps)} 只成分股") |
91 | 93 |
|
92 | | - print(" 补充行业分类 ...") |
| 94 | + # 优先用 repo 里的 tushare 申万 mapping(短名 + 细分二级,如"通信设备/元器件/半导体") |
| 95 | + primary_map = {} |
| 96 | + if INDUSTRY_MAP_CSV.exists(): |
| 97 | + df_map = pd.read_csv(INDUSTRY_MAP_CSV, dtype=str).dropna(subset=["industry"]) |
| 98 | + primary_map = dict(zip(df_map["ts_code"], df_map["industry"])) |
| 99 | + print(f" loaded primary industry mapping ({len(primary_map)} tickers)") |
| 100 | + |
| 101 | + # Fallback: BaoStock 证监会分类(粗一些,需要 shorten) |
93 | 102 | rs = bs.query_stock_industry() |
94 | 103 | ind = _bs_query_to_df(rs) |
95 | 104 | ind["ts_code"] = ind["code"].apply(_bs_code_to_ts) |
96 | | - ind_map = dict(zip(ind["ts_code"], ind["industry"])) |
97 | | - comps["industry"] = comps["ts_code"].map(ind_map).fillna("").apply(shorten_industry) |
98 | | - comps["name"] = comps["code_name"] |
| 105 | + fallback_map = dict(zip(ind["ts_code"], ind["industry"].apply(shorten_industry))) |
| 106 | + |
| 107 | + def _resolve_industry(ts): |
| 108 | + return primary_map.get(ts) or fallback_map.get(ts) or "—" |
| 109 | + |
| 110 | + comps["industry"] = comps["ts_code"].apply(_resolve_industry) |
| 111 | + missing = (comps["industry"] == "—").sum() |
| 112 | + if missing: |
| 113 | + print(f" ⚠️ {missing} 只成分股 industry 缺失(mapping 都没匹配上)") |
99 | 114 | return comps[["ts_code", "name", "industry"]] |
100 | 115 |
|
101 | 116 |
|
|
0 commit comments