Program Listing for File sundials_linsol_wrapper.h

Return to documentation for file (include/amici/sundials_linsol_wrapper.h)

#ifndef AMICI_SUNDIALS_LINSOL_WRAPPER_H
#define AMICI_SUNDIALS_LINSOL_WRAPPER_H

#include "amici/exception.h"
#include "amici/sundials_matrix_wrapper.h"
#include "amici/vector.h"

#include <sundials/sundials_config.h>
#include <sunlinsol/sunlinsol_band.h>
#include <sunlinsol/sunlinsol_dense.h>
#include <sunlinsol/sunlinsol_klu.h>
#include <sunlinsol/sunlinsol_pcg.h>
#include <sunlinsol/sunlinsol_spbcgs.h>
#include <sunlinsol/sunlinsol_spfgmr.h>
#include <sunlinsol/sunlinsol_spgmr.h>
#include <sunlinsol/sunlinsol_sptfqmr.h>
#ifdef SUNDIALS_SUPERLUMT
#include <sunlinsol/sunlinsol_superlumt.h>
#endif
#include <sunnonlinsol/sunnonlinsol_fixedpoint.h>
#include <sunnonlinsol/sunnonlinsol_newton.h>

namespace amici {

/**
 * @brief Wrapper around a SUNDIALS `SUNLinearSolver` handle.
 *
 * Instances cannot be copied, but can be moved (move operations are
 * declared `noexcept`). The destructor is virtual because concrete solver
 * wrappers derive from this class.
 */
class SUNLinSolWrapper {
  public:
    /** @brief Construct a wrapper that holds no solver yet. */
    SUNLinSolWrapper() = default;

    /**
     * @brief Wrap an already created SUNDIALS linear solver.
     * @param linsol solver handle (NOTE(review): presumably ownership passes
     * to this wrapper - confirm in the implementation)
     */
    explicit SUNLinSolWrapper(SUNLinearSolver linsol);

    /** @brief Destructor; virtual since this is a polymorphic base. */
    virtual ~SUNLinSolWrapper();

    // Non-copyable ...
    SUNLinSolWrapper(const SUNLinSolWrapper &other) = delete;
    SUNLinSolWrapper &operator=(const SUNLinSolWrapper &other) = delete;

    // ... but movable.
    SUNLinSolWrapper(SUNLinSolWrapper &&other) noexcept;
    SUNLinSolWrapper &operator=(SUNLinSolWrapper &&other) noexcept;

    /** @return the wrapped `SUNLinearSolver` handle */
    SUNLinearSolver get() const;

    /** @return the SUNDIALS type tag of the wrapped solver */
    SUNLinearSolver_Type getType() const;

    /**
     * @brief Run the solver setup step for a system matrix.
     * @param A system matrix
     */
    void setup(SUNMatrix A) const;

    /**
     * @brief Run the solver setup step for a wrapped system matrix.
     * @param A system matrix wrapper
     */
    void setup(const SUNMatrixWrapper& A) const;

    /**
     * @brief Solve a linear system with the wrapped solver.
     * @param A system matrix
     * @param x solution vector
     * @param b right-hand side vector
     * @param tol solver tolerance
     * @return status flag (presumably the SUNDIALS return code)
     */
    int Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol) const;

    /** @return the last flag reported by the wrapped solver */
    long int getLastFlag() const;

    /**
     * @brief Query the solver's workspace size.
     * @param lenrwLS output: real workspace length
     * @param leniwLS output: integer workspace length
     * @return status flag
     */
    int space(long int *lenrwLS, long int *leniwLS) const;

    /**
     * @return the matrix associated with this solver; derived classes that
     * hold a matrix override this
     */
    virtual SUNMatrix getMatrix() const;

  protected:
    /**
     * @brief Initialize the wrapped solver.
     * @return status flag
     */
    int initialize();

    /** Wrapped SUNDIALS solver handle; null until a solver is attached. */
    SUNLinearSolver solver_ = nullptr;
};


/**
 * @brief Band direct linear solver wrapper.
 *
 * Keeps its own matrix wrapper, which is exposed through getMatrix().
 */
class SUNLinSolBand : public SUNLinSolWrapper {
  public:
    /**
     * @brief Construct from a template vector and an existing matrix.
     * @param x template vector
     * @param A band matrix
     */
    SUNLinSolBand(N_Vector x, SUNMatrix A);

