Program Listing for File model.h

Return to documentation for file (include/amici/model.h)

#ifndef AMICI_MODEL_H
#define AMICI_MODEL_H

#include "amici/abstract_model.h"
#include "amici/defines.h"
#include "amici/logging.h"
#include "amici/model_dimensions.h"
#include "amici/model_state.h"
#include "amici/simulation_parameters.h"
#include "amici/splinefunctions.h"
#include "amici/sundials_matrix_wrapper.h"
#include "amici/vector.h"

#include <map>
#include <vector>

namespace amici {

// Forward declarations of amici classes that this header refers to before
// (or without) their full definitions. Model itself is defined further below.
// NOTE(review): Solver is not referenced in this header; the declaration is
// presumably kept for includers — confirm before removing.
class Solver;
class Model;
class ExpData;

} // namespace amici

// for serialization friend in amici::Model
namespace boost {
namespace serialization {
template <class Archive>
void serialize(Archive& ar, amici::Model& m, unsigned int version);
}
} // namespace boost

namespace amici {

/**
 * @brief Identifiers for model quantities.
 *
 * Used, e.g., as the `model_quantity` argument of Model::checkFinite to name
 * the quantity that is being checked. The enumerators have implicit,
 * consecutive values starting at 0, so inserting or reordering entries
 * changes their numeric values.
 */
enum class ModelQuantity {
    J,
    JB,
    Jv,
    JvB,
    JDiag,
    sx,
    sy,
    sz,
    srz,
    ssigmay,
    ssigmaz,
    xdot,
    sxdot,
    xBdot,
    x0_rdata,
    x0,
    x_rdata,
    x,
    dwdw,
    dwdx,
    dwdp,
    y,
    dydp,
    dydx,
    w,
    root,
    qBdot,
    qBdot_ss,
    xBdot_ss,
    JSparseB_ss,
    deltax,
    deltasx,
    deltaxB,
    k,
    p,
    ts,
    dJydy,
    dJydy_matlab,
    deltaqB,
    dsigmaydp,
    dsigmaydy,
    dsigmazdp,
    dJydsigma,
    dJydx,
    dzdx,
    dzdp,
    dJrzdsigma,
    dJrzdz,
    dJrzdx,
    dJzdsigma,
    dJzdz,
    dJzdx,
    drzdp,
    drzdx,
};

/**
 * @brief Map from each ModelQuantity to a string representation.
 *
 * Only declared here (extern, no initializer); the definition lives in
 * another translation unit.
 * NOTE(review): presumably this map must contain an entry for every
 * enumerator — confirm and update it when adding ModelQuantity values.
 */
extern std::map<ModelQuantity, std::string> const model_quantity_to_str;

/**
 * @brief Base class for AMICI models.
 *
 * Combines the AbstractModel callback interface with the ModelDimensions
 * size information. It holds the model state (state_, derived_state_),
 * splines, event bookkeeping and simulation settings. The class is abstract
 * (clone() is pure virtual) and is only used through concrete subclasses.
 *
 * Most member functions are only declared here. Comments that go beyond
 * what a signature shows are marked as assumptions, because the definitions
 * are not visible in this header.
 */
class Model : public AbstractModel, public ModelDimensions {
  public:
    /**
     * @brief Default constructor.
     *
     * Bases and members are default-constructed or use their in-class
     * initializers.
     */
    Model() = default;

