# Source code for amici.numpy

"""
C++ object views
----------------
This module provides views on C++ objects for efficient access.
"""

import collections
import copy
from typing import Dict, Iterator, List, Literal, Union

import amici
import numpy as np

from . import ExpData, ExpDataPtr, Model, ReturnData, ReturnDataPtr


class SwigPtrView(collections.abc.Mapping):
    """
    Interface class to expose ``std::vector<double>`` and scalar members of
    swig wrapped C++ objects as numpy array attributes and fields. This
    class is memory efficient as copies of the underlying C++ objects is
    only created when respective fields are accessed for the first time.
    Cached copies are used for all subsequent calls.

    :ivar _swigptr: pointer to the C++ object
    :ivar _field_names: names of members that will be exposed as numpy arrays
    :ivar _field_dimensions: dimensions of numpy arrays
    :ivar _cache: dictionary with cached values
    """

    _swigptr = None
    _field_names: List[str] = []
    _field_dimensions: Dict[str, List[int]] = dict()

    def __getitem__(self, item: str) -> Union[np.ndarray, float]:
        """
        Access to field names, copies data from C++ object into numpy
        array, reshapes according to field dimensions and stores values in
        cache.

        :param item: field name

        :return: value

        :raises NotImplementedError: if no C++ object is attached
        :raises KeyError: if ``item`` is not a known field name
        """
        if self._swigptr is None:
            raise NotImplementedError("Cannot get items from abstract class.")

        if item == "ptr":
            return self._swigptr

        if item in self._cache:
            return self._cache[item]

        if item == "id":
            return getattr(self._swigptr, item)

        if item not in self._field_names:
            self.__missing__(item)

        value = _field_as_numpy(self._field_dimensions, item, self._swigptr)
        self._cache[item] = value

        return value

    def __missing__(self, key: str) -> None:
        """
        Default behaviour for missing keys

        :param key: field name

        :raises KeyError: always
        """
        raise KeyError(f"Unknown field name {key}.")

    def __getattr__(self, item) -> Union[np.ndarray, float]:
        """
        Attribute accessor for field names

        Only called when normal attribute lookup fails. Unknown names raise
        ``AttributeError`` (not ``KeyError``) as required by the Python data
        model, so ``hasattr``/``getattr(..., default)`` and third-party
        introspection of the instance behave correctly.

        :param item: field name

        :returns: value

        :raises AttributeError: if ``item`` is not an accessible field
        """
        # Private/dunder names are never fields. Resolving them through
        # __getitem__ would also recurse forever if ``_cache`` is not set.
        if item.startswith("_"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        try:
            return self.__getitem__(item)
        except KeyError as e:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            ) from e

    def __init__(self, swigptr):
        """
        Constructor

        :param swigptr: pointer to the C++ object
        """
        self._swigptr = swigptr
        self._cache = {}

        super(SwigPtrView, self).__init__()

    def __len__(self) -> int:
        """
        Returns the number of available keys/fields

        :returns: length of _field_names
        """
        return len(self._field_names)

    def __iter__(self) -> Iterator:
        """
        Create an iterator of the keys/fields

        :returns: iterator over _field_names
        """
        return iter(self._field_names)

    def __copy__(self):
        """
        Create a shallow copy

        The copy has the same (sub)class as ``self``. It shares the
        underlying C++ object and the value cache with the original.

        :return: SwigPtrView shallow copy
        """
        cls = type(self)
        # bypass __init__: subclass constructors type-check / touch the
        # C++ object, and all state is copied explicitly below
        other = cls.__new__(cls)
        other._swigptr = self._swigptr
        other._field_names = self._field_names
        other._field_dimensions = self._field_dimensions
        other._cache = self._cache
        return other

    def __contains__(self, item) -> bool:
        """
        Faster implementation of ``__contains__`` that avoids copy of the
        field

        :param item: item to check for

        :returns: whether item is available as key
        """
        return item in self._field_names

    def __deepcopy__(self, memo):
        """
        Create a deep copy

        The copy has the same (sub)class as ``self``. Field names,
        dimensions and cached values are deep-copied. The underlying C++
        object is shared, not copied.

        :param memo: dict with id-to-object mapping

        :returns: SwigPtrView deep copy
        """
        cls = type(self)
        other = cls.__new__(cls)
        memo[id(self)] = other
        other._swigptr = self._swigptr
        other._field_names = copy.deepcopy(self._field_names, memo)
        other._field_dimensions = copy.deepcopy(self._field_dimensions, memo)
        other._cache = copy.deepcopy(self._cache, memo)
        return other

    def __repr__(self):
        """
        String representation of the object

        :returns: string representation
        """
        return f"<{self.__class__.__name__}({self._swigptr})>"
