Program Listing for File solver.h
↰ Return to documentation for file (include/amici/solver.h)
#ifndef AMICI_SOLVER_H
#define AMICI_SOLVER_H
#include "amici/defines.h"
#include "amici/logging.h"
#include "amici/misc.h"
#include "amici/sundials_linsol_wrapper.h"
#include "amici/vector.h"
#include <chrono>
#include <cmath>
#include <functional>
#include <memory>
// Forward declarations of AMICI classes. Only incomplete types are needed in
// this header (e.g. Model and Solver are named through pointers/references).
namespace amici {

class BackwardProblem;
class ForwardProblem;
class Model;
class ReturnData;
class Solver;

} // namespace amici

// Boost.Serialization hook for amici::Solver. It has to be declared before
// class Solver so that Solver can name it as a friend.
namespace boost {
namespace serialization {

template <class Archive>
void serialize(Archive& ar, amici::Solver& s, unsigned int version);

} // namespace serialization
} // namespace boost
namespace amici {
/**
 * @brief Abstract base class for AMICI's numerical integrators.
 *
 * Concrete subclasses implement the pure-virtual (`= 0`) hooks declared
 * below. This base class stores the integration options (tolerances, step
 * limits, sensitivity settings, linear/nonlinear solver choice). It owns the
 * type-erased integrator memory (`solver_memory_`, `solver_memory_B_`) and
 * the linear/nonlinear solver wrappers, and it collects per-run diagnostics.
 *
 * Most "action" methods are `const` and operate on `mutable` members. A
 * `Solver const&` can therefore still drive an integration.
 */
class Solver {
public:
// (model, solver) pair. Presumably this is what setUserData() /
// setUserDataB() hand to the integrator as user data -- confirm.
using user_data_type = std::pair<Model*, Solver const*>;
// Type-erased deleter for the integrator memory held in solver_memory_ and
// solver_memory_B_.
using free_solver_ptr = std::function<void(void*)>;
Solver() = default;
// Defined out of line.
// - Because the class has std::unique_ptr members, the implicit copy
//   assignment operator is deleted.
// - Declaring this copy constructor also suppresses the implicit move
//   constructor.
Solver(Solver const& other);
// Virtual destructor: Solver is a polymorphic base class.
virtual ~Solver() = default;
// Polymorphic copy. Returns a raw pointer; presumably the caller takes
// ownership -- confirm.
virtual Solver* clone() const = 0;
// run(): integrate the forward problem up to `tout`.
// step(): perform a single step towards `tout`.
// The returned int is a status code whose meaning is defined in the
// implementation.
int run(realtype tout) const;
int step(realtype tout) const;
// Integrate the backward problem(s) up to `tout`.
void runB(realtype tout) const;
/**
 * @brief Prepare the forward integration.
 *
 * @param t0 initial time
 * @param model model to integrate (non-const pointer)
 * @param x0 initial state
 * @param dx0 initial state derivative
 * @param sx0 initial state sensitivities
 * @param sdx0 initial state-derivative sensitivities
 */
void setup(
realtype t0, Model* model, AmiVector const& x0, AmiVector const& dx0,
AmiVectorArray const& sx0, AmiVectorArray const& sdx0
) const;
/**
 * @brief Prepare a backward integration starting at `tf`.
 *
 * The initial values are `xB0` / `dxB0`, and the initial quadrature is
 * `xQB0`.
 *
 * @param which passed by pointer. Presumably it receives the index of the
 *        newly created backward problem (cf. allocateSolverB(int*));
 *        confirm in the implementation.
 */
void setupB(
int* which, realtype tf, Model* model, AmiVector const& xB0,
AmiVector const& dxB0, AmiVector const& xQB0
) const;
/**
 * @brief Prepare a steady-state computation starting at `t0`.
 *
 * Takes forward states (`x0`, `dx0`), backward states (`xB0`, `dxB0`) and
 * quadrature `xQ0`.
 */
void setupSteadystate(
realtype const t0, Model* model, AmiVector const& x0,
AmiVector const& dx0, AmiVector const& xB0, AmiVector const& dxB0,
AmiVector const& xQ0
) const;
// Presumably re-initializes the integrator after the model has modified its
// states/sensitivities (e.g. at a discontinuity) -- confirm.
void updateAndReinitStatesAndSensitivities(Model* model) const;
// Integrator hooks implemented by the concrete subclasses.
virtual void getRootInfo(int* rootsfound) const = 0;
virtual void calcIC(realtype tout1) const = 0;
virtual void calcICB(int which, realtype tout1) const = 0;
virtual void solveB(realtype tBout, int itaskB) const = 0;
virtual void turnOffRootFinding() const = 0;
// --- Sensitivity settings --------------------------------------------------
// Default method: SensitivityMethod::forward (see sensi_meth_).
SensitivityMethod getSensitivityMethod() const;
void setSensitivityMethod(SensitivityMethod sensi_meth);
// Method used during pre-equilibration. Default: SensitivityMethod::forward
// (see sensi_meth_preeq_).
SensitivityMethod getSensitivityMethodPreequilibration() const;
void setSensitivityMethodPreequilibration(SensitivityMethod sensi_meth_preeq
);
void switchForwardSensisOff() const;
// --- Newton solver settings ------------------------------------------------
// Defaults: max steps 0, damping mode on, damping lower bound 1e-8.
// Note the getter/setter use int while the member is long int.
int getNewtonMaxSteps() const;
void setNewtonMaxSteps(int newton_maxsteps);
NewtonDampingFactorMode getNewtonDampingFactorMode() const;
void setNewtonDampingFactorMode(NewtonDampingFactorMode dampingFactorMode);
double getNewtonDampingFactorLowerBound() const;
void setNewtonDampingFactorLowerBound(double dampingFactorLowerBound);
// Requested sensitivity order. Default: SensitivityOrder::none.
SensitivityOrder getSensitivityOrder() const;
void setSensitivityOrder(SensitivityOrder sensi);
// --- Integration tolerances ------------------------------------------------
// Defaults:
// - states: atol 1e-16, rtol 1e-8
// - quadratures: atol 1e-12, rtol 1e-8
// - steady-state tolerance factors: 1e2
// The FSA, backward (B) and steady-state tolerances default to NaN. This
// presumably means "derive from another setting"; confirm in the
// implementation before relying on it.
double getRelativeTolerance() const;
void setRelativeTolerance(double rtol);
double getAbsoluteTolerance() const;
void setAbsoluteTolerance(double atol);
double getRelativeToleranceFSA() const;
void setRelativeToleranceFSA(double rtol);
double getAbsoluteToleranceFSA() const;
void setAbsoluteToleranceFSA(double atol);
double getRelativeToleranceB() const;
void setRelativeToleranceB(double rtol);
double getAbsoluteToleranceB() const;
void setAbsoluteToleranceB(double atol);
double getRelativeToleranceQuadratures() const;
void setRelativeToleranceQuadratures(double rtol);
double getAbsoluteToleranceQuadratures() const;
void setAbsoluteToleranceQuadratures(double atol);
double getSteadyStateToleranceFactor() const;
void setSteadyStateToleranceFactor(double factor);
double getRelativeToleranceSteadyState() const;
void setRelativeToleranceSteadyState(double rtol);
double getAbsoluteToleranceSteadyState() const;
void setAbsoluteToleranceSteadyState(double atol);
double getSteadyStateSensiToleranceFactor() const;
void setSteadyStateSensiToleranceFactor(double factor);
double getRelativeToleranceSteadyStateSensi() const;
void setRelativeToleranceSteadyStateSensi(double rtol);
double getAbsoluteToleranceSteadyStateSensi() const;
void setAbsoluteToleranceSteadyStateSensi(double atol);
// --- Step and time limits --------------------------------------------------
// Maximum number of forward integration steps (default 10000).
long int getMaxSteps() const;
void setMaxSteps(long int maxsteps);
// Time limit in seconds. It is stored as a std::chrono duration with period
// std::ratio<1>; default 0. Whether 0 disables the limit is decided in the
// implementation -- confirm.
double getMaxTime() const;
void setMaxTime(double maxtime);
// Presumably these start and check simulation_timer_ against the limit set
// via setMaxTime(); `interval` presumably throttles how often the clock is
// read -- confirm.
void startTimer() const;
bool timeExceeded(int interval = 1) const;
// Maximum number of backward integration steps (default 0).
long int getMaxStepsBackwardProblem() const;
void setMaxStepsBackwardProblem(long int maxsteps);
// --- Integrator configuration ----------------------------------------------
// Defaults:
// - LinearMultistepMethod::BDF
// - NonlinearSolverIteration::newton
// - InterpolationType::polynomial
// - state ordering SUNLinSolKLU::StateOrdering::AMD
// - stability limit detection on
// - LinearSolver::KLU
// - InternalSensitivityMethod::simultaneous
// - RDataReporting::full
LinearMultistepMethod getLinearMultistepMethod() const;
void setLinearMultistepMethod(LinearMultistepMethod lmm);
NonlinearSolverIteration getNonlinearSolverIteration() const;
void setNonlinearSolverIteration(NonlinearSolverIteration iter);
InterpolationType getInterpolationType() const;
void setInterpolationType(InterpolationType interpType);
int getStateOrdering() const;
void setStateOrdering(int ordering);
bool getStabilityLimitFlag() const;
void setStabilityLimitFlag(bool stldet);
LinearSolver getLinearSolver() const;
void setLinearSolver(LinearSolver linsol);
InternalSensitivityMethod getInternalSensitivityMethod() const;
void setInternalSensitivityMethod(InternalSensitivityMethod ism);
RDataReporting getReturnDataReportingMode() const;
void setReturnDataReportingMode(RDataReporting rdrm);
// --- Solution access -------------------------------------------------------
// Fill the output arguments (taken by pointer / non-const reference) with
// the current forward solution.
void writeSolution(
realtype* t, AmiVector& x, AmiVector& dx, AmiVectorArray& sx,
AmiVector& xQ
) const;
// Same as writeSolution, for backward problem `which`.
void writeSolutionB(
realtype* t, AmiVector& xB, AmiVector& dxB, AmiVector& xQB, int which
) const;
// Return solution quantities at time `t`. The returned references
// presumably alias the internal mutable buffers (x_, dx_, sx_, xB_, dxB_,
// xQB_, xQ_), so they may be overwritten by the next call -- confirm before
// holding on to them.
AmiVector const& getState(realtype t) const;
AmiVector const& getDerivativeState(realtype t) const;
AmiVectorArray const& getStateSensitivity(realtype t) const;
AmiVector const& getAdjointState(int which, realtype t) const;
AmiVector const& getAdjointDerivativeState(int which, realtype t) const;
AmiVector const& getAdjointQuadrature(int which, realtype t) const;
AmiVector const& getQuadrature(realtype t) const;
// --- Re-initialization hooks (implemented by subclasses) -------------------
virtual void
reInit(realtype t0, AmiVector const& yy0, AmiVector const& yp0) const
= 0;
virtual void
sensReInit(AmiVectorArray const& yyS0, AmiVectorArray const& ypS0) const
= 0;
virtual void sensToggleOff() const = 0;
virtual void reInitB(
int which, realtype tB0, AmiVector const& yyB0, AmiVector const& ypB0
) const
= 0;
virtual void quadReInitB(int which, AmiVector const& yQB0) const = 0;
// Current integrator time (presumably t_, which is NaN until first set).
realtype gett() const;
// CPU time accumulated in the forward / backward problem (see cpu_time_,
// cpu_timeB_).
realtype getCpuTime() const;
realtype getCpuTimeB() const;
// Problem dimensions: number of states, sensitivity parameters, quadratures.
int nx() const;
int nplist() const;
int nquad() const;
// True iff at least first-order sensitivities are requested via the forward
// method and there is at least one parameter to differentiate for.
bool computingFSA() const {
return getSensitivityOrder() >= SensitivityOrder::first
&& getSensitivityMethod() == SensitivityMethod::forward
&& nplist() > 0;
}
// True iff at least first-order sensitivities are requested via the adjoint
// method and there is at least one parameter to differentiate for.
bool computingASA() const {
return getSensitivityOrder() >= SensitivityOrder::first
&& getSensitivityMethod() == SensitivityMethod::adjoint
&& nplist() > 0;
}
// --- Diagnostics -----------------------------------------------------------
// Presumably these clear / append to the statistics vectors returned by the
// getters below (one entry per stored call) -- confirm.
void resetDiagnosis() const;
void storeDiagnosis() const;
void storeDiagnosisB(int which) const;
// Accessors for the recorded integrator statistics (no copy is made).
std::vector<int> const& getNumSteps() const { return ns_; }
std::vector<int> const& getNumStepsB() const { return nsB_; }
std::vector<int> const& getNumRhsEvals() const { return nrhs_; }
std::vector<int> const& getNumRhsEvalsB() const { return nrhsB_; }
std::vector<int> const& getNumErrTestFails() const { return netf_; }
std::vector<int> const& getNumErrTestFailsB() const { return netfB_; }
std::vector<int> const& getNumNonlinSolvConvFails() const {
return nnlscf_;
}
std::vector<int> const& getNumNonlinSolvConvFailsB() const {
return nnlscfB_;
}
std::vector<int> const& getLastOrder() const { return order_; }
// --- Steady-state convergence checks ---------------------------------------
// Defaults: Newton-step check off, sensitivity check on.
bool getNewtonStepSteadyStateCheck() const {
return newton_step_steadystate_conv_;
}
bool getSensiSteadyStateCheck() const {
return check_sensi_steadystate_conv_;
}
void setNewtonStepSteadyStateCheck(bool flag) {
newton_step_steadystate_conv_ = flag;
}
void setSensiSteadyStateCheck(bool flag) {
check_sensi_steadystate_conv_ = flag;
}
// Maximum nonlinear iterations per step (default 3).
void setMaxNonlinIters(int max_nonlin_iters);
int getMaxNonlinIters() const;
// Maximum convergence failures (default 10).
void setMaxConvFails(int max_conv_fails);
int getMaxConvFails() const;
// State constraints; the getter returns a copy of constraints_.
void setConstraints(std::vector<realtype> const& constraints);
std::vector<realtype> getConstraints() const {
return constraints_.getVector();
}
// Maximum integration step size (default 0.0; whether 0 means "unlimited"
// is decided in the implementation -- confirm).
void setMaxStepSize(realtype max_step_size);
realtype getMaxStepSize() const;
// Boost.Serialization needs access to the private option members.
template <class Archive>
friend void boost::serialization::serialize(
Archive& ar, Solver& s, unsigned int version
);
// Defined out of line; see the namespace-scope declaration after the class.
friend bool operator==(Solver const& a, Solver const& b);
// Logger used by the solver; nullptr by default. Ownership is not expressed
// here -- presumably non-owning.
Logger* logger = nullptr;
protected:
// --- Integrator hooks implemented by subclasses ----------------------------
virtual void setStopTime(realtype tstop) const = 0;
virtual int solve(realtype tout, int itask) const = 0;
virtual int solveF(realtype tout, int itask, int* ncheckPtr) const = 0;
virtual void reInitPostProcessF(realtype tnext) const = 0;
virtual void reInitPostProcessB(realtype tnext) const = 0;
virtual void getSens() const = 0;
virtual void getB(int which) const = 0;
virtual void getQuadB(int which) const = 0;
virtual void getQuad(realtype& t) const = 0;
virtual void
init(realtype t0, AmiVector const& x0, AmiVector const& dx0) const
= 0;
virtual void initSteadystate(
realtype t0, AmiVector const& x0, AmiVector const& dx0
) const
= 0;
virtual void
sensInit1(AmiVectorArray const& sx0, AmiVectorArray const& sdx0) const
= 0;
virtual void binit(
int which, realtype tf, AmiVector const& xB0, AmiVector const& dxB0
) const
= 0;
virtual void qbinit(int which, AmiVector const& xQB0) const = 0;
virtual void rootInit(int ne) const = 0;
void initializeNonLinearSolverSens(Model const* model) const;
// Jacobian callback registration (dense / sparse / band / Jacobian-vector
// product), for the forward problem, backward problem `which`, and the
// sparse steady-state case.
virtual void setDenseJacFn() const = 0;
virtual void setSparseJacFn() const = 0;
virtual void setBandJacFn() const = 0;
virtual void setJacTimesVecFn() const = 0;
virtual void setDenseJacFnB(int which) const = 0;
virtual void setSparseJacFnB(int which) const = 0;
virtual void setBandJacFnB(int which) const = 0;
virtual void setJacTimesVecFnB(int which) const = 0;
virtual void setSparseJacFn_ss() const = 0;
// Memory allocation and option forwarding to the underlying integrator.
virtual void allocateSolver() const = 0;
virtual void setSStolerances(double rtol, double atol) const = 0;
virtual void setSensSStolerances(double rtol, double const* atol) const = 0;
virtual void setSensErrCon(bool error_corr) const = 0;
virtual void setQuadErrConB(int which, bool flag) const = 0;
virtual void setQuadErrCon(bool flag) const = 0;
virtual void setErrHandlerFn() const = 0;
virtual void setUserData() const = 0;
virtual void setUserDataB(int which) const = 0;
virtual void setMaxNumSteps(long int mxsteps) const = 0;
virtual void setMaxNumStepsB(int which, long int mxstepsB) const = 0;
virtual void setStabLimDet(int stldet) const = 0;
virtual void setStabLimDetB(int which, int stldet) const = 0;
virtual void setId(Model const* model) const = 0;
virtual void setSuppressAlg(bool flag) const = 0;
virtual void setSensParams(
realtype const* p, realtype const* pbar, int const* plist
) const
= 0;
// Dense-output evaluation of the k-th derivative at time t.
virtual void getDky(realtype t, int k) const = 0;
virtual void getDkyB(realtype t, int k, int which) const = 0;
virtual void getSensDky(realtype t, int k) const = 0;
virtual void getQuadDkyB(realtype t, int k, int which) const = 0;
virtual void getQuadDky(realtype t, int k) const = 0;
virtual void adjInit() const = 0;
virtual void quadInit(AmiVector const& xQ0) const = 0;
virtual void allocateSolverB(int* which) const = 0;
virtual void
setSStolerancesB(int which, realtype relTolB, realtype absTolB) const
= 0;
virtual void
quadSStolerancesB(int which, realtype reltolQB, realtype abstolQB) const
= 0;
virtual void quadSStolerances(realtype reltolQB, realtype abstolQB) const
= 0;
// Statistics queries on a raw integrator memory block `ami_mem`.
virtual void getNumSteps(void const* ami_mem, long int* numsteps) const = 0;
virtual void
getNumRhsEvals(void const* ami_mem, long int* numrhsevals) const
= 0;
virtual void
getNumErrTestFails(void const* ami_mem, long int* numerrtestfails) const
= 0;
virtual void getNumNonlinSolvConvFails(
void const* ami_mem, long int* numnonlinsolvconvfails
) const
= 0;
virtual void getLastOrder(void const* ami_mem, int* order) const = 0;
// Linear / nonlinear solver setup (forward, backward, sensitivities).
void initializeLinearSolver(Model const* model) const;
void initializeNonLinearSolver() const;
virtual void setLinearSolver() const = 0;
virtual void setLinearSolverB(int which) const = 0;
virtual void setNonLinearSolver() const = 0;
virtual void setNonLinearSolverB(int which) const = 0;
virtual void setNonLinearSolverSens() const = 0;
void initializeLinearSolverB(Model const* model, int which) const;
void initializeNonLinearSolverB(int which) const;
virtual Model const* getModel() const = 0;
// Initialization-state queries (backed by the *initialized* flags below).
bool getInitDone() const;
bool getSensInitDone() const;
bool getAdjInitDone() const;
bool getInitDoneB(int which) const;
bool getQuadInitDoneB(int which) const;
bool getQuadInitDone() const;
virtual void diag() const = 0;
virtual void diagB(int which) const = 0;
// Presumably resizes the mutable buffers (x_, sx_, xQ_, ...) to the given
// dimensions -- confirm.
void resetMutableMemory(int nx, int nplist, int nquad) const;
virtual void* getAdjBmem(void* ami_mem, int which) const = 0;
// Push the stored tolerance options to the underlying integrator.
void applyTolerances() const;
void applyTolerancesFSA() const;
void applyTolerancesASA(int which) const;
void applyQuadTolerancesASA(int which) const;
void applyQuadTolerances() const;
void applySensitivityTolerances() const;
// Non-pure virtual: has a base-class implementation (not shown here).
virtual void apply_constraints() const;
// Type-erased integrator memory; the deleter is supplied by the subclass.
mutable std::unique_ptr<void, free_solver_ptr> solver_memory_;
// One memory block per backward problem.
mutable std::vector<std::unique_ptr<void, free_solver_ptr>>
solver_memory_B_;
mutable user_data_type user_data;
InternalSensitivityMethod ism_{InternalSensitivityMethod::simultaneous};
LinearMultistepMethod lmm_{LinearMultistepMethod::BDF};
NonlinearSolverIteration iter_{NonlinearSolverIteration::newton};
InterpolationType interp_type_{InterpolationType::polynomial};
long int maxsteps_{10000};
// Time limit; period std::ratio<1> means the count is in seconds.
std::chrono::duration<double, std::ratio<1>> maxtime_{0};
mutable CpuTimer simulation_timer_;
mutable std::unique_ptr<SUNLinSolWrapper> linear_solver_;
mutable std::unique_ptr<SUNLinSolWrapper> linear_solver_B_;
mutable std::unique_ptr<SUNNonLinSolWrapper> non_linear_solver_;
mutable std::unique_ptr<SUNNonLinSolWrapper> non_linear_solver_B_;
mutable std::unique_ptr<SUNNonLinSolWrapper> non_linear_solver_sens_;
mutable bool solver_was_called_F_{false};
mutable bool solver_was_called_B_{false};
void setInitDone() const;
void setSensInitDone() const;
void setSensInitOff() const;
void setAdjInitDone() const;
void setInitDoneB(int which) const;
void setQuadInitDoneB(int which) const;
void setQuadInitDone() const;
// Presumably validates `sensi_meth` against the current configuration
// (forward vs. pre-equilibration) -- confirm.
void checkSensitivityMethod(
SensitivityMethod const sensi_meth, bool preequilibration
) const;
virtual void apply_max_nonlin_iters() const = 0;
virtual void apply_max_conv_fails() const = 0;
virtual void apply_max_step_size() const = 0;
// Mutable work buffers for solution access; start empty (size 0).
mutable AmiVector x_{0};
mutable AmiVector dky_{0};
mutable AmiVector dx_{0};
mutable AmiVectorArray sx_{0, 0};
mutable AmiVectorArray sdx_{0, 0};
mutable AmiVector xB_{0};
mutable AmiVector dxB_{0};
mutable AmiVector xQB_{0};
mutable AmiVector xQ_{0};
// Current time; NaN until the integrator has produced a value.
mutable realtype t_{std::nan("")};
mutable bool force_reinit_postprocess_F_{false};
mutable bool force_reinit_postprocess_B_{false};
mutable bool sens_initialized_{false};
mutable AmiVector constraints_;
private:
void apply_max_num_steps() const;
void apply_max_num_steps_B() const;
SensitivityMethod sensi_meth_{SensitivityMethod::forward};
SensitivityMethod sensi_meth_preeq_{SensitivityMethod::forward};
booleantype stldet_{true};
int ordering_{static_cast<int>(SUNLinSolKLU::StateOrdering::AMD)};
long int newton_maxsteps_{0L};
long int newton_maxlinsteps_{0L};
NewtonDampingFactorMode newton_damping_factor_mode_{
NewtonDampingFactorMode::on
};
realtype newton_damping_factor_lower_bound_{1e-8};
LinearSolver linsol_{LinearSolver::KLU};
realtype atol_{1e-16};
realtype rtol_{1e-8};
// NaN: not explicitly set (see the tolerance accessors).
realtype atol_fsa_{NAN};
realtype rtol_fsa_{NAN};
realtype atolB_{NAN};
realtype rtolB_{NAN};
realtype quad_atol_{1e-12};
realtype quad_rtol_{1e-8};
realtype ss_tol_factor_{1e2};
realtype ss_atol_{NAN};
realtype ss_rtol_{NAN};
realtype ss_tol_sensi_factor_{1e2};
realtype ss_atol_sensi_{NAN};
realtype ss_rtol_sensi_{NAN};
RDataReporting rdata_mode_{RDataReporting::full};
bool newton_step_steadystate_conv_{false};
bool check_sensi_steadystate_conv_{true};
int max_nonlin_iters_{3};
int max_conv_fails_{10};
realtype max_step_size_{0.0};
mutable realtype cpu_time_{0.0};
mutable realtype cpu_timeB_{0.0};
long int maxstepsB_{0L};
SensitivityOrder sensi_{SensitivityOrder::none};
mutable bool initialized_{false};
mutable bool adj_initialized_{false};
mutable bool quad_initialized_{false};
// NOTE: brace-initialization selects the initializer_list constructor, so
// each of these starts as a ONE-element vector {false}, not an empty vector.
mutable std::vector<bool> initializedB_{false};
mutable std::vector<bool> initializedQB_{false};
mutable int ncheckPtr_{0};
// Recorded integrator statistics (see the getNum* accessors).
mutable std::vector<int> ns_;
mutable std::vector<int> nsB_;
mutable std::vector<int> nrhs_;
mutable std::vector<int> nrhsB_;
mutable std::vector<int> netf_;
mutable std::vector<int> netfB_;
mutable std::vector<int> nnlscf_;
mutable std::vector<int> nnlscfB_;
mutable std::vector<int> order_;
};
// Equality comparison of two solvers (defined out of line; presumably
// compares the configuration options -- confirm).
bool operator==(Solver const& a, Solver const& b);
// Error-handler callback. Its parameter list presumably matches the SUNDIALS
// error-handler function type, with `eh_data` carrying the registering
// object (cf. setErrHandlerFn()) -- confirm in the implementation.
void wrapErrHandlerFn(
int error_code, char const* module, char const* function, char* msg,
void* eh_data
);
} // namespace amici
#endif // AMICI_SOLVER_H