amici.jax

JAX

This module provides an interface to generate and use AMICI models with JAX. Please note that this module is experimental: the API may change substantially in the future. Use it at your own risk and do not expect backward compatibility.

class amici.jax.JAXModel

JAXModel provides an abstract base class for a JAX-based implementation of an AMICI model. The class implements routines for simulation and for evaluation of derived quantities; model-specific implementations need to be provided by classes inheriting from JAXModel.

Variables:
  • api_version – API version of the derived class. Needs to match the API version of the base class (MODEL_API_VERSION).

  • MODEL_API_VERSION – API version of the base class.

  • jax_py_file – Path to the JAX model file.

MODEL_API_VERSION = '0.0.3'
__init__()
api_version: str
jax_py_file: pathlib.Path
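The api_version / MODEL_API_VERSION pair described above guards against combining a generated model file with a base class of a different API version. A minimal, hypothetical sketch of such a check follows; the class names are invented for illustration, this is not the JAXModel source, and where and how AMICI performs the comparison is an assumption.

```python
class VersionCheckedBase:
    """Hypothetical base class mimicking the documented version contract."""

    MODEL_API_VERSION = "0.0.3"  # value documented for JAXModel
    api_version: str = ""  # overridden by each generated subclass

    def __init__(self) -> None:
        # Refuse to instantiate a subclass generated for a different API version.
        if self.api_version != self.MODEL_API_VERSION:
            raise ValueError(
                f"model API version {self.api_version!r} does not match "
                f"base class version {self.MODEL_API_VERSION!r}"
            )


class CurrentModel(VersionCheckedBase):
    api_version = "0.0.3"


class OutdatedModel(VersionCheckedBase):
    api_version = "0.0.2"
```

Instantiating CurrentModel succeeds, while OutdatedModel() raises a ValueError.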
abstract property observable_ids: list[str]

Get the observable ids of the model.

Returns:

Observable ids

abstract property parameter_ids: list[str]

Get the parameter ids of the model.

Returns:

Parameter ids
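Arrays such as p are positional, so their entries must follow the order of parameter_ids (and observable indices, analogously, follow observable_ids). A small NumPy sketch, assuming the values come from a plain dict; order_by_ids is a hypothetical helper, not part of amici.jax, and the ids are made up.

```python
import numpy as np


def order_by_ids(ids, values):
    """Return the values in ``ids`` order as a float array (KeyError if an id is missing)."""
    return np.array([values[i] for i in ids], dtype=float)


parameter_ids = ["k1", "k2", "scaling"]  # e.g. model.parameter_ids
p = order_by_ids(parameter_ids, {"scaling": 2.0, "k1": 0.1, "k2": 0.5})
```

Building p from ids rather than by hand avoids silent mix-ups when the model's parameter ordering changes.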

preequilibrate_condition(p, x_reinit, mask_reinit, solver, controller, steady_state_event, max_steps)

Simulate a condition until steady state is reached (pre-equilibration).

Parameters:
  • p (jaxtyping.Float[Array, 'np']) – parameters for simulation, ordered according to the ids in parameter_ids

  • x_reinit (jaxtyping.Float[Array, '*nx']) – re-initialized state vector. If not provided, the state vector is not re-initialized.

  • mask_reinit (jaxtyping.Bool[Array, '*nx']) – mask for re-initialization. If True, the corresponding state variable is re-initialized.

  • solver (diffrax._solver.base.AbstractSolver) – ODE solver

  • controller (diffrax._step_size_controller.base.AbstractStepSizeController) – step size controller

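  • steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – event function for steady state. See diffrax.steady_state_event() for details.

  • max_steps (int | jax.numpy.int64) – maximum number of solver steps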

Return type:

tuple[jaxtyping.Float[Array, 'nx'], dict]

Returns:

pre-equilibrated state variables and statistics
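One way to read the mask_reinit / x_reinit semantics is elementwise: where the mask is True the state takes the value from x_reinit, elsewhere it keeps its current value. The NumPy sketch below assumes that reading and also assumes that an empty mask means "no re-initialization"; apply_reinitialisation is a hypothetical helper, not the AMICI implementation, which operates on JAX arrays.

```python
import numpy as np


def apply_reinitialisation(x, mask_reinit, x_reinit):
    """Take x_reinit where mask_reinit is True, keep x elsewhere (hypothetical helper)."""
    if mask_reinit.size == 0:
        # Empty mask: assumed to mean that no state is re-initialized.
        return x
    return np.where(mask_reinit, x_reinit, x)


x = np.array([1.0, 2.0, 3.0])
x_new = apply_reinitialisation(x, np.array([False, True, False]), np.array([0.0, 10.0, 0.0]))
```

Only the second state is replaced, so x_new holds 1.0, 10.0 and 3.0.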

simulate_condition(p, ts_dyn, ts_posteq, my, iys, iy_trafos, ops, nps, solver, controller, adjoint, steady_state_event, max_steps, x_preeq=Array([], shape=(0,), dtype=float32), mask_reinit=Array([], shape=(0,), dtype=float32), x_reinit=Array([], shape=(0,), dtype=float32), ts_mask=Array([], shape=(0,), dtype=float32), ret=ReturnValue.llh)

Simulate a condition.

Parameters:
  • p (jaxtyping.Float[Array, 'np']) – parameters for simulation, ordered according to the ids in parameter_ids

  • ts_dyn (jaxtyping.Float[Array, 'nt_dyn']) – time points for dynamic simulation. Must be sorted in non-decreasing order; duplicate time points are allowed to facilitate the evaluation of multiple observables at the same time point.

  • ts_posteq (jaxtyping.Float[Array, 'nt_posteq']) – time points for post-equilibration. Usually all values are inf, but the array must be shaped according to the number of observables that are evaluated after post-equilibration.

  • my (jaxtyping.Float[Array, 'nt']) – observed data

  • iys (jaxtyping.Int[Array, 'nt']) – indices of the observables according to the ordering in observable_ids

  • iy_trafos (jaxtyping.Int[Array, 'nt']) – indices of transformations for observables

  • ops (jaxtyping.Float[Array, 'nt *nop']) – observable parameters

  • nps (jaxtyping.Float[Array, 'nt *nnp']) – noise parameters

  • solver (diffrax._solver.base.AbstractSolver) – ODE solver

  • controller (diffrax._step_size_controller.base.AbstractStepSizeController) – step size controller

  • adjoint (diffrax._adjoint.AbstractAdjoint) – adjoint method. Recommended values are diffrax.DirectAdjoint() for jax.jacfwd (with vector-valued outputs) and diffrax.RecursiveCheckpointAdjoint() for jax.grad (for scalar-valued outputs).

  • steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – event function for steady state. See diffrax.steady_state_event() for details.

  • max_steps (int | jax.numpy.int64) – maximum number of solver steps

  • x_preeq (jaxtyping.Float[Array, '*nx']) – initial state vector for pre-equilibration. If not provided, the initial state vector is computed using _x0().

  • mask_reinit (jaxtyping.Bool[Array, '*nx']) – mask for re-initialization. If True, the corresponding state variable is re-initialized.

  • x_reinit (jaxtyping.Float[Array, '*nx']) – re-initialized state vector. If not provided, the state vector is not re-initialized.

  • ts_mask (jaxtyping.Bool[Array, 'nt']) – mask to remove (padded) time points. If True, the corresponding time point is used for the evaluation of the output. Only applied if ret is ReturnValue.llh, ReturnValue.nllhs, ReturnValue.res, or ReturnValue.chi2.

  • ret (amici.jax.model.ReturnValue) – which output to return. See ReturnValue for available options.
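To illustrate how ts_dyn, ts_posteq, my and iys fit together, the following NumPy sketch builds them for a single condition from a hypothetical list of (observable id, time, value) records, where np.inf marks a measurement taken after post-equilibration. The concatenation order of my and iys (dynamic entries first, then post-equilibration entries) is an assumption; iy_trafos, ops and nps are omitted.

```python
import numpy as np

observable_ids = ["obs_a", "obs_b"]  # e.g. model.observable_ids
# (observable id, time, measured value); np.inf marks a post-equilibration measurement
measurements = [
    ("obs_b", 5.0, 0.7),
    ("obs_a", 1.0, 0.2),
    ("obs_b", 1.0, 0.4),  # duplicate time point for a different observable
    ("obs_a", np.inf, 0.9),
]

# Stable sort by time; duplicates are kept so several observables share a time point.
ordered = sorted(measurements, key=lambda m: m[1])
dyn = [m for m in ordered if np.isfinite(m[1])]
posteq = [m for m in ordered if not np.isfinite(m[1])]

ts_dyn = np.array([t for _, t, _ in dyn])
ts_posteq = np.array([t for _, t, _ in posteq])  # one inf entry per post-eq evaluation
# Assumption: dynamic entries first, then post-equilibration entries.
my = np.array([v for _, _, v in dyn + posteq])
iys = np.array([observable_ids.index(o) for o, _, _ in dyn + posteq], dtype=np.int32)
```

Here ts_dyn is [1.0, 1.0, 5.0], ts_posteq holds a single inf, and iys is [0, 1, 1, 0].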

Return type:

tuple[jaxtyping.Float[Array, 'nt *nx'] | jax.numpy.float64, dict]

Returns:

output according to ret and general results/statistics

abstract property state_ids: list[str]

Get the state ids of the model.

Returns:

State ids

class amici.jax.JAXProblem(model, petab_problem)

PEtab problem wrapper for JAX models.

Variables:
  • parameters – Values for the model parameters. Do not change their dimensions; the values may be changed, e.g., during model training.

  • model – JAXModel instance to use for simulation.

  • _parameter_mappings – ParameterMappingForCondition instances for each simulation condition.

  • _measurements – Preprocessed arrays for each simulation condition.

  • _petab_problem – PEtab problem to simulate.

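__init__(model, petab_problem)

Initialize a JAXProblem instance with a model and a PEtab problem.

Parameters:
  • model (amici.jax.model.JAXModel) – JAXModel instance to use for simulation.

  • petab_problem – PEtab problem to simulate.

get_all_simulation_conditions()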
Return type:

tuple[tuple[str, ...], ...]

get_petab_parameter_by_id(name)

Get the value of a PEtab parameter by name.

Parameters:

name (str) – PEtab parameter id, as returned by parameter_ids.

Return type:

jax.numpy.float64

Returns:

Value of the parameter

classmethod load(directory)

Load a problem from a directory.

Parameters:

directory (pathlib.Path) – Directory to load the problem from.

Returns:

Loaded problem instance.

load_parameters(simulation_condition)

Load parameters for a simulation condition.

Parameters:

simulation_condition (str) – Simulation condition to load parameters for.

Return type:

jaxtyping.Float[Array, 'np']

Returns:

Parameters for the simulation condition.

load_reinitialisation(simulation_condition, p)

Load reinitialisation values and mask for the state vector for a simulation condition.

Parameters:
  • simulation_condition (str) – Simulation condition to load reinitialisation for.

  • p (jaxtyping.Float[Array, 'np']) – Parameters for the simulation condition.

Return type:

tuple[jaxtyping.Bool[Array, 'nx'], jaxtyping.Float[Array, 'nx']]

Returns:

Tuple of reinitialisation mask and values for the states.

model: amici.jax.model.JAXModel
property parameter_ids: list[str]

Parameter ids that are estimated in the PEtab problem. Same ordering as values in parameters.

Returns:

PEtab parameter ids
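Since parameters and parameter_ids share one ordering, looking up a value by id reduces to an index search. A NumPy sketch with made-up ids and values; get_value_by_id is hypothetical and is not claimed to be how get_petab_parameter_by_id is implemented.

```python
import numpy as np

parameter_ids = ["k1", "k2", "sigma_obs"]  # made-up ids, e.g. problem.parameter_ids
parameters = np.array([0.1, 0.5, 0.05])  # same ordering, e.g. problem.parameters


def get_value_by_id(name):
    """Return the entry of ``parameters`` for ``name``; ValueError if the id is unknown."""
    return parameters[parameter_ids.index(name)]
```

For example, get_value_by_id("k2") returns 0.5.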

parameters: jax.Array
run_preequilibration(p, mask_reinit, x_reinit, solver, controller, steady_state_event, max_steps)

Run a pre-equilibration simulation for a given simulation condition.

Parameters:
  • p (jaxtyping.Float[Array, 'np']) – Parameters for the simulation condition

  • mask_reinit (jaxtyping.Bool[Array, 'nx']) – Mask for states that need reinitialisation

  • x_reinit (jaxtyping.Float[Array, 'nx']) – Reinitialisation values for states

  • solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation

  • controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation

  • steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – Steady state event function to use for pre-equilibration. Allows customisation of the steady state condition, see diffrax.steady_state_event() for details.

  • max_steps (jax.numpy.int64) – Maximum number of steps to take during simulation

Return type:

tuple[jaxtyping.Float[Array, 'nx'], dict]

Returns:

Pre-equilibrated state and simulation statistics

run_preequilibrations(simulation_conditions, solver, controller, steady_state_event, max_steps)
run_simulation(p, ts_dyn, ts_posteq, my, iys, iy_trafos, ops, nps, mask_reinit, x_reinit, solver, controller, steady_state_event, max_steps, x_preeq=Array([], shape=(0,), dtype=float32), ts_mask=array([], dtype=float64), ret=ReturnValue.llh)

Run a simulation for a given simulation condition.

Parameters:
  • p (jaxtyping.Float[Array, 'np']) – Parameters for the simulation condition

  • ts_dyn (numpy.ndarray) – (Padded) dynamic time points

  • ts_posteq (numpy.ndarray) – (Padded) post-equilibrium time points

  • my (numpy.ndarray) – (Padded) measurements

  • iys (numpy.ndarray) – (Padded) observable indices

  • iy_trafos (numpy.ndarray) – (Padded) observable transformation indices

  • ops (jaxtyping.Float[Array, 'nt *nop']) – (Padded) observable parameters

  • nps (jaxtyping.Float[Array, 'nt *nnp']) – (Padded) noise parameters

  • mask_reinit (jaxtyping.Bool[Array, 'nx']) – Mask for states that need reinitialisation

  • x_reinit (jaxtyping.Float[Array, 'nx']) – Reinitialisation values for states

  • solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation

  • controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation

  • steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – Steady state event function to use for post-equilibration. Allows customisation of the steady state condition, see diffrax.steady_state_event() for details.

  • max_steps (jax.numpy.int64) – Maximum number of steps to take during simulation

  • x_preeq (jaxtyping.Float[Array, '*nx']) – Pre-equilibration state. Can be empty if no pre-equilibration is available, in which case the states will be initialised to the model default values.

  • ts_mask (numpy.ndarray) – padding mask, see JAXModel.simulate_condition() for details.

  • ret (amici.jax.model.ReturnValue) – which output to return. See ReturnValue for available options.

Return type:

tuple[jax.numpy.float64, dict]

Returns:

Tuple of output value and simulation statistics
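The (Padded) arrays let conditions with different numbers of measurements share one array shape. The NumPy sketch below pads ragged time vectors at the end and returns a mask following the ts_mask convention (True = real time point). Padding at the end and repeating the last time point to keep each row sorted are assumptions, not necessarily what JAXProblem does internally; my, iys and the other per-time-point arrays would be padded the same way.

```python
import numpy as np


def pad_to_common_length(ts_per_condition):
    """Pad ragged time vectors; return (padded_ts, ts_mask), mask True = real point."""
    n_max = max(len(ts) for ts in ts_per_condition)
    padded, masks = [], []
    for ts in ts_per_condition:
        ts = np.asarray(ts, dtype=float)
        n_pad = n_max - ts.size
        # Repeat the last time point (0.0 for an empty vector) so each row stays sorted.
        fill = ts[-1] if ts.size else 0.0
        padded.append(np.concatenate([ts, np.full(n_pad, fill)]))
        masks.append(np.concatenate([np.ones(ts.size, dtype=bool), np.zeros(n_pad, dtype=bool)]))
    return np.stack(padded), np.stack(masks)


ts_dyn, ts_mask = pad_to_common_length([[0.0, 1.0, 5.0], [2.0]])
```

The second condition becomes [2.0, 2.0, 2.0] with mask [True, False, False], so its two padded entries are ignored when the output is evaluated.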

run_simulations(simulation_conditions, preeq_array, solver, controller, steady_state_event, max_steps, ret=ReturnValue.llh)

Run simulations for a list of simulation conditions.

Parameters:
  • simulation_conditions (list[str]) – List of simulation conditions to run simulations for.

  • preeq_array (jaxtyping.Float[Array, 'ncond *nx']) – Matrix of pre-equilibrated states for the simulation conditions. Ordering must match the simulation conditions. If no pre-equilibration is available for a condition, the corresponding row must be empty.

  • solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation.

  • controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation.

  • steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – Steady state event function to use for post-equilibration. Allows customisation of the steady state condition, see diffrax.steady_state_event() for details.

  • max_steps (jax.numpy.int64) – Maximum number of steps to take during simulation.

  • ret (amici.jax.model.ReturnValue) – which output to return. See ReturnValue for available options.

Returns:

Output value and condition specific results and statistics. Results and statistics are returned as a dict with arrays with the leading dimension corresponding to the simulation conditions.
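Because the returned dict stacks results along a leading condition axis, per-condition values can be recovered by indexing that axis. A NumPy sketch; split_by_condition is a hypothetical helper, and the keys and values below are placeholders, not the actual keys returned by run_simulations.

```python
import numpy as np


def split_by_condition(stacked, conditions):
    """Turn {key: array with leading condition axis} into {condition: {key: array}}."""
    for key, value in stacked.items():
        if value.shape[0] != len(conditions):
            raise ValueError(f"{key!r}: leading dimension {value.shape[0]} != {len(conditions)}")
    return {
        cond: {key: value[i] for key, value in stacked.items()}
        for i, cond in enumerate(conditions)
    }


# Placeholder keys and values for two conditions.
stacked = {"value": np.array([-1.5, -2.0]), "trajectory": np.arange(6.0).reshape(2, 3)}
per_condition = split_by_condition(stacked, ["cond_a", "cond_b"])
```

The ordering of the list passed as conditions must match the ordering used for the simulation call.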

save(directory)

Save the problem to a directory.

Parameters:

directory (pathlib.Path) – Directory to save the problem to.

simulation_conditions: tuple[tuple[str, ...], ...]
update_parameters(p)

Update parameters for the model.

Parameters:

p (jaxtyping.Float[Array, 'np']) – New values for the problem parameters. The dimensions must not change.

Return type:

amici.jax.petab.JAXProblem

Returns:

New problem instance with updated parameters.
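update_parameters follows a functional style: it returns a new problem instance instead of modifying the existing one, which suits JAX's preference for immutable values, e.g. when parameters are updated inside a training loop. The pattern can be sketched with a frozen dataclass. ToyProblem is hypothetical and only mirrors the documented contract (new instance, unchanged dimensions); JAXProblem is not claimed to be a dataclass.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True, eq=False)
class ToyProblem:
    """Hypothetical immutable stand-in for a problem object (not JAXProblem)."""

    parameter_ids: tuple
    parameters: np.ndarray

    def update_parameters(self, p):
        p = np.asarray(p, dtype=float)
        # Documented contract: values may change, dimensions may not.
        if p.shape != self.parameters.shape:
            raise ValueError(f"expected shape {self.parameters.shape}, got {p.shape}")
        # Return a modified copy; the original instance is left untouched.
        return dataclasses.replace(self, parameters=p)


problem = ToyProblem(("k1", "k2"), np.array([0.1, 0.5]))
updated = problem.update_parameters([0.2, 0.5])
```

Callers therefore need to keep the returned object, e.g. problem = problem.update_parameters(p).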

class amici.jax.ReturnValue(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)
chi2 = 'sum(((observed - simulated) / sigma ) ** 2)'
llh = 'log-likelihood'
nllhs = 'pointwise negative log-likelihood'
res = 'residuals'
sigmay = 'standard deviations of the observables'
tcl = 'total values for conservation laws'
x = 'full state vector'
x0 = 'full initial state vector'
x0_solver = 'reduced initial state vector'
x_solver = 'reduced state vector'
y = 'observables'
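For orientation, the NumPy sketch below shows how chi2, nllhs and llh relate under the assumption of normally distributed measurement noise on linear scale. PEtab also allows other noise distributions and observable transformations, which change these formulas, and this is not AMICI's code. chi2 follows the formula given for ReturnValue.chi2, and padded time points are dropped as described for ts_mask.

```python
import numpy as np


def gaussian_outputs(observed, simulated, sigma, ts_mask):
    """Masked chi2, pointwise nllhs and llh, assuming Gaussian noise on linear scale."""
    r = (observed - simulated) / sigma  # scaled residuals
    chi2 = np.sum(np.where(ts_mask, r**2, 0.0))
    nllhs = np.where(ts_mask, 0.5 * np.log(2 * np.pi * sigma**2) + 0.5 * r**2, 0.0)
    llh = -np.sum(nllhs)
    return chi2, nllhs, llh


observed = np.array([1.0, 2.0, 3.0])
simulated = np.array([1.0, 1.0, 0.0])
sigma = np.ones(3)
ts_mask = np.array([True, True, False])  # third point is padding
chi2, nllhs, llh = gaussian_outputs(observed, simulated, sigma, ts_mask)
```

With the third point masked, chi2 is 1.0 and llh equals -0.5 * chi2 minus the normalisation terms of the two unmasked points.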
amici.jax.petab_simulate(problem, solver=Kvaerno5(), controller=PIDController(rtol=1e-08, atol=1e-08, pcoeff=0.4, icoeff=0.3, dcoeff=0.0), steady_state_event=<function steady_state_event.<locals>._cond_fn>, max_steps=1024)

Run simulations for a problem and return the results as a PEtab simulation dataframe.

Parameters:
  • problem (amici.jax.petab.JAXProblem) – Problem to run simulations for.

  • solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation.

  • controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation.

  • steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state condition, see diffrax.steady_state_event() for details.

  • max_steps (int) – Maximum number of steps to take during simulation.

Returns:

PEtab simulation dataframe.

amici.jax.run_simulations(problem, simulation_conditions=None, solver=Kvaerno5(), controller=PIDController(rtol=1e-08, atol=1e-08, pcoeff=0.4, icoeff=0.3, dcoeff=0.0), steady_state_event=<function steady_state_event.<locals>._cond_fn>, max_steps=1024, ret=ReturnValue.llh)

Run simulations for a problem.

Parameters:
  • problem (amici.jax.petab.JAXProblem) – Problem to run simulations for.

  • simulation_conditions (collections.abc.Iterable[tuple[str, ...]] | None) – Simulation conditions to run simulations for. This is a series of tuples, each containing either only the simulation condition id, or the pre-equilibration condition id followed by the simulation condition id. Defaults to running simulations for all conditions.

  • solver (diffrax._solver.base.AbstractSolver) – ODE solver to use for simulation.

  • controller (diffrax._step_size_controller.base.AbstractStepSizeController) – Step size controller to use for simulation.

  • steady_state_event (collections.abc.Callable[..., typing.Union[jaxtyping.Bool[Array, ''], jaxtyping.Bool[ndarray, ''], numpy.bool_, numpy.number, bool]]) – Steady state event function to use for pre-/post-equilibration. Allows customisation of the steady state condition, see diffrax.steady_state_event() for details.

  • max_steps (int) – Maximum number of steps to take during simulation.

  • ret (amici.jax.model.ReturnValue | str) – which output to return. See ReturnValue for available options.

Returns:

Overall output value and condition specific results and statistics.
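The simulation_conditions format can be unpacked as shown below. split_condition is a hypothetical helper written against the description of the parameter (one id: simulation condition only; two ids: pre-equilibration condition, then simulation condition), and the condition ids are made up.

```python
def split_condition(condition):
    """Return (preequilibration_id or None, simulation_id) for one condition tuple."""
    if len(condition) == 1:
        return None, condition[0]
    if len(condition) == 2:
        return condition[0], condition[1]
    raise ValueError(f"expected 1 or 2 condition ids, got {condition!r}")


conditions = [("c0",), ("preeq_ctrl", "c1")]
pairs = [split_condition(c) for c in conditions]
```

Here pairs is [(None, "c0"), ("preeq_ctrl", "c1")]: c0 is simulated without pre-equilibration, c1 after pre-equilibration under preeq_ctrl.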