"""
C++ object views
----------------
This module provides views on C++ objects for efficient access.
"""
import numpy as np
import copy
import collections
import amici
from . import ExpDataPtr, ReturnDataPtr, ExpData, ReturnData, Model
from typing import Union, List, Dict, Iterator, Literal
class SwigPtrView(collections.abc.Mapping):
    """
    Interface class to expose ``std::vector<double>`` and scalar members of
    swig wrapped C++ objects as numpy array attributes and fields. This
    class is memory efficient as copies of the underlying C++ data are
    only created when the respective fields are accessed for the first time.
    Cached copies are used for all subsequent calls.

    :ivar _swigptr: pointer to the C++ object
    :ivar _field_names: names of members that will be exposed as numpy arrays
    :ivar _field_dimensions: dimensions of numpy arrays
    :ivar _cache: dictionary with cached values
    """

    _swigptr = None
    _field_names: List[str] = []
    _field_dimensions: Dict[str, List[int]] = dict()

    def __getitem__(self, item: str) -> Union[np.ndarray, float]:
        """
        Access to field names, copies data from C++ object into numpy
        array, reshapes according to field dimensions and stores values in
        cache.

        :param item: field name
        :return: value
        :raises NotImplementedError: if no swig pointer has been set
        :raises KeyError: if ``item`` is not a known field name
        """
        if self._swigptr is None:
            raise NotImplementedError('Cannot get items from abstract class.')

        # 'ptr' exposes the wrapped object itself and is never cached
        if item == 'ptr':
            return self._swigptr

        if item in self._cache:
            return self._cache[item]

        # 'id' is read straight from the wrapped object on every access,
        # bypassing both the field-name check and the cache
        if item == 'id':
            return getattr(self._swigptr, item)

        if item not in self._field_names:
            self.__missing__(item)

        value = _field_as_numpy(
            self._field_dimensions, item, self._swigptr
        )
        self._cache[item] = value
        return value

    def __missing__(self, key: str) -> None:
        """
        Default behaviour for missing keys

        :param key: field name
        :raises KeyError: always
        """
        raise KeyError(f'Unknown field name {key}.')

    def __getattr__(self, item) -> Union[np.ndarray, float]:
        """
        Attribute accessor for field names

        Only invoked when regular attribute lookup fails. Unknown field
        names are reported as :class:`AttributeError`, as required by the
        attribute protocol, so that ``hasattr(view, name)`` and
        ``getattr(view, name, default)`` work as expected.

        :param item: field name
        :returns: value
        :raises AttributeError: if ``item`` is not a known field name
        """
        try:
            return self.__getitem__(item)
        except KeyError as e:
            raise AttributeError(item) from e

    def __init__(self, swigptr):
        """
        Constructor

        :param swigptr: pointer to the C++ object
        """
        self._swigptr = swigptr
        self._cache = {}
        super(SwigPtrView, self).__init__()

    def __len__(self) -> int:
        """
        Returns the number of available keys/fields

        :returns: length of _field_names
        """
        return len(self._field_names)

    def __iter__(self) -> Iterator:
        """
        Create an iterator of the keys/fields

        :returns: iterator over _field_names
        """
        return iter(self._field_names)

    def __copy__(self):
        """
        Create a shallow copy

        The copy shares the swig pointer, the field names, the field
        dimensions and the cache dictionary with the original.

        :return: SwigPtrView shallow copy
        """
        # NOTE(review): this always builds a plain ``SwigPtrView``, even when
        # called on a subclass instance, so subclass-specific ``__getitem__``
        # behaviour is not available on the copy -- confirm this is intended.
        other = SwigPtrView(self._swigptr)
        other._field_names = self._field_names
        other._field_dimensions = self._field_dimensions
        other._cache = self._cache
        return other

    def __contains__(self, item) -> bool:
        """
        Faster implementation of ``__contains__`` that avoids copy of the field

        :param item: item to check for
        :returns: whether item is available as key
        """
        return item in self._field_names

    def __deepcopy__(self, memo):
        """
        Create a deep copy

        Field names, field dimensions and cached values are deep-copied;
        the swig pointer itself is shared with the original.

        :param memo: dict with id-to-object mapping
        :returns: SwigPtrView deep copy
        """
        # NOTE(review): like ``__copy__`` this always builds a plain
        # ``SwigPtrView`` -- confirm this is intended for subclasses.
        other = SwigPtrView(self._swigptr)
        # forward ``memo`` so objects already copied during the enclosing
        # deepcopy operation are reused rather than duplicated
        other._field_names = copy.deepcopy(self._field_names, memo)
        other._field_dimensions = copy.deepcopy(self._field_dimensions, memo)
        other._cache = copy.deepcopy(self._cache, memo)
        return other

    def __repr__(self):
        """
        String representation of the object

        :returns: string representation
        """
        return f'<{self.__class__.__name__}({self._swigptr})>'
