Program Listing for File sundials_linsol_wrapper.h

Return to documentation for file (include/amici/sundials_linsol_wrapper.h)

#ifndef AMICI_SUNDIALS_LINSOL_WRAPPER_H
#define AMICI_SUNDIALS_LINSOL_WRAPPER_H

#include "amici/sundials_matrix_wrapper.h"
#include "amici/vector.h"

#include <sundials/sundials_config.h>
#include <sunlinsol/sunlinsol_band.h>
#include <sunlinsol/sunlinsol_dense.h>
#include <sunlinsol/sunlinsol_klu.h>
#include <sunlinsol/sunlinsol_pcg.h>
#include <sunlinsol/sunlinsol_spbcgs.h>
#include <sunlinsol/sunlinsol_spfgmr.h>
#include <sunlinsol/sunlinsol_spgmr.h>
#include <sunlinsol/sunlinsol_sptfqmr.h>
#ifdef SUNDIALS_SUPERLUMT
#include <sunlinsol/sunlinsol_superlumt.h>
#endif
#include <sunnonlinsol/sunnonlinsol_fixedpoint.h>
#include <sunnonlinsol/sunnonlinsol_newton.h>

namespace amici {

/**
 * @brief Thin C++ wrapper around a SUNDIALS `SUNLinearSolver` handle.
 *
 * The handle lives in `solver_`. Copy construction and copy assignment are
 * deleted; only move construction and move assignment are declared.
 * NOTE(review): the destructor body is in the implementation file;
 * presumably it releases the wrapped solver — confirm there.
 */
class SUNLinSolWrapper {
  public:
    /** @brief Construct a wrapper that holds no solver (`solver_` is null). */
    SUNLinSolWrapper() = default;

    /**
     * @brief Wrap an existing SUNDIALS linear solver.
     * @param linsol solver handle to wrap
     */
    explicit SUNLinSolWrapper(SUNLinearSolver linsol);

    /** @brief Virtual so derived wrappers are destroyed through base pointers. */
    virtual ~SUNLinSolWrapper();

    // --- copying is disabled ---
    SUNLinSolWrapper(SUNLinSolWrapper const& other) = delete;
    SUNLinSolWrapper& operator=(SUNLinSolWrapper const& other) = delete;

    // --- moving is allowed ---
    SUNLinSolWrapper(SUNLinSolWrapper&& other) noexcept;
    SUNLinSolWrapper& operator=(SUNLinSolWrapper&& other) noexcept;

    /** @return the wrapped `SUNLinearSolver` handle */
    SUNLinearSolver get() const;

    /** @return the SUNDIALS type tag of the wrapped solver */
    SUNLinearSolver_Type getType() const;

    /**
     * @brief Set up the solver for a system matrix.
     * @param A system matrix
     */
    void setup(SUNMatrix A) const;

    /**
     * @brief Overload of setup() taking an AMICI matrix wrapper.
     * @param A system matrix
     */
    void setup(SUNMatrixWrapper const& A) const;

    /**
     * @brief Solve a linear system with the wrapped solver.
     *
     * Presumably forwards to `SUNLinSolSolve`, i.e. solves A x = b — confirm
     * in the implementation.
     * @param A system matrix
     * @param x solution vector (output)
     * @param b right-hand side
     * @param tol solver tolerance
     * @return status flag
     */
    int Solve(SUNMatrix A, N_Vector x, N_Vector b, realtype tol) const;

    /** @return the last flag reported by the wrapped solver */
    long int getLastFlag() const;

    /**
     * @brief Query the solver's workspace sizes.
     * @param lenrwLS receives the real workspace length
     * @param leniwLS receives the integer workspace length
     * @return status flag
     */
    int space(long int* lenrwLS, long int* leniwLS) const;

    /**
     * @brief Matrix associated with this solver.
     *
     * Overridden by the matrix-based wrappers (Band, Dense, KLU, SuperLUMT).
     * The base implementation is not visible here.
     */
    virtual SUNMatrix getMatrix() const;

  protected:
    /**
     * @brief Initialize the wrapped solver.
     * @return status flag
     */
    int initialize();