    /**
     * @brief Construct a model from its dimensions and metadata.
     *
     * The definition is not part of this header. Descriptions marked
     * "presumably" are inferred from parameter names and should be confirmed
     * against the implementation.
     *
     * @param model_dimensions model dimensions (presumably initialize the
     * ModelDimensions base)
     * @param simulation_parameters initial simulation parameters
     * @param o2mode second-order sensitivity mode
     * @param idlist per-state id list (meaning not visible here — TODO
     * confirm)
     * @param z2event presumably maps event outputs to event indices
     * @param pythonGenerated presumably stored in the public member of the
     * same name
     * @param ndxdotdp_explicit presumably the number of nonzeros of the
     * explicit dxdot/dp part
     * @param ndxdotdx_explicit presumably the number of nonzeros of the
     * explicit dxdot/dx part
     * @param w_recursion_depth presumably stored in w_recursion_depth_
     * @param state_independent_events presumably maps trigger times to
     * event indices
     */
    Model(
        ModelDimensions const& model_dimensions,
        SimulationParameters simulation_parameters,
        amici::SecondOrderMode o2mode, std::vector<amici::realtype> idlist,
        std::vector<int> z2event, bool pythonGenerated = false,
        int ndxdotdp_explicit = 0, int ndxdotdx_explicit = 0,
        int w_recursion_depth = 0,
        std::map<realtype, std::vector<int>> state_independent_events = {}
    );

    ~Model() override = default;

    /**
     * @brief Member-wise copy constructor.
     *
     * Declared explicitly because the class has a user-declared destructor
     * and a (deleted) copy-assignment operator. The standard deprecates
     * implicit generation of the copy constructor in that case
     * ([depr.impldec]). Defaulting it keeps exactly the same copy semantics.
     * NOTE(review): presumably relied upon by subclass clone() overrides.
     *
     * @param other model to copy from
     */
    Model(Model const& other) = default;

    /** @brief Copy assignment is disabled. */
    Model& operator=(Model const& other) = delete;

    /**
     * @brief Create a copy of this model; implemented by concrete subclasses.
     * @return pointer to a new model instance (presumably owned by the
     * caller — confirm)
     */
    virtual Model* clone() const = 0;

    /** @brief Boost.Serialization hook; a friend so it can access non-public
     * members. */
    template <class Archive>
    friend void boost::serialization::serialize(
        Archive& ar, Model& m, unsigned int version
    );

    /** @brief Equality comparison; a friend so it can access non-public
     * members. */
    friend bool operator==(Model const& a, Model const& b);

    // Overloaded base class methods. Where Model declares a member function
    // with the same name (e.g. fx0, fy, fw below), that declaration would
    // otherwise hide every AbstractModel overload of the name. These
    // using-declarations bring the base overloads back into overload
    // resolution.
    using AbstractModel::fdeltaqB;
    using AbstractModel::fdeltasx;
    using AbstractModel::fdeltax;
    using AbstractModel::fdeltaxB;
    using AbstractModel::fdJrzdsigma;
    using AbstractModel::fdJrzdz;
    using AbstractModel::fdJydsigma;
    using AbstractModel::fdJydy;
    using AbstractModel::fdJydy_colptrs;
    using AbstractModel::fdJydy_rowvals;
    using AbstractModel::fdJzdsigma;
    using AbstractModel::fdJzdz;
    using AbstractModel::fdrzdp;
    using AbstractModel::fdrzdx;
    using AbstractModel::fdsigmaydp;
    using AbstractModel::fdsigmaydy;
    using AbstractModel::fdsigmazdp;
    using AbstractModel::fdtotal_cldp;
    using AbstractModel::fdtotal_cldx_rdata;
    using AbstractModel::fdtotal_cldx_rdata_colptrs;
    using AbstractModel::fdtotal_cldx_rdata_rowvals;
    using AbstractModel::fdwdp;
    using AbstractModel::fdwdp_colptrs;
    using AbstractModel::fdwdp_rowvals;
    using AbstractModel::fdwdw;
    using AbstractModel::fdwdw_colptrs;
    using AbstractModel::fdwdw_rowvals;
    using AbstractModel::fdwdx;
    using AbstractModel::fdwdx_colptrs;
    using AbstractModel::fdwdx_rowvals;
    using AbstractModel::fdx_rdatadp;
    using AbstractModel::fdx_rdatadtcl;
    using AbstractModel::fdx_rdatadtcl_colptrs;
    using AbstractModel::fdx_rdatadtcl_rowvals;
    using AbstractModel::fdx_rdatadx_solver;
    using AbstractModel::fdx_rdatadx_solver_colptrs;
    using AbstractModel::fdx_rdatadx_solver_rowvals;
    using AbstractModel::fdydp;
    using AbstractModel::fdydx;
    using AbstractModel::fdzdp;
    using AbstractModel::fdzdx;
    using AbstractModel::fJrz;
    using AbstractModel::fJy;
    using AbstractModel::fJz;
    using AbstractModel::frz;
    using AbstractModel::fsigmay;
    using AbstractModel::fsigmaz;
    using AbstractModel::fsrz;
    using AbstractModel::fstau;
    using AbstractModel::fsx0;
    using AbstractModel::fsx0_fixedParameters;
    using AbstractModel::fsz;
    using AbstractModel::fw;
    using AbstractModel::fx0;
    using AbstractModel::fx0_fixedParameters;
    using AbstractModel::fy;
    using AbstractModel::fz;

