Program Listing for File solver.h

Return to documentation for file (include/amici/solver.h)

#ifndef AMICI_SOLVER_H
#define AMICI_SOLVER_H

#include "amici/defines.h"
#include "amici/logging.h"
#include "amici/misc.h"
#include "amici/sundials_linsol_wrapper.h"
#include "amici/vector.h"

#include <chrono>
#include <cmath>
#include <functional>
#include <memory>

namespace amici {

// Forward declarations: lets this header name these types (by pointer or
// reference) without pulling in their full definitions. Solver itself is
// defined further down in this header.
// NOTE(review): BackwardProblem, ForwardProblem and ReturnData are not
// referenced anywhere else in this header — confirm they are still needed.
class BackwardProblem;
class ForwardProblem;
class Model;
class ReturnData;
class Solver;

} // namespace amici

// Forward declaration of the Boost.Serialization hook, needed so that
// amici::Solver can declare it a friend before its definition is seen.
namespace boost {
namespace serialization {

/**
 * @brief Non-intrusive Boost.Serialization entry point for amici::Solver.
 *
 * The definition is not part of this header.
 *
 * @param archive Boost archive (input or output)
 * @param solver solver instance to (de)serialize
 * @param version class version number supplied by Boost.Serialization
 */
template <typename Archive>
void serialize(Archive& archive, amici::Solver& solver, unsigned int version);

} // namespace serialization
} // namespace boost

namespace amici {

/**
 * @brief Abstract base class for the numerical integrators used by AMICI.
 *
 * Solver is abstract: it declares pure virtual members that concrete
 * backends must implement. It holds an opaque backend memory handle
 * (`solver_memory_`, released through a `free_solver_ptr` deleter), the
 * SUNDIALS linear/nonlinear solver wrappers, and every integration setting
 * with an in-class default (see the private members). Presumably the
 * backends are the CVODES/IDAS wrappers — confirm in the derived classes.
 *
 * Much of the workflow (setup(), run(), getState(), ...) is declared `const`
 * yet works on the many `mutable` members below, so even a `Solver const&`
 * carries integration state.
 * NOTE(review): because of this hidden mutation, concurrent use of a single
 * instance presumably requires external synchronization — confirm.
 *
 * NOTE(review): only inline bodies and in-class initializers are visible in
 * this header. Descriptions of out-of-line members below are inferred from
 * their names and signatures — confirm against the implementation.
 *
 * NOTE(review): std::pair and std::vector are used here, but <utility> and
 * <vector> are not included directly; they presumably arrive transitively
 * through the amici headers.
 */
class Solver {
  public:
    /** (model, solver) pair stored in `user_data`. */
    using user_data_type = std::pair<Model*, Solver const*>;
    /** Type-erased deleter for the opaque `solver_memory_` handles. */
    using free_solver_ptr = std::function<void(void*)>;
    /** Default constructor; members take their in-class initializers. */
    Solver() = default;

    /**
     * @brief Copy constructor (defined out of line).
     *
     * NOTE(review): no copy assignment is declared; the implicit one is
     * defined as deleted because of the std::unique_ptr members, and the
     * user-declared copy constructor/destructor suppress implicit moves.
     */
    Solver(Solver const& other);

    /** Virtual destructor: deleting a derived solver via Solver* is safe. */
    virtual ~Solver() = default;

    /**
     * @brief Polymorphic copy; implemented by each concrete solver.
     * @return presumably a heap-allocated copy owned by the caller (raw
     *         pointer) — confirm who is responsible for deleting it.
     */
    virtual Solver* clone() const = 0;

    /**
     * @brief Presumably integrates the forward problem up to @p tout.
     * @param tout target time
     * @return status code; its meaning is defined by the implementation
     */
    int run(realtype tout) const;

    /**
     * @brief Presumably takes a single integration step towards @p tout.
     * @param tout target time
     * @return status code; its meaning is defined by the implementation
     */
    int step(realtype tout) const;

    /**
     * @brief Backward (adjoint) counterpart of run(); returns no status.
     * @param tout target time
     */
    void runB(realtype tout) const;

