Program Listing for File model_dae.h

Return to documentation for file (include/amici/model_dae.h)

#ifndef AMICI_MODEL_DAE_H
#define AMICI_MODEL_DAE_H

#include "amici/model.h"

#include <nvector/nvector_serial.h>

#include <sunmatrix/sunmatrix_band.h>
#include <sunmatrix/sunmatrix_dense.h>
#include <sunmatrix/sunmatrix_sparse.h>

#include <map>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

namespace amici {

class ExpData;
class IDASolver;

/**
 * @brief Model whose functions are evaluated on the states `x` and on their
 * time derivatives `dx` (implicit / DAE formulation).
 *
 * Most public functions come in two flavours:
 *  - an `AmiVector`-based overload that overrides the corresponding `Model`
 *    method, and
 *  - a non-virtual `const_N_Vector`-based overload of the same name.
 *    NOTE(review): the latter presumably serve as entry points for the
 *    SUNDIALS (IDAS) callbacks; confirm in model_dae.cpp.
 *
 * The protected virtual functions operating on raw `realtype` arrays are the
 * model-specific kernels. `fxdot` is pure virtual and must be implemented by
 * every concrete model; the other kernels are non-pure, so a (default)
 * definition for them exists outside this header.
 *
 * NOTE(review): the one-line descriptions below are derived from AMICI's
 * naming scheme (J: Jacobian, B: backward/adjoint, q: quadrature,
 * _ss: steady state, s: sensitivity) and from the signatures; confirm the
 * details against model_dae.cpp.
 */
class Model_DAE : public Model {
  public:
    /** @brief Default constructor. */
    Model_DAE() = default;

    /**
     * @brief Constructor.
     *
     * Forwards all arguments to `Model`, then allocates in the derived state:
     *  - `M_`: dense nx_solver x nx_solver matrix,
     *  - `MSparse_`: CSC nx_solver x nx_solver matrix with sum(idlist)
     *    reserved nonzeros,
     *  - `dfdx_`: CSC nx_solver x nx_solver matrix with zero nonzeros.
     *
     * @param model_dimensions forwarded to `Model`
     * @param simulation_parameters moved into `Model`
     * @param o2mode forwarded to `Model`
     * @param idlist forwarded to `Model`; its element sum is the number of
     * nonzeros reserved for `MSparse_`. NOTE(review): this assumes the
     * entries are exactly 0.0 or 1.0 (algebraic vs. differential state).
     * @param z2event forwarded to `Model`
     * @param pythonGenerated forwarded to `Model`
     * @param ndxdotdp_explicit forwarded to `Model`
     * @param ndxdotdx_explicit forwarded to `Model`
     * @param w_recursion_depth forwarded to `Model`
     * @param state_independent_events moved into `Model`
     */
    Model_DAE(
        ModelDimensions const& model_dimensions,
        SimulationParameters simulation_parameters,
        SecondOrderMode const o2mode, std::vector<realtype> const& idlist,
        std::vector<int> const& z2event, bool const pythonGenerated = false,
        int const ndxdotdp_explicit = 0, int const ndxdotdx_explicit = 0,
        int const w_recursion_depth = 0,
        std::map<realtype, std::vector<int>> state_independent_events = {}
    )
        // The by-value sink parameters are not used after this point, so
        // move them instead of copying (both are potentially large).
        : Model(
            model_dimensions, std::move(simulation_parameters), o2mode, idlist,
            z2event, pythonGenerated, ndxdotdp_explicit, ndxdotdx_explicit,
            w_recursion_depth, std::move(state_independent_events)
        ) {
        derived_state_.M_ = SUNMatrixWrapper(nx_solver, nx_solver);
        // One reserved nonzero per set entry of idlist (see @param idlist).
        auto const M_nnz = static_cast<sunindextype>(
            std::reduce(idlist.begin(), idlist.end())
        );
        derived_state_.MSparse_
            = SUNMatrixWrapper(nx_solver, nx_solver, M_nnz, CSC_MAT);
        // Created without nonzeros here. NOTE(review): presumably resized
        // once the sparsity pattern is known.
        derived_state_.dfdx_
            = SUNMatrixWrapper(nx_solver, nx_solver, 0, CSC_MAT);
    }

