11import importlib
22import os
3+ from typing import Any
34
45import lephare as lp
56import numpy as np
67import qp
78from astropy .table import Table
89from ceci .config import StageParameter as Param
10+ from ceci .config import StageConfig
911from rail .core .common_params import SHARED_PARAMS
1012from rail .estimation .estimator import CatEstimator , CatInformer
1113
3537 }
3638)
3739
40+ star_default_config = dict (
41+ LIB_ASCII = "YES"
42+ )
43+
44+ gal_default_config = dict (
45+ LIB_ASCII = "YES" ,
46+ MOD_EXTINC = "18,26,26,33,26,33,26,33" ,
47+ EXTINC_LAW = "SMC_prevot.dat,SB_calzetti.dat,SB_calzetti_bump1.dat,SB_calzetti_bump2.dat" ,
48+ EM_LINES = "EMP_UV" ,
49+ EM_DISPERSION = "0.5,0.75,1.,1.5,2." ,
50+ )
51+
52+ qso_default_config = dict (
53+ LIB_ASCII = "YES" ,
54+ MOD_EXTINC = "0,1000" ,
55+ EB_V = "0.,0.1,0.2,0.3" ,
56+ EXTINC_LAW = "SB_calzetti.dat" ,
57+ )
58+
59+
60+ def _add_sub_config (
61+ config : dict [str , Any ],
62+ sub_config : dict [str , Any ],
63+ prefix : str ,
64+ ) -> None :
65+ """Add all sub-config parameters to the stage config with
66+ the requested prefix. This will make the correct paramter types
67+ and defaults
68+ """
69+ for key , val in sub_config .items ():
70+ dtype = type (val )
71+ default = val
72+ param = Param (dtype = dtype , default = default )
73+ config [f"{ prefix } { key } " ] = param
74+
75+
76+ def _get_sub_config (config : StageConfig , prefix : str ) -> dict [str , Any ]:
77+ """Extract all config parameters that start with a
78+ particular prefix into a dict"""
79+ out_dict = {key [len (prefix ):]: val for key , val in config .items () if key .find (prefix ) == 0 }
80+ return out_dict
81+
3882
3983class LephareInformer (CatInformer ):
4084 """Inform stage for LephareEstimator
@@ -56,46 +100,19 @@ class LephareInformer(CatInformer):
56100 err_bands = SHARED_PARAMS ,
57101 ref_band = SHARED_PARAMS ,
58102 redshift_col = SHARED_PARAMS ,
59- lephare_config = Param (
60- dict ,
61- lsst_default_config ,
62- msg = "The lephare config keymap." ,
63- ),
64- star_config = Param (
65- dict ,
66- dict (LIB_ASCII = "YES" ),
67- msg = "Star config overrides." ,
68- ),
69- gal_config = Param (
70- dict ,
71- dict (
72- LIB_ASCII = "YES" ,
73- MOD_EXTINC = "18,26,26,33,26,33,26,33" ,
74- EXTINC_LAW = "SMC_prevot.dat,SB_calzetti.dat,SB_calzetti_bump1.dat,SB_calzetti_bump2.dat" ,
75- EM_LINES = "EMP_UV" ,
76- EM_DISPERSION = "0.5,0.75,1.,1.5,2." ,
77- ),
78- msg = "Galaxy config overrides." ,
79- ),
80- qso_config = Param (
81- dict ,
82- dict (
83- LIB_ASCII = "YES" ,
84- MOD_EXTINC = "0,1000" ,
85- EB_V = "0.,0.1,0.2,0.3" ,
86- EXTINC_LAW = "SB_calzetti.dat" ,
87- ),
88- msg = "QSO config overrides." ,
89- ),
90103 )
104+ _add_sub_config (config_options , lsst_default_config , "lephare." )
105+ _add_sub_config (config_options , star_default_config , "gal." )
106+ _add_sub_config (config_options , gal_default_config , "gal." )
107+ _add_sub_config (config_options , qso_default_config , "qso." )
91108
92109 def __init__ (self , args , ** kwargs ):
93110 """Init function, init config stuff (COPIED from rail_bpz)"""
94111
95112 super ().__init__ (args , ** kwargs )
96113
97114 def validate (self ):
98- self .lephare_config = self .config [ "lephare_config" ]
115+ self .lephare_config = _get_sub_config ( self .config , 'lephare.' )
99116
100117 # Put something in place to allow for not rerunning the prepare stage
101118 try :
@@ -114,7 +131,7 @@ def validate(self):
114131 print (
115132 f"rail_lephare is setting the Z_STEP config to { Z_STEP } based on the informer params."
116133 )
117- self .config ["lephare_config" ][ " Z_STEP" ] = Z_STEP
134+ self .config ["lephare. Z_STEP" ] = Z_STEP
118135 # We create a run directory with the informer name
119136 self .run_dir = _set_run_dir (self .config ["name" ])
120137
@@ -147,13 +164,17 @@ def run(self):
147164 # Get number of sources
148165 ngal = len (training_data [self .config .ref_band ])
149166
167+ star_config = _get_sub_config (self .config , 'star.' )
168+ gal_config = _get_sub_config (self .config , 'gal.' )
169+ qso_config = _get_sub_config (self .config , 'gso.' )
170+
150171 # The three main lephare specific inform tasks
151172 if self .do_prepare :
152173 lp .prepare (
153174 self .lephare_config ,
154- star_config = self . config [ " star_config" ] ,
155- gal_config = self . config [ " gal_config" ] ,
156- qso_config = self . config [ " qso_config" ] ,
175+ star_config = star_config ,
176+ gal_config = gal_config ,
177+ qso_config = qso_config ,
157178 )
158179 else :
159180 print (
@@ -168,20 +189,20 @@ def run(self):
168189 training_data , self .config .bands , self .config .err_bands
169190 )
170191 # This will return zeros if AUTO_ADAPT is NO
171- offsets = lp .calculate_offsets_from_input (self .config [ " lephare_config" ] , input )
192+ offsets = lp .calculate_offsets_from_input (self .lephare_config , input )
172193 # We must make a string dictionary to allow pickling and saving
173194 lephare_config = lp .keymap_to_string_dict (
174- lp .all_types_to_keymap (self .config [ " lephare_config" ] )
195+ lp .all_types_to_keymap (self .lephare_config )
175196 )
176197 # Give principle inform config 'model' to instance.
177198 self .model = dict (
178199 lephare_version = lp .__version__ ,
179- lephare_config = lephare_config ,
200+ lephare_config = self . lephare_config ,
180201 offsets = offsets ,
181202 run_dir = self .run_dir ,
182- star_config = self . config [ " star_config" ] ,
183- gal_config = self . config [ " gal_config" ] ,
184- qso_config = self . config [ " qso_config" ] ,
203+ star_config = star_config ,
204+ gal_config = gal_config ,
205+ qso_config = qso_config ,
185206 )
186207 self .add_data ("model" , self .model )
187208
@@ -202,10 +223,10 @@ class LephareEstimator(CatEstimator):
202223 ref_band = SHARED_PARAMS ,
203224 err_bands = SHARED_PARAMS ,
204225 redshift_col = SHARED_PARAMS ,
205- lephare_config = Param (
206- dict ,
207- {} ,
208- msg = "The lephare config keymap. If unset we load it from the model. " ,
226+ lephare_config_from_model = Param (
227+ bool ,
228+ True ,
229+ "Load lephare config keymap from model" ,
209230 ),
210231 use_inform_offsets = Param (
211232 bool ,
@@ -252,6 +273,7 @@ class LephareEstimator(CatEstimator):
252273 ),
253274 ),
254275 )
276+ _add_sub_config (config_options , lsst_default_config , "lephare." )
255277
256278 def __init__ (self , args , ** kwargs ):
257279 super ().__init__ (args , ** kwargs )
@@ -262,10 +284,10 @@ def __init__(self, args, **kwargs):
262284
263285 def open_model (self , ** kwargs ):
264286 CatEstimator .open_model (self , ** kwargs )
265- if self .config ["lephare_config" ]:
266- self .lephare_config = self .config ["lephare_config" ]
267- else :
287+ if self .config ["lephare_config_from_model" ]:
268288 self .lephare_config = self .model ["lephare_config" ]
289+ else :
290+ self .lephare_config = _get_sub_config (self .config , "lephare." )
269291 # Use string dictionary config in case keymap passed to estimate stage
270292 self .lephare_config = lp .keymap_to_string_dict (
271293 lp .all_types_to_keymap (self .lephare_config )
0 commit comments