    /**
     * @brief Sets up the forward problem at @p t0.
     * @param t0 initial time
     * @param model model instance
     * @param x0 initial state
     * @param dx0 initial state derivative
     * @param sx0 initial state sensitivities
     * @param sdx0 initial state-derivative sensitivities
     */
    void setup(
        realtype t0, Model* model, AmiVector const& x0, AmiVector const& dx0,
        AmiVectorArray const& sx0, AmiVectorArray const& sdx0
    ) const;

    /**
     * @brief Sets up a backward (adjoint) problem starting at @p tf.
     * @param which passed as int*; presumably an out-parameter receiving
     *        the index of the new backward problem — confirm
     * @param tf start time of the backward integration
     * @param model model instance
     * @param xB0 initial adjoint state
     * @param dxB0 initial adjoint state derivative
     * @param xQB0 initial adjoint quadratures
     */
    void setupB(
        int* which, realtype tf, Model* model, AmiVector const& xB0,
        AmiVector const& dxB0, AmiVector const& xQB0
    ) const;

    /**
     * @brief Sets up the solver for a steady-state computation at @p t0.
     *
     * Receives forward (@p x0, @p dx0), adjoint (@p xB0, @p dxB0) and
     * quadrature (@p xQ0) initial values.
     */
    void setupSteadystate(
        realtype const t0, Model* model, AmiVector const& x0,
        AmiVector const& dx0, AmiVector const& xB0, AmiVector const& dxB0,
        AmiVector const& xQ0
    ) const;

    /**
     * Presumably re-reads states/sensitivities from @p model and
     * reinitializes the backend with them — confirm.
     */
    void updateAndReinitStatesAndSensitivities(Model* model) const;

    // Backend-specific public operations (pure virtual).
    /**
     * @param rootsfound output buffer for root information; its required
     *        length is not visible here.
     */
    virtual void getRootInfo(int* rootsfound) const = 0;

    /** Presumably computes consistent initial conditions up to @p tout1. */
    virtual void calcIC(realtype tout1) const = 0;

    /** Backward counterpart of calcIC() for backward problem @p which. */
    virtual void calcICB(int which, realtype tout1) const = 0;

    /** Presumably integrates the backward problems to @p tBout in mode @p itaskB. */
    virtual void solveB(realtype tBout, int itaskB) const = 0;

    /** Presumably disables root (event) finding in the backend. */
    virtual void turnOffRootFinding() const = 0;

    /** Sensitivity method (default SensitivityMethod::forward, `sensi_meth_`). */
    SensitivityMethod getSensitivityMethod() const;

    /** @param sensi_meth new sensitivity method */
    void setSensitivityMethod(SensitivityMethod sensi_meth);

    /**
     * Sensitivity method for preequilibration (default
     * SensitivityMethod::forward, `sensi_meth_preeq_`).
     */
    SensitivityMethod getSensitivityMethodPreequilibration() const;

    /** @param sensi_meth_preeq new preequilibration sensitivity method */
    void setSensitivityMethodPreequilibration(SensitivityMethod sensi_meth_preeq
    );

    /** Presumably disables forward sensitivities; declared const. */
    void switchForwardSensisOff() const;

    /**
     * Maximum number of Newton steps (default 0, `newton_maxsteps_`).
     * NOTE(review): these accessors use int while `newton_maxsteps_` is a
     * long int.
     */
    int getNewtonMaxSteps() const;

    /** @param newton_maxsteps new maximum number of Newton steps */
    void setNewtonMaxSteps(int newton_maxsteps);

    /** Newton damping factor mode (default NewtonDampingFactorMode::on). */
    NewtonDampingFactorMode getNewtonDampingFactorMode() const;

    /** @param dampingFactorMode new damping factor mode */
    void setNewtonDampingFactorMode(NewtonDampingFactorMode dampingFactorMode);

    /** Lower bound of the Newton damping factor (default 1e-8). */
    double getNewtonDampingFactorLowerBound() const;

    /** @param dampingFactorLowerBound new lower bound */
    void setNewtonDampingFactorLowerBound(double dampingFactorLowerBound);

    /** Sensitivity order (default SensitivityOrder::none, `sensi_`). */
    SensitivityOrder getSensitivityOrder() const;

    /** @param sensi new sensitivity order */
    void setSensitivityOrder(SensitivityOrder sensi);

