Skip to content

Commit 0ab5baf

Browse files
authored
[FIX] Mismtached quantiles in loss and valid_loss (#1470)
1 parent 2757284 commit 0ab5baf

2 files changed

Lines changed: 50 additions & 2 deletions

File tree

neuralforecast/core.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,9 @@
7373
)
7474

7575
from .common._base_auto import BaseAuto, MockTrial
76-
from .common._base_model import DistributedConfig
76+
from .common._base_model import DistributedConfig, MULTIQUANTILE_LOSSES
7777
from .compat import SparkDataFrame
78-
from .losses.pytorch import HuberIQLoss, IQLoss
78+
from .losses.pytorch import HuberIQLoss, IQLoss, sCRPS
7979

8080
# this disables warnings about the number of workers in the dataloaders
8181
# which the user can't control
@@ -248,6 +248,26 @@ def __init__(
248248
model.h == models[0].h for model in models
249249
), "All models should have the same horizon"
250250

251+
for model in models:
252+
valid_loss = getattr(model, "valid_loss", None)
253+
if isinstance(valid_loss, sCRPS):
254+
valid_qs = valid_loss.mql.quantiles
255+
elif isinstance(valid_loss, MULTIQUANTILE_LOSSES):
256+
valid_qs = valid_loss.quantiles
257+
else:
258+
continue
259+
loss = getattr(model, "loss", None)
260+
loss_qs = getattr(loss, "quantiles", None)
261+
if loss_qs is None:
262+
continue
263+
if sorted(loss_qs.tolist()) != sorted(valid_qs.tolist()):
264+
raise ValueError(
265+
f"{model.__class__.__name__}: `loss` ({loss.__class__.__name__}) "
266+
f"quantiles {loss_qs.tolist()} do not match `valid_loss` "
267+
f"({valid_loss.__class__.__name__}) quantiles {valid_qs.tolist()}. "
268+
f"Ensure both use the same `level` or `quantiles` argument."
269+
)
270+
251271
self.h = models[0].h
252272
self.models_init = models
253273
self.freq = freq

tests/test_core.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
PMM,
6060
DistributionLoss,
6161
MQLoss,
62+
sCRPS,
6263
)
6364
from neuralforecast.tsdataset import TimeSeriesDataset
6465
from neuralforecast.utils import (
@@ -2251,3 +2252,30 @@ def test_mase_validation_loss_scale(setup_airplane_data):
22512252
f"MASE validation loss is {valid_loss}, which indicates the scale mismatch "
22522253
f"bug may have regressed. Expected < 50 for a properly scaled MASE."
22532254
)
2255+
2256+
2257+
@pytest.mark.parametrize("loss,valid_loss", [
2258+
# DistributionLoss uses NLL — mismatched levels are allowed
2259+
(DistributionLoss(distribution="Normal", level=[80, 90]), DistributionLoss(distribution="Normal", level=[50])),
2260+
# GMM + sCRPS with matching quantiles
2261+
(GMM(n_components=5, level=[80, 90]), sCRPS(level=[80, 90])),
2262+
# MQLoss with matching quantiles
2263+
(MQLoss(level=[80, 90]), MQLoss(level=[80, 90])),
2264+
])
2265+
def test_loss_valid_loss_quantiles_allowed(loss, valid_loss):
2266+
"""Loss/valid_loss combinations that should not raise a quantile mismatch error."""
2267+
model = NHITS(h=12, input_size=24, loss=loss, valid_loss=valid_loss, max_steps=1)
2268+
NeuralForecast(models=[model], freq="M")
2269+
2270+
2271+
@pytest.mark.parametrize("loss,valid_loss", [
2272+
# MQLoss quantiles embedded in loss computation — mismatch must be rejected
2273+
(MQLoss(level=[80, 90]), MQLoss(level=[50])),
2274+
# GMM + sCRPS reproduces the issue
2275+
(GMM(n_components=5, level=[80, 90]), sCRPS(level=[50])),
2276+
])
2277+
def test_loss_valid_loss_quantiles_mismatch_raises(loss, valid_loss):
2278+
"""Loss/valid_loss combinations with mismatched quantiles must raise ValueError."""
2279+
model = NHITS(h=12, input_size=24, loss=loss, valid_loss=valid_loss, max_steps=1)
2280+
with pytest.raises(ValueError, match="quantiles.*do not match"):
2281+
NeuralForecast(models=[model], freq="M")

0 commit comments

Comments
 (0)