@@ -198,6 +198,10 @@ def __init__(
198198 raise ValueError (
199199 f"Unknown backend { backend } . The supported backends are 'ray' and 'optuna'."
200200 )
201+ # Translate `*input_size_multiplier` entries (as used by the models'
202+ # `default_config`) into concrete `*input_size` ones, so that configs
203+ # built on top of `default_config` are valid `config` arguments.
204+ config_base = self ._translate_input_size_multipliers (config_base , h )
201205
202206 # Shallow-copy user-supplied options so subsequent mutations
203207 # (default resolution) don't leak back to the caller.
@@ -268,7 +272,9 @@ def __init__(
268272 else :
269273
270274 def config_f (trial ):
271- return {** config (trial ), ** config_base }
275+ return self ._translate_input_size_multipliers (
276+ {** config (trial ), ** config_base }, h
277+ )
272278
273279 self .config = config_f
274280
@@ -300,6 +306,43 @@ def config_f(trial):
300306 def __repr__ (self ):
301307 return type (self ).__name__ if self .alias is None else self .alias
302308
309+ @staticmethod
310+ def _translate_input_size_multipliers (config , h ):
311+ """Translate `*input_size_multiplier` config entries into `*input_size` ones.
312+
313+ The models' `default_config` express the input sizes as multiples of the
314+ horizon (`input_size_multiplier` and `inference_input_size_multiplier`).
315+ These keys are not valid model arguments, so apply here the same
316+ translation that `get_default_config` performs, which allows users to
317+ provide configs built on top of `default_config`.
318+
319+ Args:
320+ config (dict): Configuration dict, possibly containing
321+ `*input_size_multiplier` entries.
322+ h (int): Forecast horizon.
323+
324+ Returns:
325+ dict: Configuration dict with the multipliers translated into sizes.
326+ """
327+ translations = {
328+ "input_size_multiplier" : "input_size" ,
329+ "inference_input_size_multiplier" : "inference_input_size" ,
330+ }
331+ if not any (key in config for key in translations ):
332+ return config
333+ config = dict (config )
334+ for multiplier_key , size_key in translations .items ():
335+ if multiplier_key not in config :
336+ continue
337+ multipliers = config .pop (multiplier_key )
338+ if size_key in config :
339+ continue
340+ if isinstance (multipliers , (list , tuple )):
341+ config [size_key ] = tune .choice ([h * x for x in multipliers ])
342+ else :
343+ config [size_key ] = h * multipliers
344+ return config
345+
303346 def _train_tune (self , config_step , cls_model , dataset , val_size , test_size ):
304347 """BaseAuto._train_tune
305348
0 commit comments