AMICI & JAX

Overview

The purpose of this guide is to showcase how AMICI can be combined with differentiable programming in JAX. We will do so by reimplementing the parameter transformations available in AMICI in JAX and comparing the results to the native implementation.

[1]:
import jax
import jax.numpy as jnp

Preparation

To get started, we will import a model using PEtab. To this end, we will use the benchmark collection, which features a variety of different models. For more details about PEtab import, see the respective PEtab notebook.

From the benchmark collection, we now import the Böhm model.

[2]:
import petab

model_name = "Boehm_JProteomeRes2014"
yaml_file = f"https://raw.githubusercontent.com/Benchmarking-Initiative/Benchmark-Models-PEtab/master/Benchmark-Models/{model_name}/{model_name}.yaml"
petab_problem = petab.Problem.from_yaml(yaml_file)

The PEtab problem includes information about parameter scaling in its parameter table. For the Böhm model, all estimated parameters (petab.ESTIMATE column equal to 1) use petab.LOG10 as parameter scale.

[3]:
petab_problem.parameter_df
[3]:
parameterName parameterScale lowerBound upperBound nominalValue estimate
parameterId
Epo_degradation_BaF3 EPO_{degradation,BaF3} log10 0.00001 100000 0.026983 1
k_exp_hetero k_{exp,hetero} log10 0.00001 100000 0.000010 1
k_exp_homo k_{exp,homo} log10 0.00001 100000 0.006170 1
k_imp_hetero k_{imp,hetero} log10 0.00001 100000 0.016368 1
k_imp_homo k_{imp,homo} log10 0.00001 100000 97749.379402 1
k_phos k_{phos} log10 0.00001 100000 15766.507020 1
ratio ratio lin -5.00000 5 0.693000 0
sd_pSTAT5A_rel \sigma_{pSTAT5A,rel} log10 0.00001 100000 3.852612 1
sd_pSTAT5B_rel \sigma_{pSTAT5B,rel} log10 0.00001 100000 6.591478 1
sd_rSTAT5A_rel \sigma_{rSTAT5A,rel} log10 0.00001 100000 3.152713 1
specC17 specC17 lin -5.00000 5 0.107000 0
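
This can also be checked programmatically; the following small snippet (not part of the original notebook) confirms that all estimated parameters use the log10 scale:

# Select the estimated parameters and list the parameter scales that occur.
estimated = petab_problem.parameter_df[
    petab_problem.parameter_df[petab.ESTIMATE] == 1
]
print(estimated[petab.PARAMETER_SCALE].unique())  # expected: ['log10']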

We now import the PEtab problem using import_petab_problem from amici.petab.petab_import.

[4]:
from amici.petab.petab_import import import_petab_problem

amici_model = import_petab_problem(
    petab_problem, compile_=True, verbose=False
)

JAX implementation

For full JAX support, we would have to implement a new primitive, which would require quite a bit of engineering and, in the end, wouldn't add much benefit, since AMICI can't run on GPUs. Instead, we will interface with AMICI using the experimental JAX module host_callback.

To do so, we define a base function that takes a single argument (the parameters) and runs a PEtab-based simulation via simulate_petab. To enable gradient computation later on, we create a solver object, set the sensitivity order to first order, and pass it to simulate_petab. Moreover, simulate_petab expects a dictionary of parameters, so we create a dictionary using the free parameter ids from the PEtab problem. As we want to implement the parameter transformation in JAX, we keep the default scaled_parameters=False, i.e., the parameters passed to simulate_petab are interpreted on linear scale.

[5]:
from amici.petab.simulations import simulate_petab
import amici

amici_solver = amici_model.getSolver()
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)


def amici_hcb_base(parameters: jnp.array):
    return simulate_petab(
        petab_problem,
        amici_model,
        problem_parameters=dict(zip(petab_problem.x_free_ids, parameters)),
        solver=amici_solver,
    )
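
As a quick, purely illustrative check (not part of the original workflow), the base function can be called once at the nominal, linear-scale parameter values to inspect the structure of the returned dictionary:

# Run one simulation at the nominal (linear-scale) parameter values.
res = amici_hcb_base(petab_problem.x_nominal_free)
print(res["llh"])        # scalar log-likelihood
print(len(res["sllh"]))  # one gradient entry per free parameter, keyed by id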

Now we can use this base function to create two separate functions that compute the log-likelihood (llh) and its gradient (sllh) in two individual routines. Note that, as we are using the same base function here, the log-likelihood computation will also run with sensitivities, which is not necessary and adds some overhead. This is only done out of convenience and should be fixed in an application where efficiency is important.

[6]:
def amici_hcb_llh(parameters: jnp.array):
    return amici_hcb_base(parameters)["llh"]


def amici_hcb_sllh(parameters: jnp.array):
    sllh = amici_hcb_base(parameters)["sllh"]
    return jnp.asarray(
        tuple(sllh[par_id] for par_id in petab_problem.x_free_ids)
    )

Now we can finally define the JAX function that runs the AMICI simulation via the host callback. We add the custom_jvp decorator so that we can define a custom Jacobian-vector product function in the next step. More details about custom Jacobian-vector product functions can be found in the JAX documentation.

[7]:
import jax.experimental.host_callback as hcb
from jax import custom_jvp

import numpy as np


@custom_jvp
def jax_objective(parameters: jnp.array):
    return hcb.call(
        amici_hcb_llh,
        parameters,
        result_shape=jax.ShapeDtypeStruct((), np.float64),
    )
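
Note that host_callback is an experimental module and has been deprecated in newer JAX releases. A roughly equivalent formulation using jax.pure_callback (shown only as a sketch; it is not used in the remainder of this notebook) would be:

# Hypothetical alternative using jax.pure_callback; the shape/dtype
# specification plays the same role as result_shape above.
def jax_objective_pure(parameters: jnp.array):
    return jax.pure_callback(
        amici_hcb_llh,
        jax.ShapeDtypeStruct((), np.float64),
        parameters,
    )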

Now we define the function that implements the Jacobian-vector product. This effectively just returns the objective function value (computed using the previously defined jax_objective) as well as the inner product of the gradient (computed via a host callback to the previously defined amici_hcb_sllh) and the tangent vector. Note that this implementation performs two simulation runs, one for the function value and one for the gradient, which is inefficient and could be avoided by caching solutions (see the sketch after the next cell).

[8]:
@jax_objective.defjvp
def jax_objective_jvp(primals: jnp.array, tangents: jnp.array):
    (parameters,) = primals
    (x_dot,) = tangents
    llh = jax_objective(parameters)
    sllh = hcb.call(
        amici_hcb_sllh,
        parameters,
        result_shape=jax.ShapeDtypeStruct(
            (petab_problem.parameter_df.estimate.sum(),), np.float64
        ),
    )
    return llh, sllh.dot(x_dot)
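
One way to avoid the duplicate simulation run would be to memoize the base function on the host side, so that the llh and sllh callbacks reuse the same result. A minimal sketch (not part of the original implementation; it assumes the callback receives the parameters as a numpy array):

# Hypothetical memoization of the base function: both callbacks hit the same
# cache entry, so each parameter vector is simulated only once.
_simulation_cache = {}


def amici_hcb_base_cached(parameters):
    key = np.asarray(parameters).tobytes()
    if key not in _simulation_cache:
        _simulation_cache[key] = amici_hcb_base(parameters)
    return _simulation_cache[key]

amici_hcb_llh and amici_hcb_sllh would then call amici_hcb_base_cached instead of amici_hcb_base.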

As the last step, we implement the parameter transformation in JAX. This effectively just extracts the parameter scales from the PEtab problem, implements the rescaling to linear scale in JAX, and then passes the transformed parameters to the objective function we defined previously. We add the value_and_grad decorator such that the generated JAX function returns both the function value and the function gradient in a tuple. Moreover, we add the jax.jit decorator such that the function is just-in-time compiled upon the first function call.

[9]:
from jax import value_and_grad

parameter_scales = petab_problem.parameter_df.loc[
    petab_problem.x_free_ids, petab.PARAMETER_SCALE
].values


@jax.jit
@value_and_grad
def jax_objective_with_parameter_transform(parameters: jnp.array):
    par_scaled = jnp.asarray(
        tuple(
            value
            if scale == petab.LIN
            else jnp.exp(value)
            if scale == petab.LOG
            else jnp.power(10, value)
            for value, scale in zip(parameters, parameter_scales)
        )
    )
    return jax_objective(par_scaled)
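
The tuple comprehension above is evaluated element by element at trace time. An equivalent, fully vectorized formulation (a sketch, under the assumption that only the LIN, LOG, and LOG10 scales from the parameter table occur) could use jnp.where instead:

# Boolean masks over the parameter scales, computed once outside the jitted function.
is_log = np.asarray(parameter_scales == petab.LOG)
is_log10 = np.asarray(parameter_scales == petab.LOG10)


def unscale_parameters(parameters: jnp.array):
    # LIN parameters pass through, LOG uses exp, LOG10 uses 10**x.
    return jnp.where(
        is_log,
        jnp.exp(parameters),
        jnp.where(is_log10, jnp.power(10.0, parameters), parameters),
    )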

Testing

We can now run the function to compute the log-likelihood and the gradient.

[10]:
parameters = dict(zip(petab_problem.x_free_ids, petab_problem.x_nominal_free))
scaled_parameters = petab_problem.scale_parameters(parameters)
scaled_parameters_np = np.asarray(list(scaled_parameters.values()))
[11]:
llh_jax, sllh_jax = jax_objective_with_parameter_transform(
    scaled_parameters_np
)

As a sanity check, we compare the computed values to the native parameter transformation in AMICI.

[12]:
r = simulate_petab(
    petab_problem,
    amici_model,
    solver=amici_solver,
    scaled_parameters=True,
    scaled_gradients=True,
    problem_parameters=scaled_parameters,
)
[13]:
import pandas as pd

pd.DataFrame(
    dict(
        amici=r["llh"],
        jax=float(llh_jax),
        rel_diff=(r["llh"] - float(llh_jax)) / r["llh"],
    ),
    index=("llh",),
)
[13]:
amici jax rel_diff
llh -138.221997 -138.222 -2.135248e-08
[14]:
grad_amici = np.asarray(list(r["sllh"].values()))
grad_jax = np.asarray(sllh_jax)
rel_diff = (grad_amici - grad_jax) / grad_jax
pd.DataFrame(
    index=r["sllh"].keys(),
    data=dict(amici=grad_amici, jax=grad_jax, rel_diff=rel_diff),
)
[14]:
amici jax rel_diff
Epo_degradation_BaF3 -0.022045 -0.022034 4.645833e-04
k_exp_hetero -0.055323 -0.055323 8.646725e-08
k_exp_homo -0.005789 -0.005801 -2.013520e-03
k_imp_hetero -0.005414 -0.005403 1.973517e-03
k_imp_homo 0.000045 0.000045 1.119566e-06
k_phos -0.007907 -0.007794 1.447768e-02
sd_pSTAT5A_rel -0.010784 -0.010800 -1.469604e-03
sd_pSTAT5B_rel -0.024037 -0.024037 -8.729860e-06
sd_rSTAT5A_rel -0.019191 -0.019186 2.829431e-04

We see quite some differences in the gradient calculation, with over 1% error for k_phos. The primary reason is that JAX in its default configuration uses float32 precision, both for the parameters that are passed to AMICI (which uses float64 internally) and for the derivative of the parameter transformation. As the AMICI simulations running on the CPU are by far the most expensive operation, there is barely any performance benefit to using float32 instead of float64 in JAX. Therefore, we configure JAX to use float64 and rerun the simulations.

[15]:
jax.config.update("jax_enable_x64", True)
llh_jax, sllh_jax = jax_objective_with_parameter_transform(
    scaled_parameters_np
)

We can now evaluate the results again and see that differences between pure AMICI and AMICI/JAX implementations have now disappeared.

[16]:
pd.DataFrame(
    dict(
        amici=r["llh"],
        jax=float(llh_jax),
        rel_diff=(r["llh"] - float(llh_jax)) / r["llh"],
    ),
    index=("llh",),
)
[16]:
amici jax rel_diff
llh -138.221997 -138.221997 -0.0
[17]:
grad_amici = np.asarray(list(r["sllh"].values()))
grad_jax = np.asarray(sllh_jax)
rel_diff = (grad_amici - grad_jax) / grad_jax
pd.DataFrame(
    index=r["sllh"].keys(),
    data=dict(amici=grad_amici, jax=grad_jax, rel_diff=rel_diff),
)
[17]:
amici jax rel_diff
Epo_degradation_BaF3 -0.022045 -0.022045 -0.0
k_exp_hetero -0.055323 -0.055323 -0.0
k_exp_homo -0.005789 -0.005789 -0.0
k_imp_hetero -0.005414 -0.005414 -0.0
k_imp_homo 0.000045 0.000045 0.0
k_phos -0.007907 -0.007907 -0.0
sd_pSTAT5A_rel -0.010784 -0.010784 -0.0
sd_pSTAT5B_rel -0.024037 -0.024037 -0.0
sd_rSTAT5A_rel -0.019191 -0.019191 -0.0
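
Since the jitted function returns both the objective value and its gradient, it can be plugged directly into gradient-based optimization. As a purely illustrative sketch (the step size and the single ascent step are assumptions, not part of the original notebook):

# One gradient-ascent step on the log-likelihood in scaled parameter space.
step_size = 1e-3  # illustrative value
llh, grad = jax_objective_with_parameter_transform(scaled_parameters_np)
updated_parameters = scaled_parameters_np + step_size * np.asarray(grad)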