Skip to content

Commit c67a722

Browse files
Translate input_size_multiplier in user-supplied Auto configs (#1553)
Co-authored-by: Marco <marco@nixtla.io>
1 parent 9198931 commit c67a722

2 files changed

Lines changed: 97 additions & 1 deletion

File tree

neuralforecast/common/_base_auto.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ def __init__(
198198
raise ValueError(
199199
f"Unknown backend {backend}. The supported backends are 'ray' and 'optuna'."
200200
)
201+
# Translate `*input_size_multiplier` entries (as used by the models'
202+
# `default_config`) into concrete `*input_size` ones, so that configs
203+
# built on top of `default_config` are valid `config` arguments.
204+
config_base = self._translate_input_size_multipliers(config_base, h)
201205

202206
# Shallow-copy user-supplied options so subsequent mutations
203207
# (default resolution) don't leak back to the caller.
@@ -268,7 +272,9 @@ def __init__(
268272
else:
269273

270274
def config_f(trial):
271-
return {**config(trial), **config_base}
275+
return self._translate_input_size_multipliers(
276+
{**config(trial), **config_base}, h
277+
)
272278

273279
self.config = config_f
274280

@@ -300,6 +306,43 @@ def config_f(trial):
300306
def __repr__(self):
301307
return type(self).__name__ if self.alias is None else self.alias
302308

309+
@staticmethod
310+
def _translate_input_size_multipliers(config, h):
311+
"""Translate `*input_size_multiplier` config entries into `*input_size` ones.
312+
313+
The models' `default_config` express the input sizes as multiples of the
314+
horizon (`input_size_multiplier` and `inference_input_size_multiplier`).
315+
These keys are not valid model arguments, so apply here the same
316+
translation that `get_default_config` performs, which allows users to
317+
provide configs built on top of `default_config`.
318+
319+
Args:
320+
config (dict): Configuration dict, possibly containing
321+
`*input_size_multiplier` entries.
322+
h (int): Forecast horizon.
323+
324+
Returns:
325+
dict: Configuration dict with the multipliers translated into sizes.
326+
"""
327+
translations = {
328+
"input_size_multiplier": "input_size",
329+
"inference_input_size_multiplier": "inference_input_size",
330+
}
331+
if not any(key in config for key in translations):
332+
return config
333+
config = dict(config)
334+
for multiplier_key, size_key in translations.items():
335+
if multiplier_key not in config:
336+
continue
337+
multipliers = config.pop(multiplier_key)
338+
if size_key in config:
339+
continue
340+
if isinstance(multipliers, (list, tuple)):
341+
config[size_key] = tune.choice([h * x for x in multipliers])
342+
else:
343+
config[size_key] = h * multipliers
344+
return config
345+
303346
def _train_tune(self, config_step, cls_model, dataset, val_size, test_size):
304347
"""BaseAuto._train_tune
305348

tests/test_common/test_base_auto.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,59 @@ def config_fn(trial):
166166
AutoMLP(h=4, config=config_fn, backend="optuna")
167167

168168

169+
def test_default_config_copy_is_valid_user_config(setup_module):
170+
# https://github.com/Nixtla/neuralforecast/issues/571
171+
# The default configs express the input size as `input_size_multiplier`,
172+
# which used to be translated into `input_size` only when `config=None`,
173+
# making configs built on top of `default_config` invalid.
174+
dataset, _, _ = setup_module
175+
config = AutoMLP.default_config.copy()
176+
# tweak the search space, keeping it small so the test runs fast
177+
config["input_size_multiplier"] = [1, 2]
178+
config["hidden_size"] = tune.choice([8])
179+
config["num_layers"] = 2
180+
config["max_steps"] = 1
181+
config["val_check_steps"] = 1
182+
auto = AutoMLP(
183+
h=12,
184+
config=config,
185+
num_samples=1,
186+
ray_options=RayOptions(cpus=1, gpus=0),
187+
)
188+
assert "input_size_multiplier" not in auto.config
189+
assert "input_size" in auto.config
190+
auto.fit(dataset=dataset)
191+
y_hat = auto.predict(dataset=dataset)
192+
assert y_hat.shape[0] > 0
193+
194+
195+
def test_optuna_config_translates_input_size_multiplier(setup_module):
196+
# https://github.com/Nixtla/neuralforecast/issues/571 (optuna variant)
197+
dataset, _, _ = setup_module
198+
199+
def config_multiplier(trial):
200+
return {
201+
"input_size_multiplier": trial.suggest_categorical(
202+
"input_size_multiplier", [1, 2]
203+
),
204+
"hidden_size": trial.suggest_categorical("hidden_size", [8]),
205+
"num_layers": 2,
206+
"max_steps": 1,
207+
"val_check_steps": 1,
208+
}
209+
210+
auto = AutoMLP(
211+
h=12,
212+
config=config_multiplier,
213+
backend="optuna",
214+
search_alg=optuna.samplers.RandomSampler(seed=0),
215+
num_samples=1,
216+
)
217+
auto.fit(dataset=dataset)
218+
y_hat = auto.predict(dataset=dataset)
219+
assert y_hat.shape[0] > 0
220+
221+
169222
def test_ray_time_budget(setup_module):
170223
dataset, _, _ = setup_module
171224
config = {

0 commit comments

Comments
 (0)