Program Listing for File model_dae.h

Return to documentation for file (include/amici/model_dae.h)

#ifndef AMICI_MODEL_DAE_H
#define AMICI_MODEL_DAE_H

#include "amici/model.h"

#include <nvector/nvector_serial.h>

#include <sunmatrix/sunmatrix_band.h>
#include <sunmatrix/sunmatrix_dense.h>
#include <sunmatrix/sunmatrix_sparse.h>

#include <utility>
#include <vector>

namespace amici {

// Forward declarations only; the full definitions of these types are not
// part of this header.
class ExpData;
class IDASolver;

class Model_DAE : public Model {
  public:
    Model_DAE() = default;

    Model_DAE(const int nx_rdata, const int nxtrue_rdata, const int nx_solver,
              const int nxtrue_solver, const int nx_solver_reinit, const int ny, const int nytrue,
              const int nz, const int nztrue, const int ne, const int nJ,
              const int nw, const int ndwdx, const int ndwdp, const int ndwdw,
              const int ndxdotdw, std::vector<int> ndJydy, const int nnz,
              const int ubw, const int lbw, const SecondOrderMode o2mode,
              std::vector<realtype> const &p, std::vector<realtype> const &k,
              std::vector<int> const &plist,
              std::vector<realtype> const &idlist,
              std::vector<int> const &z2event, const bool pythonGenerated=false,
              const int ndxdotdp_explicit=0)
        : Model(nx_rdata, nxtrue_rdata, nx_solver, nxtrue_solver,
                nx_solver_reinit, ny, nytrue, nz, nztrue, ne, nJ, nw, ndwdx,
                ndwdp, ndwdw, ndxdotdw, std::move(ndJydy), nnz, ubw, lbw,
                o2mode, p, k, plist, idlist, z2event, pythonGenerated,
                ndxdotdp_explicit) {
            M_ = SUNMatrixWrapper(nx_solver, nx_solver);
        }

    void fJ(realtype t, realtype cj, const AmiVector &x, const AmiVector &dx,
            const AmiVector &xdot, SUNMatrix J) override;

    void fJ(realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
            const_N_Vector xdot, SUNMatrix J);

    void fJB(const realtype t, realtype cj, const AmiVector &x,
             const AmiVector &dx, const AmiVector &xB, const AmiVector &dxB,
             const AmiVector &xBdot, SUNMatrix JB) override;

    void fJB(realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
             const_N_Vector xB, const_N_Vector dxB, SUNMatrix JB);

    void fJSparse(realtype t, realtype cj, const AmiVector &x,
                  const AmiVector &dx, const AmiVector &xdot,
                  SUNMatrix J) override;

    void fJSparse(realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
                  SUNMatrix J);

    void fJSparseB(const realtype t, realtype cj, const AmiVector &x,
                   const AmiVector &dx, const AmiVector &xB,
                   const AmiVector &dxB, const AmiVector &xBdot,
                   SUNMatrix JB) override;

    void fJSparseB(realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
                   const_N_Vector xB, const_N_Vector dxB, SUNMatrix JB);

    void fJDiag(realtype t, AmiVector &JDiag, realtype cj, const AmiVector &x,
                const AmiVector &dx) override;

    void fJv(realtype t, const AmiVector &x, const AmiVector &dx,
             const AmiVector &xdot, const AmiVector &v, AmiVector &nJv,
             realtype cj) override;

    void fJv(realtype t, const_N_Vector x, const_N_Vector dx, const_N_Vector v,
             N_Vector Jv, realtype cj);

    void fJvB(realtype t, const_N_Vector x, const_N_Vector dx,
              const_N_Vector xB, const_N_Vector dxB,
              const_N_Vector vB, N_Vector JvB, realtype cj);

    void froot(realtype t, const AmiVector &x, const AmiVector &dx,
               gsl::span<realtype> root) override;

    void froot(realtype t, const_N_Vector x, const_N_Vector dx, gsl::span<realtype> root);

    void fxdot(realtype t, const AmiVector &x, const AmiVector &dx,
               AmiVector &xdot) override;

    void fxdot(realtype t, const_N_Vector x, const_N_Vector dx, N_Vector xdot);

    void fxBdot(realtype t, const_N_Vector x, const_N_Vector dx,
                const_N_Vector xB, const_N_Vector dxB, N_Vector xBdot);

    void fqBdot(realtype t, const_N_Vector x, const_N_Vector dx,
                const_N_Vector xB, const_N_Vector dxB,
                N_Vector qBdot);

    void fxBdot_ss(const realtype t, const AmiVector &xB,
                   const AmiVector &dxB, AmiVector &xBdot) override;

    void fxBdot_ss(realtype t, const_N_Vector xB, const_N_Vector dxB,
                   N_Vector xBdot) const;

    void fqBdot_ss(realtype t, const_N_Vector xB, const_N_Vector dxB,
                   N_Vector qBdot) const;

    void fJSparseB_ss(SUNMatrix JB) override;

    void writeSteadystateJB(const realtype t, realtype cj,
                            const AmiVector &x, const AmiVector &dx,
                            const AmiVector &xB, const AmiVector &dxB,
                            const AmiVector &xBdot) override;

    void fdxdotdp(realtype t, const const_N_Vector x, const const_N_Vector dx);
    void fdxdotdp(const realtype t, const AmiVector &x,
                  const AmiVector &dx) override {
        fdxdotdp(t, x.getNVector(), dx.getNVector());
    };

    void fsxdot(realtype t, const AmiVector &x, const AmiVector &dx, int ip,
                const AmiVector &sx, const AmiVector &sdx,
                AmiVector &sxdot) override;
    void fsxdot(realtype t, const_N_Vector x, const_N_Vector dx, int ip,
                const_N_Vector sx, const_N_Vector sdx, N_Vector sxdot);

    void fM(realtype t, const_N_Vector x);

    std::unique_ptr<Solver> getSolver() override;

  protected:
    virtual void fJSparse(SUNMatrixContent_Sparse JSparse, realtype t,
                          const realtype *x, const double *p, const double *k,
                          const realtype *h, realtype cj, const realtype *dx,
                          const realtype *w, const realtype *dwdx) = 0;

    virtual void froot(realtype *root, realtype t, const realtype *x,
                       const double *p, const double *k, const realtype *h,
                       const realtype *dx);

    virtual void fxdot(realtype *xdot, realtype t, const realtype *x,
                       const double *p, const double *k, const realtype *h,
                       const realtype *dx, const realtype *w) = 0;

    virtual void fdxdotdp(realtype *dxdotdp, realtype t, const realtype *x,
                          const realtype *p, const realtype *k,
                          const realtype *h, int ip, const realtype *dx,
                          const realtype *w, const realtype *dwdp);

    virtual void fM(realtype *M, const realtype t, const realtype *x,
                    const realtype *p, const realtype *k);
};
} // namespace amici

#endif // AMICI_MODEL_DAE_H