|
5 | 5 | 'AutoNBEATS', 'AutoNBEATSx', 'AutoNHITS', 'AutoDLinear', 'AutoNLinear', 'AutoTiDE', 'AutoDeepNPTS', |
6 | 6 | 'AutoKAN', 'AutoTFT', 'AutoVanillaTransformer', 'AutoInformer', 'AutoAutoformer', 'AutoFEDformer', |
7 | 7 | 'AutoPatchTST', 'AutoiTransformer', 'AutoTimeXer', 'AutoTimesNet', 'AutoStemGNN', 'AutoHINT', 'AutoTSMixer', |
8 | | - 'AutoTSMixerx', 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoTimeMixer', 'AutoRMoK', 'AutoXLinear', |
9 | | - 'RayOptions', 'OptunaOptions'] |
| 8 | + 'AutoTSMixerx', 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoSOFTSSharp', 'AutoTimeMixer', 'AutoRMoK', |
| 9 | + 'AutoXLinear', 'RayOptions', 'OptunaOptions'] |
10 | 10 |
|
11 | 11 |
|
12 | 12 | from ray import tune |
|
37 | 37 | from .models.rmok import RMoK |
38 | 38 | from .models.rnn import RNN |
39 | 39 | from .models.softs import SOFTS |
| 40 | +from .models.softssharp import SOFTSSharp |
40 | 41 | from .models.stemgnn import StemGNN |
41 | 42 | from .models.tcn import TCN |
42 | 43 | from .models.tft import TFT |
@@ -2575,6 +2576,92 @@ def get_default_config(cls, h, backend, n_series): |
2575 | 2576 | return config |
2576 | 2577 |
|
2577 | 2578 |
|
| 2579 | +class AutoSOFTSSharp(BaseAuto): |
| 2580 | + |
| 2581 | + default_config = { |
| 2582 | + "input_size_multiplier": [1, 2, 3, 4, 5], |
| 2583 | + "h": None, |
| 2584 | + "n_series": None, |
| 2585 | + "hidden_size": tune.choice([64, 128, 256, 512]), |
| 2586 | + "d_core": tune.choice([64, 128, 256, 512]), |
| 2587 | + "pe_keep_prob": tune.choice([0.25, 0.5, 0.75, 1.0]), |
| 2588 | + "learning_rate": tune.loguniform(1e-4, 1e-1), |
| 2589 | + "scaler_type": tune.choice([None, "robust", "standard", "identity"]), |
| 2590 | + "max_steps": tune.choice([500, 1000, 2000]), |
| 2591 | + "batch_size": tune.choice([32, 64, 128, 256]), |
| 2592 | + "loss": None, |
| 2593 | + "random_seed": tune.randint(1, 20), |
| 2594 | + } |
| 2595 | + |
| 2596 | + def __init__( |
| 2597 | + self, |
| 2598 | + h, |
| 2599 | + n_series, |
| 2600 | + loss=MAE(), |
| 2601 | + valid_loss=None, |
| 2602 | + config=None, |
| 2603 | + search_alg=BasicVariantGenerator(random_state=1), |
| 2604 | + num_samples=10, |
| 2605 | + time_budget=None, |
| 2606 | + refit_with_val=False, |
| 2607 | + cpus=None, |
| 2608 | + gpus=None, |
| 2609 | + verbose=False, |
| 2610 | + alias=None, |
| 2611 | + backend="ray", |
| 2612 | + callbacks=None, |
| 2613 | + ray_options=None, |
| 2614 | + optuna_options=None, |
| 2615 | + ): |
| 2616 | + |
| 2617 | + if config is None: |
| 2618 | + config = self.get_default_config(h=h, backend=backend, n_series=n_series) |
| 2619 | + |
| 2620 | + if backend == "ray": |
| 2621 | + config["n_series"] = n_series |
| 2622 | + elif backend == "optuna": |
| 2623 | + mock_trial = MockTrial() |
| 2624 | + if ( |
| 2625 | + "n_series" in config(mock_trial) |
| 2626 | + and config(mock_trial)["n_series"] != n_series |
| 2627 | + ) or ("n_series" not in config(mock_trial)): |
| 2628 | + raise Exception(f"config needs 'n_series': {n_series}") |
| 2629 | + |
| 2630 | + super(AutoSOFTSSharp, self).__init__( |
| 2631 | + cls_model=SOFTSSharp, |
| 2632 | + h=h, |
| 2633 | + loss=loss, |
| 2634 | + valid_loss=valid_loss, |
| 2635 | + config=config, |
| 2636 | + search_alg=search_alg, |
| 2637 | + num_samples=num_samples, |
| 2638 | + time_budget=time_budget, |
| 2639 | + refit_with_val=refit_with_val, |
| 2640 | + cpus=cpus, |
| 2641 | + gpus=gpus, |
| 2642 | + verbose=verbose, |
| 2643 | + alias=alias, |
| 2644 | + backend=backend, |
| 2645 | + callbacks=callbacks, |
| 2646 | + ray_options=ray_options, |
| 2647 | + optuna_options=optuna_options, |
| 2648 | + ) |
| 2649 | + |
| 2650 | + @classmethod |
| 2651 | + def get_default_config(cls, h, backend, n_series): |
| 2652 | + config = cls.default_config.copy() |
| 2653 | + config["input_size"] = tune.choice( |
| 2654 | + [h * x for x in config["input_size_multiplier"]] |
| 2655 | + ) |
| 2656 | + config["step_size"] = tune.choice([1, h]) |
| 2657 | + del config["input_size_multiplier"] |
| 2658 | + if backend == "optuna": |
| 2659 | + config["n_series"] = n_series |
| 2660 | + config = cls._ray_config_to_optuna(config) |
| 2661 | + |
| 2662 | + return config |
| 2663 | + |
| 2664 | + |
2578 | 2665 | class AutoTimeMixer(BaseAuto): |
2579 | 2666 |
|
2580 | 2667 | default_config = { |
|
0 commit comments