Program Listing for File sundials_linsol_wrapper.h
Return to documentation for file (include/amici/sundials_linsol_wrapper.h)
#ifndef AMICI_SUNDIALS_LINSOL_WRAPPER_H
#define AMICI_SUNDIALS_LINSOL_WRAPPER_H
#include "amici/exception.h"
#include "amici/sundials_matrix_wrapper.h"
#include "amici/vector.h"
#include <sundials/sundials_config.h>
#include <sunlinsol/sunlinsol_band.h>
#include <sunlinsol/sunlinsol_dense.h>
#include <sunlinsol/sunlinsol_klu.h>
#include <sunlinsol/sunlinsol_pcg.h>
#include <sunlinsol/sunlinsol_spbcgs.h>
#include <sunlinsol/sunlinsol_spfgmr.h>
#include <sunlinsol/sunlinsol_spgmr.h>
#include <sunlinsol/sunlinsol_sptfqmr.h>
#ifdef SUNDIALS_SUPERLUMT
#include <sunlinsol/sunlinsol_superlumt.h>
#endif
#include <sunnonlinsol/sunnonlinsol_fixedpoint.h>
#include <sunnonlinsol/sunnonlinsol_newton.h>
namespace amici {
/**
 * @brief C++ wrapper holding a SUNDIALS SUNLinearSolver handle.
 *
 * Instances cannot be copied; the handle can only be transferred by move.
 * NOTE(review): the destructor presumably releases the handle (e.g. via
 * SUNLinSolFree) — confirm in the implementation file.
 */
class SUNLinSolWrapper {
  public:
    /** @brief Constructs a wrapper that holds no solver (handle is nullptr). */
    SUNLinSolWrapper() = default;

    /**
     * @brief Wraps an existing linear solver handle.
     * @param linsol SUNDIALS linear solver to wrap
     */
    explicit SUNLinSolWrapper(SUNLinearSolver linsol);

    virtual ~SUNLinSolWrapper();

    // Copying is deleted (presumably so two wrappers never share one handle).
    SUNLinSolWrapper(const SUNLinSolWrapper& other) = delete;
    SUNLinSolWrapper& operator=(const SUNLinSolWrapper& other) = delete;

    // Moving hands the handle over to the destination object.
    SUNLinSolWrapper(SUNLinSolWrapper&& other) noexcept;
    SUNLinSolWrapper& operator=(SUNLinSolWrapper&& other) noexcept;

    /** @return the wrapped SUNLinearSolver handle */
    SUNLinearSolver get() const;

    /** @return the SUNDIALS type tag of the wrapped solver */
    SUNLinearSolver_Type getType() const;

    /**
     * @brief Runs the solver setup step for the given matrix.
     * @param A system matrix (presumably forwarded to SUNLinSolSetup)
     */
    void setup(SUNMatrix A) const;

    /**
     * @brief Runs the solver setup step for the given wrapped matrix.
     * @param A wrapped system matrix
     */
    void setup(const SUNMatrixWrapper& A) const;

    /**
     * @brief Solves the linear system A x = b.
     * @param A system matrix
     * @param x solution vector (output)
     * @param b right-hand side
     * @param tol solver tolerance
     * @return status flag (presumably the SUNDIALS return code)
     */
    int Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol) const;

    /** @return the last flag reported by the solver */
    long getLastFlag() const;

    /**
     * @brief Reports solver workspace sizes.
     * @param lenrwLS output: real workspace size (assumed, per SUNDIALS naming)
     * @param leniwLS output: integer workspace size (assumed, per SUNDIALS naming)
     * @return status flag
     */
    int space(long* lenrwLS, long* leniwLS) const;

    /**
     * @brief Matrix associated with the solver; overridden by derived
     * classes that own a matrix.
     */
    virtual SUNMatrix getMatrix() const;

  protected:
    /** @brief Initializes the wrapped solver; returns a status flag. */
    int initialize();