    /**
     * @brief Construct from an AMICI vector and band widths.
     * @param x template vector
     * @param ubw upper bandwidth
     * @param lbw lower bandwidth
     */
    SUNLinSolBand(AmiVector const &x, int ubw, int lbw);

    /** @return the matrix held by this solver */
    SUNMatrix getMatrix() const override;

  private:
    /** Matrix associated with this solver. */
    SUNMatrixWrapper A_;
};


/**
 * @brief Dense direct linear solver wrapper.
 *
 * Keeps its own matrix wrapper, which is exposed through getMatrix().
 */
class SUNLinSolDense : public SUNLinSolWrapper {
  public:
    /**
     * @brief Construct from an AMICI template vector.
     * @param x template vector
     */
    explicit SUNLinSolDense(AmiVector const &x);

    /** @return the matrix held by this solver */
    SUNMatrix getMatrix() const override;

  private:
    /** Matrix associated with this solver. */
    SUNMatrixWrapper A_;
};


/**
 * @brief Sparse direct linear solver wrapper based on KLU.
 *
 * Keeps its own sparse matrix wrapper, which is exposed through getMatrix().
 */
class SUNLinSolKLU : public SUNLinSolWrapper {
  public:
    /**
     * @brief Fill-reducing ordering choices for the factorization.
     *
     * Values are pinned explicitly to the integer codes documented for
     * SUNDIALS' `SUNLinSol_KLUSetOrdering` (0 = AMD, 1 = COLAMD,
     * 2 = natural). They previously relied on implicit enumerator
     * numbering. NOTE(review): the implementation likely forwards these as
     * integers, so do not reorder or renumber.
     */
    enum class StateOrdering {
        AMD = 0,
        COLAMD = 1,
        natural = 2
    };

    /**
     * @brief Construct from a template vector and an existing sparse matrix.
     * @param x template vector
     * @param A sparse matrix
     */
    SUNLinSolKLU(N_Vector x, SUNMatrix A);

    /**
     * @brief Construct from an AMICI vector and sparsity information.
     * @param x template vector
     * @param nnz number of nonzeros
     * @param sparsetype sparse storage format code
     * @param ordering fill-reducing ordering to use
     */
    SUNLinSolKLU(AmiVector const &x, int nnz, int sparsetype,
                 StateOrdering ordering);

    /** @return the matrix held by this solver */
    SUNMatrix getMatrix() const override;

    /**
     * @brief Reinitialize the solver.
     * @param nnz number of nonzeros
     * @param reinit_type reinitialization mode (presumably passed on to
     * `SUNLinSol_KLUReInit` - confirm)
     */
    void reInit(int nnz, int reinit_type);

    /**
     * @brief Select the fill-reducing ordering.
     * @param ordering ordering to use
     */
    void setOrdering(StateOrdering ordering);

  private:
    /** Sparse matrix associated with this solver. */
    SUNMatrixWrapper A_;
};

#ifdef SUNDIALS_SUPERLUMT

/**
 * @brief Sparse direct linear solver wrapper based on SuperLU_MT.
 *
 * Keeps its own sparse matrix wrapper, which is exposed through getMatrix().
 */
class SUNLinSolSuperLUMT : public SUNLinSolWrapper {
  public:
    /**
     * @brief Fill-reducing ordering choices for the factorization.
     *
     * Values are pinned explicitly to the integer codes documented for
     * SUNDIALS' `SUNLinSol_SuperLUMTSetOrdering`:
     *  - 0 = natural
     *  - 1 = minimum degree on A^T A
     *  - 2 = minimum degree on A^T + A
     *  - 3 = COLAMD
     *
     * They previously relied on implicit enumerator numbering.
     * NOTE(review): the implementation likely forwards these as integers,
     * so do not reorder or renumber.
     */
    enum class StateOrdering {
        natural = 0,
        minDegATA = 1,
        minDegATPlusA = 2,
        COLAMD = 3
    };

