amici.jax
JAX
This module provides an interface to generate and use AMICI models with JAX. Please note that this module is experimental: the API may change substantially in the future. Use at your own risk and do not expect backward compatibility.
- class amici.jax.JAXModel
JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements routines for simulation and evaluation of derived quantities; model-specific implementations need to be provided by classes inheriting from JAXModel.
- MODEL_API_VERSION = '0.0.2'
- jax_py_file: pathlib.Path
- abstract property observable_ids: list[str]
Get the observable ids of the model.
- Returns:
Observable ids
- abstract property parameter_ids: list[str]
Get the parameter ids of the model.
- Returns:
Parameter ids
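As a rough illustration of the abstract-property pattern described for JAXModel, the following plain-Python sketch uses hypothetical classes ModelBase and ToyModel (neither is part of amici). It shows how a base class can declare observable_ids and parameter_ids as abstract properties that a model-specific subclass must implement. It does not reproduce JAXModel's simulation routines.

```python
from abc import ABC, abstractmethod


class ModelBase(ABC):
    """Hypothetical stand-in for the abstract-property pattern of JAXModel."""

    @property
    @abstractmethod
    def observable_ids(self) -> list[str]:
        """Observable ids of the model."""

    @property
    @abstractmethod
    def parameter_ids(self) -> list[str]:
        """Parameter ids of the model."""


class ToyModel(ModelBase):
    """Hypothetical model-specific subclass that supplies the ids."""

    @property
    def observable_ids(self) -> list[str]:
        return ["obs_a"]

    @property
    def parameter_ids(self) -> list[str]:
        return ["k1", "k2"]


# The base class cannot be instantiated while its abstract properties are unimplemented.
try:
    ModelBase()
except TypeError as err:
    print("cannot instantiate:", err)

model = ToyModel()
print(model.parameter_ids)  # ['k1', 'k2']
```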
- preequilibrate_condition(p, x_reinit, mask_reinit, solver, controller, max_steps)
Run pre-equilibration for a condition.
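- Parameters:
p (jaxtyping.Float[Array, 'np']) – parameters for simulation ordered according to ids in parameter_ids
x_reinit (jaxtyping.Float[Array, '*nx']) – re-initialized state vector
mask_reinit (jaxtyping.Bool[Array, '*nx']) – mask for re-initialization. If True, the corresponding state variable is re-initialized.
solver (diffrax._solver.base.AbstractSolver) – ODE solver
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – step size controller
max_steps (int | jax.numpy.int64) – maximum number of solver steps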
- Return type:
- Returns:
pre-equilibrated state variables and statistics
- simulate_condition(p, ts_init, ts_dyn, ts_posteq, my, iys, iy_trafos, solver, controller, adjoint, max_steps, x_preeq=Array([], shape=(0,), dtype=float32), mask_reinit=Array([], shape=(0,), dtype=float32), x_reinit=Array([], shape=(0,), dtype=float32), ret=ReturnValue.llh)
Simulate a condition.
- Parameters:
p (jaxtyping.Float[Array, 'np']) – parameters for simulation ordered according to ids in parameter_ids
ts_init (jaxtyping.Float[Array, 'nt_preeq']) – time points that do not require simulation. Usually valued 0.0, but needs to be shaped according to the number of observables that are evaluated before dynamic simulation.
ts_dyn (jaxtyping.Float[Array, 'nt_dyn']) – time points for dynamic simulation. Usually valued > 0.0 and sorted in monotonically increasing order. Duplicate time points are allowed to facilitate the evaluation of multiple observables at specific time points.
ts_posteq (jaxtyping.Float[Array, 'nt_posteq']) – time points for post-equilibration. Usually valued infinity, but needs to be shaped according to the number of observables that are evaluated after post-equilibration.
my (jaxtyping.Float[Array, 'nt']) – observed data
iys (jaxtyping.Int[Array, 'nt']) – indices of the observables according to ordering in observable_ids
x_preeq (jaxtyping.Float[Array, '*nx']) – initial state vector for pre-equilibration. If not provided, the initial state vector is computed using _x0().
mask_reinit (jaxtyping.Bool[Array, '*nx']) – mask for re-initialization. If True, the corresponding state variable is re-initialized.
x_reinit (jaxtyping.Float[Array, '*nx']) – re-initialized state vector. If not provided, the state vector is not re-initialized.
solver (diffrax._solver.base.AbstractSolver) – ODE solver
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – step size controller
adjoint (diffrax._adjoint.AbstractAdjoint) – adjoint method. Recommended values are diffrax.DirectAdjoint() for jax.jacfwd (with vector-valued outputs) and diffrax.RecursiveCheckpointAdjoint() for jax.grad (with scalar-valued outputs).
max_steps (int | jax.numpy.int64) – maximum number of solver steps
ret (amici.jax.model.ReturnValue) – which output to return. See ReturnValue for available options.
- Return type:
tuple[jaxtyping.Float[Array, 'nt *nx'] | jax.numpy.float64, dict]
- Returns:
output according to ret and statistics
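The sketch below illustrates, under stated assumptions, how measurements might be laid out across ts_init, ts_dyn, ts_posteq, my and iys. The helper split_timepoints is hypothetical. It assumes that time 0.0 maps to ts_init, finite positive times to ts_dyn, infinite times to ts_posteq, and that my and iys follow the concatenated order ts_init + ts_dyn + ts_posteq. Plain lists stand in for JAX arrays.

```python
import math


def split_timepoints(measurements, observable_ids):
    """Hypothetical helper: partition (time, observable_id, value) tuples.

    Assumes time 0.0 -> ts_init, finite t > 0 -> ts_dyn, t = inf -> ts_posteq,
    and that my/iys follow the concatenated order ts_init + ts_dyn + ts_posteq.
    Negative times are ignored by this sketch.
    """
    init = [m for m in measurements if m[0] == 0.0]
    # Stable sort keeps the input order of measurements at duplicate time points.
    dyn = sorted((m for m in measurements if 0.0 < m[0] < math.inf), key=lambda m: m[0])
    posteq = [m for m in measurements if m[0] == math.inf]
    ordered = init + dyn + posteq
    return {
        "ts_init": [m[0] for m in init],
        "ts_dyn": [m[0] for m in dyn],  # duplicates kept: one entry per measurement
        "ts_posteq": [m[0] for m in posteq],
        "my": [m[2] for m in ordered],
        "iys": [observable_ids.index(m[1]) for m in ordered],
    }


observable_ids = ["obs_a", "obs_b"]
measurements = [
    (2.0, "obs_a", 0.7),
    (0.0, "obs_b", 1.0),
    (math.inf, "obs_a", 0.2),
    (2.0, "obs_b", 0.4),
    (1.0, "obs_a", 0.9),
]
arrays = split_timepoints(measurements, observable_ids)
print(arrays["ts_dyn"])  # [1.0, 2.0, 2.0]
print(arrays["iys"])     # [1, 0, 0, 1, 0]
```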
- class amici.jax.JAXProblem(model, petab_problem)
PEtab problem wrapper for JAX models.
- Variables:
parameters – Values for the model parameters. Do not change dimensions; values may be changed, e.g., during model training.
model – JAXModel instance to use for simulation.
_parameter_mappings – ParameterMappingForCondition instances for each simulation condition.
_measurements – Subset measurement dataframes for each simulation condition.
_petab_problem – PEtab problem to simulate.
- __init__(model, petab_problem)
Initialize a JAXProblem instance with a model and a PEtab problem.
- Parameters:
model (amici.jax.model.JAXModel) – JAXModel instance to use for simulation.
petab_problem (petab.v1.problem.Problem) – PEtab problem to simulate.
- get_petab_parameter_by_id(name)
Get the value of a PEtab parameter by name.
- Parameters:
name (str) – PEtab parameter id, as returned by parameter_ids.
- Return type:
- Returns:
Value of the parameter
- classmethod load(directory)
Load a problem from a directory.
- Parameters:
directory (pathlib.Path) – Directory to load the problem from.
- Returns:
Loaded problem instance.
- load_parameters(simulation_condition)
Load parameters for a simulation condition.
- Parameters:
simulation_condition (str) – Simulation condition to load parameters for.
- Return type:
jaxtyping.Float[Array, 'np']
- Returns:
Parameters for the simulation condition.
- load_reinitialisation(simulation_condition, p)
Load reinitialisation values and mask for the state vector for a simulation condition.
- Parameters:
simulation_condition (str) – Simulation condition to load reinitialisation for.
p (jaxtyping.Float[Array, 'np']) – Parameters for the simulation condition.
- Return type:
tuple[jaxtyping.Bool[Array, 'nx'], jaxtyping.Float[Array, 'nx']]
- Returns:
Tuple of reinitialisation mask and values for states.
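A minimal sketch of how a mask and values like those returned by load_reinitialisation could be applied to a state vector, following the mask_reinit and x_reinit descriptions given for simulate_condition (True means the state variable is re-initialized; empty inputs mean no re-initialization). The helper apply_reinitialisation is hypothetical and uses plain lists; a comment notes the jnp.where equivalent.

```python
def apply_reinitialisation(x, mask_reinit, x_reinit):
    """Hypothetical helper: element-wise selection of re-initialised states.

    In JAX the same selection could be written as jnp.where(mask_reinit, x_reinit, x).
    An empty mask (the documented default) leaves the state untouched.
    """
    if len(mask_reinit) == 0:
        return list(x)
    if not (len(x) == len(mask_reinit) == len(x_reinit)):
        raise ValueError("x, mask_reinit and x_reinit must have the same length")
    # Take the re-initialised value where the mask is True, else keep the current state.
    return [xr if m else xi for xi, m, xr in zip(x, mask_reinit, x_reinit)]


x_preeq = [1.0, 2.0, 3.0]
mask = [False, True, False]
values = [0.0, 10.0, 0.0]
print(apply_reinitialisation(x_preeq, mask, values))  # [1.0, 10.0, 3.0]
```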
- model: amici.jax.model.JAXModel
- property parameter_ids: list[str]
Parameter ids that are estimated in the PEtab problem, in the same order as the values in parameters.
- Returns:
PEtab parameter ids
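Because parameter_ids shares its ordering with parameters, looking up a value by id reduces to an index lookup. The sketch below uses a hypothetical function get_parameter_by_id with made-up ids and values. It mirrors what get_petab_parameter_by_id is documented to do but is not its implementation, and it does not address whether stored values are on linear or estimation scale.

```python
def get_parameter_by_id(parameter_ids, parameters, name):
    """Hypothetical mirror of get_petab_parameter_by_id.

    Relies only on the documented invariant that parameter_ids and
    parameters share the same ordering.
    """
    try:
        index = parameter_ids.index(name)
    except ValueError:
        raise KeyError(f"unknown PEtab parameter id: {name!r}") from None
    return parameters[index]


parameter_ids = ["k1", "k2", "sigma_obs"]
parameters = [0.5, 1.5, 0.1]
print(get_parameter_by_id(parameter_ids, parameters, "k2"))  # 1.5
```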
- run_preequilibration(simulation_condition, solver, controller, max_steps)
Run a pre-equilibration simulation for a given simulation condition.
- Parameters:
simulation_condition (str) – Simulation condition to run simulation for.
solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation.
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation.
max_steps (jax.numpy.int64) – Maximum number of steps to take during simulation.
- Return type:
- Returns:
Pre-equilibration state
- run_simulation(simulation_condition, solver, controller, max_steps, x_preeq=Array([], shape=(0,), dtype=float32), ret=ReturnValue.llh)
Run a simulation for a given simulation condition.
- Parameters:
simulation_condition (tuple[str, ...]) – Simulation condition to run simulation for.
solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation.
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation.
max_steps (jax.numpy.int64) – Maximum number of steps to take during simulation.
x_preeq (jaxtyping.Float[Array, '*nx']) – Pre-equilibration state, if available.
ret (amici.jax.model.ReturnValue) – Which output to return. See ReturnValue for available options.
- Return type:
- Returns:
Tuple of output value and simulation statistics
- save(directory)
Save the problem to a directory.
- Parameters:
directory (pathlib.Path) – Directory to save the problem to.
- class amici.jax.ReturnValue(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
Enumeration of the outputs that can be requested via the ret argument.
- chi2 = 'sum(((observed - simulated) / sigma ) ** 2)'
- llh = 'log-likelihood'
- nllhs = 'pointwise negative log-likelihood'
- res = 'residuals'
- sigmay = 'standard deviations of the observables'
- tcl = 'total values for conservation laws'
- x = 'full state vector'
- x0 = 'full initial state vector'
- x0_solver = 'reduced initial state vector'
- x_solver = 'reduced state vector'
- y = 'observables'
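To relate several of these outputs, the following sketch computes res, chi2, nllhs and llh for plain lists. It assumes a Gaussian noise model with standard deviations sigma (as reported by sigmay) and residuals defined as (observed - simulated) / sigma to match the chi2 string. The residual sign convention and the noise model are assumptions; PEtab also allows other noise distributions.

```python
import math


def gaussian_outputs(observed, simulated, sigma):
    """Sketch of res, chi2, nllhs and llh under an assumed Gaussian noise model."""
    res = [(m - y) / s for m, y, s in zip(observed, simulated, sigma)]
    chi2 = sum(r ** 2 for r in res)  # sum(((observed - simulated) / sigma) ** 2)
    # Pointwise negative log-likelihood of a normal distribution.
    nllhs = [0.5 * math.log(2 * math.pi * s ** 2) + 0.5 * r ** 2 for r, s in zip(res, sigma)]
    llh = -sum(nllhs)  # log-likelihood = -(sum of pointwise nllhs)
    return {"res": res, "chi2": chi2, "nllhs": nllhs, "llh": llh}


out = gaussian_outputs(observed=[1.0, 2.0], simulated=[1.5, 1.0], sigma=[0.5, 1.0])
print(out["res"])   # [-1.0, 1.0]
print(out["chi2"])  # 2.0
```

Under this assumption llh equals -0.5 * chi2 minus the sum of the 0.5 * log(2 * pi * sigma ** 2) normalisation terms, so chi2 and llh share their optimum when sigma does not depend on the parameters.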
- amici.jax.petab_simulate(problem, solver=Kvaerno5( scan_kind=None, root_finder=VeryChord( rtol=<tolerance taken from `diffeqsolve(..., stepsize_controller=...)` argument>, atol=<tolerance taken from `diffeqsolve(..., stepsize_controller=...)` argument>, norm=<tolerance taken from `diffeqsolve(..., stepsize_controller=...)` argument>, kappa=0.01, linear_solver=AutoLinearSolver(well_posed=None) ), root_find_max_steps=10 ), controller=PIDController( rtol=1e-08, atol=1e-08, pcoeff=0.4, icoeff=0.3, dcoeff=0.0, dtmin=None, dtmax=None, force_dtmin=True, step_ts=None, jump_ts=None, factormin=0.2, factormax=10.0, norm=<function rms_norm>, safety=0.9, error_order=None ), max_steps=1024)
Run simulations for a problem and return the results as a PEtab simulation dataframe.
- Parameters:
problem (amici.jax.petab.JAXProblem) – Problem to run simulations for.
solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation.
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation.
max_steps (int) – Maximum number of steps to take during simulation.
- Returns:
PEtab simulation dataframe.
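The PEtab v1 simulation-table format reuses the measurement-table columns, with a simulation column in place of measurement. The sketch below writes two made-up rows with the mandatory columns using only the standard library, to show that layout. The column subset and values are illustrative, and the dataframe returned by petab_simulate may contain additional columns. With pandas, an equivalent file could be written via DataFrame.to_csv(path, sep="\t", index=False).

```python
import csv
import io

# Made-up simulated values; columns follow the PEtab v1 measurement-table
# layout with "simulation" in place of "measurement".
rows = [
    {"observableId": "obs_a", "simulationConditionId": "c0", "time": 1.0, "simulation": 0.93},
    {"observableId": "obs_a", "simulationConditionId": "c0", "time": 2.0, "simulation": 0.71},
]

# PEtab tables are tab-separated.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=list(rows[0]), delimiter="\t")
writer.writeheader()
writer.writerows(rows)
print(buffer.getvalue())
```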
- amici.jax.run_simulations(problem, simulation_conditions=None, solver=Kvaerno5( scan_kind=None, root_finder=VeryChord( rtol=<tolerance taken from `diffeqsolve(..., stepsize_controller=...)` argument>, atol=<tolerance taken from `diffeqsolve(..., stepsize_controller=...)` argument>, norm=<tolerance taken from `diffeqsolve(..., stepsize_controller=...)` argument>, kappa=0.01, linear_solver=AutoLinearSolver(well_posed=None) ), root_find_max_steps=10 ), controller=PIDController( rtol=1e-08, atol=1e-08, pcoeff=0.4, icoeff=0.3, dcoeff=0.0, dtmin=None, dtmax=None, force_dtmin=True, step_ts=None, jump_ts=None, factormin=0.2, factormax=10.0, norm=<function rms_norm>, safety=0.9, error_order=None ), max_steps=1024, ret=ReturnValue.llh)
Run simulations for a problem.
- Parameters:
problem (amici.jax.petab.JAXProblem) – Problem to run simulations for.
simulation_conditions (collections.abc.Iterable[tuple[str, ...]] | None) – Simulation conditions to run simulations for.
solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation.
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation.
max_steps (int) – Maximum number of steps to take during simulation.
ret (amici.jax.model.ReturnValue | str) – Which output to return. See ReturnValue for available options.
- Returns:
Overall output value and condition specific results and statistics.
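Since ret accepts either a ReturnValue member or a str, a string presumably has to be converted to an enum member before dispatch. The sketch below assumes conversion by member name (for example "llh"). ReturnValueMirror is a hypothetical three-member stand-in for amici.jax.ReturnValue, and coerce_ret is a hypothetical helper, not the amici implementation. If conversion were by value instead, ReturnValueMirror("log-likelihood") would be the corresponding lookup.

```python
from enum import Enum


class ReturnValueMirror(Enum):
    """Hypothetical subset mirroring some amici.jax.ReturnValue members."""

    llh = "log-likelihood"
    chi2 = "sum(((observed - simulated) / sigma ) ** 2)"
    y = "observables"


def coerce_ret(ret):
    """Accept an enum member or its name (name-based lookup is an assumption)."""
    if isinstance(ret, str):
        return ReturnValueMirror[ret]  # raises KeyError for unknown names
    return ret


print(coerce_ret("y"))  # ReturnValueMirror.y
print(coerce_ret(ReturnValueMirror.llh).value)  # log-likelihood
```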