Skip to content

Commit f6161ee

Browse files
authored
Merge pull request #22 from JR-1991/solver-and-penalities-update
Solver and penalties update
2 parents 36124f4 + 0676311 commit f6161ee

7 files changed

Lines changed: 166 additions & 25 deletions

File tree

catalax/neural/neuralbase.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

3+
from abc import abstractmethod
34
import json
45
import os
5-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
6+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type
67

78

89
try:
@@ -18,6 +19,7 @@
1819
import jax.random as jrandom
1920
import jax.tree_util as jtn
2021
import optax
22+
import diffrax as dfx
2123

2224
from catalax import Model
2325
from catalax.model.model import SimulationConfig
@@ -32,12 +34,17 @@
3234
from catalax.dataset import Dataset
3335
from catalax.neural.strategy import Strategy
3436

37+
# Solver classes that `_create_controller` pairs with
# `diffrax.ConstantStepSize` instead of an adaptive `diffrax.PIDController`.
# Membership is tested with `in`, so only these exact classes match.
NON_ADAPTIVE_SOLVERS = [
38+
dfx.Euler,
39+
dfx.Heun,
40+
]
41+
3542

3643
class NeuralBase(eqx.Module, Predictor, Surrogate):
3744
func: MLP
3845
observable_indices: List[int]
3946
hyperparams: Dict
40-
solver: diffrax.AbstractSolver
47+
solver: Type[diffrax.AbstractSolver]
4148
vector_field: Optional[Stack]
4249
species_order: List[str]
4350

@@ -58,7 +65,7 @@ def __init__(
5865
**kwargs,
5966
):
6067
# Save solver and observable indices
61-
self.solver = solver # type: ignore
68+
self.solver = solver
6269
self.observable_indices = observable_indices
6370
self.vector_field = None
6471
self.species_order = species_order
@@ -88,6 +95,18 @@ def __init__(
8895
out_size=out_size,
8996
)
9097

98+
@abstractmethod
def __call__(
    self,
    ts,
    y0,
    solver: Optional[Type[diffrax.AbstractSolver]] = None,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    dt0: Optional[float] = None,
) -> jax.Array:
    """Abstract integration entry point that concrete models override.

    Args:
        ts: Time points handed to the subclass implementation.
        y0: Initial state handed to the subclass implementation.
        solver: Optional solver class override for this call.
        rtol: Optional relative tolerance override.
        atol: Optional absolute tolerance override.
        dt0: Optional initial step size override.

    Raises:
        NotImplementedError: Always; this base version carries no logic.
    """
    message = "This method is not implemented"
    raise NotImplementedError(message)
