AMICI & JAX
Overview
The purpose of this guide is to showcase how AMICI can be combined with differentiable programming in JAX. We will do so by reimplementing the parameter transformations available in AMICI in JAX and comparing the results to the native implementation.
[1]:
import jax
import jax.numpy as jnp
Preparation
To get started, we will import a model using PEtab. To this end, we will use the benchmark collection, which features a variety of different models. For more details about PEtab import, see the respective PEtab notebook.
From the benchmark collection, we now import the Böhm model.
[2]:
import petab.v1 as petab
model_name = "Boehm_JProteomeRes2014"
yaml_file = f"https://raw.githubusercontent.com/Benchmarking-Initiative/Benchmark-Models-PEtab/master/Benchmark-Models/{model_name}/{model_name}.yaml"
petab_problem = petab.Problem.from_yaml(yaml_file)
The PEtab problem includes information about parameter scaling in its parameter table. For the Boehm model, all estimated parameters (petab.ESTIMATE column equal to 1) use petab.LOG10 as parameter scale.
[3]:
petab_problem.parameter_df
[3]:
| parameterId | parameterName | parameterScale | lowerBound | upperBound | nominalValue | estimate |
| --- | --- | --- | --- | --- | --- | --- |
| Epo_degradation_BaF3 | EPO_{degradation,BaF3} | log10 | 0.00001 | 100000 | 0.026983 | 1 |
| k_exp_hetero | k_{exp,hetero} | log10 | 0.00001 | 100000 | 0.000010 | 1 |
| k_exp_homo | k_{exp,homo} | log10 | 0.00001 | 100000 | 0.006170 | 1 |
| k_imp_hetero | k_{imp,hetero} | log10 | 0.00001 | 100000 | 0.016368 | 1 |
| k_imp_homo | k_{imp,homo} | log10 | 0.00001 | 100000 | 97749.379402 | 1 |
| k_phos | k_{phos} | log10 | 0.00001 | 100000 | 15766.507020 | 1 |
| ratio | ratio | lin | 0.00000 | 5 | 0.693000 | 0 |
| sd_pSTAT5A_rel | \sigma_{pSTAT5A,rel} | log10 | 0.00001 | 100000 | 3.852612 | 1 |
| sd_pSTAT5B_rel | \sigma_{pSTAT5B,rel} | log10 | 0.00001 | 100000 | 6.591478 | 1 |
| sd_rSTAT5A_rel | \sigma_{rSTAT5A,rel} | log10 | 0.00001 | 100000 | 3.152713 | 1 |
| specC17 | specC17 | lin | -5.00000 | 5 | 0.107000 | 0 |
We now import the PEtab problem into AMICI using import_petab_problem.
[4]:
from amici.petab.petab_import import import_petab_problem
amici_model = import_petab_problem(petab_problem, compile_=True, verbose=False)
JAX implementation
For full JAX support, we would have to implement a new primitive, which would require quite a bit of engineering and, in the end, wouldn't add much benefit, since AMICI can't run on GPUs. Instead, we will interface AMICI using the JAX function jax.pure_callback.
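Before wiring up AMICI, here is a minimal, stand-alone sketch of how pure_callback works. The host_sin helper is purely illustrative and is not used in the remainder of this notebook:

import numpy as np
import jax
import jax.numpy as jnp


# Illustrative host-side function: it operates on NumPy arrays and is opaque
# to JAX tracing, so it could just as well wrap an external simulator.
def host_sin(x):
    # NumPy's sin preserves the input dtype, so it matches the declared dtype.
    return np.sin(x)


@jax.jit
def jax_sin(x):
    # pure_callback needs the output shape/dtype up front, since the callback
    # itself is only executed at runtime, outside of JAX's tracing machinery.
    return jax.pure_callback(
        host_sin,
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
    )


jax_sin(jnp.array([-1.0, 0.0, 1.0]))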
To do so, we define a base function that takes a single argument (the parameters) and runs a simulation via simulate_petab. To enable gradient computation later on, we create a solver object, set its sensitivity order to first order, and pass it to simulate_petab. Moreover, simulate_petab expects a dictionary of parameters, so we build one from the free parameter IDs of the PEtab problem.
[5]:
import numpy as np

from amici.petab.simulations import simulate_petab
import amici

# Solver with first-order sensitivities so that gradients are available later.
amici_solver = amici_model.getSolver()
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)


def amici_callback_base(parameters: jnp.array):
    # Run the PEtab simulation for the given problem parameters.
    ret = simulate_petab(
        petab_problem,
        amici_model,
        problem_parameters=dict(zip(petab_problem.x_free_ids, parameters)),
        solver=amici_solver,
    )
    # Extract the log-likelihood and its gradient, ordered by the free
    # parameter ids.
    llh = np.array(ret["llh"])
    sllh = np.array(
        tuple(ret["sllh"][par_id] for par_id in petab_problem.x_free_ids)
    )
    return llh, sllh
Now we can use this base function to create two separate functions that return the log-likelihood (llh) and a tuple of log-likelihood and its gradient (sllh). Both functions use pure_callback such that they can be called from other JAX functions. Note that, as we are using the same base function here, the log-likelihood-only computation will also run with sensitivities, which is not necessary and adds some overhead. This is only done for convenience and should be fixed in an application where efficiency matters (one possible approach is sketched below).
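As an aside, one way to avoid this overhead would be a second solver without sensitivity analysis for the log-likelihood-only case. The following is only a sketch; amici_solver_llh and amici_callback_llh_only are illustrative names and are not used in the rest of this notebook:

# Sketch: a separate solver without sensitivities avoids computing gradients
# when only the log-likelihood is required.
amici_solver_llh = amici_model.getSolver()
amici_solver_llh.setSensitivityOrder(amici.SensitivityOrder.none)


def amici_callback_llh_only(parameters: jnp.array):
    ret = simulate_petab(
        petab_problem,
        amici_model,
        problem_parameters=dict(zip(petab_problem.x_free_ids, parameters)),
        solver=amici_solver_llh,
    )
    return np.array(ret["llh"])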
[6]:
def device_fun_llh(x: jnp.array):
    # Return only the log-likelihood: a scalar with the input dtype.
    return jax.pure_callback(
        lambda x: amici_callback_base(x)[0],
        jax.ShapeDtypeStruct((), x.dtype),
        x,
    )


def device_fun_llh_sllh(x: jnp.array):
    # Return the log-likelihood and its gradient (same shape as the input).
    return jax.pure_callback(
        amici_callback_base,
        (
            jax.ShapeDtypeStruct((), x.dtype),
            jax.ShapeDtypeStruct(x.shape, x.dtype),
        ),
        x,
    )
Even though the two functions we just defined are valid JAX functions, they cannot compute derivatives yet. To support derivative computation, we have to define a new function with the jax.custom_jvp decorator, which specifies that we will provide a custom Jacobian-vector product (JVP), and register the corresponding JVP using the @jax_objective.defjvp decorator. More details about custom Jacobian-vector products can be found in the JAX documentation.
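For orientation, a minimal stand-alone example of this mechanism, not specific to AMICI, could look as follows (the cube function is purely illustrative):

# Declare f(x) = x**3 and register its JVP, 3 * x**2 * x_dot, by hand.
@jax.custom_jvp
def cube(x):
    return x**3


@cube.defjvp
def cube_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # The tangent output must be linear in x_dot, which also allows JAX to
    # derive reverse-mode differentiation (grad) from this JVP.
    return cube(x), 3 * x**2 * x_dot


jax.grad(cube)(2.0)  # 12.0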
[7]:
@jax.custom_jvp
def jax_objective(parameters: jnp.array):
    return device_fun_llh(parameters)


@jax_objective.defjvp
def jax_objective_jvp(primals: jnp.array, tangents: jnp.array):
    (parameters,) = primals
    (x_dot,) = tangents
    # Evaluate log-likelihood and gradient once, then contract the gradient
    # with the tangent vector to obtain the JVP.
    llh, sllh = device_fun_llh_sllh(parameters)
    return llh, sllh @ x_dot
As a last step, we implement the parameter transformation in JAX. This effectively just extracts the parameter scales from the PEtab problem, transforms the parameters from their PEtab scale back to linear scale in JAX (identity for lin, exp for log, 10**x for log10), and then passes the resulting parameters to the objective function we previously defined. We add the jax.value_and_grad decorator such that the generated JAX function returns both the function value and the gradient as a tuple. Moreover, we add the jax.jit decorator such that the function is just-in-time compiled upon the first call.
[8]:
parameter_scales = petab_problem.parameter_df.loc[
    petab_problem.x_free_ids, petab.PARAMETER_SCALE
].values


@jax.jit
@jax.value_and_grad
def jax_objective_with_parameter_transform(parameters: jnp.array):
    # Transform parameters from their PEtab scale to linear scale before
    # passing them to the objective.
    par_scaled = jnp.asarray(
        tuple(
            value
            if scale == petab.LIN
            else jnp.exp(value)
            if scale == petab.LOG
            else jnp.power(10, value)
            for value, scale in zip(parameters, parameter_scales)
        )
    )
    return jax_objective(par_scaled)
Testing
We can now run the function to compute the log-likelihood and the gradient.
[9]:
import numpy as np
parameters = dict(zip(petab_problem.x_free_ids, petab_problem.x_nominal_free))
scaled_parameters = petab_problem.scale_parameters(parameters)
scaled_parameters_np = np.asarray(list(scaled_parameters.values()))
[10]:
llh_jax, sllh_jax = jax_objective_with_parameter_transform(
    jnp.array(scaled_parameters_np)
)
As a sanity check, we compare the computed values to the native parameter transformation in AMICI.
[11]:
r = simulate_petab(
    petab_problem,
    amici_model,
    solver=amici_solver,
    scaled_parameters=True,
    scaled_gradients=True,
    problem_parameters=scaled_parameters,
)
[12]:
import pandas as pd

pd.DataFrame(
    dict(
        amici=r["llh"],
        jax=float(llh_jax),
        rel_diff=(r["llh"] - float(llh_jax)) / r["llh"],
    ),
    index=("llh",),
)
[12]:
| | amici | jax | rel_diff |
| --- | --- | --- | --- |
| llh | -138.221997 | -138.222 | -2.135248e-08 |
[13]:
grad_amici = np.asarray(list(r["sllh"].values()))
grad_jax = np.asarray(sllh_jax)
rel_diff = (grad_amici - grad_jax) / grad_jax
pd.DataFrame(
    index=r["sllh"].keys(),
    data=dict(amici=grad_amici, jax=grad_jax, rel_diff=rel_diff),
)
[13]:
| | amici | jax | rel_diff |
| --- | --- | --- | --- |
| Epo_degradation_BaF3 | -0.022045 | -0.022034 | 4.645833e-04 |
| k_exp_hetero | -0.055323 | -0.055323 | 8.646725e-08 |
| k_exp_homo | -0.005789 | -0.005801 | -2.013520e-03 |
| k_imp_hetero | -0.005414 | -0.005403 | 1.973604e-03 |
| k_imp_homo | 0.000045 | 0.000045 | 1.119566e-06 |
| k_phos | -0.007907 | -0.007794 | 1.447768e-02 |
| sd_pSTAT5A_rel | -0.010784 | -0.010800 | -1.469518e-03 |
| sd_pSTAT5B_rel | -0.024037 | -0.024037 | -8.729860e-06 |
| sd_rSTAT5A_rel | -0.019191 | -0.019186 | 2.829431e-04 |
We see quite some differences in the gradient, with errors of over 1% for k_phos. The primary reason is that JAX in its default configuration uses float32 precision, both for the parameters that are passed to AMICI (which uses float64) and for the derivative of the parameter transformation. As the AMICI simulations running on the CPU are by far the most expensive operation, there is barely any performance benefit to using float32 over float64 in JAX. Therefore, we configure JAX to use float64 instead and rerun the simulations.
[14]:
jax.config.update("jax_enable_x64", True)

llh_jax, sllh_jax = jax_objective_with_parameter_transform(
    scaled_parameters_np
)
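As an aside (general JAX guidance, not specific to this notebook), double precision is usually best enabled at program startup, before any JAX arrays are created. One hypothetical alternative is the JAX_ENABLE_X64 environment variable:

# Alternative: enable float64 before the first `import jax`.
import os

os.environ["JAX_ENABLE_X64"] = "1"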
We can now evaluate the results again and see that differences between pure AMICI and AMICI/JAX implementations have now disappeared.
[15]:
pd.DataFrame(
    dict(
        amici=r["llh"],
        jax=float(llh_jax),
        rel_diff=(r["llh"] - float(llh_jax)) / r["llh"],
    ),
    index=("llh",),
)
[15]:
| | amici | jax | rel_diff |
| --- | --- | --- | --- |
| llh | -138.221997 | -138.221997 | -0.0 |
[16]:
grad_amici = np.asarray(list(r["sllh"].values()))
grad_jax = np.asarray(sllh_jax)
rel_diff = (grad_amici - grad_jax) / grad_jax
pd.DataFrame(
    index=r["sllh"].keys(),
    data=dict(amici=grad_amici, jax=grad_jax, rel_diff=rel_diff),
)
[16]:
| | amici | jax | rel_diff |
| --- | --- | --- | --- |
| Epo_degradation_BaF3 | -0.022045 | -0.022045 | -0.0 |
| k_exp_hetero | -0.055323 | -0.055323 | -0.0 |
| k_exp_homo | -0.005789 | -0.005789 | -0.0 |
| k_imp_hetero | -0.005414 | -0.005414 | -0.0 |
| k_imp_homo | 0.000045 | 0.000045 | 0.0 |
| k_phos | -0.007907 | -0.007907 | -0.0 |
| sd_pSTAT5A_rel | -0.010784 | -0.010784 | -0.0 |
| sd_pSTAT5B_rel | -0.024037 | -0.024037 | -0.0 |
| sd_rSTAT5A_rel | -0.019191 | -0.019191 | -0.0 |
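Finally, as a hypothetical usage sketch (not part of the comparison above), the JAX objective composes with ordinary JAX code, for example a few steps of plain gradient ascent on the log-likelihood starting from the nominal scaled parameters. A real application would use a proper optimizer and respect the parameter bounds:

# Illustrative only: a handful of gradient-ascent steps on the log-likelihood.
x = jnp.array(scaled_parameters_np)
learning_rate = 1e-3
for _ in range(10):
    llh, sllh = jax_objective_with_parameter_transform(x)
    x = x + learning_rate * sllh  # ascend, since we maximize the likelihood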