    /** @brief Wrapped solver handle; null for a default-constructed wrapper. */
    SUNLinearSolver solver_ = nullptr;
};

/**
 * @brief Wrapper for the SUNDIALS banded direct linear solver
 * (sunlinsol_band).
 */
class SUNLinSolBand : public SUNLinSolWrapper {
  public:
    /**
     * @brief Create a banded solver for an existing matrix.
     * @param x template vector
     * @param A banded system matrix
     */
    SUNLinSolBand(N_Vector x, SUNMatrix A);

    /**
     * @brief Create a banded solver from bandwidths.
     *
     * Presumably allocates the band matrix stored in `A_` with the given
     * bandwidths — confirm in the implementation.
     * @param x template vector
     * @param ubw upper bandwidth
     * @param lbw lower bandwidth
     */
    SUNLinSolBand(AmiVector const& x, int ubw, int lbw);

    /** @return presumably the matrix stored in `A_` */
    SUNMatrix getMatrix() const override;

  private:
    /** @brief Matrix held by this solver. */
    SUNMatrixWrapper A_;
};

/**
 * @brief Wrapper for the SUNDIALS dense direct linear solver
 * (sunlinsol_dense).
 */
class SUNLinSolDense : public SUNLinSolWrapper {
  public:
    /**
     * @brief Create a dense solver.
     * @param x template vector; presumably also determines the size of the
     * dense matrix stored in `A_` — confirm
     */
    explicit SUNLinSolDense(AmiVector const& x);

    /** @return presumably the matrix stored in `A_` */
    SUNMatrix getMatrix() const override;

  private:
    /** @brief Matrix held by this solver. */
    SUNMatrixWrapper A_;
};

/**
 * @brief Wrapper for the KLU sparse direct linear solver (sunlinsol_klu).
 */
class SUNLinSolKLU : public SUNLinSolWrapper {
  public:
    /**
     * @brief Fill-reducing ordering choices.
     *
     * The values are written out explicitly and equal the implicit enumerator
     * values. NOTE(review): presumably passed through to
     * `SUNLinSol_KLUSetOrdering` (0 = AMD, 1 = COLAMD, 2 = natural) — confirm.
     */
    enum class StateOrdering { AMD = 0, COLAMD = 1, natural = 2 };

    /**
     * @brief Create a KLU solver for an existing sparse matrix.
     * @param x template vector
     * @param A sparse system matrix
     */
    SUNLinSolKLU(N_Vector x, SUNMatrix A);

    /**
     * @brief Create a KLU solver together with its sparse matrix.
     * @param x template vector
     * @param nnz number of nonzeros to allocate
     * @param sparsetype sparse storage format (presumably CSC_MAT / CSR_MAT)
     * @param ordering fill-reducing ordering to use
     */
    SUNLinSolKLU(
        AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
    );

    /** @return presumably the matrix stored in `A_` */
    SUNMatrix getMatrix() const override;

    /**
     * @brief Reinitialize the solver.
     * @param nnz number of nonzeros
     * @param reinit_type reinitialization mode (presumably
     * SUNKLU_REINIT_FULL / SUNKLU_REINIT_PARTIAL — confirm)
     */
    void reInit(int nnz, int reinit_type);

    /**
     * @brief Change the fill-reducing ordering.
     * @param ordering ordering to use
     */
    void setOrdering(StateOrdering ordering);

  private:
    /** @brief Sparse matrix held by this solver. */
    SUNMatrixWrapper A_;
};

#ifdef SUNDIALS_SUPERLUMT
/**
 * @brief Wrapper for the multithreaded SuperLU_MT sparse direct linear
 * solver. Only compiled when SUNDIALS_SUPERLUMT is defined.
 */
class SUNLinSolSuperLUMT : public SUNLinSolWrapper {
  public:
    /**
     * @brief Column ordering choices.
     *
     * The values are written out explicitly and equal the implicit enumerator
     * values. NOTE(review): presumably forwarded to
     * `SUNLinSol_SuperLUMTSetOrdering` (0 natural, 1 min. degree on A^T A,
     * 2 min. degree on A^T + A, 3 COLAMD) — confirm.
     */
    enum class StateOrdering {
        natural = 0,
        minDegATA = 1,
        minDegATPlusA = 2,
        COLAMD = 3
    };

