Skip to content

Commit 50ea67e

Browse files
committed
made syntax to deal with parameter sub-dicts in lephare
1 parent c88c610 commit 50ea67e

2 files changed

Lines changed: 72 additions & 50 deletions

File tree

src/rail/estimation/algos/lephare.py

Lines changed: 71 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import importlib
22
import os
3+
from typing import Any
34

45
import lephare as lp
56
import numpy as np
67
import qp
78
from astropy.table import Table
89
from ceci.config import StageParameter as Param
10+
from ceci.config import StageConfig
911
from rail.core.common_params import SHARED_PARAMS
1012
from rail.estimation.estimator import CatEstimator, CatInformer
1113

@@ -35,6 +37,48 @@
3537
}
3638
)
3739

40+
star_default_config=dict(
41+
LIB_ASCII="YES"
42+
)
43+
44+
gal_default_config=dict(
45+
LIB_ASCII="YES",
46+
MOD_EXTINC="18,26,26,33,26,33,26,33",
47+
EXTINC_LAW="SMC_prevot.dat,SB_calzetti.dat,SB_calzetti_bump1.dat,SB_calzetti_bump2.dat",
48+
EM_LINES="EMP_UV",
49+
EM_DISPERSION="0.5,0.75,1.,1.5,2.",
50+
)
51+
52+
qso_default_config=dict(
53+
LIB_ASCII="YES",
54+
MOD_EXTINC="0,1000",
55+
EB_V="0.,0.1,0.2,0.3",
56+
EXTINC_LAW="SB_calzetti.dat",
57+
)
58+
59+
60+
def _add_sub_config(
61+
config: dict[str, Any],
62+
sub_config: dict[str, Any],
63+
prefix: str,
64+
) -> None:
65+
"""Add all sub-config parameters to the stage config with
66+
the requested prefix. This will make the correct paramter types
67+
and defaults
68+
"""
69+
for key, val in sub_config.items():
70+
dtype = type(val)
71+
default = val
72+
param = Param(dtype=dtype, default=default)
73+
config[f"{prefix}{key}"] = param
74+
75+
76+
def _get_sub_config(config: StageConfig, prefix: str) -> dict[str, Any]:
77+
"""Extract all config parameters that start with a
78+
particular prefix into a dict"""
79+
out_dict = {key[len(prefix):]: val for key, val in config.items() if key.find(prefix) == 0}
80+
return out_dict
81+
3882