class ReturnDataView(SwigPtrView):
    """
    Interface class for C++ :class:`amici.ReturnData` objects that avoids
    possibly costly copies of member data.
    """

    _field_names = [
        "ts",
        "x",
        "x0",
        "x_ss",
        "sx",
        "sx0",
        "sx_ss",
        "y",
        "sigmay",
        "sy",
        "ssigmay",
        "z",
        "rz",
        "sigmaz",
        "sz",
        "srz",
        "ssigmaz",
        "sllh",
        "s2llh",
        "J",
        "xdot",
        "status",
        "llh",
        "chi2",
        "res",
        "sres",
        "FIM",
        "w",
        "preeq_wrms",
        "preeq_t",
        "preeq_numsteps",
        "preeq_numstepsB",
        "preeq_status",
        "preeq_cpu_time",
        "preeq_cpu_timeB",
        "posteq_wrms",
        "posteq_t",
        "posteq_numsteps",
        "posteq_numstepsB",
        "posteq_status",
        "posteq_cpu_time",
        "posteq_cpu_timeB",
        "numsteps",
        "numrhsevals",
        "numerrtestfails",
        "numnonlinsolvconvfails",
        "order",
        "cpu_time",
        "numstepsB",
        "numrhsevalsB",
        "numerrtestfailsB",
        "numnonlinsolvconvfailsB",
        "cpu_timeB",
        "cpu_time_total",
    ]

    def __init__(self, rdata: Union[ReturnDataPtr, ReturnData]):
        """
        Constructor

        :param rdata: pointer to the ``ReturnData`` instance

        :raises TypeError: if ``rdata`` is neither a ``ReturnDataPtr`` nor
            a ``ReturnData``
        """
        if not isinstance(rdata, (ReturnDataPtr, ReturnData)):
            raise TypeError(
                f"Unsupported pointer {type(rdata)}, must be "
                f"amici.ReturnDataPtr!"
            )
        # length of the residual vector; doubled if ``rdata.sigma_res`` is set
        n_res = rdata.nt * rdata.nytrue * (2 if rdata.sigma_res else 1)
        # Fields listed in _field_names but absent here are read as scalars
        # (see _field_as_numpy).
        self._field_dimensions = {
            "ts": [rdata.nt],
            "x": [rdata.nt, rdata.nx],
            "x0": [rdata.nx],
            "x_ss": [rdata.nx],
            "sx": [rdata.nt, rdata.nplist, rdata.nx],
            "sx0": [rdata.nplist, rdata.nx],
            "sx_ss": [rdata.nplist, rdata.nx],
            # observables
            "y": [rdata.nt, rdata.ny],
            "sigmay": [rdata.nt, rdata.ny],
            "sy": [rdata.nt, rdata.nplist, rdata.ny],
            "ssigmay": [rdata.nt, rdata.nplist, rdata.ny],
            # event observables
            "z": [rdata.nmaxevent, rdata.nz],
            "rz": [rdata.nmaxevent, rdata.nz],
            "sigmaz": [rdata.nmaxevent, rdata.nz],
            "sz": [rdata.nmaxevent, rdata.nplist, rdata.nz],
            "srz": [rdata.nmaxevent, rdata.nplist, rdata.nz],
            "ssigmaz": [rdata.nmaxevent, rdata.nplist, rdata.nz],
            # objective function
            "sllh": [rdata.nplist],
            "s2llh": [rdata.np, rdata.nplist],
            "res": [n_res],
            "sres": [n_res, rdata.nplist],
            "FIM": [rdata.nplist, rdata.nplist],
            # diagnosis
            "J": [rdata.nx_solver, rdata.nx_solver],
            "w": [rdata.nt, rdata.nw],
            "xdot": [rdata.nx_solver],
            # NOTE(review): *_numlinsteps are not listed in _field_names, so
            # __getitem__ never reaches these entries; confirm intent.
            "preeq_numlinsteps": [rdata.newton_maxsteps, 2],
            "preeq_numsteps": [1, 3],
            "preeq_status": [1, 3],
            "posteq_numlinsteps": [rdata.newton_maxsteps, 2],
            "posteq_numsteps": [1, 3],
            "posteq_status": [1, 3],
            "numsteps": [rdata.nt],
            "numrhsevals": [rdata.nt],
            "numerrtestfails": [rdata.nt],
            "numnonlinsolvconvfails": [rdata.nt],
            "order": [rdata.nt],
            "numstepsB": [rdata.nt],
            "numrhsevalsB": [rdata.nt],
            "numerrtestfailsB": [rdata.nt],
            "numnonlinsolvconvfailsB": [rdata.nt],
        }
        super(ReturnDataView, self).__init__(rdata)

    def __getitem__(
        self, item: str
    ) -> Union[np.ndarray, ReturnDataPtr, ReturnData, float]:
        """
        Access fields by name.

        Custom ``__getitem__`` implementation shim to map ``t`` to ``ts``
        and to return ``status`` as ``int``.

        :param item: field/attribute key

        :returns: self[item]
        """
        if item == "status":
            return int(super().__getitem__(item))

        if item == "t":
            item = "ts"
        return super().__getitem__(item)

    def by_id(
        self,
        entity_id: str,
        field: Union[str, None] = None,
        model: Union[Model, None] = None,
    ) -> np.ndarray:
        """
        Get the value of a given field for a named entity.

        :param entity_id: The ID of the model entity that is to be extracted
            from ``field`` (e.g. a state ID).
        :param field: The requested field, e.g. 'x' for model states. This is
            optional if field would be one of ``{'x', 'y', 'w'}``
        :param model: The model from which this ReturnDataView was generated.
            This is optional if this ReturnData was generated with
            ``solver.getReturnDataReportingMode() ==
            amici.RDataReporting.full``.

        :returns: ``field`` indexed along its last axis at the position of
            ``entity_id``

        :raises NotImplementedError: if ``field`` cannot be subset by ID
        :raises ValueError: if ``entity_id`` is not among the relevant IDs
        """
        if field is None:
            field = _entity_type_from_id(entity_id, self, model)

        if field in {"x", "x0", "x_ss", "sx", "sx0", "sx_ss"}:
            ids = (model and model.getStateIds()) or self._swigptr.state_ids
        elif field in {"w"}:
            ids = (
                model and model.getExpressionIds()
            ) or self._swigptr.expression_ids
        elif field in {"y", "sy", "sigmay"}:
            ids = (
                model and model.getObservableIds()
            ) or self._swigptr.observable_ids
        elif field in {"sllh"}:
            # NOTE(review): sllh has length nplist; if plist selects a subset
            # of parameters, the position in all parameter IDs may not match.
            ids = (
                model and model.getParameterIds()
            ) or self._swigptr.parameter_ids
        else:
            raise NotImplementedError(
                f"Subsetting {field} by ID is not implemented or not possible."
            )
        col_index = ids.index(entity_id)
        # Index the last axis. ``[..., i]`` equals ``[:, ..., i]`` for
        # arrays with ndim >= 2, and also works for 1-D fields such as
        # ``x0``, ``x_ss`` or ``sllh``, where ``[:, ..., i]`` raises
        # IndexError.
        return getattr(self, field)[..., col_index]