    /**
     * @brief Dense Jacobian at (t, x, dx), written to `J`.
     * @param t timepoint
     * @param cj scaling factor (NOTE(review): presumably the IDAS `cj`
     * coefficient of the `dx` contribution)
     * @param x states
     * @param dx state derivatives
     * @param xdot right-hand side
     * @param J output matrix
     */
    void
    fJ(realtype t, realtype cj, AmiVector const& x, AmiVector const& dx,
       AmiVector const& xdot, SUNMatrix J) override;

    /** @brief `const_N_Vector` variant of the dense Jacobian `fJ`. */
    void
    fJ(realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
       const_N_Vector xdot, SUNMatrix J);

    /**
     * @brief Dense backward (adjoint) Jacobian, written to `JB`.
     * @param t timepoint
     * @param cj scaling factor
     * @param x states
     * @param dx state derivatives
     * @param xB adjoint states
     * @param dxB adjoint state derivatives
     * @param xBdot adjoint right-hand side
     * @param JB output matrix
     */
    void
    fJB(realtype const t, realtype cj, AmiVector const& x, AmiVector const& dx,
        AmiVector const& xB, AmiVector const& dxB, AmiVector const& xBdot,
        SUNMatrix JB) override;

    /** @brief `const_N_Vector` variant of the dense backward Jacobian. */
    void
    fJB(realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
        const_N_Vector xB, const_N_Vector dxB, SUNMatrix JB);

    /**
     * @brief Sparse Jacobian at (t, x, dx), written to `J`.
     * @param t timepoint
     * @param cj scaling factor
     * @param x states
     * @param dx state derivatives
     * @param xdot right-hand side
     * @param J output sparse matrix
     */
    void fJSparse(
        realtype t, realtype cj, AmiVector const& x, AmiVector const& dx,
        AmiVector const& xdot, SUNMatrix J
    ) override;

    /**
     * @brief `const_N_Vector` variant of the sparse Jacobian (takes no
     * right-hand side argument).
     */
    void fJSparse(
        realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
        SUNMatrix J
    );

    /**
     * @brief Sparse backward (adjoint) Jacobian, written to `JB`.
     * @param t timepoint
     * @param cj scaling factor
     * @param x states
     * @param dx state derivatives
     * @param xB adjoint states
     * @param dxB adjoint state derivatives
     * @param xBdot adjoint right-hand side
     * @param JB output sparse matrix
     */
    void fJSparseB(
        realtype const t, realtype cj, AmiVector const& x, AmiVector const& dx,
        AmiVector const& xB, AmiVector const& dxB, AmiVector const& xBdot,
        SUNMatrix JB
    ) override;

    /** @brief `const_N_Vector` variant of the sparse backward Jacobian. */
    void fJSparseB(
        realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
        const_N_Vector xB, const_N_Vector dxB, SUNMatrix JB
    );

    /**
     * @brief Diagonal of the Jacobian, written to `JDiag`.
     * @param t timepoint
     * @param JDiag output vector
     * @param cj scaling factor
     * @param x states
     * @param dx state derivatives
     */
    void fJDiag(
        realtype t, AmiVector& JDiag, realtype cj, AmiVector const& x,
        AmiVector const& dx
    ) override;

    /**
     * @brief Jacobian-vector product with `v`, written to `nJv`.
     * @param t timepoint
     * @param x states
     * @param dx state derivatives
     * @param xdot right-hand side
     * @param v vector to multiply with
     * @param nJv output vector
     * @param cj scaling factor
     */
    void
    fJv(realtype t, AmiVector const& x, AmiVector const& dx,
        AmiVector const& xdot, AmiVector const& v, AmiVector& nJv,
        realtype cj) override;

    /** @brief `const_N_Vector` variant of the Jacobian-vector product. */
    void
    fJv(realtype t, const_N_Vector x, const_N_Vector dx, const_N_Vector v,
        N_Vector Jv, realtype cj);

    /**
     * @brief Backward (adjoint) Jacobian-vector product with `vB`, written
     * to `JvB`.
     */
    void fJvB(
        realtype t, const_N_Vector x, const_N_Vector dx, const_N_Vector xB,
        const_N_Vector dxB, const_N_Vector vB, N_Vector JvB, realtype cj
    );