    // --- Initialization of states, sensitivities, splines and events -----
    // Output arguments are passed by non-const reference.

    void initialize(
        AmiVector& x, AmiVector& dx, AmiVectorArray& sx, AmiVectorArray& sdx,
        bool computeSensitivities, std::vector<int>& roots_found
    );

    void reinitialize(
        realtype t, AmiVector& x, AmiVectorArray& sx, bool computeSensitivities
    );

    void initializeB(AmiVector& xB, AmiVector& dxB, AmiVector& xQB, bool posteq)
        const;

    void initializeStates(AmiVector& x);

    void initializeStateSensitivities(AmiVectorArray& sx, AmiVector const& x);

    void initializeSplines();

    void initializeSplineSensitivities();

    void initEvents(
        AmiVector const& x, AmiVector const& dx, std::vector<int>& roots_found
    );

    // --- Dimension accessors and event limit ------------------------------

    int nplist() const;

    int np() const;

    int nk() const;

    int ncl() const;

    int nx_reinit() const;

    double const* k() const;

    int nMaxEvent() const;

    void setNMaxEvent(int nmaxevent);

    int nt() const;

    // --- Parameters (p) and parameter scaling -----------------------------
    // The *Regex setters return an int (presumably the number of matched
    // entries — confirm). `ignoreErrors` presumably suppresses errors for
    // unknown ids/names — confirm.

    std::vector<ParameterScaling> const& getParameterScale() const;

    void setParameterScale(ParameterScaling pscale);

    void setParameterScale(std::vector<ParameterScaling> const& pscaleVec);

    std::vector<realtype> const& getUnscaledParameters() const;

    std::vector<realtype> const& getParameters() const;

    realtype getParameterById(std::string const& par_id) const;

    realtype getParameterByName(std::string const& par_name) const;

    void setParameters(std::vector<realtype> const& p);

    void setParameterById(
        std::map<std::string, realtype> const& p, bool ignoreErrors = false
    );

    void setParameterById(std::string const& par_id, realtype value);

    int setParametersByIdRegex(std::string const& par_id_regex, realtype value);

    void setParameterByName(std::string const& par_name, realtype value);

    void setParameterByName(
        std::map<std::string, realtype> const& p, bool ignoreErrors = false
    );

    int
    setParametersByNameRegex(std::string const& par_name_regex, realtype value);

    // --- Fixed parameters (k) ---------------------------------------------

    std::vector<realtype> const& getFixedParameters() const;

    realtype getFixedParameterById(std::string const& par_id) const;

    realtype getFixedParameterByName(std::string const& par_name) const;

    void setFixedParameters(std::vector<realtype> const& k);

    void setFixedParameterById(std::string const& par_id, realtype value);

    int setFixedParametersByIdRegex(
        std::string const& par_id_regex, realtype value
    );

    void setFixedParameterByName(std::string const& par_name, realtype value);

    int setFixedParametersByNameRegex(
        std::string const& par_name_regex, realtype value
    );

    // --- Model metadata (names and ids) -----------------------------------
    // Virtual so that concrete models can provide their own metadata.

