|
59 | 59 | PMM, |
60 | 60 | DistributionLoss, |
61 | 61 | MQLoss, |
| 62 | + sCRPS, |
62 | 63 | ) |
63 | 64 | from neuralforecast.tsdataset import TimeSeriesDataset |
64 | 65 | from neuralforecast.utils import ( |
@@ -2251,3 +2252,30 @@ def test_mase_validation_loss_scale(setup_airplane_data): |
2251 | 2252 | f"MASE validation loss is {valid_loss}, which indicates the scale mismatch " |
2252 | 2253 | f"bug may have regressed. Expected < 50 for a properly scaled MASE." |
2253 | 2254 | ) |
| 2255 | + |
| 2256 | + |
| 2257 | +@pytest.mark.parametrize("loss,valid_loss", [ |
| 2258 | + # DistributionLoss uses NLL — mismatched levels are allowed |
| 2259 | + (DistributionLoss(distribution="Normal", level=[80, 90]), DistributionLoss(distribution="Normal", level=[50])), |
| 2260 | + # GMM + sCRPS with matching quantiles |
| 2261 | + (GMM(n_components=5, level=[80, 90]), sCRPS(level=[80, 90])), |
| 2262 | + # MQLoss with matching quantiles |
| 2263 | + (MQLoss(level=[80, 90]), MQLoss(level=[80, 90])), |
| 2264 | +]) |
| 2265 | +def test_loss_valid_loss_quantiles_allowed(loss, valid_loss): |
| 2266 | + """Loss/valid_loss combinations that should not raise a quantile mismatch error.""" |
| 2267 | + model = NHITS(h=12, input_size=24, loss=loss, valid_loss=valid_loss, max_steps=1) |
| 2268 | + NeuralForecast(models=[model], freq="M") |
| 2269 | + |
| 2270 | + |
| 2271 | +@pytest.mark.parametrize("loss,valid_loss", [ |
| 2272 | + # MQLoss quantiles embedded in loss computation — mismatch must be rejected |
| 2273 | + (MQLoss(level=[80, 90]), MQLoss(level=[50])), |
| 2274 | + # GMM + sCRPS reproduces the issue |
| 2275 | + (GMM(n_components=5, level=[80, 90]), sCRPS(level=[50])), |
| 2276 | +]) |
| 2277 | +def test_loss_valid_loss_quantiles_mismatch_raises(loss, valid_loss): |
| 2278 | + """Loss/valid_loss combinations with mismatched quantiles must raise ValueError.""" |
| 2279 | + model = NHITS(h=12, input_size=24, loss=loss, valid_loss=valid_loss, max_steps=1) |
| 2280 | + with pytest.raises(ValueError, match="quantiles.*do not match"): |
| 2281 | + NeuralForecast(models=[model], freq="M") |
0 commit comments