Skip to content

Commit 8afac6d

Browse files
authored
[FIX] Raise error when missing params in auto config (#1540)
1 parent 6898904 commit 8afac6d

2 files changed

Lines changed: 53 additions & 0 deletions

File tree

neuralforecast/common/_base_auto.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
__all__ = ['BaseAuto', 'RayOptions', 'OptunaOptions']
22

33

4+
import inspect
45
import warnings
56
from copy import deepcopy
67
from dataclasses import dataclass, fields, replace
@@ -117,6 +118,10 @@ class BaseAuto(pl.LightningModule):
117118
loss (PyTorch module): Instantiated train loss class from [losses collection](./losses.pytorch.html).
118119
valid_loss (PyTorch module): Instantiated valid loss class from [losses collection](./losses.pytorch.html).
119120
config (dict or callable): Dictionary with ray.tune defined search space or function that takes an optuna trial and returns a configuration dict.
121+
The config must include every parameter of the underlying model that has no
122+
default value (e.g. `input_size`, and `n_series` for multivariate models),
123+
either as a fixed value or as a search variable. `h`, `loss`, and `valid_loss`
124+
are injected automatically and must not be set in `config`.
120125
search_alg (ray.tune.search variant or optuna.sampler): For ray see https://docs.ray.io/en/latest/tune/api_docs/suggestion.html
121126
For optuna see https://optuna.readthedocs.io/en/stable/reference/samplers/index.html.
122127
num_samples (int): Number of hyperparameter optimization steps/samples.
@@ -241,6 +246,29 @@ def __init__(
241246
else:
242247
self.early_stop_patience_steps = -1
243248

249+
# Surface required-but-missing parameters up front. Without this,
250+
auto_provided = {"h", "loss", "valid_loss"}
251+
missing_required = [
252+
name
253+
for name, param in inspect.signature(cls_model.__init__).parameters.items()
254+
if name != "self"
255+
and name not in auto_provided
256+
and name not in config_base
257+
and param.default is inspect.Parameter.empty
258+
and param.kind
259+
in (
260+
inspect.Parameter.POSITIONAL_OR_KEYWORD,
261+
inspect.Parameter.KEYWORD_ONLY,
262+
)
263+
]
264+
if missing_required:
265+
raise ValueError(
266+
f"`config` is missing required parameter(s) for "
267+
f"{cls_model.__name__}: {missing_required}. These parameters "
268+
f"have no default and must be provided in `config` (either as "
269+
f"a fixed value or as a tune/optuna search variable)."
270+
)
271+
244272
if callable(config):
245273
# reset config_base here to save params to override in the config fn
246274
config_base = {}

tests/test_common/test_base_auto.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88
from ray import tune
99

10+
from neuralforecast.auto import AutoMLP
1011
from neuralforecast.common._base_auto import BaseAuto, OptunaOptions, RayOptions
1112
from neuralforecast.losses.pytorch import MAE, MSE
1213
from neuralforecast.models.mlp import MLP
@@ -141,6 +142,30 @@ def test_validation_default(setup_config):
141142
assert str(type(auto.valid_loss)) == "<class 'neuralforecast.losses.pytorch.MSE'>"
142143

143144

145+
def test_config_missing_required_param_raises():
146+
# A config that omits a required no-default arg of the underlying model used to fail deep
147+
# inside ray with an opaque "No best trial found" error. Now it must fail fast
148+
# at construction time with a message naming the missing key.
149+
with pytest.raises(ValueError, match="input_size"):
150+
AutoMLP(
151+
h=4,
152+
config={
153+
"max_steps": tune.choice([5]),
154+
"random_seed": tune.choice([0]),
155+
},
156+
backend="ray",
157+
)
158+
159+
def config_fn(trial):
160+
return {
161+
"max_steps": trial.suggest_categorical("max_steps", [5]),
162+
"random_seed": trial.suggest_categorical("random_seed", [0]),
163+
}
164+
165+
with pytest.raises(ValueError, match="input_size"):
166+
AutoMLP(h=4, config=config_fn, backend="optuna")
167+
168+
144169
def test_ray_time_budget(setup_module):
145170
dataset, _, _ = setup_module
146171
config = {

0 commit comments

Comments
 (0)