Skip to content

Commit b5bca21

Browse files
hljubicmarcopeix
andauthored
Add SOFTSSharp model (#1516)
Co-authored-by: Marco <marco@nixtla.io>
1 parent 77c8095 commit b5bca21

10 files changed

Lines changed: 509 additions & 6 deletions

File tree

docs/mintlify/docs.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
"models.rmok.html",
113113
"models.rnn.html",
114114
"models.softs.html",
115+
"models.softssharp.html",
115116
"models.stemgnn.html",
116117
"models.tcn.html",
117118
"models.tft.html",
@@ -162,4 +163,4 @@
162163
"href": "https://github.com/Nixtla/neuralforecast"
163164
}
164165
}
165-
}
166+
}
290 KB
Loading

docs/models.html.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ The optimization process uses temporal cross-validation where the validation set
3131

3232
## 2. Available AutoModels
3333

34-
NeuralForecast provides 34 `AutoModel` variants, each wrapping a specific forecasting model with automatic hyperparameter optimization. Each `AutoModel` has a `default_config` attribute that defines sensible search spaces for its corresponding model.
34+
NeuralForecast provides 35 `AutoModel` variants, each wrapping a specific forecasting model with automatic hyperparameter optimization. Each `AutoModel` has a `default_config` attribute that defines sensible search spaces for its corresponding model.
3535

3636
### RNN-Based Models
3737
Recurrent neural networks for sequential forecasting:
@@ -83,6 +83,7 @@ Models designed for specific forecasting scenarios:
8383
- `AutoKAN`: [Kolmogorov-Arnold Network for time series](./models.kan.html)
8484
- `AutoStemGNN`: [Graph neural network for multivariate forecasting](./models.stemgnn.html)
8585
- `AutoSOFTS`: [Spectral Optimal Fourier Transform model](./models.softs.html)
86+
- `AutoSOFTSSharp`: [SOFTS extension with stochastic variable-position encoding](./models.softssharp.html)
8687
- `AutoTimeMixer`: [Temporal mixing architecture](./models.timemixer.html)
8788
- `AutoRMoK`: [Random Mixture of Kernels](./models.rmok.html)
8889
- `AutoHINT`: [Hierarchical forecasting with automatic reconciliation](./models.hint.html)
@@ -191,4 +192,4 @@ model = AutoHINT(
191192

192193
model.fit(dataset=dataset, val_size=4)
193194
y_hat = model.predict(dataset=dataset)
194-
```
195+
```

docs/models.softssharp.html.md

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
---
2+
description: >-
3+
SOFTSSharp: SOFTS extension with stochastic variable-position encoding for multivariate time series forecasting.
4+
output-file: models.softssharp.html
5+
title: SOFTSSharp
6+
---
7+
8+
SOFTSSharp extends SOFTS by stochastically adding variable-position embeddings and multiple dropout layers inside the STAD aggregation-redistribution component, aiming to improve forecasting accuracy while preserving linear complexity.
9+
10+
11+
![Figure 1. Architecture of SOFTSSharp](imgs_models/softssharp.png)
12+
*Figure 1. Architecture of SOFTSSharp*
13+
14+
## 1. SOFTSSharp
15+
16+
::: neuralforecast.models.softssharp.SOFTSSharp
17+
options:
18+
members:
19+
- fit
20+
- predict
21+
heading_level: 3
22+
23+
### Usage example
24+
25+
```python
26+
import pandas as pd
27+
import matplotlib.pyplot as plt
28+
29+
from neuralforecast import NeuralForecast
30+
from neuralforecast.models import SOFTSSharp
31+
from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic
32+
from neuralforecast.losses.pytorch import MASE
33+
34+
Y_train_df = AirPassengersPanel[AirPassengersPanel.ds<AirPassengersPanel['ds'].values[-12]].reset_index(drop=True)
35+
Y_test_df = AirPassengersPanel[AirPassengersPanel.ds>=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True)
36+
37+
model = SOFTSSharp(h=12,
38+
input_size=24,
39+
n_series=2,
40+
hidden_size=256,
41+
d_core=256,
42+
e_layers=2,
43+
d_ff=64,
44+
dropout=0.1,
45+
pe_keep_prob=0.5,
46+
use_norm=True,
47+
loss=MASE(seasonality=4),
48+
early_stop_patience_steps=3,
49+
batch_size=32)
50+
51+
fcst = NeuralForecast(models=[model], freq='ME')
52+
fcst.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)
53+
forecasts = fcst.predict(futr_df=Y_test_df)
54+
55+
fig, ax = plt.subplots(1, 1, figsize = (20, 7))
56+
Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])
57+
plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)
58+
plot_df = pd.concat([Y_train_df, plot_df])
59+
60+
plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)
61+
plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')
62+
plt.plot(plot_df['ds'], plot_df['SOFTSSharp'], c='blue', label='Forecast')
63+
ax.set_title('AirPassengers Forecast', fontsize=22)
64+
ax.set_ylabel('Monthly Passengers', fontsize=20)
65+
ax.set_xlabel('Year', fontsize=20)
66+
ax.legend(prop={'size': 15})
67+
ax.grid()
68+
```
69+
70+
## 2. Auxiliary functions
71+
72+
::: neuralforecast.models.softssharp.PositionalEmbedding
73+
options:
74+
members: []
75+
76+
::: neuralforecast.models.softssharp.STADSharp
77+
options:
78+
members: []

nbs/docs/capabilities/overview.ipynb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"|`RMoK` | `AutoRMoK` | KAN | Multivariate | Direct | - |\n",
3838
"|`RNN` | `AutoRNN` | RNN | Univariate | Both<sup>8</sup> | F/H/S | \n",
3939
"|`SOFTS` | `AutoSOFTS` | MLP | Multivariate | Direct | - | \n",
40+
"|`SOFTSSharp` | `AutoSOFTSSharp` | MLP | Multivariate | Direct | - | \n",
4041
"|`StemGNN` | `AutoStemGNN` | GNN | Multivariate | Direct | - | \n",
4142
"|`TCN` | `AutoTCN` | CNN | Univariate | Direct | F/H/S | \n",
4243
"|`TFT` | `AutoTFT` | Transformer | Univariate | Direct | F/H/S | \n",

neuralforecast/auto.py

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
'AutoNBEATS', 'AutoNBEATSx', 'AutoNHITS', 'AutoDLinear', 'AutoNLinear', 'AutoTiDE', 'AutoDeepNPTS',
66
'AutoKAN', 'AutoTFT', 'AutoVanillaTransformer', 'AutoInformer', 'AutoAutoformer', 'AutoFEDformer',
77
'AutoPatchTST', 'AutoiTransformer', 'AutoTimeXer', 'AutoTimesNet', 'AutoStemGNN', 'AutoHINT', 'AutoTSMixer',
8-
'AutoTSMixerx', 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoTimeMixer', 'AutoRMoK', 'AutoXLinear',
9-
'RayOptions', 'OptunaOptions']
8+
'AutoTSMixerx', 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoSOFTSSharp', 'AutoTimeMixer', 'AutoRMoK',
9+
'AutoXLinear', 'RayOptions', 'OptunaOptions']
1010