    // Tolerance accessors (get/set pairs). Defaults come from the private
    // members: atol_ 1e-16, rtol_ 1e-8, quad_atol_ 1e-12, quad_rtol_ 1e-8,
    // both steady-state factors 1e2. The FSA, B and steady-state (sensi)
    // tolerances default to NAN, presumably meaning "derive from another
    // setting" — confirm in the implementation.
    /** Relative tolerance of the forward problem. */
    double getRelativeTolerance() const;

    void setRelativeTolerance(double rtol);

    /** Absolute tolerance of the forward problem. */
    double getAbsoluteTolerance() const;

    void setAbsoluteTolerance(double atol);

    /** Relative tolerance for forward sensitivities (cf. computingFSA()). */
    double getRelativeToleranceFSA() const;

    void setRelativeToleranceFSA(double rtol);

    /** Absolute tolerance for forward sensitivities (cf. computingFSA()). */
    double getAbsoluteToleranceFSA() const;

    void setAbsoluteToleranceFSA(double atol);

    /** Relative tolerance of the backward (adjoint) problem. */
    double getRelativeToleranceB() const;

    void setRelativeToleranceB(double rtol);

    /** Absolute tolerance of the backward (adjoint) problem. */
    double getAbsoluteToleranceB() const;

    void setAbsoluteToleranceB(double atol);

    /** Relative tolerance for quadratures. */
    double getRelativeToleranceQuadratures() const;

    void setRelativeToleranceQuadratures(double rtol);

    /** Absolute tolerance for quadratures. */
    double getAbsoluteToleranceQuadratures() const;

    void setAbsoluteToleranceQuadratures(double atol);

    /** Steady-state tolerance factor (default 1e2). */
    double getSteadyStateToleranceFactor() const;

    void setSteadyStateToleranceFactor(double factor);

    /** Relative tolerance for steady-state computations. */
    double getRelativeToleranceSteadyState() const;

    void setRelativeToleranceSteadyState(double rtol);

    /** Absolute tolerance for steady-state computations. */
    double getAbsoluteToleranceSteadyState() const;

    void setAbsoluteToleranceSteadyState(double atol);

    /** Steady-state sensitivity tolerance factor (default 1e2). */
    double getSteadyStateSensiToleranceFactor() const;

    void setSteadyStateSensiToleranceFactor(double factor);

    /** Relative tolerance for steady-state sensitivities. */
    double getRelativeToleranceSteadyStateSensi() const;

    void setRelativeToleranceSteadyStateSensi(double rtol);

    /** Absolute tolerance for steady-state sensitivities. */
    double getAbsoluteToleranceSteadyStateSensi() const;

    void setAbsoluteToleranceSteadyStateSensi(double atol);

    /** Maximum number of forward integration steps (default 10000). */
    long int getMaxSteps() const;

    void setMaxSteps(long int maxsteps);

    /**
     * Time limit in seconds (`maxtime_`, default 0; presumably 0 disables
     * the limit — confirm).
     */
    double getMaxTime() const;

    void setMaxTime(double maxtime);

    /** Presumably (re)starts `simulation_timer_`, read by timeExceeded(). */
    void startTimer() const;

    /**
     * @brief Presumably checks the elapsed time against getMaxTime().
     * @param interval defaults to 1; presumably controls how often the
     *        timer is actually queried — confirm
     */
    bool timeExceeded(int interval = 1) const;

    /** Maximum number of backward integration steps (default 0, `maxstepsB_`). */
    long int getMaxStepsBackwardProblem() const;

    void setMaxStepsBackwardProblem(long int maxsteps);

    /** Linear multistep method (default LinearMultistepMethod::BDF). */
    LinearMultistepMethod getLinearMultistepMethod() const;

    void setLinearMultistepMethod(LinearMultistepMethod lmm);

    /** Nonlinear solver iteration (default NonlinearSolverIteration::newton). */
    NonlinearSolverIteration getNonlinearSolverIteration() const;

    void setNonlinearSolverIteration(NonlinearSolverIteration iter);

    /** Interpolation type (default InterpolationType::polynomial). */
    InterpolationType getInterpolationType() const;

    void setInterpolationType(InterpolationType interpType);