    /**
     * @brief Create a SuperLU_MT solver for an existing sparse matrix.
     * @param x template vector
     * @param A sparse system matrix
     * @param numThreads number of threads
     */
    SUNLinSolSuperLUMT(N_Vector x, SUNMatrix A, int numThreads);

    /**
     * @brief Create a SuperLU_MT solver together with its sparse matrix.
     *
     * The thread count is chosen by the implementation (not visible here).
     * @param x template vector
     * @param nnz number of nonzeros to allocate
     * @param sparsetype sparse storage format
     * @param ordering column ordering to use
     */
    SUNLinSolSuperLUMT(
        AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
    );

    /**
     * @brief As above, with an explicit thread count.
     * @param x template vector
     * @param nnz number of nonzeros to allocate
     * @param sparsetype sparse storage format
     * @param ordering column ordering to use
     * @param numThreads number of threads
     */
    SUNLinSolSuperLUMT(
        AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering,
        int numThreads
    );

    /** @return presumably the matrix stored in `A` */
    SUNMatrix getMatrix() const override;

    /**
     * @brief Change the column ordering.
     * @param ordering ordering to use
     */
    void setOrdering(StateOrdering ordering);

  private:
    // NOTE(review): named `A` rather than `A_` like the sibling wrappers;
    // renaming requires updating the implementation file too.
    SUNMatrixWrapper A;
};

#endif

/**
 * @brief Wrapper for the SUNDIALS preconditioned conjugate gradient (PCG)
 * iterative linear solver (sunlinsol_pcg).
 */
class SUNLinSolPCG : public SUNLinSolWrapper {
  public:
    /**
     * @brief Create a PCG solver.
     *
     * The defaults match the N_Vector constructors of SUNLinSolSPBCGS and
     * SUNLinSolSPTFQMR in this header. The constructor is `explicit` because,
     * with the defaults, it can be called with a single argument and must not
     * act as an implicit conversion from N_Vector.
     * @param y template vector
     * @param pretype preconditioning type (default: PREC_NONE)
     * @param maxl SUNDIALS `maxl` parameter (default: SUNPCG_MAXL_DEFAULT);
     * presumably the iteration limit — confirm
     */
    explicit SUNLinSolPCG(
        N_Vector y, int pretype = PREC_NONE, int maxl = SUNPCG_MAXL_DEFAULT
    );

    /**
     * @brief Register the matrix-vector product callback.
     * @param A_data user data passed to @p ATimes
     * @param ATimes matrix-times-vector function
     * @return status flag
     */
    int setATimes(void* A_data, ATimesFn ATimes);

    /**
     * @brief Register the preconditioner callbacks.
     * @param P_data user data passed to the callbacks
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void* P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Set the scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return iteration count (presumably of the most recent solve) */
    int getNumIters() const;

    /** @return residual norm (presumably of the most recent solve) */
    realtype getResNorm() const;

    /** @return residual vector (presumably of the most recent solve) */
    N_Vector getResid() const;
};

/**
 * @brief Wrapper for the SUNDIALS SPBCGS iterative linear solver
 * (sunlinsol_spbcgs).
 *
 * Can be built from either a raw N_Vector or an AmiVector template; both
 * constructors use the same defaults.
 */
class SUNLinSolSPBCGS : public SUNLinSolWrapper {
  public:
    /**
     * @brief Create an SPBCGS solver from a raw SUNDIALS vector.
     * @param x template vector
     * @param pretype preconditioning type (default: PREC_NONE)
     * @param maxl SUNDIALS `maxl` parameter (default: SUNSPBCGS_MAXL_DEFAULT)
     */
    explicit SUNLinSolSPBCGS(
        N_Vector x, int pretype = PREC_NONE, int maxl = SUNSPBCGS_MAXL_DEFAULT
    );

