Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions catalax/neural/neuralbase.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

from abc import abstractmethod
import json
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type


try:
Expand All @@ -18,6 +19,7 @@
import jax.random as jrandom
import jax.tree_util as jtn
import optax
import diffrax as dfx

from catalax import Model
from catalax.model.model import SimulationConfig
Expand All @@ -32,12 +34,17 @@
from catalax.dataset import Dataset
from catalax.neural.strategy import Strategy

NON_ADAPTIVE_SOLVERS = [
dfx.Euler,
dfx.Heun,
]


class NeuralBase(eqx.Module, Predictor, Surrogate):
func: MLP
observable_indices: List[int]
hyperparams: Dict
solver: diffrax.AbstractSolver
solver: Type[diffrax.AbstractSolver]
vector_field: Optional[Stack]
species_order: List[str]

Expand All @@ -58,7 +65,7 @@ def __init__(
**kwargs,
):
# Save solver and observable indices
self.solver = solver # type: ignore
self.solver = solver
self.observable_indices = observable_indices
self.vector_field = None
self.species_order = species_order
Expand Down Expand Up @@ -88,6 +95,18 @@ def __init__(
out_size=out_size,
)

@abstractmethod
def __call__(
self,
ts,
y0,
solver: Optional[Type[diffrax.AbstractSolver]] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
dt0: Optional[float] = None,
) -> jax.Array:
raise NotImplementedError("This method is not implemented")