    /** State ordering as int (default SUNLinSolKLU::StateOrdering::AMD). */
    int getStateOrdering() const;

    void setStateOrdering(int ordering);

    /** Stability limit detection flag (default true, stored as booleantype). */
    bool getStabilityLimitFlag() const;

    void setStabilityLimitFlag(bool stldet);

    /** Linear solver (default LinearSolver::KLU). */
    LinearSolver getLinearSolver() const;

    /**
     * NOTE(review): this overloads the protected pure virtual
     * setLinearSolver() const; a derived class declaring either name hides
     * the other for callers using the derived type unless it adds
     * `using Solver::setLinearSolver;`.
     */
    void setLinearSolver(LinearSolver linsol);

    /** Internal sensitivity method (default InternalSensitivityMethod::simultaneous). */
    InternalSensitivityMethod getInternalSensitivityMethod() const;

    void setInternalSensitivityMethod(InternalSensitivityMethod ism);

    /** Return-data reporting mode (default RDataReporting::full). */
    RDataReporting getReturnDataReportingMode() const;

    void setReturnDataReportingMode(RDataReporting rdrm);

    /**
     * @brief Presumably writes the current forward solution into the outputs.
     * @param t output time
     * @param x, dx, sx, xQ output state, derivative, sensitivities, quadratures
     */
    void writeSolution(
        realtype* t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx,
        AmiVector& xQ
    ) const;

    /**
     * @brief Backward counterpart of writeSolution() for problem @p which.
     * @param t output time
     * @param xB, dxB, xQB output adjoint state, derivative, quadratures
     */
    void writeSolutionB(
        realtype* t, AmiVector& xB, AmiVector& dxB, AmiVector& xQB, int which
    ) const;

    // Solution accessors at time t. They return const references —
    // presumably to the mutable work vectors (x_, dx_, sx_, ...), so a later
    // call may overwrite the referenced data — confirm.
    AmiVector const& getState(realtype t) const;

    AmiVector const& getDerivativeState(realtype t) const;

    AmiVectorArray const& getStateSensitivity(realtype t) const;

    AmiVector const& getAdjointState(int which, realtype t) const;

    AmiVector const& getAdjointDerivativeState(int which, realtype t) const;

    AmiVector const& getAdjointQuadrature(int which, realtype t) const;

    AmiVector const& getQuadrature(realtype t) const;

    // Reinitialization hooks (pure virtual, implemented by the backend).
    virtual void
    reInit(realtype t0, AmiVector const& yy0, AmiVector const& yp0) const
        = 0;

    virtual void
    sensReInit(AmiVectorArray const& yyS0, AmiVectorArray const& ypS0) const
        = 0;

    virtual void sensToggleOff() const = 0;

    virtual void reInitB(
        int which, realtype tB0, AmiVector const& yyB0, AmiVector const& ypB0
    ) const
        = 0;

    virtual void quadReInitB(int which, AmiVector const& yQB0) const = 0;

    /** Current solver time (presumably `t_`). */
    realtype gett() const;

    /** CPU time of the forward run (presumably `cpu_time_`). */
    realtype getCpuTime() const;

    /** CPU time of the backward run (presumably `cpu_timeB_`). */
    realtype getCpuTimeB() const;

    /** Number of states. */
    int nx() const;

    /** Number of parameters for which sensitivities are computed. */
    int nplist() const;

    /** Number of quadratures. */
    int nquad() const;

    /**
     * @return true iff the sensitivity order is at least first, the method
     *         is SensitivityMethod::forward and nplist() > 0.
     */
    bool computingFSA() const {
        return getSensitivityOrder() >= SensitivityOrder::first
               && getSensitivityMethod() == SensitivityMethod::forward
               && nplist() > 0;
    }

    /**
     * @return true iff the sensitivity order is at least first, the method
     *         is SensitivityMethod::adjoint and nplist() > 0.
     */
    bool computingASA() const {
        return getSensitivityOrder() >= SensitivityOrder::first
               && getSensitivityMethod() == SensitivityMethod::adjoint
               && nplist() > 0;
    }

    /** Presumably clears the diagnostics vectors returned below. */
    void resetDiagnosis() const;