    /**
     * @brief Root (event) functions at (t, x, dx), written to `root`.
     * @param t timepoint
     * @param x states
     * @param dx state derivatives
     * @param root output buffer
     */
    void froot(
        realtype t, AmiVector const& x, AmiVector const& dx,
        gsl::span<realtype> root
    ) override;

    /** @brief `const_N_Vector` variant of `froot`. */
    void froot(
        realtype t, const_N_Vector x, const_N_Vector dx,
        gsl::span<realtype> root
    );

    /**
     * @brief Right-hand side / residual at (t, x, dx), written to `xdot`.
     * @param t timepoint
     * @param x states
     * @param dx state derivatives
     * @param xdot output vector
     */
    void fxdot(
        realtype t, AmiVector const& x, AmiVector const& dx, AmiVector& xdot
    ) override;

    /** @brief `const_N_Vector` variant of `fxdot`. */
    void fxdot(realtype t, const_N_Vector x, const_N_Vector dx, N_Vector xdot);

    /** @brief Backward (adjoint) right-hand side, written to `xBdot`. */
    void fxBdot(
        realtype t, const_N_Vector x, const_N_Vector dx, const_N_Vector xB,
        const_N_Vector dxB, N_Vector xBdot
    );

    /** @brief Backward quadrature right-hand side, written to `qBdot`. */
    void fqBdot(
        realtype t, const_N_Vector x, const_N_Vector dx, const_N_Vector xB,
        const_N_Vector dxB, N_Vector qBdot
    );

    /**
     * @brief Steady-state backward right-hand side, written to `xBdot`.
     * @param t timepoint
     * @param xB adjoint states
     * @param dxB adjoint state derivatives
     * @param xBdot output vector
     */
    void fxBdot_ss(
        realtype const t, AmiVector const& xB, AmiVector const& dxB,
        AmiVector& xBdot
    ) override;

    /** @brief `const_N_Vector` variant of `fxBdot_ss` (const member). */
    void fxBdot_ss(
        realtype t, const_N_Vector xB, const_N_Vector dxB, N_Vector xBdot
    ) const;

    /** @brief Steady-state backward quadrature, written to `qBdot`. */
    void fqBdot_ss(
        realtype t, const_N_Vector xB, const_N_Vector dxB, N_Vector qBdot
    ) const;

    /** @brief Steady-state sparse backward Jacobian, written to `JB`. */
    void fJSparseB_ss(SUNMatrix JB) override;

    /**
     * @brief Evaluates and stores the backward Jacobian for steady-state
     * computations. NOTE(review): storage location not visible here.
     */
    void writeSteadystateJB(
        realtype const t, realtype cj, AmiVector const& x, AmiVector const& dx,
        AmiVector const& xB, AmiVector const& dxB, AmiVector const& xBdot
    ) override;

    /** @brief Derivative of the right-hand side w.r.t. the parameters. */
    void fdxdotdp(realtype t, const_N_Vector const x, const_N_Vector const dx);

    /**
     * @brief `AmiVector` variant of `fdxdotdp`; forwards the underlying
     * N_Vectors of `x` and `dx` to the `const_N_Vector` overload.
     */
    void fdxdotdp(realtype const t, AmiVector const& x, AmiVector const& dx)
        override {
        fdxdotdp(t, x.getNVector(), dx.getNVector());
    }

    /**
     * @brief Right-hand side of the forward sensitivity equations for
     * parameter index `ip`, written to `sxdot`.
     * @param t timepoint
     * @param x states
     * @param dx state derivatives
     * @param ip parameter index
     * @param sx state sensitivities
     * @param sdx state derivative sensitivities
     * @param sxdot output vector
     */
    void fsxdot(
        realtype t, AmiVector const& x, AmiVector const& dx, int ip,
        AmiVector const& sx, AmiVector const& sdx, AmiVector& sxdot
    ) override;

    /** @brief `const_N_Vector` variant of `fsxdot`. */
    void fsxdot(
        realtype t, const_N_Vector x, const_N_Vector dx, int ip,
        const_N_Vector sx, const_N_Vector sdx, N_Vector sxdot
    );

