11from __future__ import annotations
22
3+ from abc import abstractmethod
34import json
45import os
5- from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional
6+ from typing import TYPE_CHECKING , Any , Callable , Dict , List , Optional , Type
67
78
89try :
1819import jax .random as jrandom
1920import jax .tree_util as jtn
2021import optax
22+ import diffrax as dfx
2123
2224from catalax import Model
2325from catalax .model .model import SimulationConfig
3234 from catalax .dataset import Dataset
3335 from catalax .neural .strategy import Strategy
3436
37+ NON_ADAPTIVE_SOLVERS = [
38+ dfx .Euler ,
39+ dfx .Heun ,
40+ ]
41+
3542
3643class NeuralBase (eqx .Module , Predictor , Surrogate ):
3744 func : MLP
3845 observable_indices : List [int ]
3946 hyperparams : Dict
40- solver : diffrax .AbstractSolver
47+ solver : Type [ diffrax .AbstractSolver ]
4148 vector_field : Optional [Stack ]
4249 species_order : List [str ]
4350
@@ -58,7 +65,7 @@ def __init__(
5865 ** kwargs ,
5966 ):
6067 # Save solver and observable indices
61- self .solver = solver # type: ignore
68+ self .solver = solver
6269 self .observable_indices = observable_indices
6370 self .vector_field = None
6471 self .species_order = species_order
@@ -88,6 +95,18 @@ def __init__(
8895 out_size = out_size ,
8996 )
9097
98+ @abstractmethod
99+ def __call__ (
100+ self ,
101+ ts ,
102+ y0 ,
103+ solver : Optional [Type [diffrax .AbstractSolver ]] = None ,
104+ rtol : Optional [float ] = None ,
105+ atol : Optional [float ] = None ,
106+ dt0 : Optional [float ] = None ,
107+ ) -> jax .Array :
108+ raise NotImplementedError ("This method is not implemented" )
109+
91110 def train (
92111 self ,
93112 dataset : Dataset ,
@@ -156,6 +175,10 @@ def predict(
156175 config : Optional [SimulationConfig ] = None ,
157176 n_steps : int = 100 ,
158177 use_times : bool = False ,
178+ solver : Optional [Type [diffrax .AbstractSolver ]] = None ,
179+ rtol : Optional [float ] = None ,
180+ atol : Optional [float ] = None ,
181+ dt0 : Optional [float ] = None ,
159182 ):
160183 """Predict model behavior using the given dataset.
161184
@@ -195,7 +218,10 @@ def predict(
195218 if config :
196219 times = jnp .linspace (config .t0 , config .t1 , config .nsteps ).T # type: ignore
197220
198- predictions = jax .vmap (self , in_axes = (0 , 0 ))(times , y0s ) # type: ignore
221+ predictions = jax .vmap (
222+ lambda ts , y0 : self (ts , y0 , solver = solver , rtol = rtol , atol = atol , dt0 = dt0 ),
223+ in_axes = (0 , 0 ),
224+ )(times , y0s )
199225
200226 return Dataset .from_jax_arrays (
201227 species_order = self .species_order ,
@@ -419,3 +445,30 @@ def n_parameters(self) -> int:
419445 for layer in layers :
420446 n_parameters += layer .size
421447 return n_parameters
448+
449+ def _create_controller (
450+ self ,
451+ solver : Optional [Type [diffrax .AbstractSolver ]] = None ,
452+ rtol : Optional [float ] = None ,
453+ atol : Optional [float ] = None ,
454+ ):
455+ """Create the appropriate stepsize controller"""
456+ if solver is None :
457+ return diffrax .PIDController (1e-3 , 1e-6 )
458+
459+ if solver in NON_ADAPTIVE_SOLVERS :
460+ return diffrax .ConstantStepSize ()
461+ else :
462+ return diffrax .PIDController (
463+ rtol = rtol if rtol is not None else 1e-3 ,
464+ atol = atol if atol is not None else 1e-6 ,
465+ )
466+
467+ def _instantiate_solver (
468+ self ,
469+ solver : Optional [Type [diffrax .AbstractSolver ]],
470+ ) -> diffrax .AbstractSolver :
471+ if solver is None :
472+ return self .solver ()
473+ else :
474+ return solver ()
0 commit comments