    /**
     * @brief Create an SPBCGS solver from an AmiVector template.
     * @param x template vector
     * @param pretype preconditioning type (default: PREC_NONE)
     * @param maxl SUNDIALS `maxl` parameter (default: SUNSPBCGS_MAXL_DEFAULT)
     */
    explicit SUNLinSolSPBCGS(
        AmiVector const& x, int pretype = PREC_NONE,
        int maxl = SUNSPBCGS_MAXL_DEFAULT
    );

    /**
     * @brief Register the matrix-vector product callback.
     * @param A_data user data passed to @p ATimes
     * @param ATimes matrix-times-vector function
     * @return status flag
     */
    int setATimes(void* A_data, ATimesFn ATimes);

    /**
     * @brief Register the preconditioner callbacks.
     * @param P_data user data passed to the callbacks
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void* P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Set the scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return iteration count (presumably of the most recent solve) */
    int getNumIters() const;

    /** @return residual norm (presumably of the most recent solve) */
    realtype getResNorm() const;

    /** @return residual vector (presumably of the most recent solve) */
    N_Vector getResid() const;
};

/**
 * @brief Wrapper for the SUNDIALS flexible GMRES (SPFGMR) iterative linear
 * solver (sunlinsol_spfgmr).
 */
class SUNLinSolSPFGMR : public SUNLinSolWrapper {
  public:
    /**
     * @brief Create an SPFGMR solver.
     *
     * The defaults mirror SUNLinSolSPGMR's AmiVector constructor, so the two
     * GMRES variants can be constructed the same way. The constructor is
     * `explicit` because, with the defaults, it can be called with a single
     * argument and must not act as an implicit conversion from AmiVector.
     * @param x template vector
     * @param pretype preconditioning type (default: PREC_NONE)
     * @param maxl SUNDIALS `maxl` parameter (default: SUNSPFGMR_MAXL_DEFAULT);
     * presumably the Krylov subspace dimension — confirm
     */
    explicit SUNLinSolSPFGMR(
        AmiVector const& x, int pretype = PREC_NONE,
        int maxl = SUNSPFGMR_MAXL_DEFAULT
    );

    /**
     * @brief Register the matrix-vector product callback.
     * @param A_data user data passed to @p ATimes
     * @param ATimes matrix-times-vector function
     * @return status flag
     */
    int setATimes(void* A_data, ATimesFn ATimes);

    /**
     * @brief Register the preconditioner callbacks.
     * @param P_data user data passed to the callbacks
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void* P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Set the scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return iteration count (presumably of the most recent solve) */
    int getNumIters() const;

    /** @return residual norm (presumably of the most recent solve) */
    realtype getResNorm() const;

    /** @return residual vector (presumably of the most recent solve) */
    N_Vector getResid() const;
};

/**
 * @brief Wrapper for the SUNDIALS GMRES (SPGMR) iterative linear solver
 * (sunlinsol_spgmr).
 */
class SUNLinSolSPGMR : public SUNLinSolWrapper {
  public:
    /**
     * @brief Create an SPGMR solver from an AmiVector template.
     * @param x template vector
     * @param pretype preconditioning type (default: PREC_NONE)
     * @param maxl SUNDIALS `maxl` parameter (default: SUNSPGMR_MAXL_DEFAULT);
     * presumably the Krylov subspace dimension — confirm
     */
    explicit SUNLinSolSPGMR(
        AmiVector const& x, int pretype = PREC_NONE,
        int maxl = SUNSPGMR_MAXL_DEFAULT
    );

    /**
     * @brief Register the matrix-vector product callback.
     * @param A_data user data passed to @p ATimes
     * @param ATimes matrix-times-vector function
     * @return status flag
     */
    int setATimes(void* A_data, ATimesFn ATimes);

    /**
     * @brief Register the preconditioner callbacks.
     * @param P_data user data passed to the callbacks
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void* P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Set the scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return iteration count (presumably of the most recent solve) */
    int getNumIters() const;

    /** @return residual norm (presumably of the most recent solve) */
    realtype getResNorm() const;

