Program Listing for File model.h

Return to documentation for file (include/amici/model.h)

#ifndef AMICI_MODEL_H
#define AMICI_MODEL_H

#include "amici/abstract_model.h"
#include "amici/defines.h"
#include "amici/sundials_matrix_wrapper.h"
#include "amici/vector.h"
#include "amici/simulation_parameters.h"
#include "amici/model_dimensions.h"
#include "amici/model_state.h"

#include <map>
#include <memory>
#include <vector>

namespace amici {

// Forward declarations only; no definitions are provided in this block.
class AmiciApplication;
class ExpData;
class Model;
class Solver;

// Declared here with external linkage; the definition lives elsewhere.
extern AmiciApplication defaultContext;

} // namespace amici

// Forward declaration of the serialization function template so that
// amici::Model can name it in a friend declaration.
namespace boost {
namespace serialization {
template <typename Archive>
void serialize(Archive &archive, amici::Model &model, unsigned int version);
} // namespace serialization
} // namespace boost

namespace amici {

/**
 * @brief Model class deriving from both AbstractModel and ModelDimensions.
 *
 * Apart from a handful of inline accessors, the member functions are only
 * declared in this header; their definitions live elsewhere. Comments on
 * such declarations describe what the signature shows and explicitly hedge
 * anything that is inferred from naming alone.
 *
 * Copy assignment is deleted; polymorphic copies go through clone().
 */
class Model : public AbstractModel, public ModelDimensions {
  public:
    /**
     * @brief Default constructor.
     *
     * Data members that carry an in-class initializer take that default.
     */
    Model() = default;

    /**
     * @brief Constructor taking the model description (defined out of line).
     *
     * How each argument is stored is decided by the definition, which is not
     * visible here; the parameter names suggest a correspondence to the
     * equally named members / base class — confirm against the definition.
     *
     * @param model_dimensions dimension data
     * @param simulation_parameters simulation parameters (taken by value)
     * @param o2mode second order sensitivity mode
     * @param idlist vector of realtype flags (taken by value)
     * @param z2event vector of int indices (taken by value)
     * @param pythonGenerated defaults to false
     * @param ndxdotdp_explicit defaults to 0
     * @param ndxdotdx_explicit defaults to 0
     * @param w_recursion_depth defaults to 0
     */
    Model(ModelDimensions const& model_dimensions,
          SimulationParameters simulation_parameters,
          amici::SecondOrderMode o2mode,
          std::vector<amici::realtype> idlist,
          std::vector<int> z2event, bool pythonGenerated = false,
          int ndxdotdp_explicit = 0, int ndxdotdx_explicit = 0,
          int w_recursion_depth = 0);

    /** @brief Destructor (overrides a virtual base-class destructor). */
    ~Model() override = default;

    /** @brief Copy assignment is disabled. */
    Model &operator=(Model const &other) = delete;

    /**
     * @brief Pure virtual clone function, to be provided by derived classes.
     * @return raw pointer to a Model; NOTE(review): ownership is presumably
     * transferred to the caller — confirm with the implementations.
     */
    virtual Model *clone() const = 0;

    /** @brief Grants the boost serialization function access to internals. */
    template <class Archive>
    friend void boost::serialization::serialize(Archive &ar, Model &m,
                                                unsigned int version);

    /** @brief Grants the equality operator access to internals. */
    friend bool operator==(const Model &a, const Model &b);

