AMICI & JAX
Overview
The purpose of this guide is to showcase how AMICI can be combined with differentiable programming in JAX. We will do so by reimplementing the parameter transformations available in AMICI in JAX and comparing them to the native implementation.
[1]:
import jax
import jax.numpy as jnp
Preparation
To get started, we will import a model using PEtab. To this end, we will use the benchmark collection, which features a variety of different models. For more details about PEtab import, see the respective PEtab notebook.
From the benchmark collection, we now import the Böhm model.
[2]:
import petab
model_name = "Boehm_JProteomeRes2014"
yaml_file = f"https://raw.githubusercontent.com/Benchmarking-Initiative/Benchmark-Models-PEtab/master/Benchmark-Models/{model_name}/{model_name}.yaml"
petab_problem = petab.Problem.from_yaml(yaml_file)
The PEtab problem includes information about parameter scaling in its parameter table. For the Böhm model, all estimated parameters (petab.ESTIMATE column equal to 1) have petab.LOG10 as parameter scaling.
[3]:
petab_problem.parameter_df
[3]:
parameterId | parameterName | parameterScale | lowerBound | upperBound | nominalValue | estimate
---|---|---|---|---|---|---
Epo_degradation_BaF3 | EPO_{degradation,BaF3} | log10 | 0.00001 | 100000 | 0.026983 | 1 |
k_exp_hetero | k_{exp,hetero} | log10 | 0.00001 | 100000 | 0.000010 | 1 |
k_exp_homo | k_{exp,homo} | log10 | 0.00001 | 100000 | 0.006170 | 1 |
k_imp_hetero | k_{imp,hetero} | log10 | 0.00001 | 100000 | 0.016368 | 1 |
k_imp_homo | k_{imp,homo} | log10 | 0.00001 | 100000 | 97749.379402 | 1 |
k_phos | k_{phos} | log10 | 0.00001 | 100000 | 15766.507020 | 1 |
ratio | ratio | lin | -5.00000 | 5 | 0.693000 | 0 |
sd_pSTAT5A_rel | \sigma_{pSTAT5A,rel} | log10 | 0.00001 | 100000 | 3.852612 | 1 |
sd_pSTAT5B_rel | \sigma_{pSTAT5B,rel} | log10 | 0.00001 | 100000 | 6.591478 | 1 |
sd_rSTAT5A_rel | \sigma_{rSTAT5A,rel} | log10 | 0.00001 | 100000 | 3.152713 | 1 |
specC17 | specC17 | lin | -5.00000 | 5 | 0.107000 | 0 |
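For reference, the parameterScale column determines how a value $x$ on the estimation scale is mapped back to the linear-scale parameter $\theta$ that enters the model. The derivative of this mapping is the factor the chain rule adds when gradients are taken with respect to the scaled parameters, which is what the JAX implementation in this guide has to reproduce:

```latex
\theta = \begin{cases} x & \text{lin} \\ e^{x} & \text{log} \\ 10^{x} & \text{log10} \end{cases}
\qquad
\frac{\partial \theta}{\partial x} = \begin{cases} 1 & \text{lin} \\ e^{x} & \text{log} \\ \ln(10)\, 10^{x} & \text{log10} \end{cases}
\qquad
\frac{\partial\, \mathrm{llh}}{\partial x_i} = \frac{\partial\, \mathrm{llh}}{\partial \theta_i}\, \frac{\partial \theta_i}{\partial x_i}
```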
We now import the PEtab problem using import_petab_problem from amici.petab.petab_import.
[4]:
from amici.petab.petab_import import import_petab_problem
amici_model = import_petab_problem(
    petab_problem, compile_=True, verbose=False
)
JAX implementation
For full JAX support, we would have to implement a new primitive, which would require quite a bit of engineering and, in the end, wouldn't add much benefit since AMICI can't run on GPUs. Instead, we will interface with AMICI using the experimental JAX module host_callback.
To do so, we define a base function that takes only a single argument (the parameters) and runs a simulation of the PEtab problem via simulate_petab. To enable gradient computation later on, we create a solver object, set its sensitivity order to first order, and pass it to simulate_petab. Moreover, simulate_petab expects a dictionary of parameters, so we create one using the free parameter IDs from the PEtab problem. As we want to implement the parameter transformation in JAX, we disable parameter scaling in AMICI by passing the parameters on linear scale, i.e., by keeping the default scaled_parameters=False.
[5]:
from amici.petab.simulations import simulate_petab
import amici
amici_solver = amici_model.getSolver()
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)
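def amici_hcb_base(parameters: jnp.array):
    return simulate_petab(
        petab_problem,
        amici_model,
        problem_parameters=dict(zip(petab_problem.x_free_ids, parameters)),
        solver=amici_solver,
        # parameters are passed on linear scale; the transformation happens in JAX
        scaled_parameters=False,
    )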
Now we can use this base function to create two separate functions that compute the log-likelihood (llh) and its gradient (sllh). Note that, as both use the same base function, the log-likelihood computation also runs with sensitivities, which is unnecessary and adds some overhead. This is done only for convenience and should be changed in applications where efficiency is important.
[6]:
def amici_hcb_llh(parameters: jnp.array):
    return amici_hcb_base(parameters)["llh"]
def amici_hcb_sllh(parameters: jnp.array):
    sllh = amici_hcb_base(parameters)["sllh"]
    return jnp.asarray(
        tuple(sllh[par_id] for par_id in petab_problem.x_free_ids)
    )
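The overhead mentioned above can be reduced by caching the most recent simulation result, so that evaluating llh and sllh at the same parameters triggers only one simulation. The following is a minimal sketch using functools.lru_cache, with a hypothetical simulate function (a made-up toy, not the Böhm model) standing in for simulate_petab; parameters are converted to a tuple of floats because NumPy and JAX arrays are not hashable.

```python
import functools

call_count = 0


def simulate(parameters: tuple) -> dict:
    # Hypothetical stand-in for simulate_petab with sensitivities enabled:
    # llh = -sum(p^2), sllh = -2 p
    global call_count
    call_count += 1
    return {
        "llh": -sum(p**2 for p in parameters),
        "sllh": [-2 * p for p in parameters],
    }


@functools.lru_cache(maxsize=1)
def cached_simulate(parameters: tuple) -> dict:
    # Keep only the most recent result; tuples are hashable, arrays are not.
    return simulate(parameters)


def llh(parameters) -> float:
    return cached_simulate(tuple(float(p) for p in parameters))["llh"]


def sllh(parameters) -> list:
    return cached_simulate(tuple(float(p) for p in parameters))["sllh"]


x = [1.0, 2.0]
value, gradient = llh(x), sllh(x)  # the second call is served from the cache
```

The cached dictionary is shared between callers, so it should not be modified in place.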
Now we can finally define the JAX function that runs the AMICI simulation using the host callback. We add a custom_jvp decorator so that we can define a custom Jacobian-vector product (JVP) function in the next step. More details about custom JVP functions can be found in the JAX documentation.
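To illustrate the mechanism independently of AMICI, here is a minimal custom_jvp example on a function with a known derivative. The JVP rule returns the primal output and a tangent output that must be linear in the tangent x_dot; this linearity is what allows JAX to derive reverse-mode gradients (jax.grad) from a forward-mode rule.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x):
    return jnp.sin(x)


@f.defjvp
def f_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # Return the primal output and a tangent output that is linear in x_dot.
    return f(x), jnp.cos(x) * x_dot


# Reverse mode works because the tangent output is linear in x_dot.
df = jax.grad(f)(0.0)
```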
[7]:
import jax.experimental.host_callback as hcb
from jax import custom_jvp
import numpy as np
@custom_jvp
def jax_objective(parameters: jnp.array):
    return hcb.call(
        amici_hcb_llh,
        parameters,
        result_shape=jax.ShapeDtypeStruct((), np.float64),
    )
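As an aside, more recent JAX releases have deprecated jax.experimental.host_callback, and jax.pure_callback is the commonly recommended replacement for calling host-side Python code such as the hcb.call above. The following minimal sketch shows the general pattern with a hypothetical NumPy stand-in (host_function) instead of an AMICI simulation; it is not a drop-in replacement for the code in this guide. Like the host callback, pure_callback has no differentiation rule of its own, so the custom_jvp wrapper would still be needed for gradients.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_function(x):
    # Runs on the host with plain NumPy, like an AMICI simulation would.
    # Hypothetical stand-in: sum of squares of the inputs.
    x = np.asarray(x)  # convert in case the callback receives a jax.Array
    return np.asarray(np.sum(x**2), dtype=x.dtype)


@jax.jit
def call_host(x):
    # The shape and dtype of the host result must be declared up front for tracing.
    result_shape = jax.ShapeDtypeStruct((), x.dtype)
    return jax.pure_callback(host_function, result_shape, x)


value = call_host(jnp.array([1.0, 2.0, 3.0]))
```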
Now we define the function that implements the Jacobian-vector product. It returns the objective function value (computed using the previously defined jax_objective) as well as the inner product of the gradient (computed using a host callback to the previously defined amici_hcb_sllh) and the tangent vector. Note that this implementation performs two simulation runs, one for the function value and one for the gradient, which is inefficient and could be avoided by caching solutions.
[8]:
@jax_objective.defjvp
def jax_objective_jvp(primals: jnp.array, tangents: jnp.array):
    (parameters,) = primals
    (x_dot,) = tangents
    # function value via the custom_jvp-decorated objective (first simulation)
    llh = jax_objective(parameters)
    # gradient via a second host callback, one entry per free parameter
    sllh = hcb.call(
        amici_hcb_sllh,
        parameters,
        result_shape=jax.ShapeDtypeStruct(
            (len(petab_problem.x_free_ids),), np.float64
        ),
    )
    return llh, sllh.dot(x_dot)
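Caching is one way to avoid the second simulation; another is to let a single host call return both the value and the gradient inside the JVP rule. The sketch below assumes jax.pure_callback instead of host_callback and a hypothetical host function host_value_and_grad, a toy quadratic standing in for one AMICI simulation with sensitivities. As in the code above, the tangent output is the inner product of the gradient with the tangent vector.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_value_and_grad(x):
    # Hypothetical stand-in for ONE simulation with sensitivities:
    # f(x) = -0.5 * sum(x**2), grad f(x) = -x
    x = np.asarray(x)
    value = np.asarray(-0.5 * np.sum(x**2), dtype=x.dtype)
    grad = (-x).astype(x.dtype)
    return value, grad


def _call(x):
    # Declare shapes and dtypes of both host outputs (value, gradient).
    shapes = (
        jax.ShapeDtypeStruct((), x.dtype),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    return jax.pure_callback(host_value_and_grad, shapes, x)


@jax.custom_jvp
def objective(x):
    value, _ = _call(x)
    return value


@objective.defjvp
def objective_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    # A single host call supplies both the value and the gradient.
    value, grad = _call(x)
    return value, jnp.dot(grad, x_dot)


value, grad = jax.value_and_grad(objective)(jnp.array([1.0, 2.0]))
```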
As the last step, we implement the parameter transformation in JAX. This extracts the parameter scales from the PEtab problem, transforms the parameters from their estimation scale back to linear scale in JAX, and then passes them to the objective function we defined previously. We add the value_and_grad decorator such that the generated JAX function returns both the function value and its gradient in a tuple. Moreover, we add the jax.jit decorator such that the function is just-in-time compiled upon the first function call.
[9]:
from jax import value_and_grad
parameter_scales = petab_problem.parameter_df.loc[
    petab_problem.x_free_ids, petab.PARAMETER_SCALE
].values
@jax.jit
@value_and_grad
def jax_objective_with_parameter_transform(parameters: jnp.array):
    # back-transform each parameter from its estimation scale to linear scale
    par_linear = jnp.asarray(
        tuple(
            value
            if scale == petab.LIN
            else jnp.exp(value)
            if scale == petab.LOG
            else jnp.power(10, value)
            for value, scale in zip(parameters, parameter_scales)
        )
    )
    return jax_objective(par_linear)
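The conditional expression above is evaluated in Python and unrolled into one operation per parameter during tracing. A vectorized alternative is sketched below, assuming the PEtab scale strings have first been mapped to integer codes; SCALE_LIN, SCALE_LOG, SCALE_LOG10 and scale_codes are hypothetical names (not PEtab constants), and toy_objective is a stand-in for jax_objective that simply sums the linear-scale parameters.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical integer codes, so the scales can live in an array instead of strings.
SCALE_LIN, SCALE_LOG, SCALE_LOG10 = 0, 1, 2
scale_codes = jnp.array([SCALE_LOG10, SCALE_LOG, SCALE_LIN])


def unscale(x, codes):
    # Select the back-transformation element-wise: 10**x, exp(x), or x.
    return jnp.where(
        codes == SCALE_LOG10,
        jnp.power(10.0, x),
        jnp.where(codes == SCALE_LOG, jnp.exp(x), x),
    )


@jax.jit
@jax.value_and_grad
def toy_objective(x):
    # Hypothetical objective on linear scale: the sum of the parameters.
    return jnp.sum(unscale(x, scale_codes))


value, grad = toy_objective(jnp.array([0.0, 0.0, 0.0]))
```

At $x = 0$, the gradient entries are the derivatives of the back-transformations: $\ln(10)$ for log10, and 1 for both log and lin.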
Testing
We can now run the function to compute the log-likelihood and the gradient.
[10]:
parameters = dict(zip(petab_problem.x_free_ids, petab_problem.x_nominal_free))
scaled_parameters = petab_problem.scale_parameters(parameters)
scaled_parameters_np = np.asarray(list(scaled_parameters.values()))
[11]:
llh_jax, sllh_jax = jax_objective_with_parameter_transform(
    scaled_parameters_np
)
As a sanity check, we compare the computed values to AMICI's native parameter transformation.
[12]:
r = simulate_petab(
    petab_problem,
    amici_model,
    solver=amici_solver,
    scaled_parameters=True,
    scaled_gradients=True,
    problem_parameters=scaled_parameters,
)
[13]:
import pandas as pd
pd.DataFrame(
    dict(
        amici=r["llh"],
        jax=float(llh_jax),
        rel_diff=(r["llh"] - float(llh_jax)) / r["llh"],
    ),
    index=("llh",),
)
[13]:
| | amici | jax | rel_diff |
|---|---|---|---|
| llh | -138.221997 | -138.222 | -2.135248e-08 |
[14]:
grad_amici = np.asarray(list(r["sllh"].values()))
grad_jax = np.asarray(sllh_jax)
rel_diff = (grad_amici - grad_jax) / grad_jax
pd.DataFrame(
    index=r["sllh"].keys(),
    data=dict(amici=grad_amici, jax=grad_jax, rel_diff=rel_diff),
)
[14]:
parameterId | amici | jax | rel_diff
---|---|---|---
Epo_degradation_BaF3 | -0.022045 | -0.022034 | 4.645833e-04 |
k_exp_hetero | -0.055323 | -0.055323 | 8.646725e-08 |
k_exp_homo | -0.005789 | -0.005801 | -2.013520e-03 |
k_imp_hetero | -0.005414 | -0.005403 | 1.973517e-03 |
k_imp_homo | 0.000045 | 0.000045 | 1.119566e-06 |
k_phos | -0.007907 | -0.007794 | 1.447768e-02 |
sd_pSTAT5A_rel | -0.010784 | -0.010800 | -1.469604e-03 |
sd_pSTAT5B_rel | -0.024037 | -0.024037 | -8.729860e-06 |
sd_rSTAT5A_rel | -0.019191 | -0.019186 | 2.829431e-04 |
We see considerable differences in the gradient calculation, with an error of over 1% for k_phos. The primary reason is that, in its default configuration, JAX uses float32 precision, both for the parameters that are passed to AMICI (which computes in float64) and for the derivative of the parameter transformation. As the AMICI simulations running on the CPU are the most expensive operation, there is barely any tradeoff in using float64 instead of float32 in JAX. Therefore, we configure JAX to use float64 and rerun the simulations.
[15]:
jax.config.update("jax_enable_x64", True)
llh_jax, sllh_jax = jax_objective_with_parameter_transform(
    scaled_parameters_np
)
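To put the float32/float64 difference into numbers, the following NumPy sketch computes the machine epsilon of both types and the relative error of evaluating $10^x$ in float32 for the log10-scaled nominal value of k_phos from the parameter table. It only illustrates the rounding of a single parameter, not how that rounding propagates into the gradients.

```python
import numpy as np

# Machine epsilon: spacing between 1.0 and the next representable number.
eps32 = np.finfo(np.float32).eps  # 2**-23
eps64 = np.finfo(np.float64).eps  # 2**-52

# Back-transform the log10-scaled nominal k_phos value in float64 and in float32.
x = np.log10(15766.507020)
theta64 = 10.0**x
theta32 = np.float32(10.0) ** np.float32(x)
rel_err = abs(float(theta32) - theta64) / theta64
```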
We can now evaluate the results again and see that the differences between the pure AMICI and the AMICI/JAX implementation have disappeared.
[16]:
pd.DataFrame(
    dict(
        amici=r["llh"],
        jax=float(llh_jax),
        rel_diff=(r["llh"] - float(llh_jax)) / r["llh"],
    ),
    index=("llh",),
)
[16]:
| | amici | jax | rel_diff |
|---|---|---|---|
| llh | -138.221997 | -138.221997 | -0.0 |
[17]:
grad_amici = np.asarray(list(r["sllh"].values()))
grad_jax = np.asarray(sllh_jax)
rel_diff = (grad_amici - grad_jax) / grad_jax
pd.DataFrame(
    index=r["sllh"].keys(),
    data=dict(amici=grad_amici, jax=grad_jax, rel_diff=rel_diff),
)
[17]:
parameterId | amici | jax | rel_diff
---|---|---|---
Epo_degradation_BaF3 | -0.022045 | -0.022045 | -0.0 |
k_exp_hetero | -0.055323 | -0.055323 | -0.0 |
k_exp_homo | -0.005789 | -0.005789 | -0.0 |
k_imp_hetero | -0.005414 | -0.005414 | -0.0 |
k_imp_homo | 0.000045 | 0.000045 | 0.0 |
k_phos | -0.007907 | -0.007907 | -0.0 |
sd_pSTAT5A_rel | -0.010784 | -0.010784 | -0.0 |
sd_pSTAT5B_rel | -0.024037 | -0.024037 | -0.0 |
sd_rSTAT5A_rel | -0.019191 | -0.019191 | -0.0 |
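Instead of inspecting the tables by eye, the comparison can be automated with a tolerance check. Note that the element-wise relative difference used above is undefined (NaN or infinite) wherever the reference gradient is zero; numpy.allclose avoids this by combining an absolute and a relative tolerance, $|a - b| \le \mathrm{atol} + \mathrm{rtol} \cdot |b|$. The gradient values and tolerances below are hypothetical.

```python
import numpy as np

# Hypothetical gradients from two implementations; the middle entry is zero.
grad_a = np.array([-0.5, 0.0, 2.0])
grad_b = np.array([-0.5, 0.0, 2.0 * (1.0 + 1e-9)])

# Element-wise relative difference: 0 / 0 yields NaN for the zero entry.
with np.errstate(divide="ignore", invalid="ignore"):
    rel_diff_toy = (grad_a - grad_b) / grad_b

# Tolerance check that also handles zero entries:
# |a - b| <= atol + rtol * |b| for every element.
match = np.allclose(grad_a, grad_b, rtol=1e-6, atol=1e-12)
```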