|
1 | 1 | __all__ = ['BaseAuto', 'RayOptions', 'OptunaOptions'] |
2 | 2 |
|
3 | 3 |
|
| 4 | +import inspect |
4 | 5 | import warnings |
5 | 6 | from copy import deepcopy |
6 | 7 | from dataclasses import dataclass, fields, replace |
@@ -117,6 +118,10 @@ class BaseAuto(pl.LightningModule): |
117 | 118 | loss (PyTorch module): Instantiated train loss class from [losses collection](./losses.pytorch.html). |
118 | 119 | valid_loss (PyTorch module): Instantiated valid loss class from [losses collection](./losses.pytorch.html). |
119 | 120 | config (dict or callable): Dictionary with ray.tune defined search space or function that takes an optuna trial and returns a configuration dict. |
| 121 | + The config must include every parameter of the underlying model that has no |
| 122 | + default value (e.g. `input_size`, and `n_series` for multivariate models), |
| 123 | + either as a fixed value or as a search variable. `h`, `loss`, and `valid_loss` |
| 124 | + are injected automatically and must not be set in `config`. |
120 | 125 | search_alg (ray.tune.search variant or optuna.sampler): For ray see https://docs.ray.io/en/latest/tune/api_docs/suggestion.html |
121 | 126 | For optuna see https://optuna.readthedocs.io/en/stable/reference/samplers/index.html. |
122 | 127 | num_samples (int): Number of hyperparameter optimization steps/samples. |
@@ -241,6 +246,29 @@ def __init__( |
241 | 246 | else: |
242 | 247 | self.early_stop_patience_steps = -1 |
243 | 248 |
|
| 249 | + # Surface required-but-missing parameters up front. Without this, |
| 250 | + auto_provided = {"h", "loss", "valid_loss"} |
| 251 | + missing_required = [ |
| 252 | + name |
| 253 | + for name, param in inspect.signature(cls_model.__init__).parameters.items() |
| 254 | + if name != "self" |
| 255 | + and name not in auto_provided |
| 256 | + and name not in config_base |
| 257 | + and param.default is inspect.Parameter.empty |
| 258 | + and param.kind |
| 259 | + in ( |
| 260 | + inspect.Parameter.POSITIONAL_OR_KEYWORD, |
| 261 | + inspect.Parameter.KEYWORD_ONLY, |
| 262 | + ) |
| 263 | + ] |
| 264 | + if missing_required: |
| 265 | + raise ValueError( |
| 266 | + f"`config` is missing required parameter(s) for " |
| 267 | + f"{cls_model.__name__}: {missing_required}. These parameters " |
| 268 | + f"have no default and must be provided in `config` (either as " |
| 269 | + f"a fixed value or as a tune/optuna search variable)." |
| 270 | + ) |
| 271 | + |
244 | 272 | if callable(config): |
245 | 273 | # reset config_base here to save params to override in the config fn |
246 | 274 | config_base = {} |
|
0 commit comments