    // Overloaded base class methods: the using-declarations keep the
    // AbstractModel overloads visible next to the overloads declared below.
    using AbstractModel::fdeltaqB;
    using AbstractModel::fdeltasx;
    using AbstractModel::fdeltax;
    using AbstractModel::fdeltaxB;
    using AbstractModel::fdJrzdsigma;
    using AbstractModel::fdJrzdz;
    using AbstractModel::fdJydsigma;
    using AbstractModel::fdJydy;
    using AbstractModel::fdJydy_colptrs;
    using AbstractModel::fdJydy_rowvals;
    using AbstractModel::fdJzdsigma;
    using AbstractModel::fdJzdz;
    using AbstractModel::fdrzdp;
    using AbstractModel::fdrzdx;
    using AbstractModel::fdsigmaydp;
    using AbstractModel::fdsigmazdp;
    using AbstractModel::fdwdp;
    using AbstractModel::fdwdp_colptrs;
    using AbstractModel::fdwdp_rowvals;
    using AbstractModel::fdwdx;
    using AbstractModel::fdwdx_colptrs;
    using AbstractModel::fdwdx_rowvals;
    using AbstractModel::fdwdw;
    using AbstractModel::fdwdw_colptrs;
    using AbstractModel::fdwdw_rowvals;
    using AbstractModel::fdydp;
    using AbstractModel::fdydx;
    using AbstractModel::fdzdp;
    using AbstractModel::fdzdx;
    using AbstractModel::fJrz;
    using AbstractModel::fJy;
    using AbstractModel::fJz;
    using AbstractModel::frz;
    using AbstractModel::fsigmay;
    using AbstractModel::fsigmaz;
    using AbstractModel::fsrz;
    using AbstractModel::fstau;
    using AbstractModel::fsx0;
    using AbstractModel::fsx0_fixedParameters;
    using AbstractModel::fsz;
    using AbstractModel::fw;
    using AbstractModel::fx0;
    using AbstractModel::fx0_fixedParameters;
    using AbstractModel::fy;
    using AbstractModel::fz;

    // ---- Initialization (declarations only) -------------------------------
    // The non-const reference arguments are in/out buffers; what exactly is
    // written to them is determined by the out-of-line definitions.

    void initialize(AmiVector &x, AmiVector &dx, AmiVectorArray &sx,
                    AmiVectorArray &sdx, bool computeSensitivities);

    void initializeB(AmiVector &xB, AmiVector &dxB, AmiVector &xQB,
                     bool posteq) const;

    void initializeStates(AmiVector &x);

    void initializeStateSensitivities(AmiVectorArray &sx, const AmiVector &x);

    void initHeaviside(const AmiVector &x, const AmiVector &dx);

    // ---- Counts and simple accessors (declarations only) ------------------
    // NOTE(review): judging by the names these report sizes of the parameter
    // list, parameters, fixed parameters, conservation laws, reinitialized
    // states and timepoints — confirm against the definitions.

    int nplist() const;

    int np() const;

    int nk() const;

    int ncl() const;

    int nx_reinit() const;

    // Returns a raw pointer to const double; presumably the fixed parameter
    // data — lifetime/ownership are not visible here.
    const double *k() const;

    int nMaxEvent() const;

    void setNMaxEvent(int nmaxevent);

    int nt() const;

    // ---- Parameter scaling (declarations only) ----------------------------

    std::vector<ParameterScaling> const &getParameterScale() const;

    // Overload taking a single scaling value.
    void setParameterScale(ParameterScaling pscale);

    // Overload taking one scaling value per entry of the vector.
    void setParameterScale(const std::vector<ParameterScaling> &pscaleVec);

    // ---- Parameters (declarations only) -----------------------------------

    std::vector<realtype> const &getUnscaledParameters() const;

    std::vector<realtype> const &getParameters() const;

    realtype getParameterById(std::string const &par_id) const;

    realtype getParameterByName(std::string const &par_name) const;

    void setParameters(std::vector<realtype> const &p);

    // Map overload: id -> value; ignoreErrors defaults to false.
    void setParameterById(std::map<std::string, realtype> const &p,
                          bool ignoreErrors = false);

    void setParameterById(std::string const &par_id, realtype value);

    // Returns an int; presumably the number of matched parameters — confirm.
    int setParametersByIdRegex(std::string const &par_id_regex, realtype value);

    void setParameterByName(std::string const &par_name, realtype value);

    // Map overload: name -> value; ignoreErrors defaults to false.
    void setParameterByName(std::map<std::string, realtype> const &p,
                            bool ignoreErrors = false);

    // Returns an int; presumably the number of matched parameters — confirm.
    int setParametersByNameRegex(std::string const &par_name_regex,
                                 realtype value);