    /** Wrapped solver handle; nullptr when empty. */
    SUNLinearSolver solver_ = nullptr;
};
class SUNLinSolBand : public SUNLinSolWrapper {
public:
SUNLinSolBand(N_Vector x, SUNMatrix A);
SUNLinSolBand(AmiVector const &x, int ubw, int lbw);
SUNMatrix getMatrix() const override;
private:
SUNMatrixWrapper A_;
};
class SUNLinSolDense : public SUNLinSolWrapper {
public:
explicit SUNLinSolDense(AmiVector const &x);
SUNMatrix getMatrix() const override;
private:
SUNMatrixWrapper A_;
};
class SUNLinSolKLU : public SUNLinSolWrapper {
public:
enum class StateOrdering {
AMD,
COLAMD,
natural
};
SUNLinSolKLU(N_Vector x, SUNMatrix A);
SUNLinSolKLU(AmiVector const &x, int nnz, int sparsetype,
StateOrdering ordering);
SUNMatrix getMatrix() const override;
void reInit(int nnz, int reinit_type);
void setOrdering(StateOrdering ordering);
private:
SUNMatrixWrapper A_;
};
#ifdef SUNDIALS_SUPERLUMT
/**
 * @brief Sparse direct linear solver based on SuperLU_MT.
 *
 * Only available when SUNDIALS was configured with SuperLU_MT support.
 */
class SUNLinSolSuperLUMT : public SUNLinSolWrapper {
  public:
    /**
     * Fill-reducing ordering choices (enumerator order defines the values;
     * names presumably mirror SuperLU_MT's ordering options — confirm).
     */
    enum class StateOrdering {
        natural,       ///< natural ordering
        minDegATA,     ///< minimum degree on A^T A
        minDegATPlusA, ///< minimum degree on A^T + A
        COLAMD,        ///< column approximate minimum degree
    };

    /**
     * @brief Creates the solver from a vector template and a sparse matrix.
     * @param x vector template
     * @param A system matrix
     * @param numThreads number of threads for SuperLU_MT
     */
    SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads);

    /**
     * @brief Creates the solver together with a sparse matrix.
     * @param x vector template
     * @param nnz number of nonzeros
     * @param sparsetype sparse storage type
     * @param ordering fill-reducing ordering to use
     */
    SUNLinSolSuperLUMT(const AmiVector& x, int nnz, int sparsetype,
                       StateOrdering ordering);

    /**
     * @brief Creates the solver together with a sparse matrix.
     * @param x vector template
     * @param nnz number of nonzeros
     * @param sparsetype sparse storage type
     * @param ordering fill-reducing ordering to use
     * @param numThreads number of threads for SuperLU_MT
     */
    SUNLinSolSuperLUMT(const AmiVector& x, int nnz, int sparsetype,
                       StateOrdering ordering, int numThreads);

    /** @return the matrix held by this solver */
    SUNMatrix getMatrix() const override;

    /**
     * @brief Changes the fill-reducing ordering.
     * @param ordering new ordering
     */
    void setOrdering(StateOrdering ordering);

  private:
    /** Matrix held by this solver. */
    SUNMatrixWrapper A;
};
#endif
/**
 * @brief Preconditioned conjugate gradient (PCG) iterative linear solver.
 *
 * NOTE(review): presumably backed by SUNLinSol_PCG from
 * <sunlinsol/sunlinsol_pcg.h> — confirm in the implementation.
 */
class SUNLinSolPCG : public SUNLinSolWrapper {
  public:
    /**
     * @brief Creates a PCG solver.
     *
     * Defaults mirror the sibling Krylov wrappers (SPBCGS, SPGMR, SPTFQMR):
     * no preconditioning and the SUNDIALS default for maxl. Because the
     * constructor is callable with a single N_Vector it is explicit, so an
     * N_Vector never converts implicitly into a solver.
     *
     * @param y vector template used to size the solver's work vectors
     * @param pretype preconditioning type (PREC_NONE by default)
     * @param maxl maximum number of linear iterations
     *        (SUNPCG_MAXL_DEFAULT by default)
     */
    explicit SUNLinSolPCG(N_Vector y, int pretype = PREC_NONE,
                          int maxl = SUNPCG_MAXL_DEFAULT);

    /**
     * @brief Registers the matrix-vector product callback.
     * @param A_data user data passed to ATimes
     * @param ATimes matrix-vector product function
     * @return status flag
     */
    int setATimes(void *A_data, ATimesFn ATimes);

    /**
     * @brief Registers the preconditioner callbacks.
     * @param P_data user data passed to the callbacks
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void *P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Sets the scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector (name kept from the original API)
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return number of iterations performed in the last solve */
    int getNumIters() const;

    /** @return residual norm from the last solve */
    realtype getResNorm() const;