3983
class LephareInformer(CatInformer):
4084
"""Inform stage for LephareEstimator
@@ -56,46 +100,19 @@ class LephareInformer(CatInformer):
56100
err_bands=SHARED_PARAMS,
57101
ref_band=SHARED_PARAMS,
58102
redshift_col=SHARED_PARAMS,
59-
lephare_config=Param(
60-
dict,
61-
lsst_default_config,
62-
msg="The lephare config keymap.",
63-
),
64-
star_config=Param(
65-
dict,
66-
dict(LIB_ASCII="YES"),
67-
msg="Star config overrides.",
68-
),
69-
gal_config=Param(
70-
dict,
71-
dict(
72-
LIB_ASCII="YES",
73-
MOD_EXTINC="18,26,26,33,26,33,26,33",
74-
EXTINC_LAW="SMC_prevot.dat,SB_calzetti.dat,SB_calzetti_bump1.dat,SB_calzetti_bump2.dat",
75-
EM_LINES="EMP_UV",
76-
EM_DISPERSION="0.5,0.75,1.,1.5,2.",
77-
),
78-
msg="Galaxy config overrides.",
79-
),
80-
qso_config=Param(
81-
dict,
82-
dict(
83-
LIB_ASCII="YES",
84-
MOD_EXTINC="0,1000",
85-
EB_V="0.,0.1,0.2,0.3",
86-
EXTINC_LAW="SB_calzetti.dat",
87-
),
88-
msg="QSO config overrides.",
89-
),
90103
)
104+
_add_sub_config(config_options, lsst_default_config, "lephare.")
105+
_add_sub_config(config_options, star_default_config, "gal.")
106+
_add_sub_config(config_options, gal_default_config, "gal.")
107+
_add_sub_config(config_options, qso_default_config, "qso.")
91108

92109
def __init__(self, args, **kwargs):
93110
"""Init function, init config stuff (COPIED from rail_bpz)"""
94111

95112
super().__init__(args, **kwargs)
96113

97114
def validate(self):
98-
self.lephare_config = self.config["lephare_config"]
115+
self.lephare_config = _get_sub_config(self.config, 'lephare.')
99116

100117
# Put something in place to allow for not rerunning the prepare stage
101118
try:
@@ -114,7 +131,7 @@ def validate(self):
114131
print(
115132
f"rail_lephare is setting the Z_STEP config to {Z_STEP} based on the informer params."
116133
)
117-
self.config["lephare_config"]["Z_STEP"] = Z_STEP
134+
self.config["lephare.Z_STEP"] = Z_STEP
118135
# We create a run directory with the informer name
119136
self.run_dir = _set_run_dir(self.config["name"])
120137

@@ -147,13 +164,17 @@ def run(self):
147164
# Get number of sources
148165
ngal = len(training_data[self.config.ref_band])
149166

167+
star_config = _get_sub_config(self.config, 'star.')
168+
gal_config = _get_sub_config(self.config, 'gal.')
169+
qso_config = _get_sub_config(self.config, 'gso.')
170+
150171
# The three main lephare specific inform tasks
151172
if self.do_prepare:
152173
lp.prepare(
153174
self.lephare_config,
154-
star_config=self.config["star_config"],
155-
gal_config=self.config["gal_config"],
156-
qso_config=self.config["qso_config"],
175+
star_config=star_config,
176+
gal_config=gal_config,
177+
qso_config=qso_config,
157178
)
158179
else:
159180
print(
@@ -168,20 +189,20 @@ def run(self):
168189
training_data, self.config.bands, self.config.err_bands
169190
)
170191
# This will return zeros if AUTO_ADAPT is NO
171-
offsets = lp.calculate_offsets_from_input(self.config["lephare_config"], input)
192+
offsets = lp.calculate_offsets_from_input(self.lephare_config, input)
172193
# We must make a string dictionary to allow pickling and saving
173194
lephare_config = lp.keymap_to_string_dict(
174-
lp.all_types_to_keymap(self.config["lephare_config"])
195+
lp.all_types_to_keymap(self.lephare_config)
175196
)
176197
# Give principle inform config 'model' to instance.
177198
self.model = dict(
178199
lephare_version=lp.__version__,
179-
lephare_config=lephare_config,
200+
lephare_config=self.lephare_config,
180201
offsets=offsets,
181202
run_dir=self.run_dir,
182-
star_config=self.config["star_config"],
183-
gal_config=self.config["gal_config"],
184-
qso_config=self.config["qso_config"],
203+
star_config=star_config,
204+
gal_config=gal_config,
205+
qso_config=qso_config,
185206
)
186207
self.add_data("model", self.model)
187208

@@ -202,10 +223,10 @@ class LephareEstimator(CatEstimator):
202223
ref_band=SHARED_PARAMS,
203224
err_bands=SHARED_PARAMS,
204225
redshift_col=SHARED_PARAMS,
205-
lephare_config=Param(
206-
dict,
207-
{},
208-
msg="The lephare config keymap. If unset we load it from the model.",
226+
lephare_config_from_model=Param(
227+
bool,
228+
True,
229+
"Load lephare config keymap from model",
209230
),
210231
use_inform_offsets=Param(
211232
bool,
@@ -252,6 +273,7 @@ class LephareEstimator(CatEstimator):
252273
),
253274
),
254275
)
276+
_add_sub_config(config_options, lsst_default_config, "lephare.")
255277

256278
def __init__(self, args, **kwargs):
257279
super().__init__(args, **kwargs)
@@ -262,10 +284,10 @@ def __init__(self, args, **kwargs):
262284

263285
def open_model(self, **kwargs):
264286
CatEstimator.open_model(self, **kwargs)
265-
if self.config["lephare_config"]:
266-
self.lephare_config = self.config["lephare_config"]
267-
else:
287+
if self.config["lephare_config_from_model"]:
268288
self.lephare_config = self.model["lephare_config"]
289+
else:
290+
self.lephare_config = _get_sub_config(self.config, "lephare.")
269291
# Use string dictionary config in case keymap passed to estimate stage
270292
self.lephare_config = lp.keymap_to_string_dict(
271293
lp.all_types_to_keymap(self.lephare_config)

tests/lephare/test_algos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def test_informer_basic():
2323
assert inform_lephare.name == "LephareInformer"
2424
assert inform_lephare.config["name"] == "inform_Lephare"
2525
# Check config zgrid updated to stage param defaults:
26-
assert inform_lephare.config["lephare_config"]["Z_STEP"] == "0.01,0.0,3.0"
26+
assert inform_lephare.config["lephare.Z_STEP"] == "0.01,0.0,3.0"
2727

2828

2929
def test_informer_and_estimator(test_data_dir: str):

0 commit comments

Comments
 (0)