    // ---- Fixed parameters (declarations only) -----------------------------

    std::vector<realtype> const &getFixedParameters() const;

    realtype getFixedParameterById(std::string const &par_id) const;

    realtype getFixedParameterByName(std::string const &par_name) const;

    void setFixedParameters(std::vector<realtype> const &k);

    void setFixedParameterById(std::string const &par_id, realtype value);

    // Returns an int; presumably the number of matched parameters — confirm.
    int setFixedParametersByIdRegex(std::string const &par_id_regex,
                                    realtype value);

    void setFixedParameterByName(std::string const &par_name, realtype value);

    // Returns an int; presumably the number of matched parameters — confirm.
    int setFixedParametersByNameRegex(std::string const &par_name_regex,
                                      realtype value);

    // ---- Names and IDs (virtual, declarations only) -----------------------
    // Each has*/get* pair is virtual and can therefore be overridden by
    // derived classes. All getters return their vectors by value.

    virtual std::string getName() const;

    virtual bool hasParameterNames() const;

    virtual std::vector<std::string> getParameterNames() const;

    virtual bool hasStateNames() const;

    virtual std::vector<std::string> getStateNames() const;

    virtual bool hasFixedParameterNames() const;

    virtual std::vector<std::string> getFixedParameterNames() const;

    virtual bool hasObservableNames() const;

    virtual std::vector<std::string> getObservableNames() const;

    virtual bool hasExpressionNames() const;

    virtual std::vector<std::string> getExpressionNames() const;

    virtual bool hasParameterIds() const;

    virtual std::vector<std::string> getParameterIds() const;

    virtual bool hasStateIds() const;

    virtual std::vector<std::string> getStateIds() const;

    virtual bool hasFixedParameterIds() const;

    virtual std::vector<std::string> getFixedParameterIds() const;

    virtual bool hasObservableIds() const;

    virtual std::vector<std::string> getObservableIds() const;

    virtual bool hasExpressionIds() const;

    virtual std::vector<std::string> getExpressionIds() const;

    virtual bool hasQuadraticLLH() const;

    // ---- Timepoints and start time (declarations only) --------------------

    std::vector<realtype> const &getTimepoints() const;

    // `it` is an int index; bounds handling is up to the definition.
    realtype getTimepoint(int it) const;

    void setTimepoints(std::vector<realtype> const &ts);

    double t0() const;

    void setT0(double t0);

    // ---- State non-negativity flags (declarations only) -------------------

    std::vector<bool> const &getStateIsNonNegative() const;

    void setStateIsNonNegative(std::vector<bool> const &stateIsNonNegative);

    void setAllStatesNonNegative();

    // ---- Model state ------------------------------------------------------

    /**
     * @brief Read access to the stored model state.
     * @return const reference to state_
     */
    ModelState const &getModelState() const {
        return state_;
    }

    /**
     * @brief Replaces the stored model state after validating its sizes.
     *
     * The sizes of the vectors in @p state are compared against np(), nk(),
     * ne, ncl() and ncl() * np(). (`ne` is not declared in this class;
     * presumably it is inherited from ModelDimensions.) state_ is only
     * overwritten when all checks pass.
     *
     * @param state model state to copy into state_
     * @throws AmiException if any of the size checks fails
     */
    void setModelState(ModelState const &state) {
        if (static_cast<int>(state.unscaledParameters.size()) != np())
            throw AmiException("Mismatch in parameter size");
        if (static_cast<int>(state.fixedParameters.size()) != nk())
            throw AmiException("Mismatch in fixed parameter size");
        if (static_cast<int>(state.h.size()) != ne)
            throw AmiException("Mismatch in Heaviside size");
        if (static_cast<int>(state.total_cl.size()) != ncl())
            throw AmiException("Mismatch in conservation law size");
        if (static_cast<int>(state.stotal_cl.size()) != ncl() * np() )
            throw AmiException("Mismatch in conservation law sensitivity size");
        state_ = state;
    }

