Source code for amici.numpy

"""
C++ object views
----------------
This module provides views on C++ objects for efficient access.
"""

import numpy as np
import copy
import collections

import amici
from . import ExpDataPtr, ReturnDataPtr, ExpData, ReturnData, Model
from typing import Union, List, Dict, Iterator, Literal


[docs]class SwigPtrView(collections.abc.Mapping): """ Interface class to expose ``std::vector<double>`` and scalar members of swig wrapped C++ objects as numpy array attributes and fields. This class is memory efficient as copies of the underlying C++ objects is only created when respective fields are accessed for the first time. Cached copies are used for all subsequent calls. :ivar _swigptr: pointer to the C++ object :ivar _field_names: names of members that will be exposed as numpy arrays :ivar _field_dimensions: dimensions of numpy arrays :ivar _cache: dictionary with cached values """ _swigptr = None _field_names: List[str] = [] _field_dimensions: Dict[str, List[int]] = dict() def __getitem__(self, item: str) -> Union[np.ndarray, float]: """ Access to field names, copies data from C++ object into numpy array, reshapes according to field dimensions and stores values in cache. :param item: field name :return: value """ if self._swigptr is None: raise NotImplementedError('Cannot get items from abstract class.') if item == 'ptr': return self._swigptr if item in self._cache: return self._cache[item] if item == 'id': return getattr(self._swigptr, item) if item not in self._field_names: self.__missing__(item) value = _field_as_numpy( self._field_dimensions, item, self._swigptr ) self._cache[item] = value return value def __missing__(self, key: str) -> None: """ Default behaviour for missing keys :param key: field name """ raise KeyError(f'Unknown field name {key}.') def __getattr__(self, item) -> Union[np.ndarray, float]: """ Attribute accessor for field names :param item: field name :returns: value """ return self.__getitem__(item) def __init__(self, swigptr): """ Constructor :param swigptr: pointer to the C++ object """ self._swigptr = swigptr self._cache = {} super(SwigPtrView, self).__init__() def __len__(self) -> int: """ Returns the number of available keys/fields :returns: length of _field_names """ return len(self._field_names) def __iter__(self) -> Iterator: """ Create an iterator of the keys/fields :returns: iterator over _field_names """ return iter(self._field_names) def __copy__(self): """ Create a shallow copy :return: SwigPtrView shallow copy """ other = SwigPtrView(self._swigptr) other._field_names = self._field_names other._field_dimensions = self._field_dimensions other._cache = self._cache return other def __contains__(self, item) -> bool: """ Faster implementation of ``__contains__`` that avoids copy of the field :param item: item to check for :returns: whether item is available as key """ return item in self._field_names def __deepcopy__(self, memo): """ Create a deep copy :param memo: dict with id-to-object mapping :returns: SwigPtrView deep copy """ other = SwigPtrView(self._swigptr) other._field_names = copy.deepcopy(self._field_names) other._field_dimensions = copy.deepcopy(self._field_dimensions) other._cache = copy.deepcopy(self._cache) return other def __repr__(self): """ String representation of the object :returns: string representation """ return f'<{self.__class__.__name__}({self._swigptr})>'
[docs]class ReturnDataView(SwigPtrView): """ Interface class for C++ :class:`amici.ReturnData` objects that avoids possibly costly copies of member data. """ _field_names = [ 'ts', 'x', 'x0', 'x_ss', 'sx', 'sx0', 'sx_ss', 'y', 'sigmay', 'sy', 'ssigmay', 'z', 'rz', 'sigmaz', 'sz', 'srz', 'ssigmaz', 'sllh', 's2llh', 'J', 'xdot', 'status', 'llh', 'chi2', 'res', 'sres', 'FIM', 'w', 'preeq_wrms', 'preeq_t', 'preeq_numsteps', 'preeq_numstepsB', 'preeq_status', 'preeq_cpu_time', 'preeq_cpu_timeB', 'posteq_wrms', 'posteq_t', 'posteq_numsteps', 'posteq_numstepsB', 'posteq_status', 'posteq_cpu_time', 'posteq_cpu_timeB', 'numsteps', 'numrhsevals', 'numerrtestfails', 'numnonlinsolvconvfails', 'order', 'cpu_time', 'numstepsB', 'numrhsevalsB', 'numerrtestfailsB', 'numnonlinsolvconvfailsB', 'cpu_timeB', 'cpu_time_total' ] def __init__(self, rdata: Union[ReturnDataPtr, ReturnData]): """ Constructor :param rdata: pointer to the ``ReturnData`` instance """ if not isinstance(rdata, (ReturnDataPtr, ReturnData)): raise TypeError(f'Unsupported pointer {type(rdata)}, must be' f'amici.ExpDataPtr!') self._field_dimensions = { 'ts': [rdata.nt], 'x': [rdata.nt, rdata.nx], 'x0': [rdata.nx], 'x_ss': [rdata.nx], 'sx': [rdata.nt, rdata.nplist, rdata.nx], 'sx0': [rdata.nplist, rdata.nx], 'sx_ss': [rdata.nplist, rdata.nx], # observables 'y': [rdata.nt, rdata.ny], 'sigmay': [rdata.nt, rdata.ny], 'sy': [rdata.nt, rdata.nplist, rdata.ny], 'ssigmay': [rdata.nt, rdata.nplist, rdata.ny], # event observables 'z': [rdata.nmaxevent, rdata.nz], 'rz': [rdata.nmaxevent, rdata.nz], 'sigmaz': [rdata.nmaxevent, rdata.nz], 'sz': [rdata.nmaxevent, rdata.nplist, rdata.nz], 'srz': [rdata.nmaxevent, rdata.nplist, rdata.nz], 'ssigmaz': [rdata.nmaxevent, rdata.nplist, rdata.nz], # objective function 'sllh': [rdata.nplist], 's2llh': [rdata.np, rdata.nplist], 'res': [rdata.nt * rdata.nytrue * (2 if rdata.sigma_res else 1)], 'sres': [rdata.nt * rdata.nytrue * (2 if rdata.sigma_res else 1), rdata.nplist], 'FIM': [rdata.nplist, rdata.nplist], # diagnosis 'J': [rdata.nx_solver, rdata.nx_solver], 'w': [rdata.nt, rdata.nw], 'xdot': [rdata.nx_solver], 'preeq_numlinsteps': [rdata.newton_maxsteps, 2], 'preeq_numsteps': [1, 3], 'preeq_status': [1, 3], 'posteq_numlinsteps': [rdata.newton_maxsteps, 2], 'posteq_numsteps': [1, 3], 'posteq_status': [1, 3], 'numsteps': [rdata.nt], 'numrhsevals': [rdata.nt], 'numerrtestfails': [rdata.nt], 'numnonlinsolvconvfails': [rdata.nt], 'order': [rdata.nt], 'numstepsB': [rdata.nt], 'numrhsevalsB': [rdata.nt], 'numerrtestfailsB': [rdata.nt], 'numnonlinsolvconvfailsB': [rdata.nt], } super(ReturnDataView, self).__init__(rdata) def __getitem__(self, item: str) -> Union[np.ndarray, ReturnDataPtr, ReturnData, float]: """ Access fields by name.s Custom ``__getitem__`` implementation shim to map ``t`` to ``ts``. :param item: field/attribute key :returns: self[item] """ if item == 'status': return int(super().__getitem__(item)) if item == 't': item = 'ts' return super().__getitem__(item)
[docs] def by_id( self, entity_id: str, field: str = None, model: Model = None ) -> np.array: """ Get the value of a given field for a named entity. :param entity_id: The ID of the model entity that is to be extracted from ``field`` (e.g. a state ID). :param field: The requested field, e.g. 'x' for model states. This is optional if field would be one of ``{'x', 'y', 'w'}`` :param model: The model from which this ReturnDataView was generated. This is optional if this ReturnData was generated with ``solver.getReturnDataReportingMode() == amici.RDataReporting.full``. """ if field is None: field = _entity_type_from_id(entity_id, self, model) if field in {'x', 'x0', 'x_ss', 'sx', 'sx0', 'sx_ss'}: ids = (model and model.getStateIds()) or self._swigptr.state_ids elif field in {'w'}: ids = (model and model.getExpressionIds()) \ or self._swigptr.expression_ids elif field in {'y', 'sy', 'sigmay'}: ids = (model and model.getObservableIds()) \ or self._swigptr.observable_ids elif field in {'sllh'}: ids = (model and model.getParameterIds()) \ or self._swigptr.parameter_ids else: raise NotImplementedError( f"Subsetting {field} by ID is not implemented or not possible." ) col_index = ids.index(entity_id) return getattr(self, field)[:, ..., col_index]
[docs]class ExpDataView(SwigPtrView): """ Interface class for C++ Exp Data objects that avoids possibly costly copies of member data. """ _field_names = [ 'observedData', 'observedDataStdDev', 'observedEvents', 'observedEventsStdDev', 'fixedParameters', 'fixedParametersPreequilibration', 'fixedParametersPresimulation' ] def __init__(self, edata: Union[ExpDataPtr, ExpData]): """ Constructor :param edata: pointer to the ExpData instance """ if not isinstance(edata, (ExpDataPtr, ExpData)): raise TypeError(f'Unsupported pointer {type(edata)}, must be' f'amici.ExpDataPtr!') self._field_dimensions = { # observables 'observedData': [edata.nt(), edata.nytrue()], 'observedDataStdDev': [edata.nt(), edata.nytrue()], # event observables 'observedEvents': [edata.nmaxevent(), edata.nztrue()], 'observedEventsStdDev': [edata.nmaxevent(), edata.nztrue()], # fixed parameters 'fixedParameters': [len(edata.fixedParameters)], 'fixedParametersPreequilibration': [ len(edata.fixedParametersPreequilibration)], 'fixedParametersPresimulation': [ len(edata.fixedParametersPreequilibration)], } edata.observedData = edata.getObservedData() edata.observedDataStdDev = edata.getObservedDataStdDev() edata.observedEvents = edata.getObservedEvents() edata.observedEventsStdDev = edata.getObservedEventsStdDev() super(ExpDataView, self).__init__(edata)
def _field_as_numpy( field_dimensions: Dict[str, List[int]], field: str, data: SwigPtrView ) -> Union[np.ndarray, float, None]: """ Convert data object field to numpy array with dimensions according to specified field dimensions :param field_dimensions: dimension specifications ``dict({field: list([dim1, dim2, ...])})`` :param data: object with fields :param field: Name of field :returns: Field Data as numpy array with dimensions according to specified field dimensions """ attr = getattr(data, field) if field_dim := field_dimensions.get(field, None): return None if len(attr) == 0 else np.array(attr).reshape(field_dim) return float(attr) def _entity_type_from_id( entity_id: str, rdata: Union[amici.ReturnData, 'amici.ReturnDataView'] = None, model: amici.Model = None, ) -> Literal['x', 'y', 'w', 'p', 'k']: """Guess the type of some entity by its ID.""" for entity_type, symbol in ( ('State', 'x'), ('Observable', 'y'), ('Expression', 'w'), ('Parameter', 'p'), ('FixedParameter', 'k') ): if model: if entity_id in getattr(model, f'get{entity_type}Ids')(): return symbol else: if entity_id in getattr( rdata if isinstance(rdata, amici.ReturnData) else rdata._swigptr, f'{entity_type.lower()}_ids'): return symbol raise KeyError(f"Unknown symbol {entity_id}.")