    /** @return residual vector from the last solve */
    N_Vector getResid() const;
};
class SUNLinSolSPBCGS : public SUNLinSolWrapper {
public:
explicit SUNLinSolSPBCGS(N_Vector x, int pretype = PREC_NONE,
int maxl = SUNSPBCGS_MAXL_DEFAULT);
explicit SUNLinSolSPBCGS(AmiVector const &x, int pretype = PREC_NONE,
int maxl = SUNSPBCGS_MAXL_DEFAULT);
int setATimes(void *A_data, ATimesFn ATimes);
int setPreconditioner(void *P_data, PSetupFn Pset, PSolveFn Psol);
int setScalingVectors(N_Vector s, N_Vector nul);
int getNumIters() const;
realtype getResNorm() const;
N_Vector getResid() const;
};
/**
 * @brief Scaled preconditioned flexible GMRES (SPFGMR) iterative linear
 * solver.
 *
 * NOTE(review): presumably backed by SUNLinSol_SPFGMR from
 * <sunlinsol/sunlinsol_spfgmr.h> — confirm in the implementation.
 */
class SUNLinSolSPFGMR : public SUNLinSolWrapper {
  public:
    /**
     * @brief Creates an SPFGMR solver.
     *
     * Defaults mirror the sibling SUNLinSolSPGMR(AmiVector const&, ...)
     * constructor: no preconditioning and the SUNDIALS default Krylov
     * dimension. Because the constructor is callable with a single argument
     * it is explicit, so an AmiVector never converts implicitly into a solver.
     *
     * @param x vector template
     * @param pretype preconditioning type (PREC_NONE by default)
     * @param maxl number of Krylov basis vectors
     *        (SUNSPFGMR_MAXL_DEFAULT by default)
     */
    explicit SUNLinSolSPFGMR(AmiVector const &x, int pretype = PREC_NONE,
                             int maxl = SUNSPFGMR_MAXL_DEFAULT);

    /**
     * @brief Registers the matrix-vector product callback.
     * @param A_data user data passed to ATimes
     * @param ATimes matrix-vector product function
     * @return status flag
     */
    int setATimes(void *A_data, ATimesFn ATimes);

    /**
     * @brief Registers the preconditioner callbacks.
     * @param P_data user data passed to the callbacks
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void *P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Sets the scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector (name kept from the original API)
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return number of iterations performed in the last solve */
    int getNumIters() const;

    /** @return residual norm from the last solve */
    realtype getResNorm() const;

    /** @return residual vector from the last solve */
    N_Vector getResid() const;
};
class SUNLinSolSPGMR : public SUNLinSolWrapper {
public:
explicit SUNLinSolSPGMR(AmiVector const &x, int pretype = PREC_NONE,
int maxl = SUNSPGMR_MAXL_DEFAULT);
int setATimes(void *A_data, ATimesFn ATimes);
int setPreconditioner(void *P_data, PSetupFn Pset, PSolveFn Psol);
int setScalingVectors(N_Vector s, N_Vector nul);
int getNumIters() const;
realtype getResNorm() const;
N_Vector getResid() const;
};
class SUNLinSolSPTFQMR : public SUNLinSolWrapper {
public:
explicit SUNLinSolSPTFQMR(N_Vector x, int pretype = PREC_NONE,
int maxl = SUNSPTFQMR_MAXL_DEFAULT);
explicit SUNLinSolSPTFQMR(AmiVector const &x, int pretype = PREC_NONE,
int maxl = SUNSPTFQMR_MAXL_DEFAULT);
int setATimes(void *A_data, ATimesFn ATimes);
int setPreconditioner(void *P_data, PSetupFn Pset, PSolveFn Psol);
int setScalingVectors(N_Vector s, N_Vector nul);
int getNumIters() const;
realtype getResNorm() const;
N_Vector getResid() const;
};
/**
 * @brief C++ wrapper holding a SUNDIALS SUNNonlinearSolver handle.
 *
 * Instances cannot be copied; the handle can only be transferred by move.
 * NOTE(review): the destructor presumably releases the handle (e.g. via
 * SUNNonlinSolFree) — confirm in the implementation file.
 */
class SUNNonLinSolWrapper {
  public:
    /**
     * @brief Wraps an existing nonlinear solver handle.
     * @param sol SUNDIALS nonlinear solver to wrap
     */
    explicit SUNNonLinSolWrapper(SUNNonlinearSolver sol);

    virtual ~SUNNonLinSolWrapper();

    // Copying is deleted (presumably so two wrappers never share one handle).
    SUNNonLinSolWrapper(const SUNNonLinSolWrapper& other) = delete;
    SUNNonLinSolWrapper& operator=(const SUNNonLinSolWrapper& other) = delete;