    // ---- Sigma residual settings ------------------------------------------

    /**
     * @brief Stores @p min_sigma in min_sigma_.
     * @param min_sigma new value
     */
    void setMinimumSigmaResiduals(double min_sigma) {
        min_sigma_ = min_sigma;
    }

    /**
     * @brief Returns the stored min_sigma_ value.
     * @return min_sigma_
     */
    realtype getMinimumSigmaResiduals() const {
        return min_sigma_;
    }

    /**
     * @brief Stores @p sigma_res in sigma_res_.
     * @param sigma_res new flag value
     */
    void setAddSigmaResiduals(bool sigma_res) {
        sigma_res_ = sigma_res;
    }

    /**
     * @brief Returns the stored sigma_res_ flag.
     * @return sigma_res_
     */
    bool getAddSigmaResiduals() const {
        return sigma_res_;
    }

    // ---- Parameter list (declarations only) -------------------------------

    std::vector<int> const &getParameterList() const;

    // `pos` is an int position; presumably an index into the parameter list.
    int plist(int pos) const;

    void setParameterList(std::vector<int> const &plist);

    // ---- Initial states and their sensitivities (declarations only) -------
    // Note that the two getters are non-const and return by value.

    std::vector<realtype> getInitialStates();

    void setInitialStates(std::vector<realtype> const &x0);

    bool hasCustomInitialStates() const;

    std::vector<realtype> getInitialStateSensitivities();

    void setInitialStateSensitivities(std::vector<realtype> const &sx0);

    bool hasCustomInitialStateSensitivities() const;

    void setUnscaledInitialStateSensitivities(std::vector<realtype> const &sx0);

    // ---- Steady state / reinitialization options (declarations only) ------

    void setSteadyStateSensitivityMode(SteadyStateSensitivityMode mode);

    SteadyStateSensitivityMode getSteadyStateSensitivityMode() const;

    void setReinitializeFixedParameterInitialStates(bool flag);

    bool getReinitializeFixedParameterInitialStates() const;

    void requireSensitivitiesForAllParameters();

    // ---- Expressions and observables (declarations only) ------------------
    // gsl::span<realtype> arguments are caller-provided output buffers;
    // `it` is an int index (presumably a timepoint index — confirm).
    // ExpData is passed by pointer in the sigma functions and by reference
    // in the objective functions.

    void getExpression(gsl::span<realtype> w, const realtype t, const AmiVector &x);

    void getObservable(gsl::span<realtype> y, const realtype t,
                       const AmiVector &x);

    virtual ObservableScaling getObservableScaling(int iy) const;

    void getObservableSensitivity(gsl::span<realtype> sy, const realtype t,
                                  const AmiVector &x, const AmiVectorArray &sx);

    void getObservableSigma(gsl::span<realtype> sigmay, const int it,
                            const ExpData *edata);

    void getObservableSigmaSensitivity(gsl::span<realtype> ssigmay,
                                       const int it, const ExpData *edata);

    void addObservableObjective(realtype &Jy, const int it, const AmiVector &x,
                                const ExpData &edata);

    void addObservableObjectiveSensitivity(std::vector<realtype> &sllh,
                                           std::vector<realtype> &s2llh,
                                           const int it, const AmiVector &x,
                                           const AmiVectorArray &sx,
                                           const ExpData &edata);

    void addPartialObservableObjectiveSensitivity(std::vector<realtype> &sllh,
                                                  std::vector<realtype> &s2llh,
                                                  const int it,
                                                  const AmiVector &x,
                                                  const ExpData &edata);

    void getAdjointStateObservableUpdate(gsl::span<realtype> dJydx,
                                         const int it, const AmiVector &x,
                                         const ExpData &edata);

    // ---- Events (declarations only) ---------------------------------------
    // `ie` is an int index (presumably an event index) and `nroots` an int
    // count (presumably event occurrences so far) — confirm in definitions.

    void getEvent(gsl::span<realtype> z, const int ie, const realtype t,
                  const AmiVector &x);