def train(
self,
dataset: Dataset,
Expand Down Expand Up @@ -156,6 +175,10 @@ def predict(
config: Optional[SimulationConfig] = None,
n_steps: int = 100,
use_times: bool = False,
solver: Optional[Type[diffrax.AbstractSolver]] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
dt0: Optional[float] = None,
):
"""Predict model behavior using the given dataset.

Expand Down Expand Up @@ -195,7 +218,10 @@ def predict(
if config:
times = jnp.linspace(config.t0, config.t1, config.nsteps).T # type: ignore

predictions = jax.vmap(self, in_axes=(0, 0))(times, y0s) # type: ignore
predictions = jax.vmap(
lambda ts, y0: self(ts, y0, solver=solver, rtol=rtol, atol=atol, dt0=dt0),
in_axes=(0, 0),
)(times, y0s)

return Dataset.from_jax_arrays(
species_order=self.species_order,
Expand Down Expand Up @@ -419,3 +445,30 @@ def n_parameters(self) -> int:
for layer in layers:
n_parameters += layer.size
return n_parameters

def _create_controller(
self,
solver: Optional[Type[diffrax.AbstractSolver]] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
):
"""Create the appropriate stepsize controller"""
if solver is None:
return diffrax.PIDController(1e-3, 1e-6)

if solver is not None and solver in NON_ADAPTIVE_SOLVERS:
Comment thread
JR-1991 marked this conversation as resolved.
Outdated
return diffrax.ConstantStepSize()
else:
return diffrax.PIDController(
rtol=rtol if rtol is not None else 1e-3,
atol=atol if atol is not None else 1e-6,
)

def _instantiate_solver(
self,
solver: Optional[Type[diffrax.AbstractSolver]],
) -> diffrax.AbstractSolver:
if solver is None:
return self.solver()
else:
return solver()
20 changes: 15 additions & 5 deletions catalax/neural/neuralode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional, Type

import jax
import diffrax
Expand Down Expand Up @@ -35,15 +35,25 @@ def __init__(
final_activation=final_activation,
)

def __call__(self, ts, y0, key: jax.Array = jax.random.PRNGKey(0)):
def __call__(
self,
ts,
y0,
solver: Optional[Type[diffrax.AbstractSolver]] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
dt0: Optional[float] = None,
):
stepsize_controller = self._create_controller(solver, rtol, atol)
solver_instance = self._instantiate_solver(solver)
solution = diffrax.diffeqsolve(
diffrax.ODETerm(self.func), # type: ignore
self.solver(), # type: ignore
solver_instance, # type: ignore
t0=ts[0], # type: ignore
t1=ts[-1],
dt0=ts[1] - ts[0],
dt0=dt0 or ts[1] - ts[0],
y0=y0,
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6), # type: ignore
stepsize_controller=stepsize_controller,
saveat=diffrax.SaveAt(ts=ts), # type: ignore
)
return solution.ys
Expand Down
8 changes: 8 additions & 0 deletions catalax/neural/penalties/penalties.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
penalize_duplicate_reactions,
penalize_non_conservative,
penalize_non_integer,
penalize_null_space,
)


Expand Down Expand Up @@ -207,6 +208,7 @@ def for_rateflow(
integer_alpha: Optional[float] = None,
sparsity_alpha: Optional[float] = None,
l2_alpha: Optional[float] = None,
null_space_alpha: Optional[float] = None,
) -> Penalties:
"""Create a collection of penalties specifically designed for NeuralRDE models.

Expand Down Expand Up @@ -275,6 +277,12 @@ def for_rateflow(
alpha=l2_alpha or alpha,
)

if null_space_alpha is not None:
penalties.add_penalty(
name="null_space",
fun=penalize_null_space,
alpha=null_space_alpha,
)
return penalties


Expand Down
14 changes: 14 additions & 0 deletions catalax/neural/penalties/stoich_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@
from catalax.neural.rateflow import RateFlowODE


def penalize_null_space(model: RateFlowODE, alpha: float = 0.1, **kwargs) -> jax.Array:
"""Penalize the null space of the stoichiometric matrix.

This penalty function encourages the stoichiometric matrix to have a null space of zero.
"""
assert isinstance(model, RateFlowODE), "Model must be a RateFlowODE"

stoich_matrix = _normalize_matrix(model.stoich_matrix)
if model.mass_constraint is not None:
Comment thread
JR-1991 marked this conversation as resolved.
return alpha * jnp.mean(model.mass_constraint @ stoich_matrix)
else:
return jnp.array(0.0)


def penalize_density(model: RateFlowODE, alpha: float = 0.1, **kwargs) -> jax.Array:
"""Penalize dense stoichiometric matrices by encouraging sparsity.

Expand Down
10 changes: 6 additions & 4 deletions catalax/neural/plots/rateflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,17 +79,19 @@ def plot_learned_rates(
label=f"Reaction {reaction + 1}",
)

sub_ax[1].grid(True, which="both", linestyle="--")
sub_ax[1].grid(True, which="minor", alpha=0.3)
sub_ax[1].minorticks_on()

sub_ax[1].legend(fontsize="small", frameon=True, fancybox=True, shadow=True)
sub_ax[1].grid(alpha=0.3, linestyle="--", linewidth=0.8)
sub_ax[1].set_xlabel("Time", fontsize=12, labelpad=10)
sub_ax[1].set_ylabel("Rate Magnitude", fontsize=12, labelpad=10)
sub_ax[1].spines["top"].set_visible(False)
sub_ax[1].spines["right"].set_visible(False)

# Model fit (right panel)
dataset.measurements[i].plot(
ax=sub_ax[2], linestyle="--", linewidth=2, alpha=0.7
)
dataset.measurements[i].plot(ax=sub_ax[2], model_data=pred.measurements[i])

sub_ax[2].set_title(f"Dataset {i + 1}", fontsize=12, pad=15)
sub_ax[2].legend(fontsize="small", frameon=True, fancybox=True, shadow=True)
sub_ax[2].grid(alpha=0.2, linestyle="--", linewidth=0.8)
Expand Down
52 changes: 44 additions & 8 deletions catalax/neural/rateflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Type

import diffrax
import jax
import jax.tree_util as jtn
import jax.random as jrandom
from matplotlib.figure import Figure
import equinox as eqx

from catalax.dataset.dataset import Dataset
from catalax.model.model import Model
Expand All @@ -15,9 +16,17 @@


class RateFlowODE(NeuralBase):
reaction_size: int
stoich_matrix: jax.Array
learn_stoich: bool
reaction_size: int = eqx.field(
default=None,
)
Comment thread
JR-1991 marked this conversation as resolved.
Outdated
learn_stoich: bool = eqx.field(
default=True,
static=True,
)
mass_constraint: Optional[jax.Array] = eqx.field(
default=None,
)

def __init__(
self,
Expand All @@ -32,6 +41,7 @@ def __init__(
use_final_bias: bool = False,
learn_stoich: bool = True,
stoich_matrix: jax.Array | None = None,
mass_constraint: jax.Array | None = None,
*,
key,
**kwargs,
Expand All @@ -52,6 +62,22 @@ def __init__(
learn_stoich=learn_stoich,
)

if mass_constraint is not None:
assert mass_constraint.ndim == 2, (
"Mass constraint must be a matrix of shape (n_constraints, n_species). Given shape: "
f"{mass_constraint.shape}"
)

_, n_species = mass_constraint.shape
assert n_species == len(species_order), (
"Mass constraint must be a matrix of shape (n_constraints, n_species). Given shape: "
f"{mass_constraint.shape}"
)

self.mass_constraint = mass_constraint
else:
self.mass_constraint = None

self.reaction_size = reaction_size
self.learn_stoich = learn_stoich

Expand All @@ -62,17 +88,27 @@ def __init__(
self.stoich_matrix = stoich_matrix

def stoich_func(self, t, y, args):
return self.stoich_matrix @ self.func(t, y, args)
return self.stoich_matrix @ jax.nn.relu(self.func(t, y, args))
Comment thread
JR-1991 marked this conversation as resolved.

def __call__(self, ts, y0):
def __call__(
self,
ts,
y0,
solver: Optional[Type[diffrax.AbstractSolver]] = None,
rtol: float = 1e-3,
atol: float = 1e-6,
dt0: Optional[float] = None,
):
stepsize_controller = self._create_controller(solver, rtol, atol)
solver_instance = self._instantiate_solver(solver)
solution = diffrax.diffeqsolve(
diffrax.ODETerm(self.stoich_func), # type: ignore
self.solver(), # type: ignore
solver_instance, # type: ignore
t0=ts[0], # type: ignore
t1=ts[-1],
dt0=ts[1] - ts[0],
dt0=dt0 or ts[1] - ts[0],
y0=y0,
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6), # type: ignore
stepsize_controller=stepsize_controller,
saveat=diffrax.SaveAt(ts=ts), # type: ignore
)
return solution.ys
Expand Down
20 changes: 16 additions & 4 deletions catalax/neural/universalode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import List, Optional
from typing import List, Optional, Type

import jax
import diffrax
Expand Down Expand Up @@ -97,14 +97,26 @@ def _combined_term(self, t, y, args):

return mech_rates + self._corrective_term(t, y, args)

def __call__(self, ts, y0):
def __call__(
self,
ts,
y0,
solver: Optional[Type[diffrax.AbstractSolver]] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
dt0: Optional[float] = None,
):
stepsize_controller = self._create_controller(solver, rtol, atol)
solver_instance = self._instantiate_solver(solver)

solution = diffrax.diffeqsolve(
diffrax.ODETerm(self._combined_term), # type: ignore
self.solver(), # type: ignore
solver_instance, # type: ignore
t0=0.0, # type: ignore
t1=ts[-1],
dt0=ts[1] - ts[0],
dt0=dt0 or ts[1] - ts[0],
y0=y0,
stepsize_controller=stepsize_controller,
saveat=diffrax.SaveAt(ts=ts), # type: ignore
max_steps=64**4,
)
Expand Down
16 changes: 16 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import catalax as ctx
import catalax.neural as ctn
import equinox as eqx
from jax.flatten_util import ravel_pytree

model = ctx.Model(name="Michaelis-Menten")
model.add_species(s1="Substrate")
model.add_constant(e="Enzyme")
model.add_ode("s1", "kcat * e * s1 / (k_m + s1)")

neural_ode = ctn.NeuralODE.from_model(model, width_size=6, depth=2)

params, static = eqx.partition(neural_ode, eqx.is_array)
theta0, unravel = ravel_pytree(params)

print(unravel)