    /** Presumably appends current forward statistics to the diagnostics. */
    void storeDiagnosis() const;

    /** Presumably appends statistics of backward problem @p which. */
    void storeDiagnosisB(int which) const;

    // Diagnostics, returned by const reference to the private mutable
    // vectors: steps, RHS evaluations, error-test failures, nonlinear solver
    // convergence failures and last order; *B variants are for the backward
    // problem.
    std::vector<int> const& getNumSteps() const { return ns_; }

    std::vector<int> const& getNumStepsB() const { return nsB_; }

    std::vector<int> const& getNumRhsEvals() const { return nrhs_; }

    std::vector<int> const& getNumRhsEvalsB() const { return nrhsB_; }

    std::vector<int> const& getNumErrTestFails() const { return netf_; }

    std::vector<int> const& getNumErrTestFailsB() const { return netfB_; }

    std::vector<int> const& getNumNonlinSolvConvFails() const {
        return nnlscf_;
    }

    std::vector<int> const& getNumNonlinSolvConvFailsB() const {
        return nnlscfB_;
    }

    std::vector<int> const& getLastOrder() const { return order_; }

    /** Returns `newton_step_steadystate_conv_` (default false). */
    bool getNewtonStepSteadyStateCheck() const {
        return newton_step_steadystate_conv_;
    }

    /** Returns `check_sensi_steadystate_conv_` (default true). */
    bool getSensiSteadyStateCheck() const {
        return check_sensi_steadystate_conv_;
    }

    /** Sets `newton_step_steadystate_conv_`. */
    void setNewtonStepSteadyStateCheck(bool flag) {
        newton_step_steadystate_conv_ = flag;
    }

    /** Sets `check_sensi_steadystate_conv_`. */
    void setSensiSteadyStateCheck(bool flag) {
        check_sensi_steadystate_conv_ = flag;
    }

    /** Maximum number of nonlinear solver iterations (default 3). */
    void setMaxNonlinIters(int max_nonlin_iters);

    int getMaxNonlinIters() const;

    /** Maximum number of convergence failures (default 10). */
    void setMaxConvFails(int max_conv_fails);

    int getMaxConvFails() const;

    /** @param constraints state constraints; presumably applied via apply_constraints(). */
    void setConstraints(std::vector<realtype> const& constraints);

    /** @return a copy of `constraints_` as std::vector (via getVector()). */
    std::vector<realtype> getConstraints() const {
        return constraints_.getVector();
    }

    /** Maximum step size (default 0.0; presumably 0 means unlimited). */
    void setMaxStepSize(realtype max_step_size);

    realtype getMaxStepSize() const;

    /** Grants the forward-declared boost::serialization::serialize access. */
    template <class Archive>
    friend void boost::serialization::serialize(
        Archive& ar, Solver& s, unsigned int version
    );

    /** Grants operator== (declared again after the class) access. */
    friend bool operator==(Solver const& a, Solver const& b);

    /**
     * Optional logger; nullptr by default.
     * NOTE(review): raw pointer, ownership is not expressed — presumably
     * non-owning; confirm.
     */
    Logger* logger = nullptr;

  protected:
    // Backend hooks. Every `= 0` member below must be implemented by the
    // concrete solver; the non-virtual helpers are defined out of line.
    virtual void setStopTime(realtype tstop) const = 0;

    virtual int solve(realtype tout, int itask) const = 0;

    /** @param ncheckPtr presumably receives the number of checkpoints — confirm. */
    virtual int solveF(realtype tout, int itask, int* ncheckPtr) const = 0;

    virtual void reInitPostProcessF(realtype tnext) const = 0;

    virtual void reInitPostProcessB(realtype tnext) const = 0;

    virtual void getSens() const = 0;

    virtual void getB(int which) const = 0;

    virtual void getQuadB(int which) const = 0;

    virtual void getQuad(realtype& t) const = 0;

    virtual void
    init(realtype t0, AmiVector const& x0, AmiVector const& dx0) const
        = 0;

    virtual void initSteadystate(
        realtype t0, AmiVector const& x0, AmiVector const& dx0
    ) const
        = 0;

    virtual void
    sensInit1(AmiVectorArray const& sx0, AmiVectorArray const& sdx0) const
        = 0;