    void getEventSensitivity(gsl::span<realtype> sz, const int ie,
                             const realtype t, const AmiVector &x,
                             const AmiVectorArray &sx);

    void getUnobservedEventSensitivity(gsl::span<realtype> sz, const int ie);

    void getEventRegularization(gsl::span<realtype> rz, const int ie,
                                const realtype t, const AmiVector &x);

    void getEventRegularizationSensitivity(gsl::span<realtype> srz,
                                           const int ie, const realtype t,
                                           const AmiVector &x,
                                           const AmiVectorArray &sx);

    void getEventSigma(gsl::span<realtype> sigmaz, const int ie,
                       const int nroots, const realtype t,
                       const ExpData *edata);

    void getEventSigmaSensitivity(gsl::span<realtype> ssigmaz, const int ie,
                                  const int nroots, const realtype t,
                                  const ExpData *edata);

    void addEventObjective(realtype &Jz, const int ie, const int nroots,
                           const realtype t, const AmiVector &x,
                           const ExpData &edata);

    void addEventObjectiveRegularization(realtype &Jrz, const int ie,
                                         const int nroots, const realtype t,
                                         const AmiVector &x,
                                         const ExpData &edata);

    void addEventObjectiveSensitivity(std::vector<realtype> &sllh,
                                      std::vector<realtype> &s2llh,
                                      const int ie, const int nroots,
                                      const realtype t, const AmiVector &x,
                                      const AmiVectorArray &sx,
                                      const ExpData &edata);

    void addPartialEventObjectiveSensitivity(std::vector<realtype> &sllh,
                                             std::vector<realtype> &s2llh,
                                             const int ie, const int nroots,
                                             const realtype t,
                                             const AmiVector &x,
                                             const ExpData &edata);

    void getAdjointStateEventUpdate(gsl::span<realtype> dJzdx, const int ie,
                                    const int nroots, const realtype t,
                                    const AmiVector &x, const ExpData &edata);

    void getEventTimeSensitivity(std::vector<realtype> &stau, const realtype t,
                                 const int ie, const AmiVector &x,
                                 const AmiVectorArray &sx);

    void addStateEventUpdate(AmiVector &x, const int ie, const realtype t,
                             const AmiVector &xdot, const AmiVector &xdot_old);

    void addStateSensitivityEventUpdate(AmiVectorArray &sx, const int ie,
                                        const realtype t,
                                        const AmiVector &x_old,
                                        const AmiVector &xdot,
                                        const AmiVector &xdot_old,
                                        const std::vector<realtype> &stau);

    void addAdjointStateEventUpdate(AmiVector &xB, const int ie,
                                    const realtype t, const AmiVector &x,
                                    const AmiVector &xdot,
                                    const AmiVector &xdot_old);

    // NOTE(review): xQB is taken BY VALUE here, whereas the sibling update
    // functions above take their target vector by non-const reference. If
    // AmiVector copies deeply, updates made to xQB would not reach the
    // caller. The signature is left untouched because it must match the
    // out-of-line definition — please verify whether a reference was meant.
    void addAdjointQuadratureEventUpdate(AmiVector xQB, const int ie,
                                         const realtype t, const AmiVector &x,
                                         const AmiVector &xB,
                                         const AmiVector &xdot,
                                         const AmiVector &xdot_old);

    // ---- Heaviside updates (declarations only) ----------------------------

    void updateHeaviside(const std::vector<int> &rootsfound);

    // Takes a raw pointer without a length; the expected array size is not
    // visible here (presumably one entry per event — confirm).
    void updateHeavisideB(const int *rootsfound);

    // ---- Finite-value checks (declarations only) --------------------------

    // Returns an int status; `fun` is a C string, presumably the name of the
    // function that produced `array` — confirm in the definition.
    int checkFinite(gsl::span<const realtype> array, const char *fun) const;

    void setAlwaysCheckFinite(bool alwaysCheck);

    bool getAlwaysCheckFinite() const;

    // ---- AmiVector-based overloads of initial-condition functions ---------
    // These overloads coexist with the AbstractModel overloads re-exported
    // by the using-declarations at the top of the class.

