amici.jax
JAX
This module provides an interface to generate and use AMICI models with JAX. Please note that this module is experimental and the API may change substantially in the future. Use at your own risk and do not expect backward compatibility.
- class amici.jax.JAXModel
JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements routines for simulation and for evaluation of derived quantities; model-specific implementations need to be provided by classes inheriting from JAXModel.
- Variables:
api_version – API version of the derived class. Needs to match the API version of the base class (MODEL_API_VERSION).
MODEL_API_VERSION – API version of the base class.
jax_py_file – Path to the JAX model file.
- MODEL_API_VERSION = '0.0.3'
- jax_py_file: pathlib.Path
- abstract property observable_ids: list[str]
Get the observable ids of the model.
- Returns:
Observable ids
- abstract property parameter_ids: list[str]
Get the parameter ids of the model.
- Returns:
Parameter ids
- preequilibrate_condition(p, x_reinit, mask_reinit, solver, controller, steady_state_event, max_steps)
Pre-equilibrate a condition, i.e., simulate it until the steady state detected by steady_state_event is reached.
- Parameters:
p (jaxtyping.Float[Array, 'np']) – parameters for simulation, ordered according to the ids in parameter_ids
x_reinit (jaxtyping.Float[Array, '*nx']) – re-initialized state vector. If not provided, the state vector is not re-initialized.
mask_reinit (jaxtyping.Bool[Array, '*nx']) – mask for re-initialization. If True, the corresponding state variable is re-initialized.
solver (diffrax._solver.base.AbstractSolver) – ODE solver
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – step size controller
steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – event function for steady state. See diffrax.steady_state_event() for details.
max_steps (int | jax.numpy.int64) – maximum number of solver steps
- Returns:
pre-equilibrated state variables and statistics
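The following pure-Python sketch makes the mask_reinit/x_reinit semantics concrete. It rests on three assumptions. First, state i takes the value x_reinit[i] wherever mask_reinit[i] is True. Second, every other state keeps its current value. Third, empty inputs mean that no re-initialization takes place. The helper apply_reinit is hypothetical and not part of amici.jax, which operates on JAX arrays rather than lists.

```python
def apply_reinit(x, mask_reinit, x_reinit):
    """Return a copy of the state x with masked entries replaced.

    Hypothetical helper (not part of amici.jax) illustrating the
    documented mask_reinit / x_reinit semantics on plain lists.
    """
    # Empty mask: no re-initialization requested, keep all states.
    if len(mask_reinit) == 0:
        return list(x)
    if not (len(x) == len(mask_reinit) == len(x_reinit)):
        raise ValueError("x, mask_reinit and x_reinit must have equal length")
    # Element-wise selection, analogous to jnp.where(mask_reinit, x_reinit, x).
    return [new if m else old for old, m, new in zip(x, mask_reinit, x_reinit)]


# Only the second state is re-initialized.
x_new = apply_reinit([1.0, 2.0, 3.0], [False, True, False], [0.0, 10.0, 0.0])
```

With JAX arrays, the same selection is a single vectorized jnp.where call, which remains traceable under jax.jit.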
- simulate_condition(p, ts_dyn, ts_posteq, my, iys, iy_trafos, ops, nps, solver, controller, adjoint, steady_state_event, max_steps, x_preeq=Array([], shape=(0,), dtype=float32), mask_reinit=Array([], shape=(0,), dtype=float32), x_reinit=Array([], shape=(0,), dtype=float32), ts_mask=Array([], shape=(0,), dtype=float32), ret=ReturnValue.llh)
Simulate a condition.
- Parameters:
p (jaxtyping.Float[Array, 'np']) – parameters for simulation, ordered according to the ids in parameter_ids
ts_dyn (jaxtyping.Float[Array, 'nt_dyn']) – time points for dynamic simulation. Must be sorted in non-decreasing order; duplicate time points are allowed so that multiple observables can be evaluated at the same time point.
ts_posteq (jaxtyping.Float[Array, 'nt_posteq']) – time points for post-equilibration. Usually set to infinity, but the array must have one entry per observable evaluation after post-equilibration.
my (jaxtyping.Float[Array, 'nt']) – observed data
iys (jaxtyping.Int[Array, 'nt']) – indices of the observables, according to the ordering in observable_ids
iy_trafos (jaxtyping.Int[Array, 'nt']) – indices of the transformations for the observables
ops (jaxtyping.Float[Array, 'nt *nop']) – observable parameters
nps (jaxtyping.Float[Array, 'nt *nnp']) – noise parameters
solver (diffrax._solver.base.AbstractSolver) – ODE solver
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – step size controller
adjoint (diffrax._adjoint.AbstractAdjoint) – adjoint method. Recommended values are diffrax.DirectAdjoint() for jax.jacfwd (with vector-valued outputs) and diffrax.RecursiveCheckpointAdjoint() for jax.grad (with scalar-valued outputs).
steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – event function for steady state. See diffrax.steady_state_event() for details.
max_steps (int | jax.numpy.int64) – maximum number of solver steps
x_preeq (jaxtyping.Float[Array, '*nx']) – pre-equilibrated state vector, used as the initial state. If not provided, the initial state vector is computed using _x0().
mask_reinit (jaxtyping.Bool[Array, '*nx']) – mask for re-initialization. If True, the corresponding state variable is re-initialized.
x_reinit (jaxtyping.Float[Array, '*nx']) – re-initialized state vector. If not provided, the state vector is not re-initialized.
ts_mask (jaxtyping.Bool[Array, 'nt']) – mask to remove (padded) time points. If True, the corresponding time point is used for the evaluation of the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2.
ret (amici.jax.model.ReturnValue) – which output to return. See ReturnValue for available options.
- Return type:
tuple[jaxtyping.Float[Array, 'nt *nx'] | jax.numpy.float64, dict]
- Returns:
output according to ret, and general results/statistics
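The constraints on ts_dyn, ts_posteq, my and iys can be illustrated by converting a list of measurements into these arrays. This sketch treats infinite time points as post-equilibration measurements. It sorts dynamic measurements stably by time, so several observables can share a time point. In my and iys, the dynamic entries come first and the post-equilibration entries follow. That concatenation order, the helper name build_condition_arrays and the observable ids are assumptions for illustration only, not the amici.jax implementation.

```python
import math


def build_condition_arrays(measurements, observable_ids):
    """Arrange (time, observable_id, value) tuples into simulation inputs.

    Hypothetical helper; returns ts_dyn, ts_posteq, my and iys as lists.
    """
    # sorted() is stable, so entries sharing a time point keep their order.
    dyn = sorted((m for m in measurements if math.isfinite(m[0])), key=lambda m: m[0])
    # Infinite time points denote evaluation after post-equilibration.
    posteq = [m for m in measurements if math.isinf(m[0])]
    ts_dyn = [t for t, _, _ in dyn]
    ts_posteq = [t for t, _, _ in posteq]
    # Assumed layout: dynamic entries first, post-equilibration entries last.
    ordered = dyn + posteq
    my = [value for _, _, value in ordered]
    iys = [observable_ids.index(obs_id) for _, obs_id, _ in ordered]
    return ts_dyn, ts_posteq, my, iys


ts_dyn, ts_posteq, my, iys = build_condition_arrays(
    [(2.0, "obs_a", 0.5), (1.0, "obs_b", 0.1), (1.0, "obs_a", 0.3), (math.inf, "obs_a", 0.9)],
    observable_ids=["obs_a", "obs_b"],
)
```

In this example, time point 1.0 appears twice in ts_dyn because two different observables are measured there. ts_posteq has exactly one entry, one for each observable evaluated after post-equilibration.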
- class amici.jax.JAXProblem(model, petab_problem)
PEtab problem wrapper for JAX models.
- Variables:
parameters – Values for the model parameters. Do not change their dimensions; the values may be changed, e.g., during model training.
model – JAXModel instance to use for simulation.
_parameter_mappings – ParameterMappingForCondition instances for each simulation condition.
_measurements – Preprocessed arrays for each simulation condition.
_petab_problem – PEtab problem to simulate.
- __init__(model, petab_problem)
Initialize a JAXProblem instance with a model and a PEtab problem.
- Parameters:
model (amici.jax.model.JAXModel) – JAXModel instance to use for simulation.
petab_problem (petab.v1.problem.Problem) – PEtab problem to simulate.
- get_petab_parameter_by_id(name)
Get the value of a PEtab parameter by name.
- Parameters:
name (str) – PEtab parameter id, as returned by parameter_ids.
- Returns:
Value of the parameter
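The values in parameters share their ordering with parameter_ids, so looking up a parameter by id amounts to an index lookup. The standalone function below is a hypothetical sketch of that idea using plain lists. It is not the actual JAXProblem method.

```python
def get_parameter_by_id(parameter_ids, parameters, name):
    """Look up a parameter value by its PEtab id.

    Hypothetical stand-in for JAXProblem.get_petab_parameter_by_id that
    relies on parameters being ordered like parameter_ids.
    """
    try:
        return parameters[parameter_ids.index(name)]
    except ValueError:
        # list.index raises ValueError for unknown ids; report it as a key error.
        raise KeyError(f"unknown PEtab parameter id: {name!r}") from None


k2 = get_parameter_by_id(["k1", "k2"], [0.1, 0.2], "k2")
```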
- classmethod load(directory)
Load a problem from a directory.
- Parameters:
directory (pathlib.Path) – Directory to load the problem from.
- Returns:
Loaded problem instance.
- load_parameters(simulation_condition)
Load parameters for a simulation condition.
- Parameters:
simulation_condition (str) – Simulation condition to load parameters for.
- Return type:
jaxtyping.Float[Array, 'np']
- Returns:
Parameters for the simulation condition.
- load_reinitialisation(simulation_condition, p)
Load reinitialisation values and mask for the state vector for a simulation condition.
- Parameters:
simulation_condition (str) – Simulation condition to load reinitialisation for.
p (jaxtyping.Float[Array, 'np']) – Parameters for the simulation condition.
- Return type:
tuple[jaxtyping.Bool[Array, 'nx'], jaxtyping.Float[Array, 'nx']]
- Returns:
Tuple of reinitialisation mask and values for the states.
- model: amici.jax.model.JAXModel
- property parameter_ids: list[str]
Parameter ids that are estimated in the PEtab problem. Same ordering as the values in parameters.
- Returns:
PEtab parameter ids
- run_preequilibration(p, mask_reinit, x_reinit, solver, controller, steady_state_event, max_steps)
Run a pre-equilibration simulation for a given simulation condition.
- Parameters:
p (jaxtyping.Float[Array, 'np']) – Parameters for the simulation condition
mask_reinit (jaxtyping.Bool[Array, 'nx']) – Mask for states that need reinitialisation
x_reinit (jaxtyping.Float[Array, 'nx']) – Reinitialisation values for states
solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation
steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – Steady state event function to use for pre-equilibration. Allows customisation of the steady state condition; see diffrax.steady_state_event() for details.
max_steps (jax.numpy.int64) – Maximum number of steps to take during simulation
- Returns:
Pre-equilibration state
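The diffrax documentation describes diffrax.steady_state_event() as terminating a solve once the vector field is small relative to the state. The criterion is roughly norm(dx/dt) < atol + rtol * norm(x), with a root-mean-square norm by default. The pure-Python sketch below mimics that check and is only an approximation of the diffrax internals. The function names are hypothetical. The 1e-8 defaults simply mirror the PIDController tolerances that petab_simulate and run_simulations use by default.

```python
import math


def rms_norm(v):
    """Root-mean-square norm of a sequence of floats (0.0 if empty)."""
    return math.sqrt(sum(x * x for x in v) / len(v)) if v else 0.0


def is_steady_state(xdot, x, rtol=1e-8, atol=1e-8):
    """Return True if the time derivative xdot is small relative to x.

    Hypothetical sketch of the criterion norm(xdot) < atol + rtol * norm(x).
    """
    return rms_norm(xdot) < atol + rtol * rms_norm(x)


at_rest = is_steady_state([0.0, 0.0], [1.0, 2.0])
still_moving = is_steady_state([1e-3, 0.0], [1.0, 2.0])
```

Loosening rtol and atol makes pre-equilibration terminate earlier, at the cost of a less converged steady state.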
- run_preequilibrations(simulation_conditions, solver, controller, steady_state_event, max_steps)
- run_simulation(p, ts_dyn, ts_posteq, my, iys, iy_trafos, ops, nps, mask_reinit, x_reinit, solver, controller, steady_state_event, max_steps, x_preeq=Array([], shape=(0,), dtype=float32), ts_mask=array([], dtype=float64), ret=ReturnValue.llh)
Run a simulation for a given simulation condition.
- Parameters:
p (jaxtyping.Float[Array, 'np']) – Parameters for the simulation condition
ts_dyn (numpy.ndarray) – (Padded) dynamic time points
ts_posteq (numpy.ndarray) – (Padded) post-equilibration time points
my (numpy.ndarray) – (Padded) measurements
iys (numpy.ndarray) – (Padded) observable indices
iy_trafos (numpy.ndarray) – (Padded) observable transformation indices
ops (jaxtyping.Float[Array, 'nt *nop']) – (Padded) observable parameters
nps (jaxtyping.Float[Array, 'nt *nnp']) – (Padded) noise parameters
mask_reinit (jaxtyping.Bool[Array, 'nx']) – Mask for states that need reinitialisation
x_reinit (jaxtyping.Float[Array, 'nx']) – Reinitialisation values for states
solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation
steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – Steady state event function to use for post-equilibration. Allows customisation of the steady state condition; see diffrax.steady_state_event() for details.
max_steps (jax.numpy.int64) – Maximum number of steps to take during simulation
x_preeq (jaxtyping.Float[Array, '*nx']) – Pre-equilibration state. Can be empty if no pre-equilibration is available, in which case the states are initialised to the model default values.
ts_mask (numpy.ndarray) – Padding mask; see JAXModel.simulate_condition() for details.
ret (amici.jax.model.ReturnValue) – Which output to return. See ReturnValue for available options.
- Returns:
Tuple of output value and simulation statistics
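The (padded) inputs above let conditions with different numbers of time points share one array shape, which batched evaluation (e.g. with jax.vmap) requires. The sketch below pads plain lists to a common length and builds the matching ts_mask. It also shows that padded entries drop out of masked sums. The fill choice is an assumption of this sketch: it repeats each vector's last time point, which keeps the vector sorted. The helper names are hypothetical.

```python
def pad_conditions(ts_per_condition, fill=None):
    """Pad per-condition time vectors to a common length.

    Returns (padded, masks); masks are True for original entries and False
    for padding. Without an explicit fill value, each vector is padded by
    repeating its last time point (an assumption of this sketch).
    """
    n_max = max((len(ts) for ts in ts_per_condition), default=0)
    padded, masks = [], []
    for ts in ts_per_condition:
        n_pad = n_max - len(ts)
        value = fill if fill is not None else (ts[-1] if ts else 0.0)
        padded.append(list(ts) + [value] * n_pad)
        masks.append([True] * len(ts) + [False] * n_pad)
    return padded, masks


def masked_sum(values, mask):
    """Sum only entries whose mask is True, ignoring padded entries."""
    return sum(v for v, m in zip(values, mask) if m)


padded, masks = pad_conditions([[0.0, 1.0, 2.0], [0.5]])
```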
- run_simulations(simulation_conditions, preeq_array, solver, controller, steady_state_event, max_steps, ret=ReturnValue.llh)
Run simulations for a list of simulation conditions.
- Parameters:
simulation_conditions (list[str]) – List of simulation conditions to run simulations for.
preeq_array (jaxtyping.Float[Array, 'ncond *nx']) – Matrix of pre-equilibrated states for the simulation conditions. Ordering must match the simulation conditions. If no pre-equilibration is available for a condition, the corresponding row must be empty.
solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation.
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation.
steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – Steady state event function to use for post-equilibration. Allows customisation of the steady state condition; see diffrax.steady_state_event() for details.
max_steps (jax.numpy.int64) – Maximum number of steps to take during simulation.
ret (amici.jax.model.ReturnValue) – Which output to return. See ReturnValue for available options.
- Returns:
Output value and condition-specific results and statistics. Results and statistics are returned as a dict of arrays whose leading dimension corresponds to the simulation conditions.
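The per-condition results above come back as a dict of arrays whose leading dimension indexes the simulation conditions. The sketch below shows that layout with plain lists. The keys 'llh' and 'steps' are hypothetical placeholders, not the actual statistics names. With JAX arrays, a similar stacking could be written as jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *per_condition).

```python
def stack_results(per_condition):
    """Turn a list of per-condition dicts into a dict of per-key lists.

    The i-th element of every list belongs to the i-th simulation
    condition, mimicking a leading condition dimension.
    """
    if not per_condition:
        return {}
    keys = per_condition[0].keys()
    if any(d.keys() != keys for d in per_condition):
        raise ValueError("all conditions must report the same keys")
    return {k: [d[k] for d in per_condition] for k in keys}


# Hypothetical per-condition outputs for two conditions.
stacked = stack_results([{"llh": -1.0, "steps": 10}, {"llh": -2.0, "steps": 12}])
```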
- save(directory)
Save the problem to a directory.
- Parameters:
directory (pathlib.Path) – Directory to save the problem to.
- class amici.jax.ReturnValue(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
- chi2 = 'sum(((observed - simulated) / sigma ) ** 2)'
- llh = 'log-likelihood'
- nllhs = 'pointwise negative log-likelihood'
- res = 'residuals'
- sigmay = 'standard deviations of the observables'
- tcl = 'total values for conservation laws'
- x = 'full state vector'
- x0 = 'full initial state vector'
- x0_solver = 'reduced initial state vector'
- x_solver = 'reduced state vector'
- y = 'observables'
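The ReturnValue entries chi2, res, nllhs and llh are closely related. The following sketch computes them for plain lists and assumes Gaussian noise with standard deviation sigma. It also assumes the residual sign convention (observed - simulated) / sigma. chi2 follows the formula given for ReturnValue.chi2. AMICI models may use other noise distributions, so the nllhs and llh formulas here are only an illustration.

```python
import math


def residuals(observed, simulated, sigma):
    """Weighted residuals (observed - simulated) / sigma (sign convention assumed)."""
    return [(o - s) / sd for o, s, sd in zip(observed, simulated, sigma)]


def chi2(observed, simulated, sigma):
    """sum(((observed - simulated) / sigma) ** 2), as in ReturnValue.chi2."""
    return sum(r * r for r in residuals(observed, simulated, sigma))


def gaussian_nllhs(observed, simulated, sigma):
    """Pointwise negative log-likelihoods, assuming Gaussian noise."""
    return [
        0.5 * math.log(2.0 * math.pi * sd**2) + 0.5 * ((o - s) / sd) ** 2
        for o, s, sd in zip(observed, simulated, sigma)
    ]


def gaussian_llh(observed, simulated, sigma):
    """Log-likelihood as the negative sum of the pointwise terms."""
    return -sum(gaussian_nllhs(observed, simulated, sigma))
```

Under these assumptions, llh equals -0.5 * chi2 minus the sum of the 0.5 * log(2 * pi * sigma**2) terms. With sigma fixed, maximizing llh and minimizing chi2 therefore select the same parameters.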
- amici.jax.petab_simulate(problem, solver=Kvaerno5(), controller=PIDController(rtol=1e-08, atol=1e-08, pcoeff=0.4, icoeff=0.3, dcoeff=0.0), steady_state_event=<function steady_state_event.<locals>._cond_fn>, max_steps=1024)
Run simulations for a problem and return the results as a PEtab simulation dataframe.
- Parameters:
problem (amici.jax.petab.JAXProblem) – Problem to run simulations for.
solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation.
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation.
max_steps (int) – Maximum number of steps to take during simulation.
steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state condition; see diffrax.steady_state_event() for details.
- Returns:
PEtab simulation dataframe.
- amici.jax.run_simulations(problem, simulation_conditions=None, solver=Kvaerno5(), controller=PIDController(rtol=1e-08, atol=1e-08, pcoeff=0.4, icoeff=0.3, dcoeff=0.0), steady_state_event=<function steady_state_event.<locals>._cond_fn>, max_steps=1024, ret=ReturnValue.llh)
Run simulations for a problem.
- Parameters:
problem (amici.jax.petab.JAXProblem) – Problem to run simulations for.
simulation_conditions (collections.abc.Iterable[tuple[str, ...]] | None) – Simulation conditions to run simulations for. Each tuple contains either only the simulation condition id, or the pre-equilibration condition id followed by the simulation condition id. By default, simulations are run for all conditions.
solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation.
controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation.
steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state condition; see diffrax.steady_state_event() for details.
max_steps (int) – Maximum number of steps to take during simulation.
ret (amici.jax.model.ReturnValue | str) – Which output to return. See ReturnValue for available options.
- Returns:
Overall output value and condition-specific results and statistics.
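Entries of simulation_conditions are either (simulation_condition,) or (preequilibration_condition, simulation_condition). One plausible way to organize them is sketched below, so that each distinct pre-equilibration only needs to be simulated once. The function plan_simulations and its return layout are hypothetical. They only illustrate the tuple convention, not how amici.jax processes it internally.

```python
def plan_simulations(simulation_conditions):
    """Split condition tuples into unique pre-equilibrations and simulations.

    Each entry is (sim_id,) or (preeq_id, sim_id). Returns the unique
    pre-equilibration ids in first-seen order and, per simulation, a pair
    (sim_id, index into that list, or None without pre-equilibration).
    """
    preeq_ids = []
    plan = []
    for cond in simulation_conditions:
        if len(cond) == 1:
            plan.append((cond[0], None))
        elif len(cond) == 2:
            preeq, sim = cond
            if preeq not in preeq_ids:
                preeq_ids.append(preeq)
            plan.append((sim, preeq_ids.index(preeq)))
        else:
            raise ValueError(f"expected 1 or 2 condition ids, got {cond!r}")
    return preeq_ids, plan


preeq_ids, plan = plan_simulations([("c1",), ("pre", "c2"), ("pre", "c3")])
```

Here the pre-equilibration condition "pre" appears once in preeq_ids even though two simulations depend on it.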