    virtual void binit(
        int which, realtype tf, AmiVector const& xB0, AmiVector const& dxB0
    ) const
        = 0;

    virtual void qbinit(int which, AmiVector const& xQB0) const = 0;

    virtual void rootInit(int ne) const = 0;

    void initializeNonLinearSolverSens(Model const* model) const;

    virtual void setDenseJacFn() const = 0;

    virtual void setSparseJacFn() const = 0;

    virtual void setBandJacFn() const = 0;

    virtual void setJacTimesVecFn() const = 0;

    virtual void setDenseJacFnB(int which) const = 0;

    virtual void setSparseJacFnB(int which) const = 0;

    virtual void setBandJacFnB(int which) const = 0;

    virtual void setJacTimesVecFnB(int which) const = 0;

    virtual void setSparseJacFn_ss() const = 0;

    virtual void allocateSolver() const = 0;

    virtual void setSStolerances(double rtol, double atol) const = 0;

    virtual void setSensSStolerances(double rtol, double const* atol) const = 0;

    virtual void setSensErrCon(bool error_corr) const = 0;

    virtual void setQuadErrConB(int which, bool flag) const = 0;

    virtual void setQuadErrCon(bool flag) const = 0;

    virtual void setErrHandlerFn() const = 0;

    virtual void setUserData() const = 0;

    virtual void setUserDataB(int which) const = 0;

    virtual void setMaxNumSteps(long int mxsteps) const = 0;

    virtual void setMaxNumStepsB(int which, long int mxstepsB) const = 0;

    virtual void setStabLimDet(int stldet) const = 0;

    virtual void setStabLimDetB(int which, int stldet) const = 0;

    virtual void setId(Model const* model) const = 0;

    virtual void setSuppressAlg(bool flag) const = 0;

    virtual void setSensParams(
        realtype const* p, realtype const* pbar, int const* plist
    ) const
        = 0;

    virtual void getDky(realtype t, int k) const = 0;

    virtual void getDkyB(realtype t, int k, int which) const = 0;

    virtual void getSensDky(realtype t, int k) const = 0;

    virtual void getQuadDkyB(realtype t, int k, int which) const = 0;

    virtual void getQuadDky(realtype t, int k) const = 0;

    virtual void adjInit() const = 0;

    virtual void quadInit(AmiVector const& xQ0) const = 0;

    virtual void allocateSolverB(int* which) const = 0;

    virtual void
    setSStolerancesB(int which, realtype relTolB, realtype absTolB) const
        = 0;

    virtual void
    quadSStolerancesB(int which, realtype reltolQB, realtype abstolQB) const
        = 0;

    virtual void quadSStolerances(realtype reltolQB, realtype abstolQB) const
        = 0;

    // Backend statistics queries for the memory block `ami_mem`.
    // NOTE(review): these share names with the public diagnostics getters
    // (getNumSteps(), ...). A derived class overriding them hides the public
    // overloads for callers using the derived type unless it adds
    // `using Solver::getNumSteps;` etc.
    virtual void getNumSteps(void const* ami_mem, long int* numsteps) const = 0;

    virtual void
    getNumRhsEvals(void const* ami_mem, long int* numrhsevals) const
        = 0;

    virtual void
    getNumErrTestFails(void const* ami_mem, long int* numerrtestfails) const
        = 0;

    virtual void getNumNonlinSolvConvFails(
        void const* ami_mem, long int* numnonlinsolvconvfails
    ) const
        = 0;

    virtual void getLastOrder(void const* ami_mem, int* order) const = 0;

    void initializeLinearSolver(Model const* model) const;

    void initializeNonLinearSolver() const;

    virtual void setLinearSolver() const = 0;

    virtual void setLinearSolverB(int which) const = 0;

    virtual void setNonLinearSolver() const = 0;

    virtual void setNonLinearSolverB(int which) const = 0;

    virtual void setNonLinearSolverSens() const = 0;

    void initializeLinearSolverB(Model const* model, int which) const;

    void initializeNonLinearSolverB(int which) const;

    virtual Model const* getModel() const = 0;

    // Initialization-state queries (set via the set*Done() helpers below).
    bool getInitDone() const;