    virtual std::string getName() const;

    virtual bool hasParameterNames() const;

    virtual std::vector<std::string> getParameterNames() const;

    virtual bool hasStateNames() const;

    virtual std::vector<std::string> getStateNames() const;

    virtual std::vector<std::string> getStateNamesSolver() const;

    virtual bool hasFixedParameterNames() const;

    virtual std::vector<std::string> getFixedParameterNames() const;

    virtual bool hasObservableNames() const;

    virtual std::vector<std::string> getObservableNames() const;

    virtual bool hasExpressionNames() const;

    virtual std::vector<std::string> getExpressionNames() const;

    virtual bool hasParameterIds() const;

    virtual std::vector<std::string> getParameterIds() const;

    virtual bool hasStateIds() const;

    virtual std::vector<std::string> getStateIds() const;

    virtual std::vector<std::string> getStateIdsSolver() const;

    virtual bool hasFixedParameterIds() const;

    virtual std::vector<std::string> getFixedParameterIds() const;

    virtual bool hasObservableIds() const;

    virtual std::vector<std::string> getObservableIds() const;

    virtual bool hasExpressionIds() const;

    virtual std::vector<std::string> getExpressionIds() const;

    virtual bool hasQuadraticLLH() const;

    // --- Timepoints -------------------------------------------------------

    std::vector<realtype> const& getTimepoints() const;

    realtype getTimepoint(int it) const;

    void setTimepoints(std::vector<realtype> const& ts);

    double t0() const;

    void setT0(double t0);

    // --- Non-negativity flags for states ----------------------------------

    std::vector<bool> const& getStateIsNonNegative() const;

    void setStateIsNonNegative(std::vector<bool> const& stateIsNonNegative);

    void setAllStatesNonNegative();

    // --- Model state ------------------------------------------------------

    /**
     * @brief Get the current model state.
     * @return const reference to the internal state_ member
     */
    ModelState const& getModelState() const { return state_; }

    /**
     * @brief Replace the current model state after validating its sizes.
     *
     * The given state is copied into state_ only if all checks pass. Sizes
     * are converted with gsl::narrow, which throws if a size does not fit
     * into an int.
     *
     * @param state new state. Required sizes: unscaledParameters np(),
     * fixedParameters nk(), h ne, total_cl ncl(), stotal_cl ncl() * np().
     * @throws AmiException if any of these sizes does not match
     */
    void setModelState(ModelState const& state) {
        if (gsl::narrow<int>(state.unscaledParameters.size()) != np())
            throw AmiException("Mismatch in parameter size");
        if (gsl::narrow<int>(state.fixedParameters.size()) != nk())
            throw AmiException("Mismatch in fixed parameter size");
        if (gsl::narrow<int>(state.h.size()) != ne)
            throw AmiException("Mismatch in Heaviside size");
        if (gsl::narrow<int>(state.total_cl.size()) != ncl())
            throw AmiException("Mismatch in conservation law size");
        if (gsl::narrow<int>(state.stotal_cl.size()) != ncl() * np())
            throw AmiException("Mismatch in conservation law sensitivity size");
        state_ = state;
    }

    // --- Sigma residual settings ------------------------------------------

    /**
     * @brief Set the minimum sigma value stored in min_sigma_ (default 50.0).
     *
     * Takes realtype, like the matching getter and the member it writes.
     * NOTE(review): presumably only relevant when sigma residuals are
     * enabled via setAddSigmaResiduals — confirm.
     *
     * @param min_sigma new value
     */
    void setMinimumSigmaResiduals(realtype min_sigma) {
        min_sigma_ = min_sigma;
    }

    /**
     * @brief Get the value stored in min_sigma_.
     * @return minimum sigma for residuals
     */
    realtype getMinimumSigmaResiduals() const { return min_sigma_; }

    /**
     * @brief Set the sigma_res_ flag (default false).
     * @param sigma_res whether sigma residuals should be added
     */
    void setAddSigmaResiduals(bool sigma_res) { sigma_res_ = sigma_res; }