    void fx0(AmiVector &x);

    void fx0_fixedParameters(AmiVector &x);

    void fsx0(AmiVectorArray &sx, const AmiVector &x);

    void fsx0_fixedParameters(AmiVectorArray &sx, const AmiVector &x);

    virtual void fsdx0();

    // ---- Solver-state -> rdata-state conversion (declarations only) -------

    void fx_rdata(AmiVector &x_rdata, const AmiVector &x_solver);

    void fsx_rdata(AmiVectorArray &sx_rdata, const AmiVectorArray &sx_solver);

    // ---- Reinitialization state indices (declarations only) ---------------

    void setReinitializationStateIdxs(const std::vector<int> &idxs);

    std::vector<int> const& getReinitializationStateIdxs() const;

    /**
     * Public flag sharing its name with the constructor's `pythonGenerated`
     * argument. Initialized to false (that argument's default) so that a
     * default-constructed Model never holds an indeterminate value.
     */
    bool pythonGenerated{false};

    const AmiVectorArray &get_dxdotdp() const;

    const SUNMatrixWrapper &get_dxdotdp_full() const;

    /** Second order mode; defaults to SecondOrderMode::none. */
    SecondOrderMode o2mode{SecondOrderMode::none};

    /** Public vector of realtype values; empty by default. */
    std::vector<realtype> idlist;

    /** Non-owning pointer to an application context; defaults to the
     * address of amici::defaultContext. */
    AmiciApplication *app = &defaultContext;

  protected:
    // ---- Buffer helpers (declarations only) -------------------------------

    void writeSliceEvent(gsl::span<const realtype> slice,
                         gsl::span<realtype> buffer, const int ie);

    void writeSensitivitySliceEvent(gsl::span<const realtype> slice,
                                    gsl::span<realtype> buffer, const int ie);

    void writeLLHSensitivitySlice(const std::vector<realtype> &dLLhdp,
                                  std::vector<realtype> &sllh,
                                  std::vector<realtype> &s2llh);

    void checkLLHBufferSize(const std::vector<realtype> &sllh,
                            const std::vector<realtype> &s2llh) const;

    void initializeVectors();

    // ---- Observable-related model function overloads (declarations only) --
    // None of these return a value or take an output buffer (except the
    // realtype& accumulators); results are presumably stored in member
    // state — confirm in the definitions.

    void fy(realtype t, const AmiVector &x);

    void fdydp(realtype t, const AmiVector &x);

    void fdydx(realtype t, const AmiVector &x);

    void fsigmay(int it, const ExpData *edata);

    void fdsigmaydp(int it, const ExpData *edata);

    void fJy(realtype &Jy, int it, const AmiVector &y, const ExpData &edata);

    void fdJydy(int it, const AmiVector &x, const ExpData &edata);

    void fdJydsigma(int it, const AmiVector &x, const ExpData &edata);

    void fdJydp(const int it, const AmiVector &x, const ExpData &edata);

    void fdJydx(const int it, const AmiVector &x, const ExpData &edata);

    // ---- Event-related model function overloads (declarations only) -------

    void fz(int ie, realtype t, const AmiVector &x);

    void fdzdp(int ie, realtype t, const AmiVector &x);

    void fdzdx(int ie, realtype t, const AmiVector &x);

    void frz(int ie, realtype t, const AmiVector &x);

    void fdrzdp(int ie, realtype t, const AmiVector &x);

    void fdrzdx(int ie, realtype t, const AmiVector &x);

    void fsigmaz(const int ie, const int nroots, const realtype t,
                 const ExpData *edata);

    void fdsigmazdp(int ie, int nroots, realtype t, const ExpData *edata);

    void fJz(realtype &Jz, int nroots, const AmiVector &z,
             const ExpData &edata);

    void fdJzdz(const int ie, const int nroots, const realtype t,
                const AmiVector &x, const ExpData &edata);

    void fdJzdsigma(const int ie, const int nroots, const realtype t,
                    const AmiVector &x, const ExpData &edata);