class ReturnDataView(SwigPtrView):
    """
    Interface class for C++ :class:`amici.ReturnData` objects that avoids
    possibly costly copies of member data.
    """

    _field_names = [
        'ts', 'x', 'x0', 'x_ss', 'sx', 'sx0', 'sx_ss', 'y', 'sigmay',
        'sy', 'ssigmay', 'z', 'rz', 'sigmaz', 'sz', 'srz',
        'ssigmaz', 'sllh', 's2llh', 'J', 'xdot', 'status', 'llh',
        'chi2', 'res', 'sres', 'FIM', 'w', 'preeq_wrms', 'preeq_t',
        'preeq_numsteps', 'preeq_numstepsB', 'preeq_status', 'preeq_cpu_time',
        'preeq_cpu_timeB', 'posteq_wrms', 'posteq_t', 'posteq_numsteps',
        'posteq_numstepsB', 'posteq_status', 'posteq_cpu_time',
        'posteq_cpu_timeB', 'numsteps', 'numrhsevals',
        'numerrtestfails', 'numnonlinsolvconvfails', 'order', 'cpu_time',
        'numstepsB', 'numrhsevalsB', 'numerrtestfailsB',
        'numnonlinsolvconvfailsB', 'cpu_timeB', 'cpu_time_total'
    ]

    def __init__(self, rdata: Union[ReturnDataPtr, ReturnData]):
        """
        Constructor

        :param rdata: pointer to the ``ReturnData`` instance
        :raises TypeError: if ``rdata`` is neither a ``ReturnDataPtr`` nor a
            ``ReturnData`` instance
        """
        if not isinstance(rdata, (ReturnDataPtr, ReturnData)):
            raise TypeError(f'Unsupported pointer {type(rdata)}, must be '
                            f'amici.ReturnDataPtr or amici.ReturnData!')
        # Fields without an entry here are returned as plain floats
        self._field_dimensions = {
            'ts': [rdata.nt],
            'x': [rdata.nt, rdata.nx],
            'x0': [rdata.nx],
            'x_ss': [rdata.nx],
            'sx': [rdata.nt, rdata.nplist, rdata.nx],
            'sx0': [rdata.nplist, rdata.nx],
            'sx_ss': [rdata.nplist, rdata.nx],

            # observables
            'y': [rdata.nt, rdata.ny],
            'sigmay': [rdata.nt, rdata.ny],
            'sy': [rdata.nt, rdata.nplist, rdata.ny],
            'ssigmay': [rdata.nt, rdata.nplist, rdata.ny],

            # event observables
            'z': [rdata.nmaxevent, rdata.nz],
            'rz': [rdata.nmaxevent, rdata.nz],
            'sigmaz': [rdata.nmaxevent, rdata.nz],
            'sz': [rdata.nmaxevent, rdata.nplist, rdata.nz],
            'srz': [rdata.nmaxevent, rdata.nplist, rdata.nz],
            'ssigmaz': [rdata.nmaxevent, rdata.nplist, rdata.nz],

            # objective function
            'sllh': [rdata.nplist],
            's2llh': [rdata.np, rdata.nplist],

            'res': [rdata.nt * rdata.nytrue *
                    (2 if rdata.sigma_res else 1)],
            'sres': [rdata.nt * rdata.nytrue *
                     (2 if rdata.sigma_res else 1), rdata.nplist],
            'FIM': [rdata.nplist, rdata.nplist],

            # diagnosis
            'J': [rdata.nx_solver, rdata.nx_solver],
            'w': [rdata.nt, rdata.nw],
            'xdot': [rdata.nx_solver],
            'preeq_numlinsteps': [rdata.newton_maxsteps, 2],
            'preeq_numsteps': [1, 3],
            'preeq_status': [1, 3],
            'posteq_numlinsteps': [rdata.newton_maxsteps, 2],
            'posteq_numsteps': [1, 3],
            'posteq_status': [1, 3],
            'numsteps': [rdata.nt],
            'numrhsevals': [rdata.nt],
            'numerrtestfails': [rdata.nt],
            'numnonlinsolvconvfails': [rdata.nt],
            'order': [rdata.nt],
            'numstepsB': [rdata.nt],
            'numrhsevalsB': [rdata.nt],
            'numerrtestfailsB': [rdata.nt],
            'numnonlinsolvconvfailsB': [rdata.nt],
        }
        super(ReturnDataView, self).__init__(rdata)

    def __getitem__(self, item: str) -> Union[np.ndarray, ReturnDataPtr,
                                              ReturnData, float]:
        """
        Access fields by name.

        Custom ``__getitem__`` implementation shim that maps ``t`` to ``ts``
        and returns ``status`` as an ``int``.

        :param item: field/attribute key
        :returns: self[item]
        """
        if item == 'status':
            return int(super().__getitem__(item))

        if item == 't':
            item = 'ts'

        return super().__getitem__(item)

    def by_id(
            self,
            entity_id: str,
            field: str = None,
            model: Model = None
    ) -> np.ndarray:
        """
        Get the value of a given field for a named entity.

        :param entity_id: The ID of the model entity that is to be extracted
            from ``field`` (e.g. a state ID).
        :param field: The requested field, e.g. 'x' for model states. This is
            optional if field would be one of ``{'x', 'y', 'w'}``
        :param model: The model from which this ReturnDataView was generated.
            This is optional if this ReturnData was generated with
            ``solver.getReturnDataReportingMode() == amici.RDataReporting.full``.
        :returns: The entries of ``field`` that belong to ``entity_id``,
            i.e. ``field`` indexed at the entity's position along its last
            axis.
        :raises NotImplementedError: if ``field`` cannot be subset by ID
        """
        if field is None:
            field = _entity_type_from_id(entity_id, self, model)

        # Prefer the IDs reported by ``model``; fall back to the IDs stored
        # on the wrapped ReturnData if no model is given (or if the model
        # yields a falsy/empty ID collection).
        if field in {'x', 'x0', 'x_ss', 'sx', 'sx0', 'sx_ss'}:
            ids = (model and model.getStateIds()) or self._swigptr.state_ids
        elif field in {'w'}:
            ids = (model and model.getExpressionIds()) \
                or self._swigptr.expression_ids
        elif field in {'y', 'sy', 'sigmay'}:
            ids = (model and model.getObservableIds()) \
                or self._swigptr.observable_ids
        elif field in {'sllh'}:
            ids = (model and model.getParameterIds()) \
                or self._swigptr.parameter_ids
        else:
            raise NotImplementedError(
                f"Subsetting {field} by ID is not implemented or not possible."
            )
        col_index = ids.index(entity_id)
        # Index the last axis only. A leading ``:`` would demand at least two
        # dimensions and therefore fail for the one-dimensional fields
        # accepted above ('x0', 'x_ss', 'sllh').
        return getattr(self, field)[..., col_index]