    /** @return residual vector (presumably of the most recent solve) */
    N_Vector getResid() const;
};

/**
 * @brief Wrapper for the SUNDIALS SPTFQMR iterative linear solver
 * (sunlinsol_sptfqmr).
 *
 * Can be built from either a raw N_Vector or an AmiVector template; both
 * constructors use the same defaults.
 */
class SUNLinSolSPTFQMR : public SUNLinSolWrapper {
  public:
    /**
     * @brief Create an SPTFQMR solver from a raw SUNDIALS vector.
     * @param x template vector
     * @param pretype preconditioning type (default: PREC_NONE)
     * @param maxl SUNDIALS `maxl` parameter (default: SUNSPTFQMR_MAXL_DEFAULT)
     */
    explicit SUNLinSolSPTFQMR(
        N_Vector x, int pretype = PREC_NONE, int maxl = SUNSPTFQMR_MAXL_DEFAULT
    );

    /**
     * @brief Create an SPTFQMR solver from an AmiVector template.
     * @param x template vector
     * @param pretype preconditioning type (default: PREC_NONE)
     * @param maxl SUNDIALS `maxl` parameter (default: SUNSPTFQMR_MAXL_DEFAULT)
     */
    explicit SUNLinSolSPTFQMR(
        AmiVector const& x, int pretype = PREC_NONE,
        int maxl = SUNSPTFQMR_MAXL_DEFAULT
    );

    /**
     * @brief Register the matrix-vector product callback.
     * @param A_data user data passed to @p ATimes
     * @param ATimes matrix-times-vector function
     * @return status flag
     */
    int setATimes(void* A_data, ATimesFn ATimes);

    /**
     * @brief Register the preconditioner callbacks.
     * @param P_data user data passed to the callbacks
     * @param Pset preconditioner setup function
     * @param Psol preconditioner solve function
     * @return status flag
     */
    int setPreconditioner(void* P_data, PSetupFn Pset, PSolveFn Psol);

    /**
     * @brief Set the scaling vectors.
     * @param s first scaling vector
     * @param nul second scaling vector
     * @return status flag
     */
    int setScalingVectors(N_Vector s, N_Vector nul);

    /** @return iteration count (presumably of the most recent solve) */
    int getNumIters() const;

    /** @return residual norm (presumably of the most recent solve) */
    realtype getResNorm() const;

    /** @return residual vector (presumably of the most recent solve) */
    N_Vector getResid() const;
};

/**
 * @brief Thin C++ wrapper around a SUNDIALS `SUNNonlinearSolver` handle.
 *
 * The handle lives in `solver`. Copy construction and copy assignment are
 * deleted; only move construction and move assignment are declared.
 * NOTE(review): the destructor body is in the implementation file;
 * presumably it releases the wrapped solver — confirm there.
 */
class SUNNonLinSolWrapper {
  public:
    /**
     * @brief Wrap an existing SUNDIALS nonlinear solver.
     * @param sol solver handle to wrap
     */
    explicit SUNNonLinSolWrapper(SUNNonlinearSolver sol);

    /** @brief Virtual so derived wrappers are destroyed through base pointers. */
    virtual ~SUNNonLinSolWrapper();

    // --- copying is disabled ---
    SUNNonLinSolWrapper(SUNNonLinSolWrapper const& other) = delete;
    SUNNonLinSolWrapper& operator=(SUNNonLinSolWrapper const& other) = delete;

    // --- moving is allowed ---
    SUNNonLinSolWrapper(SUNNonLinSolWrapper&& other) noexcept;
    SUNNonLinSolWrapper& operator=(SUNNonLinSolWrapper&& other) noexcept;

    /** @return the wrapped `SUNNonlinearSolver` handle */
    SUNNonlinearSolver get() const;

    /** @return the SUNDIALS type tag of the wrapped solver */
    SUNNonlinearSolver_Type getType() const;

    /**
     * @brief Set up the solver.
     * @param y vector argument for setup
     * @param mem integrator memory (opaque pointer)
     * @return status flag
     */
    int setup(N_Vector y, void* mem);