class ExpDataView(SwigPtrView):
    """
    Interface class for C++ Exp Data objects that avoids possibly costly
    copies of member data.
    """

    _field_names = [
        "observedData",
        "observedDataStdDev",
        "observedEvents",
        "observedEventsStdDev",
        "fixedParameters",
        "fixedParametersPreequilibration",
        "fixedParametersPresimulation",
    ]

    def __init__(self, edata: Union[ExpDataPtr, ExpData]):
        """
        Constructor

        :param edata: pointer to the ExpData instance

        :raises TypeError: if ``edata`` is neither an ``ExpDataPtr`` nor an
            ``ExpData``
        """
        if not isinstance(edata, (ExpDataPtr, ExpData)):
            raise TypeError(
                f"Unsupported pointer {type(edata)}, must be "
                f"amici.ExpDataPtr!"
            )
        self._field_dimensions = {
            # observables
            "observedData": [edata.nt(), edata.nytrue()],
            "observedDataStdDev": [edata.nt(), edata.nytrue()],
            # event observables
            "observedEvents": [edata.nmaxevent(), edata.nztrue()],
            "observedEventsStdDev": [edata.nmaxevent(), edata.nztrue()],
            # fixed parameters
            "fixedParameters": [len(edata.fixedParameters)],
            "fixedParametersPreequilibration": [
                len(edata.fixedParametersPreequilibration)
            ],
            # sized by its own vector, not the preequilibration one
            "fixedParametersPresimulation": [
                len(edata.fixedParametersPresimulation)
            ],
        }
        # _field_as_numpy reads fields with getattr(); store the getter
        # results under the same names as attributes of ``edata`` so they
        # can be looked up like the other fields.
        edata.observedData = edata.getObservedData()
        edata.observedDataStdDev = edata.getObservedDataStdDev()
        edata.observedEvents = edata.getObservedEvents()
        edata.observedEventsStdDev = edata.getObservedEventsStdDev()
        super(ExpDataView, self).__init__(edata)