class ExpDataView(SwigPtrView):
    """
    Interface class for C++ Exp Data objects that avoids possibly costly
    copies of member data.
    """

    _field_names = [
        'observedData', 'observedDataStdDev', 'observedEvents',
        'observedEventsStdDev', 'fixedParameters',
        'fixedParametersPreequilibration',
        'fixedParametersPresimulation'
    ]

    def __init__(self, edata: Union[ExpDataPtr, ExpData]):
        """
        Constructor

        :param edata: pointer to the ExpData instance
        :raises TypeError: if ``edata`` is neither an ``ExpDataPtr`` nor an
            ``ExpData`` instance
        """
        if not isinstance(edata, (ExpDataPtr, ExpData)):
            raise TypeError(f'Unsupported pointer {type(edata)}, must be '
                            f'amici.ExpDataPtr or amici.ExpData!')
        self._field_dimensions = {
            # observables
            'observedData': [edata.nt(), edata.nytrue()],
            'observedDataStdDev': [edata.nt(), edata.nytrue()],

            # event observables
            'observedEvents': [edata.nmaxevent(), edata.nztrue()],
            'observedEventsStdDev': [edata.nmaxevent(), edata.nztrue()],

            # fixed parameters
            'fixedParameters': [len(edata.fixedParameters)],
            'fixedParametersPreequilibration': [
                len(edata.fixedParametersPreequilibration)],
            # each vector is sized by its own length; the presimulation
            # entry previously reused the preequilibration length
            'fixedParametersPresimulation': [
                len(edata.fixedParametersPresimulation)],
        }
        # Store the getter results as attributes on ``edata`` so that these
        # fields can be read by name via ``getattr`` like the other fields.
        edata.observedData = edata.getObservedData()
        edata.observedDataStdDev = edata.getObservedDataStdDev()
        edata.observedEvents = edata.getObservedEvents()
        edata.observedEventsStdDev = edata.getObservedEventsStdDev()
        super(ExpDataView, self).__init__(edata)