    /**
     * @brief Run the nonlinear solve.
     *
     * Presumably forwards to `SUNNonlinSolSolve` — confirm in the
     * implementation.
     * @param y0 presumably the initial guess
     * @param y solution / correction vector (output)
     * @param w weight vector
     * @param tol solver tolerance
     * @param callLSetup whether the linear-solver setup should be invoked
     * @param mem integrator memory (opaque pointer)
     * @return status flag
     */
    int Solve(
        N_Vector y0, N_Vector y, N_Vector w, realtype tol, bool callLSetup,
        void* mem
    );

    /**
     * @brief Register the nonlinear system function.
     * @param SysFn system function
     * @return status flag
     */
    int setSysFn(SUNNonlinSolSysFn SysFn);

    /**
     * @brief Register the linear-solver setup function.
     * @param SetupFn setup function
     * @return status flag
     */
    int setLSetupFn(SUNNonlinSolLSetupFn SetupFn);

    /**
     * @brief Register the linear-solver solve function.
     * @param SolveFn solve function
     * @return status flag
     */
    int setLSolveFn(SUNNonlinSolLSolveFn SolveFn);

    /**
     * @brief Register the convergence test.
     * @param CTestFn convergence test function
     * @param ctest_data user data passed to @p CTestFn
     * @return status flag
     */
    int setConvTestFn(SUNNonlinSolConvTestFn CTestFn, void* ctest_data);

    /**
     * @brief Limit the number of nonlinear iterations.
     * @param maxiters maximum iteration count
     * @return status flag
     */
    int setMaxIters(int maxiters);

    /** @return number of nonlinear iterations */
    long int getNumIters() const;

    /** @return current iteration index */
    int getCurIter() const;

    /** @return number of convergence failures */
    long int getNumConvFails() const;

  protected:
    /** @brief Initialize the wrapped solver. */
    void initialize();

    /** @brief Wrapped solver handle; null by default. */
    SUNNonlinearSolver solver{nullptr};
};

/**
 * @brief Wrapper for the SUNDIALS Newton nonlinear solver
 * (sunnonlinsol_newton).
 */
class SUNNonLinSolNewton : public SUNNonLinSolWrapper {
  public:
    /**
     * @brief Create a Newton solver.
     * @param x template vector
     */
    explicit SUNNonLinSolNewton(N_Vector x);

    /**
     * @brief Create a Newton solver for multiple vectors.
     * @param count number of vectors (presumably forwarded to
     * `SUNNonlinSol_NewtonSens` for sensitivity systems — confirm)
     * @param x template vector
     */
    SUNNonLinSolNewton(int count, N_Vector x);

    /**
     * @brief Retrieve the registered nonlinear system function.
     * @param SysFn receives the function pointer
     * @return status flag
     */
    int getSysFn(SUNNonlinSolSysFn* SysFn) const;
};

/**
 * @brief Wrapper for the SUNDIALS fixed-point nonlinear solver
 * (sunnonlinsol_fixedpoint).
 */
class SUNNonLinSolFixedPoint : public SUNNonLinSolWrapper {
  public:
    /**
     * @brief Create a fixed-point solver.
     * @param x template vector
     * @param m defaults to 0; presumably the acceleration depth passed to
     * `SUNNonlinSol_FixedPoint` — confirm
     */
    explicit SUNNonLinSolFixedPoint(const_N_Vector x, int m = 0);

    /**
     * @brief Create a fixed-point solver for multiple vectors.
     * @param count number of vectors (presumably for sensitivity systems —
     * confirm)
     * @param x template vector
     * @param m defaults to 0; presumably the acceleration depth — confirm
     */
    SUNNonLinSolFixedPoint(int count, const_N_Vector x, int m = 0);

    /**
     * @brief Retrieve the registered nonlinear system function.
     * @param SysFn receives the function pointer
     * @return status flag
     */
    int getSysFn(SUNNonlinSolSysFn* SysFn) const;
};

} // namespace amici
#endif // AMICI_SUNDIALS_LINSOL_WRAPPER_H