"""
Plotting
--------
Plotting related functions
"""
from typing import Iterable, Optional
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from . import Model, ReturnDataView
[docs]def plot_state_trajectories(
rdata: ReturnDataView,
state_indices: Optional[Iterable[int]] = None,
ax: Optional[Axes] = None,
model: Model = None
) -> None:
"""
Plot state trajectories
:param rdata:
AMICI simulation results as returned by
:func:`amici.amici.runAmiciSimulation`
:param state_indices:
Indices of states for which trajectories are to be plotted
:param ax:
matplotlib Axes instance to plot into
:param model:
amici model instance
"""
if not ax:
fig, ax = plt.subplots()
if not state_indices:
state_indices = range(rdata['x'].shape[1])
for ix in state_indices:
if model is None:
label = f'$x_{{{ix}}}$'
elif model.getStateNames()[ix]:
label = model.getStateNames()[ix]
else:
label = model.getStateIds()[ix]
ax.plot(rdata['t'], rdata['x'][:, ix], label=label)
ax.set_xlabel('$t$')
ax.set_ylabel('$x(t)$')
ax.legend()
ax.set_title('State trajectories')
[docs]def plot_observable_trajectories(
rdata: ReturnDataView,
observable_indices: Optional[Iterable[int]] = None,
ax: Optional[Axes] = None,
model: Model = None
) -> None:
"""
Plot observable trajectories
:param rdata:
AMICI simulation results as returned by
:func:`amici.amici.runAmiciSimulation`
:param observable_indices:
Indices of observables for which trajectories are to be plotted
:param ax:
matplotlib Axes instance to plot into
:param model:
amici model instance
"""
if not ax:
fig, ax = plt.subplots()
if not observable_indices:
observable_indices = range(rdata['y'].shape[1])
for iy in observable_indices:
if model is None:
label = f'$y_{{{iy}}}$'
elif model.getObservableNames()[iy]:
label = model.getObservableNames()[iy]
else:
label = model.getObservableIds()[iy]
ax.plot(rdata['t'], rdata['y'][:, iy], label=label)
ax.set_xlabel('$t$')
ax.set_ylabel('$y(t)$')
ax.legend()
ax.set_title('Observable trajectories')
[docs]def plot_jacobian(rdata: ReturnDataView):
"""Plot Jacobian as heatmap."""
df = pd.DataFrame(
data=rdata.J,
index=rdata._swigptr.state_ids_solver,
columns=rdata._swigptr.state_ids_solver,
)
sns.heatmap(df, center=0.0)
plt.title("Jacobian")
# backwards compatibility
plotStateTrajectories = plot_state_trajectories
plotObservableTrajectories = plot_observable_trajectories