    /**
     * @brief Construct from a template vector and an existing sparse matrix.
     * @param x template vector
     * @param A sparse matrix
     * @param numThreads number of threads for the factorization
     */
    SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads);

    /**
     * @brief Construct from an AMICI vector and sparsity information.
     * @param x template vector
     * @param nnz number of nonzeros
     * @param sparsetype sparse storage format code
     * @param ordering fill-reducing ordering to use
     */
    SUNLinSolSuperLUMT(AmiVector const &x, int nnz, int sparsetype,
                       StateOrdering ordering);

    /**
     * @brief Construct from an AMICI vector, sparsity information and an
     * explicit thread count.
     * @param x template vector
     * @param nnz number of nonzeros
     * @param sparsetype sparse storage format code
     * @param ordering fill-reducing ordering to use
     * @param numThreads number of threads for the factorization
     */
    SUNLinSolSuperLUMT(AmiVector const &x, int nnz, int sparsetype,
                       StateOrdering ordering, int numThreads);

    /** @return the matrix held by this solver */
    SUNMatrix getMatrix() const override;

    /**
     * @brief Select the fill-reducing ordering.
     * @param ordering ordering to use
     */
    void setOrdering(StateOrdering ordering);

  private:
    /**
     * Sparse matrix associated with this solver. The name (without the
     * trailing underscore used by sibling classes) is kept because the
     * implementation file refers to it.
     */
    SUNMatrixWrapper A;
};

#endif

/**
 * @brief Preconditioned conjugate gradient (PCG) iterative solver wrapper.
 */
class SUNLinSolPCG : public SUNLinSolWrapper {
  public:
    /**
     * @brief Construct the solver.
     *
     * This edit adds the defaults, which mirror the sibling Krylov wrappers
     * (no preconditioning, SUNDIALS' default Krylov dimension). `explicit`
     * prevents an implicit `N_Vector` -> solver conversion that the
     * defaults would otherwise enable.
     *
     * @param y template vector
     * @param pretype preconditioning type
     * @param maxl maximum number of iterations / Krylov dimension
     */
    explicit SUNLinSolPCG(N_Vector y, int pretype = PREC_NONE,
                          int maxl = SUNPCG_MAXL_DEFAULT);

    /**
     * @brief Set the matrix-vector product routine.
     * @param A_data user data passed to @p ATimes
     * @param ATimes matrix-vector product function
     * @return status flag
     */
    int setATimes(void *A_data, ATimesFn ATimes);

    /**
     * @brief Set the preconditioner setup and solve routines.
     * @param P_data user data passed to the preconditioner routines
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void *P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Set scaling vectors.
     * @param s scaling vector
     * @param nul second scaling vector argument of the SUNDIALS API
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return number of iterations performed in the last solve */
    int getNumIters() const;

    /** @return residual norm from the last solve */
    realtype getResNorm() const;

    /** @return residual vector from the last solve */
    N_Vector getResid() const;
};


/**
 * @brief Scaled, preconditioned Bi-CGStab (SPBCGS) iterative solver wrapper.
 */
class SUNLinSolSPBCGS : public SUNLinSolWrapper {
  public:
    /**
     * @brief Construct from an AMICI template vector.
     * @param x template vector
     * @param pretype preconditioning type
     * @param maxl maximum Krylov dimension
     */
    explicit SUNLinSolSPBCGS(AmiVector const &x, int pretype = PREC_NONE,
                             int maxl = SUNSPBCGS_MAXL_DEFAULT);

    /**
     * @brief Construct from a raw SUNDIALS template vector.
     * @param x template vector
     * @param pretype preconditioning type
     * @param maxl maximum Krylov dimension
     */
    explicit SUNLinSolSPBCGS(N_Vector x, int pretype = PREC_NONE,
                             int maxl = SUNSPBCGS_MAXL_DEFAULT);

    /**
     * @brief Set the matrix-vector product routine.
     * @param A_data user data passed to @p ATimes
     * @param ATimes matrix-vector product function
     * @return status flag
     */
    int setATimes(void *A_data, ATimesFn ATimes);

    /**
     * @brief Set the preconditioner setup and solve routines.
     * @param P_data user data passed to the preconditioner routines
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void *P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Set scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return number of iterations performed in the last solve */
    int getNumIters() const;

    /** @return residual norm from the last solve */
    realtype getResNorm() const;

    /** @return residual vector from the last solve */
    N_Vector getResid() const;
};


/**
 * @brief Scaled, preconditioned flexible GMRES (SPFGMR) iterative solver
 * wrapper.
 */