    /**
     * @brief Get the sigma_res_ flag.
     * @return whether sigma residuals should be added
     */
    bool getAddSigmaResiduals() const { return sigma_res_; }

    // --- Parameter list, initial states and steady-state settings ---------

    std::vector<int> const& getParameterList() const;

    int plist(int pos) const;

    void setParameterList(std::vector<int> const& plist);

    std::vector<realtype> getInitialStates();

    void setInitialStates(std::vector<realtype> const& x0);

    bool hasCustomInitialStates() const;

    std::vector<realtype> getInitialStateSensitivities();

    void setInitialStateSensitivities(std::vector<realtype> const& sx0);

    bool hasCustomInitialStateSensitivities() const;

    void setUnscaledInitialStateSensitivities(std::vector<realtype> const& sx0);

    void setSteadyStateComputationMode(SteadyStateComputationMode mode);

    SteadyStateComputationMode getSteadyStateComputationMode() const;

    void setSteadyStateSensitivityMode(SteadyStateSensitivityMode mode);

    SteadyStateSensitivityMode getSteadyStateSensitivityMode() const;

    void setReinitializeFixedParameterInitialStates(bool flag);

    bool getReinitializeFixedParameterInitialStates() const;

    void requireSensitivitiesForAllParameters();

    // --- Expressions, observables and their objective contributions -------
    // Results are written into the given spans or added to the given
    // accumulators.

    void
    getExpression(gsl::span<realtype> w, realtype const t, AmiVector const& x);

    void
    getObservable(gsl::span<realtype> y, realtype const t, AmiVector const& x);

    virtual ObservableScaling getObservableScaling(int iy) const;

    void getObservableSensitivity(
        gsl::span<realtype> sy, realtype const t, AmiVector const& x,
        AmiVectorArray const& sx
    );

    void getObservableSigma(
        gsl::span<realtype> sigmay, int const it, ExpData const* edata
    );

    void getObservableSigmaSensitivity(
        gsl::span<realtype> ssigmay, gsl::span<realtype const> sy, int const it,
        ExpData const* edata
    );

    void addObservableObjective(
        realtype& Jy, int const it, AmiVector const& x, ExpData const& edata
    );

    void addObservableObjectiveSensitivity(
        std::vector<realtype>& sllh, std::vector<realtype>& s2llh, int const it,
        AmiVector const& x, AmiVectorArray const& sx, ExpData const& edata
    );

    void addPartialObservableObjectiveSensitivity(
        std::vector<realtype>& sllh, std::vector<realtype>& s2llh, int const it,
        AmiVector const& x, ExpData const& edata
    );

    void getAdjointStateObservableUpdate(
        gsl::span<realtype> dJydx, int const it, AmiVector const& x,
        ExpData const& edata
    );

    // --- Events: outputs, regularization, sigmas and objective terms ------

    void getEvent(
        gsl::span<realtype> z, int const ie, realtype const t,
        AmiVector const& x
    );
    void getEventSensitivity(
        gsl::span<realtype> sz, int const ie, realtype const t,
        AmiVector const& x, AmiVectorArray const& sx
    );

    void getUnobservedEventSensitivity(gsl::span<realtype> sz, int const ie);

    void getEventRegularization(
        gsl::span<realtype> rz, int const ie, realtype const t,
        AmiVector const& x
    );

    void getEventRegularizationSensitivity(
        gsl::span<realtype> srz, int const ie, realtype const t,
        AmiVector const& x, AmiVectorArray const& sx
    );
    void getEventSigma(
        gsl::span<realtype> sigmaz, int const ie, int const nroots,
        realtype const t, ExpData const* edata
    );

    void getEventSigmaSensitivity(
        gsl::span<realtype> ssigmaz, int const ie, int const nroots,
        realtype const t, ExpData const* edata
    );

    void addEventObjective(
        realtype& Jz, int const ie, int const nroots, realtype const t,
        AmiVector const& x, ExpData const& edata
    );