109+
91110
def train(
92111
self,
93112
dataset: Dataset,
@@ -156,6 +175,10 @@ def predict(
156175
config: Optional[SimulationConfig] = None,
157176
n_steps: int = 100,
158177
use_times: bool = False,
178+
solver: Optional[Type[diffrax.AbstractSolver]] = None,
179+
rtol: Optional[float] = None,
180+
atol: Optional[float] = None,
181+
dt0: Optional[float] = None,
159182
):
160183
"""Predict model behavior using the given dataset.
161184
@@ -195,7 +218,10 @@ def predict(
195218
if config:
196219
times = jnp.linspace(config.t0, config.t1, config.nsteps).T # type: ignore
197220

198-
predictions = jax.vmap(self, in_axes=(0, 0))(times, y0s) # type: ignore
221+
predictions = jax.vmap(
222+
lambda ts, y0: self(ts, y0, solver=solver, rtol=rtol, atol=atol, dt0=dt0),
223+
in_axes=(0, 0),
224+
)(times, y0s)
199225

200226
return Dataset.from_jax_arrays(
201227
species_order=self.species_order,
@@ -419,3 +445,30 @@ def n_parameters(self) -> int:
419445
for layer in layers:
420446
n_parameters += layer.size
421447
return n_parameters
448+
449+
def _create_controller(
    self,
    solver: Optional[Type[diffrax.AbstractSolver]] = None,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
):
    """Create the step-size controller matching the solver that will run.

    Args:
        solver: Solver class overriding ``self.solver`` for this call.
            ``None`` means ``self.solver`` is the solver in effect.
        rtol: Relative tolerance for the adaptive controller; ``None``
            selects ``1e-3``.
        atol: Absolute tolerance for the adaptive controller; ``None``
            selects ``1e-6``.

    Returns:
        ``diffrax.ConstantStepSize()`` when the effective solver class is
        listed in ``NON_ADAPTIVE_SOLVERS``, otherwise a
        ``diffrax.PIDController`` built from the resolved tolerances.
    """
    # Resolve the override first: `_instantiate_solver` falls back to
    # `self.solver` when no override is given, so the controller must be
    # chosen for that same class. Previously a missing override always
    # produced a PID controller with hard-coded tolerances, which ignored
    # user-supplied rtol/atol and could mismatch a fixed-step `self.solver`.
    effective_solver = self.solver if solver is None else solver

    if effective_solver in NON_ADAPTIVE_SOLVERS:
        return diffrax.ConstantStepSize()

    return diffrax.PIDController(
        rtol=1e-3 if rtol is None else rtol,
        atol=1e-6 if atol is None else atol,
    )
466+
467+
def _instantiate_solver(
    self,
    solver: Optional[Type[diffrax.AbstractSolver]],
) -> diffrax.AbstractSolver:
    """Instantiate the solver class to use for one integration call.

    Args:
        solver: Solver class to instantiate, or ``None`` to use the class
            stored on ``self.solver``.

    Returns:
        A new instance created by calling the chosen class with no
        arguments.
    """
    solver_cls = self.solver if solver is None else solver
    return solver_cls()

catalax/neural/neuralode.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import List, Optional, Type
22

33
import jax
44
import diffrax
@@ -35,15 +35,25 @@ def __init__(
3535
final_activation=final_activation,
3636
)
3737

38-
def __call__(self, ts, y0, key: jax.Array = jax.random.PRNGKey(0)):
38+
def __call__(
    self,
    ts,
    y0,
    solver: Optional[Type[diffrax.AbstractSolver]] = None,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    dt0: Optional[float] = None,
):
    """Integrate ``self.func`` from ``ts[0]`` to ``ts[-1]`` starting at ``y0``.

    Args:
        ts: Time points; the solution is saved at exactly these points.
            Needs at least two entries when ``dt0`` is not given, because
            the default initial step is ``ts[1] - ts[0]``.
        y0: Initial state passed to ``diffrax.diffeqsolve``.
        solver: Optional solver class overriding ``self.solver``.
        rtol: Optional relative tolerance forwarded to
            ``_create_controller``.
        atol: Optional absolute tolerance forwarded to
            ``_create_controller``.
        dt0: Optional initial step size; ``None`` selects
            ``ts[1] - ts[0]``.

    Returns:
        ``solution.ys`` of the diffrax solve, i.e. the states saved at ``ts``.
    """
    stepsize_controller = self._create_controller(solver, rtol, atol)
    solver_instance = self._instantiate_solver(solver)

    # Explicit None check: `dt0 or ...` evaluates truthiness, which fails
    # for traced JAX values and silently discards an explicit 0.0.
    initial_step = ts[1] - ts[0] if dt0 is None else dt0

    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(self.func),  # type: ignore
        solver_instance,  # type: ignore
        t0=ts[0],  # type: ignore
        t1=ts[-1],
        dt0=initial_step,
        y0=y0,
        stepsize_controller=stepsize_controller,
        saveat=diffrax.SaveAt(ts=ts),  # type: ignore
    )
    return solution.ys

catalax/neural/penalties/penalties.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
penalize_duplicate_reactions,
2525
penalize_non_conservative,
2626
penalize_non_integer,
27+
penalize_null_space,
2728
)
2829

2930

@@ -207,6 +208,7 @@ def for_rateflow(
207208
integer_alpha: Optional[float] = None,
208209
sparsity_alpha: Optional[float] = None,
209210
l2_alpha: Optional[float] = None,
211+
null_space_alpha: Optional[float] = None,
210212
) -> Penalties:
211213
"""Create a collection of penalties specifically designed for NeuralRDE models.
212214
@@ -275,6 +277,12 @@ def for_rateflow(
275277
alpha=l2_alpha or alpha,
276278
)
277279

280+
if null_space_alpha is not None:
281+
penalties.add_penalty(
282+
name="null_space",
283+
fun=penalize_null_space,
284+
alpha=null_space_alpha,
285+
)
278286
return penalties
279287

280288

catalax/neural/penalties/stoich_mat.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,28 @@
44
from catalax.neural.rateflow import RateFlowODE
55

66

7+
def penalize_null_space(model: RateFlowODE, alpha: float = 0.1, **kwargs) -> jax.Array:
    """Penalize stoichiometries that violate the model's mass constraints.

    The product ``mass_constraint @ stoich_matrix`` holds one residual per
    (constraint, reaction) pair. It is all zeros exactly when every column of
    the (normalized) stoichiometric matrix lies in the null space of the
    mass-constraint matrix. The penalty is the mean *squared* residual, so it
    is non-negative and residuals of opposite sign cannot cancel each other.

    Args:
        model: The ``RateFlowODE`` whose ``stoich_matrix`` and
            ``mass_constraint`` are read.
        alpha: Scaling factor applied to the penalty.
        **kwargs: Accepted and ignored.

    Returns:
        ``alpha`` times the mean squared residual, or ``0.0`` when the model
        has no mass constraint.

    Raises:
        ValueError: If the second dimension of ``mass_constraint`` does not
            match the first dimension of the stoichiometric matrix.
    """
    assert isinstance(model, RateFlowODE), "Model must be a RateFlowODE"

    # NOTE(review): `_normalize_matrix` is defined elsewhere in this module;
    # it is assumed to preserve the matrix shape -- confirm.
    stoich_matrix = _normalize_matrix(model.stoich_matrix)

    if model.mass_constraint is None:
        return jnp.array(0.0)

    # Check shape compatibility for matrix multiplication
    if model.mass_constraint.shape[1] != stoich_matrix.shape[0]:
        raise ValueError(
            f"Incompatible shapes for matrix multiplication: "
            f"mass_constraint.shape={model.mass_constraint.shape}, "
            f"stoich_matrix.shape={stoich_matrix.shape}. "
            "Expected mass_constraint.shape[1] == stoich_matrix.shape[0]."
        )

    residual = model.mass_constraint @ stoich_matrix
    # A signed mean could be driven negative (rewarding violations) and lets
    # positive and negative residuals cancel; squaring fixes both.
    return alpha * jnp.mean(jnp.square(residual))
27+
28+
729
def penalize_density(model: RateFlowODE, alpha: float = 0.1, **kwargs) -> jax.Array:
830
"""Penalize dense stoichiometric matrices by encouraging sparsity.
931

catalax/neural/plots/rateflow.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,17 +79,19 @@ def plot_learned_rates(
7979
label=f"Reaction {reaction + 1}",
8080
)
8181

82+
sub_ax[1].grid(True, which="both", linestyle="--")
83+
sub_ax[1].grid(True, which="minor", alpha=0.3)
84+
sub_ax[1].minorticks_on()
85+
8286
sub_ax[1].legend(fontsize="small", frameon=True, fancybox=True, shadow=True)
83-
sub_ax[1].grid(alpha=0.3, linestyle="--", linewidth=0.8)
8487
sub_ax[1].set_xlabel("Time", fontsize=12, labelpad=10)
8588
sub_ax[1].set_ylabel("Rate Magnitude", fontsize=12, labelpad=10)
8689
sub_ax[1].spines["top"].set_visible(False)
8790
sub_ax[1].spines["right"].set_visible(False)
8891

8992
# Model fit (right panel)
90-
dataset.measurements[i].plot(
91-
ax=sub_ax[2], linestyle="--", linewidth=2, alpha=0.7
92-
)
93+
dataset.measurements[i].plot(ax=sub_ax[2], model_data=pred.measurements[i])
94+
9395
sub_ax[2].set_title(f"Dataset {i + 1}", fontsize=12, pad=15)
9496
sub_ax[2].legend(fontsize="small", frameon=True, fancybox=True, shadow=True)
9597
sub_ax[2].grid(alpha=0.2, linestyle="--", linewidth=0.8)

catalax/neural/rateflow.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from typing import List, Optional, Tuple
1+
from typing import List, Optional, Tuple, Type
22

33
import diffrax
44
import jax
55
import jax.tree_util as jtn
66
import jax.random as jrandom
77
from matplotlib.figure import Figure
8+
import equinox as eqx
89

910
from catalax.dataset.dataset import Dataset
1011
from catalax.model.model import Model
@@ -15,9 +16,15 @@
1516

1617

1718
class RateFlowODE(NeuralBase):
18-
reaction_size: int
1919
stoich_matrix: jax.Array
20-
learn_stoich: bool
20+
reaction_size: int = eqx.field()
21+
learn_stoich: bool = eqx.field(
22+
default=True,
23+
static=True,
24+
)
25+
mass_constraint: Optional[jax.Array] = eqx.field(
26+
default=None,
27+
)
2128

2229
def __init__(
2330
self,
@@ -32,6 +39,7 @@ def __init__(
3239
use_final_bias: bool = False,
3340
learn_stoich: bool = True,
3441
stoich_matrix: jax.Array | None = None,
42+
mass_constraint: jax.Array | None = None,
3543
*,
3644
key,
3745
**kwargs,
@@ -52,6 +60,22 @@ def __init__(
5260
learn_stoich=learn_stoich,
5361
)
5462

63+
if mass_constraint is not None:
64+
assert mass_constraint.ndim == 2, (
65+
"Mass constraint must be a matrix of shape (n_constraints, n_species). Given shape: "
66+
f"{mass_constraint.shape}"
67+
)
68+
69+
_, n_species = mass_constraint.shape
70+
assert n_species == len(species_order), (
71+
"Mass constraint must be a matrix of shape (n_constraints, n_species). Given shape: "
72+
f"{mass_constraint.shape}"
73+
)
74+
75+
self.mass_constraint = mass_constraint
76+
else:
77+
self.mass_constraint = None
78+
5579
self.reaction_size = reaction_size
5680
self.learn_stoich = learn_stoich
5781

@@ -62,17 +86,27 @@ def __init__(
6286
self.stoich_matrix = stoich_matrix
6387

6488
def stoich_func(self, t, y, args):
    """Vector field handed to the ODE solver.

    Evaluates ``self.func`` at ``(t, y, args)``, clips the result at zero with
    a ReLU, and maps it through ``self.stoich_matrix``.
    """
    rates = jax.nn.relu(self.func(t, y, args))
    return self.stoich_matrix @ rates
6690

67-
def __call__(self, ts, y0):
91+
def __call__(
    self,
    ts,
    y0,
    solver: Optional[Type[diffrax.AbstractSolver]] = None,
    rtol: Optional[float] = None,
    atol: Optional[float] = None,
    dt0: Optional[float] = None,
):
    """Integrate ``self.stoich_func`` from ``ts[0]`` to ``ts[-1]`` from ``y0``.

    Args:
        ts: Time points; the solution is saved at exactly these points.
            Needs at least two entries when ``dt0`` is not given, because
            the default initial step is ``ts[1] - ts[0]``.
        y0: Initial state passed to ``diffrax.diffeqsolve``.
        solver: Optional solver class overriding ``self.solver``.
        rtol: Optional relative tolerance forwarded to
            ``_create_controller``, which substitutes ``1e-3`` for ``None``.
        atol: Optional absolute tolerance forwarded to
            ``_create_controller``, which substitutes ``1e-6`` for ``None``.
        dt0: Optional initial step size; ``None`` selects
            ``ts[1] - ts[0]``.

    Returns:
        ``solution.ys`` of the diffrax solve, i.e. the states saved at ``ts``.
    """
    # rtol/atol default to None here, matching the abstract signature on
    # the base class; the numeric defaults live in `_create_controller`.
    stepsize_controller = self._create_controller(solver, rtol, atol)
    solver_instance = self._instantiate_solver(solver)

    # Explicit None check: `dt0 or ...` evaluates truthiness, which fails
    # for traced JAX values and silently discards an explicit 0.0.
    initial_step = ts[1] - ts[0] if dt0 is None else dt0

    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(self.stoich_func),  # type: ignore
        solver_instance,  # type: ignore
        t0=ts[0],  # type: ignore
        t1=ts[-1],
        dt0=initial_step,
        y0=y0,
        stepsize_controller=stepsize_controller,
        saveat=diffrax.SaveAt(ts=ts),  # type: ignore
    )
    return solution.ys

catalax/neural/universalode.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import List, Optional
2+
from typing import List, Optional, Type
33

44
import jax
55
import diffrax
@@ -97,14 +97,26 @@ def _combined_term(self, t, y, args):
9797

9898
return mech_rates + self._corrective_term(t, y, args)
9999

100-
def __call__(self, ts, y0):
100+
def __call__(
101+
self,
102+
ts,
103+
y0,
104+
solver: Optional[Type[diffrax.AbstractSolver]] = None,
105+
rtol: Optional[float] = None,
106+
atol: Optional[float] = None,
107+
dt0: Optional[float] = None,
108+
):
109+
stepsize_controller = self._create_controller(solver, rtol, atol)
110+
solver_instance = self._instantiate_solver(solver)
111+
101112
solution = diffrax.diffeqsolve(
102113
diffrax.ODETerm(self._combined_term), # type: ignore
103-
self.solver(), # type: ignore
114+
solver_instance, # type: ignore
104115
t0=0.0, # type: ignore
105116
t1=ts[-1],
106-
dt0=ts[1] - ts[0],
117+
dt0=dt0 or ts[1] - ts[0],
107118
y0=y0,
119+
stepsize_controller=stepsize_controller,
108120
saveat=diffrax.SaveAt(ts=ts), # type: ignore
109121
max_steps=64**4,
110122
)

0 commit comments

Comments
 (0)