    bool getSensInitDone() const;

    bool getAdjInitDone() const;

    bool getInitDoneB(int which) const;

    bool getQuadInitDoneB(int which) const;

    bool getQuadInitDone() const;

    virtual void diag() const = 0;

    virtual void diagB(int which) const = 0;

    /** Presumably resizes the mutable work vectors to the given dimensions. */
    void resetMutableMemory(int nx, int nplist, int nquad) const;

    virtual void* getAdjBmem(void* ami_mem, int which) const = 0;

    // Push the stored tolerances to the backend (defined out of line).
    void applyTolerances() const;

    void applyTolerancesFSA() const;

    void applyTolerancesASA(int which) const;

    void applyQuadTolerancesASA(int which) const;

    void applyQuadTolerances() const;

    void applySensitivityTolerances() const;

    /** Virtual but not pure: has a base implementation elsewhere. */
    virtual void apply_constraints() const;

    /** Opaque backend memory, released by the `free_solver_ptr` deleter. */
    mutable std::unique_ptr<void, free_solver_ptr> solver_memory_;

    /** Opaque backend memory of the backward problems (presumably by `which`). */
    mutable std::vector<std::unique_ptr<void, free_solver_ptr>>
        solver_memory_B_;

    /** (model, solver) pair, presumably handed to the backend as user data. */
    mutable user_data_type user_data;

    // Integration settings; defaults are the in-class initializers.
    InternalSensitivityMethod ism_{InternalSensitivityMethod::simultaneous};

    LinearMultistepMethod lmm_{LinearMultistepMethod::BDF};

    NonlinearSolverIteration iter_{NonlinearSolverIteration::newton};

    InterpolationType interp_type_{InterpolationType::polynomial};

    long int maxsteps_{10000};

    /** Time limit in seconds (std::ratio<1>); 0 by default. */
    std::chrono::duration<double, std::ratio<1>> maxtime_{0};

    /** Presumably started by startTimer() and read by timeExceeded(). */
    mutable CpuTimer simulation_timer_;

    // Linear / nonlinear solver wrappers for the forward (no suffix),
    // backward (B) and sensitivity (sens) problems.
    mutable std::unique_ptr<SUNLinSolWrapper> linear_solver_;

    mutable std::unique_ptr<SUNLinSolWrapper> linear_solver_B_;

    mutable std::unique_ptr<SUNNonLinSolWrapper> non_linear_solver_;

    mutable std::unique_ptr<SUNNonLinSolWrapper> non_linear_solver_B_;

    mutable std::unique_ptr<SUNNonLinSolWrapper> non_linear_solver_sens_;

    /** False until (presumably) the forward solver has been run. */
    mutable bool solver_was_called_F_{false};

    /** False until (presumably) the backward solver has been run. */
    mutable bool solver_was_called_B_{false};

    // Setters for the initialization flags queried by the get*Done() helpers.
    void setInitDone() const;

    void setSensInitDone() const;

    void setSensInitOff() const;

    void setAdjInitDone() const;

    void setInitDoneB(int which) const;

    void setQuadInitDoneB(int which) const;

    void setQuadInitDone() const;

    /**
     * Presumably validates @p sensi_meth for the main simulation or for
     * preequilibration. NOTE(review): the top-level `const` on the by-value
     * parameter has no effect in a declaration.
     */
    void checkSensitivityMethod(
        SensitivityMethod const sensi_meth, bool preequilibration
    ) const;

    // Presumably push max_nonlin_iters_, max_conv_fails_ and max_step_size_
    // to the backend.
    virtual void apply_max_nonlin_iters() const = 0;

    virtual void apply_max_conv_fails() const = 0;

    virtual void apply_max_step_size() const = 0;

    // Mutable work buffers constructed with size argument 0; presumably
    // resized by resetMutableMemory().
    mutable AmiVector x_{0};

    mutable AmiVector dky_{0};

    mutable AmiVector dx_{0};

    mutable AmiVectorArray sx_{0, 0};
    mutable AmiVectorArray sdx_{0, 0};

    mutable AmiVector xB_{0};

    mutable AmiVector dxB_{0};

    mutable AmiVector xQB_{0};

    mutable AmiVector xQ_{0};