def _field_as_numpy(
    field_dimensions: Dict[str, List[int]], field: str, data: object
) -> Union[np.ndarray, float, None]:
    """
    Convert data object field to numpy array with dimensions according to
    specified field dimensions

    :param field_dimensions:
        dimension specifications
        ``dict({field: list([dim1, dim2, ...])})``
    :param field:
        Name of field
    :param data:
        object whose attribute ``field`` is read (the views pass the wrapped
        C++ object, not the view itself)

    :returns:
        If dimensions are specified for ``field``: ``None`` when the field
        is empty, otherwise its data as a numpy array reshaped to those
        dimensions. If no dimensions are specified: the value converted
        to ``float``.
    """
    attr = getattr(data, field)
    if field_dim := field_dimensions.get(field, None):
        return None if len(attr) == 0 else np.array(attr).reshape(field_dim)

    return float(attr)


def _entity_type_from_id(
    entity_id: str,
    rdata: Union["amici.ReturnData", "amici.ReturnDataView"] = None,
    model: "amici.Model" = None,
) -> Literal["x", "y", "w", "p", "k"]:
    """
    Guess the type of some entity by its ID.

    Entity types are checked in the order state, observable, expression,
    parameter, fixed parameter; the first match wins.

    :param entity_id: ID of the entity to look up
    :param rdata: ``ReturnData`` or ``ReturnDataView`` providing
        ``<type>_ids``; only used if ``model`` is falsy
    :param model: model providing ``get<Type>Ids()``; takes precedence

    :returns: symbol of the entity type (``x``, ``y``, ``w``, ``p`` or ``k``)

    :raises ValueError: if neither ``model`` nor ``rdata`` is provided
    :raises KeyError: if ``entity_id`` is not found
    """
    if not model and rdata is None:
        raise ValueError("Either `model` or `rdata` must be provided.")

    if not model:
        # a ReturnDataView wraps the object that carries the *_ids members
        ids_source = (
            rdata if isinstance(rdata, amici.ReturnData) else rdata._swigptr
        )

    for entity_type, symbol in (
        ("State", "x"),
        ("Observable", "y"),
        ("Expression", "w"),
        ("Parameter", "p"),
        ("FixedParameter", "k"),
    ):
        if model:
            ids = getattr(model, f"get{entity_type}Ids")()
        else:
            ids = getattr(ids_source, f"{entity_type.lower()}_ids")
        if entity_id in ids:
            return symbol

    raise KeyError(f"Unknown symbol {entity_id}.")