class SUNLinSolSPFGMR : public SUNLinSolWrapper {
  public:
    /**
     * @brief Construct from an AMICI template vector.
     *
     * This edit adds the defaults for consistency with SUNLinSolSPGMR (no
     * preconditioning, SUNDIALS' default Krylov dimension). `explicit`
     * prevents an implicit `AmiVector` -> solver conversion that the
     * defaults would otherwise enable.
     *
     * @param x template vector
     * @param pretype preconditioning type
     * @param maxl maximum Krylov dimension
     */
    explicit SUNLinSolSPFGMR(AmiVector const &x, int pretype = PREC_NONE,
                             int maxl = SUNSPFGMR_MAXL_DEFAULT);

    /**
     * @brief Set the matrix-vector product routine.
     * @param A_data user data passed to @p ATimes
     * @param ATimes matrix-vector product function
     * @return status flag
     */
    int setATimes(void *A_data, ATimesFn ATimes);

    /**
     * @brief Set the preconditioner setup and solve routines.
     * @param P_data user data passed to the preconditioner routines
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void *P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Set scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return number of iterations performed in the last solve */
    int getNumIters() const;

    /** @return residual norm from the last solve */
    realtype getResNorm() const;

    /** @return residual vector from the last solve */
    N_Vector getResid() const;
};


/**
 * @brief Scaled, preconditioned GMRES (SPGMR) iterative solver wrapper.
 */
class SUNLinSolSPGMR : public SUNLinSolWrapper {
  public:
    /**
     * @brief Construct from an AMICI template vector.
     * @param x template vector
     * @param pretype preconditioning type
     * @param maxl maximum Krylov dimension
     */
    explicit SUNLinSolSPGMR(AmiVector const &x, int pretype = PREC_NONE,
                            int maxl = SUNSPGMR_MAXL_DEFAULT);

    /**
     * @brief Set the matrix-vector product routine.
     * @param A_data user data passed to @p ATimes
     * @param ATimes matrix-vector product function
     * @return status flag
     */
    int setATimes(void *A_data, ATimesFn ATimes);

    /**
     * @brief Set the preconditioner setup and solve routines.
     * @param P_data user data passed to the preconditioner routines
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void *P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Set scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return number of iterations performed in the last solve */
    int getNumIters() const;

    /** @return residual norm from the last solve */
    realtype getResNorm() const;

    /** @return residual vector from the last solve */
    N_Vector getResid() const;
};


/**
 * @brief Scaled, preconditioned TFQMR (SPTFQMR) iterative solver wrapper.
 */
class SUNLinSolSPTFQMR : public SUNLinSolWrapper {
  public:
    /**
     * @brief Construct from an AMICI template vector.
     * @param x template vector
     * @param pretype preconditioning type
     * @param maxl maximum Krylov dimension
     */
    explicit SUNLinSolSPTFQMR(AmiVector const &x, int pretype = PREC_NONE,
                              int maxl = SUNSPTFQMR_MAXL_DEFAULT);

    /**
     * @brief Construct from a raw SUNDIALS template vector.
     * @param x template vector
     * @param pretype preconditioning type
     * @param maxl maximum Krylov dimension
     */
    explicit SUNLinSolSPTFQMR(N_Vector x, int pretype = PREC_NONE,
                              int maxl = SUNSPTFQMR_MAXL_DEFAULT);

    /**
     * @brief Set the matrix-vector product routine.
     * @param A_data user data passed to @p ATimes
     * @param ATimes matrix-vector product function
     * @return status flag
     */
    int setATimes(void *A_data, ATimesFn ATimes);

    /**
     * @brief Set the preconditioner setup and solve routines.
     * @param P_data user data passed to the preconditioner routines
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void *P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Set scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return number of iterations performed in the last solve */
    int getNumIters() const;

    /** @return residual norm from the last solve */
    realtype getResNorm() const;

    /** @return residual vector from the last solve */
    N_Vector getResid() const;
};


/**
 * @brief Wrapper around a SUNDIALS `SUNNonlinearSolver` handle.
 *
 * Instances cannot be copied, but can be moved (move operations are
 * declared `noexcept`). The destructor is virtual because concrete
 * nonlinear solver wrappers derive from this class.
 */
class SUNNonLinSolWrapper {
  public:
    /**
     * @brief Wrap an already created SUNDIALS nonlinear solver.
     * @param sol solver handle (NOTE(review): presumably ownership passes to
     * this wrapper - confirm in the implementation)
     */
    explicit SUNNonLinSolWrapper(SUNNonlinearSolver sol);