    // Moving hands the handle over to the destination object.
    SUNNonLinSolWrapper(SUNNonLinSolWrapper&& other) noexcept;
    SUNNonLinSolWrapper& operator=(SUNNonLinSolWrapper&& other) noexcept;

    /** @return the wrapped SUNNonlinearSolver handle */
    SUNNonlinearSolver get() const;

    /** @return the SUNDIALS type tag of the wrapped solver */
    SUNNonlinearSolver_Type getType() const;

    /**
     * @brief Runs the solver setup step.
     * @param y current state vector
     * @param mem integrator memory passed through to callbacks
     * @return status flag
     */
    int setup(N_Vector y, void* mem);

    /**
     * @brief Solves the nonlinear system.
     * @param y0 initial iterate
     * @param y solution (output)
     * @param w error weight vector
     * @param tol solver tolerance
     * @param callLSetup whether the linear solver setup should be called
     * @param mem integrator memory passed through to callbacks
     * @return status flag
     */
    int Solve(N_Vector y0, N_Vector y, N_Vector w, realtype tol,
              bool callLSetup, void* mem);

    /** @brief Sets the nonlinear system function; returns a status flag. */
    int setSysFn(SUNNonlinSolSysFn SysFn);

    /** @brief Sets the linear solver setup function; returns a status flag. */
    int setLSetupFn(SUNNonlinSolLSetupFn SetupFn);

    /** @brief Sets the linear solver solve function; returns a status flag. */
    int setLSolveFn(SUNNonlinSolLSolveFn SolveFn);

    /**
     * @brief Sets the convergence test function.
     * @param CTestFn convergence test callback
     * @param ctest_data user data passed to the callback
     * @return status flag
     */
    int setConvTestFn(SUNNonlinSolConvTestFn CTestFn, void* ctest_data);

    /** @brief Sets the maximum number of nonlinear iterations; returns a status flag. */
    int setMaxIters(int maxiters);

    /** @return total number of nonlinear iterations */
    long getNumIters() const;

    /** @return current iteration index */
    int getCurIter() const;

    /** @return total number of convergence failures */
    long getNumConvFails() const;

  protected:
    /** @brief Initializes the wrapped solver. */
    void initialize();

    /** Wrapped solver handle; nullptr when empty. */
    SUNNonlinearSolver solver{nullptr};
};
/**
 * @brief Newton nonlinear solver.
 *
 * NOTE(review): presumably backed by SUNNonlinSol_Newton /
 * SUNNonlinSol_NewtonSens from <sunnonlinsol/sunnonlinsol_newton.h>.
 */
class SUNNonLinSolNewton : public SUNNonLinSolWrapper {
  public:
    /**
     * @brief Creates a Newton solver.
     * @param x vector template
     */
    explicit SUNNonLinSolNewton(N_Vector x);

    /**
     * @brief Creates a Newton solver for a collection of vectors.
     * @param count number of vectors in the collection (presumably for
     *        sensitivity systems — confirm)
     * @param x vector template
     */
    SUNNonLinSolNewton(int count, N_Vector x);

    /**
     * @brief Retrieves the nonlinear system function.
     * @param SysFn output: system function
     * @return status flag
     */
    int getSysFn(SUNNonlinSolSysFn* SysFn) const;
};
/**
 * @brief Fixed-point nonlinear solver.
 *
 * NOTE(review): presumably backed by SUNNonlinSol_FixedPoint /
 * SUNNonlinSol_FixedPointSens from <sunnonlinsol/sunnonlinsol_fixedpoint.h>.
 */
class SUNNonLinSolFixedPoint : public SUNNonLinSolWrapper {
  public:
    /**
     * @brief Creates a fixed-point solver.
     * @param x vector template
     * @param m number of acceleration vectors (0 by default; assumed from
     *          SUNDIALS naming)
     */
    explicit SUNNonLinSolFixedPoint(const_N_Vector x, int m = 0);

    /**
     * @brief Creates a fixed-point solver for a collection of vectors.
     * @param count number of vectors in the collection
     * @param x vector template
     * @param m number of acceleration vectors (0 by default)
     */
    SUNNonLinSolFixedPoint(int count, const_N_Vector x, int m = 0);

    /**
     * @brief Retrieves the nonlinear system function.
     * @param SysFn output: system function
     * @return status flag
     */
    int getSysFn(SUNNonlinSolSysFn* SysFn) const;
};
} // namespace amici
#endif // AMICI_SUNDIALS_LINSOL_WRAPPER_H