    void addEventObjectiveRegularization(
        realtype& Jrz, int const ie, int const nroots, realtype const t,
        AmiVector const& x, ExpData const& edata
    );

    void addEventObjectiveSensitivity(
        std::vector<realtype>& sllh, std::vector<realtype>& s2llh, int const ie,
        int const nroots, realtype const t, AmiVector const& x,
        AmiVectorArray const& sx, ExpData const& edata
    );

    void addPartialEventObjectiveSensitivity(
        std::vector<realtype>& sllh, std::vector<realtype>& s2llh, int const ie,
        int const nroots, realtype const t, AmiVector const& x,
        ExpData const& edata
    );

    void getAdjointStateEventUpdate(
        gsl::span<realtype> dJzdx, int const ie, int const nroots,
        realtype const t, AmiVector const& x, ExpData const& edata
    );

    void getEventTimeSensitivity(
        std::vector<realtype>& stau, realtype const t, int const ie,
        AmiVector const& x, AmiVectorArray const& sx
    );

    // --- State / sensitivity / adjoint updates at events ------------------

    void addStateEventUpdate(
        AmiVector& x, int const ie, realtype const t, AmiVector const& xdot,
        AmiVector const& xdot_old
    );

    void addStateSensitivityEventUpdate(
        AmiVectorArray& sx, int const ie, realtype const t,
        AmiVector const& x_old, AmiVector const& xdot,
        AmiVector const& xdot_old, std::vector<realtype> const& stau
    );

    void addAdjointStateEventUpdate(
        AmiVector& xB, int const ie, realtype const t, AmiVector const& x,
        AmiVector const& xdot, AmiVector const& xdot_old
    );

    // NOTE(review): unlike xB in addAdjointStateEventUpdate (non-const
    // reference), xQB is taken BY VALUE. Any update the definition makes to
    // xQB therefore goes to a copy and is lost for the caller, unless copies
    // of AmiVector share storage (TODO confirm). If this function is meant
    // to update the caller's quadrature vector, change the parameter to
    // `AmiVector& xQB` here AND in the out-of-line definition together.
    void addAdjointQuadratureEventUpdate(
        AmiVector xQB, int const ie, realtype const t, AmiVector const& x,
        AmiVector const& xB, AmiVector const& xdot, AmiVector const& xdot_old
    );

    void updateHeaviside(std::vector<int> const& rootsfound);

    void updateHeavisideB(int const* rootsfound);

    // --- Finiteness checks ------------------------------------------------
    // `model_quantity` names the checked quantity; the int result is
    // presumably a status code — confirm.

    int checkFinite(
        gsl::span<realtype const> array, ModelQuantity model_quantity,
        realtype t
    ) const;
    int checkFinite(
        gsl::span<realtype const> array, ModelQuantity model_quantity,
        size_t num_cols, realtype t
    ) const;

    int
    checkFinite(SUNMatrix m, ModelQuantity model_quantity, realtype t) const;

    void setAlwaysCheckFinite(bool alwaysCheck);

    bool getAlwaysCheckFinite() const;

    // --- Initial states and solver <-> rdata state conversion -------------

    void fx0(AmiVector& x);

    void fx0_fixedParameters(AmiVector& x);

    void fsx0(AmiVectorArray& sx, AmiVector const& x);

    void fsx0_fixedParameters(AmiVectorArray& sx, AmiVector const& x);

    virtual void fsdx0();

    void fx_rdata(AmiVector& x_rdata, AmiVector const& x_solver);

    void fsx_rdata(
        AmiVectorArray& sx_rdata, AmiVectorArray const& sx_solver,
        AmiVector const& x_solver
    );

    void setReinitializationStateIdxs(std::vector<int> const& idxs);

    std::vector<int> const& getReinitializationStateIdxs() const;

    /** Public flag (default false); presumably set from the constructor
     * argument of the same name — confirm. */
    bool pythonGenerated = false;

    AmiVectorArray const& get_dxdotdp() const;