    /** Current time; NaN until set. */
    mutable realtype t_{std::nan("")};

    mutable bool force_reinit_postprocess_F_{false};

    mutable bool force_reinit_postprocess_B_{false};

    mutable bool sens_initialized_{false};

    /** State constraints (see setConstraints() / getConstraints()). */
    mutable AmiVector constraints_;

  private:
    // Presumably push maxsteps_ / maxstepsB_ to the backend.
    void apply_max_num_steps() const;

    void apply_max_num_steps_B() const;

    SensitivityMethod sensi_meth_{SensitivityMethod::forward};

    SensitivityMethod sensi_meth_preeq_{SensitivityMethod::forward};

    booleantype stldet_{true};

    int ordering_{static_cast<int>(SUNLinSolKLU::StateOrdering::AMD)};

    long int newton_maxsteps_{0L};

    /** NOTE(review): no accessor for this member is declared in this class. */
    long int newton_maxlinsteps_{0L};

    NewtonDampingFactorMode newton_damping_factor_mode_{
        NewtonDampingFactorMode::on
    };

    realtype newton_damping_factor_lower_bound_{1e-8};

    LinearSolver linsol_{LinearSolver::KLU};

    // Tolerances. NAN defaults presumably mean "not set", with the effective
    // value derived elsewhere — confirm in the implementation.
    realtype atol_{1e-16};

    realtype rtol_{1e-8};

    realtype atol_fsa_{NAN};

    realtype rtol_fsa_{NAN};

    realtype atolB_{NAN};

    realtype rtolB_{NAN};

    realtype quad_atol_{1e-12};

    realtype quad_rtol_{1e-8};

    realtype ss_tol_factor_{1e2};

    realtype ss_atol_{NAN};

    realtype ss_rtol_{NAN};

    realtype ss_tol_sensi_factor_{1e2};

    realtype ss_atol_sensi_{NAN};

    realtype ss_rtol_sensi_{NAN};

    RDataReporting rdata_mode_{RDataReporting::full};

    bool newton_step_steadystate_conv_{false};

    bool check_sensi_steadystate_conv_{true};

    int max_nonlin_iters_{3};

    int max_conv_fails_{10};

    realtype max_step_size_{0.0};

    mutable realtype cpu_time_{0.0};

    mutable realtype cpu_timeB_{0.0};

    long int maxstepsB_{0L};

    SensitivityOrder sensi_{SensitivityOrder::none};

    mutable bool initialized_{false};

    mutable bool adj_initialized_{false};

    mutable bool quad_initialized_{false};

    // NOTE(review): `{false}` selects std::vector's initializer_list
    // constructor, so initializedB_ and initializedQB_ start with ONE element
    // (false), not empty — confirm this is intended.
    mutable std::vector<bool> initializedB_{false};

    mutable std::vector<bool> initializedQB_{false};

    /** Presumably the checkpoint counter passed to solveF() — confirm. */
    mutable int ncheckPtr_{0};

    // Diagnostics storage returned by the public getNum*() / getLastOrder().
    mutable std::vector<int> ns_;

    mutable std::vector<int> nsB_;

    mutable std::vector<int> nrhs_;

    mutable std::vector<int> nrhsB_;

    mutable std::vector<int> netf_;

    mutable std::vector<int> netfB_;

    mutable std::vector<int> nnlscf_;

    mutable std::vector<int> nnlscfB_;

    mutable std::vector<int> order_;
};

bool operator==(Solver const& a, Solver const& b);

/**
 * @brief Error-handler callback taking (error code, module, function,
 * message, user data).
 *
 * Presumably registered with the backend integrator (cf. the
 * Solver::setErrHandlerFn() hook) to forward solver errors and warnings to
 * AMICI — confirm in the definition.
 *
 * @param error_code error code reported by the caller
 * @param module name of the reporting module
 * @param function name of the reporting function
 * @param msg message text; non-const, presumably to match the callback type
 *        expected by the caller
 * @param eh_data user data pointer passed through by the caller
 */
void wrapErrHandlerFn(
    int error_code,
    const char* module,
    const char* function,
    char* msg,
    void* eh_data
);

} // namespace amici

#endif // AMICI_SOLVER_H