    void fdJzdp(const int ie, const int nroots, realtype t, const AmiVector &x,
                const ExpData &edata);

    void fdJzdx(const int ie, const int nroots, realtype t, const AmiVector &x,
                const ExpData &edata);

    void fJrz(realtype &Jrz, int nroots, const AmiVector &rz,
              const ExpData &edata);

    void fdJrzdz(const int ie, const int nroots, const realtype t,
                 const AmiVector &x, const ExpData &edata);

    void fdJrzdsigma(const int ie, const int nroots, const realtype t,
                     const AmiVector &x, const ExpData &edata);

    // ---- w and its derivatives (declarations only) ------------------------
    // These take the state as a raw const realtype pointer.

    void fw(realtype t, const realtype *x);

    void fdwdp(realtype t, const realtype *x);

    void fdwdx(realtype t, const realtype *x);

    void fdwdw(realtype t, const realtype *x);

    // ---- Overridable raw-pointer hooks (virtual, declarations only) -------
    // Array lengths are not part of the signatures; callers and overriders
    // must agree on them (presumably via the model dimensions — confirm).

    virtual void fx_rdata(realtype *x_rdata, const realtype *x_solver,
                          const realtype *tcl);

    virtual void fsx_rdata(realtype *sx_rdata, const realtype *sx_solver,
                           const realtype *stcl, int ip);

    virtual void fx_solver(realtype *x_solver, const realtype *x_rdata);

    virtual void fsx_solver(realtype *sx_solver, const realtype *sx_rdata);

    virtual void ftotal_cl(realtype *total_cl, const realtype *x_rdata);

    virtual void fstotal_cl(realtype *stotal_cl, const realtype *sx_rdata,
                            int ip);

    // Maps a const_N_Vector to a const_N_Vector; NOTE(review): presumably
    // related to state_is_non_negative_ — confirm in the definition.
    const_N_Vector computeX_pos(const_N_Vector x);

    /** Model state; read/written via getModelState()/setModelState(). */
    ModelState state_;

    /** Derived model state (type ModelStateDerived). */
    ModelStateDerived derived_state_;

    /** Vector of int indices; shares its name with the constructor's
     * `z2event` argument. */
    std::vector<int> z2event_;

    /** Vector of realtype; presumably custom initial states — confirm. */
    std::vector<realtype> x0data_;

    /** Vector of realtype; presumably custom initial state sensitivities —
     * confirm. */
    std::vector<realtype> sx0data_;

    /** Per-state boolean flags. */
    std::vector<bool> state_is_non_negative_;

    /** Defaults to false. */
    bool any_state_non_negative_ {false};

    /** Defaults to 10. */
    int nmaxevent_ {10};

    /** Defaults to SteadyStateSensitivityMode::newtonOnly. */
    SteadyStateSensitivityMode steadystate_sensitivity_mode_ {
        SteadyStateSensitivityMode::newtonOnly};

    /** Defaults to false. */
    bool always_check_finite_ {false};

    /** Flag accessed via set/getAddSigmaResiduals(); defaults to false. */
    bool sigma_res_ {false};

    /** Value accessed via set/getMinimumSigmaResiduals(); defaults to 50.0. */
    realtype min_sigma_ {50.0};

  private:
    // The matrices below are `mutable`, i.e. they may be modified from
    // const member functions.

    mutable std::vector<SUNMatrixWrapper> dwdp_hierarchical_;

    mutable SUNMatrixWrapper dwdw_;

    mutable std::vector<SUNMatrixWrapper> dwdx_hierarchical_;

    /** Defaults to 0; shares its name with the constructor's
     * `w_recursion_depth` argument. */
    int w_recursion_depth_ {0};

    /** Simulation parameters (type SimulationParameters). */
    SimulationParameters simulation_parameters_;
};

bool operator==(const Model &a, const Model &b);
bool operator==(const ModelDimensions &a, const ModelDimensions &b);

} // namespace amici

#endif // AMICI_MODEL_H