    /**
     * @brief Mass matrix at (t, x). NOTE(review): presumably written into
     * the derived state's `M_`.
     */
    void fM(realtype t, const_N_Vector x);

    /**
     * @brief Creates a solver for this model.
     * NOTE(review): presumably an IDASolver (forward-declared above).
     * @return owning pointer to the new solver
     */
    std::unique_ptr<Solver> getSolver() override;

  protected:
    // Model-specific kernels on raw arrays. NOTE(review): argument names
    // presumably follow the Model base conventions (p: parameters,
    // k: constants, h: Heaviside values, w: expressions, dwdx/dwdp: their
    // derivatives) -- confirm against model.h.

    /** @brief Kernel: fills the sparse Jacobian content `JSparse`. */
    virtual void fJSparse(
        SUNMatrixContent_Sparse JSparse, realtype t, realtype const* x,
        double const* p, double const* k, realtype const* h, realtype cj,
        realtype const* dx, realtype const* w, realtype const* dwdx
    );

    /** @brief Kernel: writes the root function values to `root`. */
    virtual void froot(
        realtype* root, realtype t, realtype const* x, double const* p,
        double const* k, realtype const* h, realtype const* dx
    );

    /** @brief Kernel: writes the right-hand side to `xdot` (pure virtual). */
    virtual void fxdot(
        realtype* xdot, realtype t, realtype const* x, double const* p,
        double const* k, realtype const* h, realtype const* dx,
        realtype const* w
    ) = 0;

    /** @brief Kernel: parameter derivative for parameter index `ip`. */
    virtual void fdxdotdp(
        realtype* dxdotdp, realtype t, realtype const* x, realtype const* p,
        realtype const* k, realtype const* h, int ip, realtype const* dx,
        realtype const* w, realtype const* dwdp
    );

    /** @brief Kernel: values of the explicit parameter derivative. */
    virtual void fdxdotdp_explicit(
        realtype* dxdotdp_explicit, realtype t, realtype const* x,
        realtype const* p, realtype const* k, realtype const* h,
        realtype const* dx, realtype const* w
    );

    /** @brief Kernel: column pointers of the explicit `dxdotdp` matrix. */
    virtual void fdxdotdp_explicit_colptrs(SUNMatrixWrapper& dxdotdp);

    /** @brief Kernel: row indices of the explicit `dxdotdp` matrix. */
    virtual void fdxdotdp_explicit_rowvals(SUNMatrixWrapper& dxdotdp);

    /** @brief Kernel: values of the explicit state derivative. */
    virtual void fdxdotdx_explicit(
        realtype* dxdotdx_explicit, realtype t, realtype const* x,
        realtype const* p, realtype const* k, realtype const* h,
        realtype const* dx, realtype const* w
    );

    /** @brief Kernel: column pointers of the explicit `dxdotdx` matrix. */
    virtual void fdxdotdx_explicit_colptrs(SUNMatrixWrapper& dxdotdx);

    /** @brief Kernel: row indices of the explicit `dxdotdx` matrix. */
    virtual void fdxdotdx_explicit_rowvals(SUNMatrixWrapper& dxdotdx);

    /** @brief Kernel: values of the derivative w.r.t. `w`. */
    virtual void fdxdotdw(
        realtype* dxdotdw, realtype t, realtype const* x, realtype const* p,
        realtype const* k, realtype const* h, realtype const* dx,
        realtype const* w
    );

    /** @brief Kernel: column pointers of the `dxdotdw` matrix. */
    virtual void fdxdotdw_colptrs(SUNMatrixWrapper& dxdotdw);

    /** @brief Kernel: row indices of the `dxdotdw` matrix. */
    virtual void fdxdotdw_rowvals(SUNMatrixWrapper& dxdotdw);

    /** @brief Evaluates the derivative w.r.t. `w` at (t, x, dx). */
    void fdxdotdw(realtype t, const_N_Vector x, const_N_Vector dx);

    /** @brief Kernel: writes the mass matrix entries to `M`. */
    virtual void
    fM(realtype* M, realtype const t, realtype const* x, realtype const* p,
       realtype const* k);
};
} // namespace amici

#endif // AMICI_MODEL_DAE_H