def _field_as_numpy(
        field_dimensions: Dict[str, List[int]],
        field: str, data: SwigPtrView
) -> Union[np.ndarray, float, None]:
    """
    Read ``field`` from ``data`` and return it either as a float or as a
    numpy array shaped according to ``field_dimensions``.

    :param field_dimensions: dimension specifications
        ``dict({field: list([dim1, dim2, ...])})``
    :param field: Name of field
    :param data: object with fields
    :returns: ``float(value)`` if no (non-empty) dimension specification
        exists for ``field``; ``None`` if a specification exists but the
        field holds no elements; otherwise the field data as numpy array
        reshaped to the specified dimensions
    """
    raw_value = getattr(data, field)
    shape = field_dimensions.get(field)

    # no (or empty) dimension specification: treat as scalar
    if not shape:
        return float(raw_value)

    # array-valued field without any entries
    if len(raw_value) == 0:
        return None

    return np.array(raw_value).reshape(shape)
def _entity_type_from_id(
        entity_id: str,
        rdata: Union[amici.ReturnData, 'amici.ReturnDataView'] = None,
        model: amici.Model = None,
) -> Literal['x', 'y', 'w', 'p', 'k']:
    """
    Guess the type of some entity by its ID.

    The ID collections are searched in the order state, observable,
    expression, parameter, fixed parameter; the symbol of the first
    collection containing ``entity_id`` is returned.

    :param entity_id: ID of the entity to classify
    :param rdata: source of the ID collections if ``model`` is not given;
        read via its ``<type>_ids`` attributes (for anything that is not an
        ``amici.ReturnData``, the attributes are read from ``rdata._swigptr``)
    :param model: source of the ID collections if given; read via its
        ``get<Type>Ids()`` methods
    :returns: symbol of the matching entity type
    :raises KeyError: if ``entity_id`` is in none of the ID collections
    """
    symbols_by_type = (
        ('State', 'x'),
        ('Observable', 'y'),
        ('Expression', 'w'),
        ('Parameter', 'p'),
        ('FixedParameter', 'k'),
    )

    if model:
        def ids_of(type_name):
            return getattr(model, f'get{type_name}Ids')()
    else:
        id_source = rdata if isinstance(rdata, amici.ReturnData) \
            else rdata._swigptr

        def ids_of(type_name):
            return getattr(id_source, f'{type_name.lower()}_ids')

    for type_name, type_symbol in symbols_by_type:
        if entity_id in ids_of(type_name):
            return type_symbol

    raise KeyError(f"Unknown symbol {entity_id}.")