1111

1212
from ray import tune
@@ -37,6 +37,7 @@
3737
from .models.rmok import RMoK
3838
from .models.rnn import RNN
3939
from .models.softs import SOFTS
40+
from .models.softssharp import SOFTSSharp
4041
from .models.stemgnn import StemGNN
4142
from .models.tcn import TCN
4243
from .models.tft import TFT
@@ -2575,6 +2576,92 @@ def get_default_config(cls, h, backend, n_series):
25752576
return config
25762577

25772578

2579+
class AutoSOFTSSharp(BaseAuto):
2580+
2581+
default_config = {
2582+
"input_size_multiplier": [1, 2, 3, 4, 5],
2583+
"h": None,
2584+
"n_series": None,
2585+
"hidden_size": tune.choice([64, 128, 256, 512]),
2586+
"d_core": tune.choice([64, 128, 256, 512]),
2587+
"pe_keep_prob": tune.choice([0.25, 0.5, 0.75, 1.0]),
2588+
"learning_rate": tune.loguniform(1e-4, 1e-1),
2589+
"scaler_type": tune.choice([None, "robust", "standard", "identity"]),
2590+
"max_steps": tune.choice([500, 1000, 2000]),
2591+
"batch_size": tune.choice([32, 64, 128, 256]),
2592+
"loss": None,
2593+
"random_seed": tune.randint(1, 20),
2594+
}
2595+
2596+
def __init__(
2597+
self,
2598+
h,
2599+
n_series,
2600+
loss=MAE(),
2601+
valid_loss=None,
2602+
config=None,
2603+
search_alg=BasicVariantGenerator(random_state=1),
2604+
num_samples=10,
2605+
time_budget=None,
2606+
refit_with_val=False,
2607+
cpus=None,
2608+
gpus=None,
2609+
verbose=False,
2610+
alias=None,
2611+
backend="ray",
2612+
callbacks=None,
2613+
ray_options=None,
2614+
optuna_options=None,
2615+
):
2616+
2617+
if config is None:
2618+
config = self.get_default_config(h=h, backend=backend, n_series=n_series)
2619+
2620+
if backend == "ray":
2621+
config["n_series"] = n_series
2622+
elif backend == "optuna":
2623+
mock_trial = MockTrial()
2624+
if (
2625+
"n_series" in config(mock_trial)
2626+
and config(mock_trial)["n_series"] != n_series
2627+
) or ("n_series" not in config(mock_trial)):
2628+
raise Exception(f"config needs 'n_series': {n_series}")
2629+
2630+
super(AutoSOFTSSharp, self).__init__(
2631+
cls_model=SOFTSSharp,
2632+
h=h,
2633+
loss=loss,
2634+
valid_loss=valid_loss,
2635+
config=config,
2636+
search_alg=search_alg,
2637+
num_samples=num_samples,
2638+
time_budget=time_budget,
2639+
refit_with_val=refit_with_val,
2640+
cpus=cpus,
2641+
gpus=gpus,
2642+
verbose=verbose,
2643+
alias=alias,
2644+
backend=backend,
2645+
callbacks=callbacks,
2646+
ray_options=ray_options,
2647+
optuna_options=optuna_options,
2648+
)
2649+
2650+
@classmethod
2651+
def get_default_config(cls, h, backend, n_series):
2652+
config = cls.default_config.copy()
2653+
config["input_size"] = tune.choice(
2654+
[h * x for x in config["input_size_multiplier"]]
2655+
)
2656+
config["step_size"] = tune.choice([1, h])
2657+
del config["input_size_multiplier"]
2658+
if backend == "optuna":
2659+
config["n_series"] = n_series
2660+
config = cls._ray_config_to_optuna(config)
2661+
2662+
return config
2663+
2664+
25782665
class AutoTimeMixer(BaseAuto):
25792666