    /** @brief Destructor; virtual since this is a polymorphic base. */
    virtual ~SUNNonLinSolWrapper();

    // Non-copyable ...
    SUNNonLinSolWrapper(const SUNNonLinSolWrapper &other) = delete;
    SUNNonLinSolWrapper &operator=(const SUNNonLinSolWrapper &other) = delete;

    // ... but movable.
    SUNNonLinSolWrapper(SUNNonLinSolWrapper &&other) noexcept;
    SUNNonLinSolWrapper &operator=(SUNNonLinSolWrapper &&other) noexcept;

    /** @return the wrapped `SUNNonlinearSolver` handle */
    SUNNonlinearSolver get() const;

    /** @return the SUNDIALS type tag of the wrapped solver */
    SUNNonlinearSolver_Type getType() const;

    /**
     * @brief Run the solver setup step.
     * @param y vector argument forwarded to setup
     * @param mem integrator memory
     * @return status flag
     */
    int setup(N_Vector y, void *mem);

    /**
     * @brief Solve the nonlinear system.
     * @param y0 initial guess
     * @param y solution vector
     * @param w weight vector
     * @param tol solver tolerance
     * @param callLSetup whether the linear solver setup should be called
     * @param mem integrator memory
     * @return status flag
     */
    int Solve(N_Vector y0, N_Vector y, N_Vector w, realtype tol,
              bool callLSetup, void *mem);

    /** @brief Set the nonlinear system function. @return status flag */
    int setSysFn(SUNNonlinSolSysFn SysFn);

    /** @brief Set the linear solver setup function. @return status flag */
    int setLSetupFn(SUNNonlinSolLSetupFn SetupFn);

    /** @brief Set the linear solver solve function. @return status flag */
    int setLSolveFn(SUNNonlinSolLSolveFn SolveFn);

    /**
     * @brief Set the convergence test function.
     * @param CTestFn convergence test function
     * @param ctest_data user data passed to @p CTestFn
     * @return status flag
     */
    int setConvTestFn(SUNNonlinSolConvTestFn CTestFn, void* ctest_data);

    /**
     * @brief Set the maximum number of nonlinear iterations.
     * @param maxiters iteration limit
     * @return status flag
     */
    int setMaxIters(int maxiters);

    /** @return cumulative number of nonlinear iterations */
    long int getNumIters() const;

    /** @return index of the current nonlinear iteration */
    int getCurIter() const;

    /** @return cumulative number of convergence failures */
    long int getNumConvFails() const;

  protected:
    /** @brief Initialize the wrapped solver. */
    void initialize();

    /** Wrapped SUNDIALS solver handle; null until a solver is attached. */
    SUNNonlinearSolver solver {nullptr};
};


/**
 * @brief Newton nonlinear solver wrapper.
 */
class SUNNonLinSolNewton : public SUNNonLinSolWrapper {
  public:
    /**
     * @brief Construct from a template vector.
     * @param x template vector
     */
    explicit SUNNonLinSolNewton(N_Vector x);

    /**
     * @brief Construct a variant for @p count vectors.
     * @param count number of vectors (presumably the sensitivity variant -
     * confirm)
     * @param x template vector
     */
    SUNNonLinSolNewton(int count, N_Vector x);

    /**
     * @brief Retrieve the nonlinear system function.
     * @param SysFn output: the system function
     * @return status flag
     */
    int getSysFn(SUNNonlinSolSysFn *SysFn) const;
};


/**
 * @brief Fixed-point nonlinear solver wrapper.
 */
class SUNNonLinSolFixedPoint : public SUNNonLinSolWrapper {
  public:
    /**
     * @brief Construct from a template vector.
     * @param x template vector
     * @param m acceleration vector count
     */
    explicit SUNNonLinSolFixedPoint(const_N_Vector x, int m = 0);

    /**
     * @brief Construct a variant for @p count vectors.
     * @param count number of vectors (presumably the sensitivity variant -
     * confirm)
     * @param x template vector
     * @param m acceleration vector count
     */
    SUNNonLinSolFixedPoint(int count, const_N_Vector x, int m = 0);

    /**
     * @brief Retrieve the nonlinear system function.
     * @param SysFn output: the system function
     * @return status flag
     */
    int getSysFn(SUNNonlinSolSysFn *SysFn) const;
};

} // namespace amici
#endif // AMICI_SUNDIALS_LINSOL_WRAPPER_H