    SUNMatrixWrapper const& get_dxdotdp_full() const;

    virtual std::vector<double> get_trigger_timepoints() const;

    /**
     * @brief Get the steady-state mask values.
     * @return copy of the values of steadystate_mask_
     */
    std::vector<double> get_steadystate_mask() const {
        return steadystate_mask_.getVector();
    }

    /**
     * @brief Get the steady-state mask as an AmiVector.
     * @return const reference to steadystate_mask_
     */
    AmiVector const& get_steadystate_mask_av() const {
        return steadystate_mask_;
    }

    void set_steadystate_mask(std::vector<double> const& mask);

    /** Second-order sensitivity mode (default SecondOrderMode::none). */
    SecondOrderMode o2mode{SecondOrderMode::none};

    /** Per-state id list (meaning not visible here — TODO confirm). */
    std::vector<realtype> idlist;

    /** NOTE(review): raw pointer, presumably non-owning and optional
     * (defaults to nullptr) — confirm ownership. */
    Logger* logger = nullptr;

    /** NOTE(review): public member despite the trailing-underscore naming
     * used for non-public members; presumably maps trigger times to event
     * indices — confirm. */
    std::map<realtype, std::vector<int>> state_independent_events_ = {};

  protected:
    // --- Helpers for writing event slices and llh sensitivity buffers -----

    void writeSliceEvent(
        gsl::span<realtype const> slice, gsl::span<realtype> buffer,
        int const ie
    );

    void writeSensitivitySliceEvent(
        gsl::span<realtype const> slice, gsl::span<realtype> buffer,
        int const ie
    );

    void writeLLHSensitivitySlice(
        std::vector<realtype> const& dLLhdp, std::vector<realtype>& sllh,
        std::vector<realtype>& s2llh
    );

    void checkLLHBufferSize(
        std::vector<realtype> const& sllh, std::vector<realtype> const& s2llh
    ) const;

    void initializeVectors();

    // --- Observable-related model functions -------------------------------

    void fy(realtype t, AmiVector const& x);

    void fdydp(realtype t, AmiVector const& x);

    void fdydx(realtype t, AmiVector const& x);

    void fsigmay(int it, ExpData const* edata);

    void fdsigmaydp(int it, ExpData const* edata);

    void fdsigmaydy(int it, ExpData const* edata);

    void fJy(realtype& Jy, int it, AmiVector const& y, ExpData const& edata);

    void fdJydy(int it, AmiVector const& x, ExpData const& edata);

    void fdJydsigma(int it, AmiVector const& x, ExpData const& edata);

    void fdJydp(int const it, AmiVector const& x, ExpData const& edata);

    void fdJydx(int const it, AmiVector const& x, ExpData const& edata);

    // --- Event-related model functions ------------------------------------

    void fz(int ie, realtype t, AmiVector const& x);

    void fdzdp(int ie, realtype t, AmiVector const& x);

    void fdzdx(int ie, realtype t, AmiVector const& x);

    void frz(int ie, realtype t, AmiVector const& x);

    void fdrzdp(int ie, realtype t, AmiVector const& x);

    void fdrzdx(int ie, realtype t, AmiVector const& x);

    void fsigmaz(
        int const ie, int const nroots, realtype const t, ExpData const* edata
    );

    void fdsigmazdp(int ie, int nroots, realtype t, ExpData const* edata);

    void
    fJz(realtype& Jz, int nroots, AmiVector const& z, ExpData const& edata);

    void fdJzdz(
        int const ie, int const nroots, realtype const t, AmiVector const& x,
        ExpData const& edata
    );

    void fdJzdsigma(
        int const ie, int const nroots, realtype const t, AmiVector const& x,
        ExpData const& edata
    );

    void fdJzdp(
        int const ie, int const nroots, realtype t, AmiVector const& x,
        ExpData const& edata
    );

    void fdJzdx(
        int const ie, int const nroots, realtype t, AmiVector const& x,
        ExpData const& edata
    );