25802667
default_config = {

neuralforecast/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
NHITS,
3333
RNN,
3434
SOFTS,
35+
SOFTSSharp,
3536
TCN,
3637
TFT,
3738
Autoformer,
@@ -196,6 +197,8 @@ def _insample_times(
196197
"autodeepnpts": DeepNPTS,
197198
"softs": SOFTS,
198199
"autosofts": SOFTS,
200+
"softssharp": SOFTSSharp,
201+
"autosoftssharp": SOFTSSharp,
199202
"timemixer": TimeMixer,
200203
"autotimemixer": TimeMixer,
201204
"kan": KAN,

neuralforecast/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
'MLP', 'NHITS', 'NBEATS', 'NBEATSx', 'DLinear', 'NLinear',
33
'TFT', 'VanillaTransformer', 'Informer', 'Autoformer', 'PatchTST', 'FEDformer',
44
'StemGNN', 'HINT', 'TimesNet', 'TimeLLM', 'TSMixer', 'TSMixerx', 'MLPMultivariate',
5-
'iTransformer', 'BiTCN', 'TiDE', 'DeepNPTS', 'SOFTS', 'TimeMixer', 'KAN', 'RMoK',
5+
'iTransformer', 'BiTCN', 'TiDE', 'DeepNPTS', 'SOFTS', 'SOFTSSharp', 'TimeMixer', 'KAN', 'RMoK',
66
'TimeXer', 'xLSTM', 'XLinear'
77
]
88

@@ -36,6 +36,7 @@
3636
from .tide import TiDE
3737
from .deepnpts import DeepNPTS
3838
from .softs import SOFTS
39+
from .softssharp import SOFTSSharp
3940
from .timemixer import TimeMixer
4041
from .kan import KAN
4142
from .rmok import RMoK

0 commit comments

Comments
 (0)