Skip to content

Commit 455bc3f

Browse files
authored
[FIX] Return self so we don't lose HINT wrapper and preds are reconciled (#1524)
1 parent 40d71c9 commit 455bc3f

2 files changed

Lines changed: 36 additions & 2 deletions

File tree

neuralforecast/models/hint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def fit(
214214
Returns:
215215
self: A fitted base `NeuralForecast` model.
216216
"""
217-
model = self.model.fit(
217+
self.model = self.model.fit(
218218
dataset=dataset,
219219
val_size=val_size,
220220
test_size=test_size,
@@ -226,7 +226,7 @@ def fit(
226226
self.futr_exog_list = self.model.futr_exog_list
227227
self.hist_exog_list = self.model.hist_exog_list
228228
self.stat_exog_list = self.model.stat_exog_list
229-
return model
229+
return self
230230

231231
def predict(self, dataset, step_size=1, random_seed=None, **data_module_kwargs):
232232
"""HINT.predict

tests/test_models/test_hint.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,37 @@ def test_hint_model():
8787
parent_value = hint_mean[parent_idx]
8888
children_sum = hint_mean[children_list].sum()
8989
np.testing.assert_allclose(children_sum, parent_value, rtol=1e-6)
90+
91+
92+
# HINT wrapper must survive nf.fit() so that
93+
# nf.predict() emits the HINT column and applies reconciliation.
94+
def test_hint_fit_predict():
95+
Y_df, S, quantiles = setup_synthetic_data()
96+
nhits = NHITS(
97+
h=4,
98+
input_size=4,
99+
loss=GMM(n_components=2, quantiles=quantiles, num_samples=len(quantiles)),
100+
max_steps=5,
101+
early_stop_patience_steps=2,
102+
val_check_steps=1,
103+
scaler_type="robust",
104+
learning_rate=1e-3,
105+
)
106+
model = HINT(h=4, model=nhits, S=S, reconciliation="BottomUp")
107+
108+
nf = NeuralForecast(models=[model], freq="Q")
109+
nf.fit(df=Y_df, val_size=4)
110+
forecasts = nf.predict()
111+
112+
assert isinstance(nf.models[0], HINT), (
113+
"HINT wrapper was replaced by the underlying model after fit()."
114+
)
115+
assert "HINT" in forecasts.columns, "fit()/predict() did not emit a HINT column."
116+
117+
parent_children_dict = {0: [1, 2], 1: [3, 4], 2: [5, 6]}
118+
for _, df in forecasts.groupby("ds"):
119+
hint_mean = df["HINT"].values
120+
for parent_idx, children_list in parent_children_dict.items():
121+
parent_value = hint_mean[parent_idx]
122+
children_sum = hint_mean[children_list].sum()
123+
np.testing.assert_allclose(children_sum, parent_value, rtol=1e-6)

0 commit comments

Comments
 (0)