    void
    fJrz(realtype& Jrz, int nroots, AmiVector const& rz, ExpData const& edata);

    void fdJrzdz(
        int const ie, int const nroots, realtype const t, AmiVector const& x,
        ExpData const& edata
    );

    void fdJrzdsigma(
        int const ie, int const nroots, realtype const t, AmiVector const& x,
        ExpData const& edata
    );

    // --- Splines and expressions (w) --------------------------------------

    void fspl(realtype t);

    void fsspl(realtype t);

    void fw(realtype t, realtype const* x, bool include_static = true);

    void fdwdp(realtype t, realtype const* x, bool include_static = true);

    void fdwdx(realtype t, realtype const* x, bool include_static = true);

    void fdwdw(realtype t, realtype const* x, bool include_static = true);

    // --- Raw-pointer conversion hooks (virtual, overridable by models) ----

    virtual void fx_rdata(
        realtype* x_rdata, realtype const* x_solver, realtype const* tcl,
        realtype const* p, realtype const* k
    );

    virtual void fsx_rdata(
        realtype* sx_rdata, realtype const* sx_solver, realtype const* stcl,
        realtype const* p, realtype const* k, realtype const* x_solver,
        realtype const* tcl, int const ip
    );

    virtual void fx_solver(realtype* x_solver, realtype const* x_rdata);

    virtual void fsx_solver(realtype* sx_solver, realtype const* sx_rdata);

    virtual void ftotal_cl(
        realtype* total_cl, realtype const* x_rdata, realtype const* p,
        realtype const* k
    );

    virtual void fstotal_cl(
        realtype* stotal_cl, realtype const* sx_rdata, int const ip,
        realtype const* x_rdata, realtype const* p, realtype const* k,
        realtype const* tcl
    );

    // NOTE(review): presumably returns x with the states flagged in
    // state_is_non_negative_ clipped to non-negative values — confirm.
    const_N_Vector computeX_pos(const_N_Vector x);

    realtype const* computeX_pos(AmiVector const& x);

    /** Current model state; replaced as a whole by setModelState. */
    ModelState state_;

    ModelStateDerived derived_state_;

    std::vector<HermiteSpline> splines_;

    std::vector<int> z2event_;

    std::vector<realtype> x0data_;

    std::vector<realtype> sx0data_;

    std::vector<bool> state_is_non_negative_;

    std::vector<bool> root_initial_values_;

    bool any_state_non_negative_{false};

    /** Default maximum number of events: 10. */
    int nmaxevent_{10};

    SteadyStateComputationMode steadystate_computation_mode_{
        SteadyStateComputationMode::integrateIfNewtonFails
    };

    SteadyStateSensitivityMode steadystate_sensitivity_mode_{
        SteadyStateSensitivityMode::integrateIfNewtonFails
    };

    // Finiteness checks are on by default in debug builds (NDEBUG undefined)
    // and off by default in release builds.
#ifdef NDEBUG
    bool always_check_finite_{false};
#else
    bool always_check_finite_{true};
#endif

    /** Flag exposed via set/getAddSigmaResiduals. */
    bool sigma_res_{false};

    /** Value exposed via set/getMinimumSigmaResiduals. */
    realtype min_sigma_{50.0};

  private:
    // mutable: these can be modified from const member functions
    // (presumably as cached work matrices — confirm).
    mutable std::vector<SUNMatrixWrapper> dwdp_hierarchical_;

    mutable SUNMatrixWrapper dwdw_;

    mutable std::vector<SUNMatrixWrapper> dwdx_hierarchical_;

    int w_recursion_depth_{0};

    SimulationParameters simulation_parameters_;

    AmiVector steadystate_mask_;
};

/** @brief Equality comparison of two ModelDimensions instances. */
bool operator==(ModelDimensions const& a, ModelDimensions const& b);
/** @brief Equality comparison of two models (a friend of Model, so it may
 * access non-public members). */
bool operator==(Model const& a, Model const& b);

} // namespace amici

#endif // AMICI_MODEL_H