Program Listing for File forwardproblem.h
↰ Return to documentation for file (include/amici/forwardproblem.h)
#ifndef AMICI_FORWARDPROBLEM_H
#define AMICI_FORWARDPROBLEM_H
#include "amici/defines.h"
#include "amici/misc.h"
#include "amici/model.h"
#include "amici/vector.h"
#include <amici/amici.h>

#include <sundials/sundials_direct.h>

#include <algorithm>
#include <map>
#include <vector>
namespace amici {
// Forward declarations. ExpData, Solver and SteadystateProblem are only used
// through pointers in this header. FinalStateStorer is defined further below
// and is befriended by ForwardProblem.
class ExpData;
class Solver;
class SteadystateProblem;
class FinalStateStorer;
class ForwardProblem {
public:
ForwardProblem(
ExpData const* edata, Model* model, Solver* solver,
SteadystateProblem const* preeq
);
~ForwardProblem() = default;
friend ::amici::FinalStateStorer;
void workForwardProblem();
void getAdjointUpdates(Model& model, ExpData const& edata);
realtype getTime() const { return t_; }
AmiVector const& getState() const { return x_; }
AmiVector const& getStateDerivative() const { return dx_; }
AmiVectorArray const& getStateSensitivity() const { return sx_; }
std::vector<AmiVector> const& getStatesAtDiscontinuities() const {
return x_disc_;
}
std::vector<AmiVector> const& getRHSAtDiscontinuities() const {
return xdot_disc_;
}
std::vector<AmiVector> const& getRHSBeforeDiscontinuities() const {
return xdot_old_disc_;
}
std::vector<int> const& getNumberOfRoots() const { return nroots_; }
std::vector<realtype> const& getDiscontinuities() const { return discs_; }
std::vector<std::vector<int>> const& getRootIndexes() const {
return root_idx_;
}
std::vector<realtype> const& getDJydx() const { return dJydx_; }
std::vector<realtype> const& getDJzdx() const { return dJzdx_; }
AmiVector* getStatePointer() { return &x_; }
AmiVector* getStateDerivativePointer() { return &dx_; }
AmiVectorArray* getStateSensitivityPointer() { return &sx_; }
AmiVectorArray* getStateDerivativeSensitivityPointer() { return &sdx_; }
int getCurrentTimeIteration() const { return it_; }
realtype getFinalTime() const { return final_state_.t; }
int getEventCounter() const {
return gsl::narrow<int>(event_states_.size()) - 1;
}
int getRootCounter() const { return gsl::narrow<int>(discs_.size()) - 1; }
SimulationState const& getSimulationStateTimepoint(int it) const {
if (model->getTimepoint(it) == initial_state_.t)
return getInitialSimulationState();
auto map_iter = timepoint_states_.find(model->getTimepoint(it));
assert(map_iter != timepoint_states_.end());
return map_iter->second;
};
SimulationState const& getSimulationStateEvent(int iroot) const {
return event_states_.at(iroot);
};
SimulationState const& getInitialSimulationState() const {
return initial_state_;
};
SimulationState const& getFinalSimulationState() const {
return final_state_;
};
Model* model;
Solver* solver;
ExpData const* edata;
private:
void handlePresimulation();
void handleEvent(realtype* tlastroot, bool seflag, bool initial_event);
void store_pre_event_state(bool seflag, bool initial_event);
void handle_secondary_event(realtype* tlastroot);
void storeEvent();
void handleDataPoint(realtype t);
void applyEventBolus();
void applyEventSensiBolusFSA();
bool checkEventsToFill(int nmaxevent) const {
return std::any_of(
nroots_.cbegin(), nroots_.cend(),
[nmaxevent](int curNRoots) { return curNRoots < nmaxevent; }
);
};
void fillEvents(int nmaxevent) {
if (checkEventsToFill(nmaxevent)) {
discs_.push_back(t_);
storeEvent();
}
}
SimulationState getSimulationState();
std::vector<std::vector<int>> root_idx_;
std::vector<int> nroots_;
std::vector<realtype> rootvals_;
std::vector<realtype> rval_tmp_;
std::vector<realtype> discs_;
std::vector<realtype> irdiscs_;
std::vector<AmiVector> x_disc_;
std::vector<AmiVector> xdot_disc_;
std::vector<AmiVector> xdot_old_disc_;
std::vector<realtype> dJydx_;
std::vector<realtype> dJzdx_;
realtype t_;
std::vector<int> roots_found_;
std::map<realtype, SimulationState> timepoint_states_;
std::vector<SimulationState> event_states_;
SimulationState initial_state_;
SimulationState final_state_;
AmiVector x_;
AmiVector x_old_;
AmiVector dx_;
AmiVector dx_old_;
AmiVector xdot_;
AmiVector xdot_old_;
AmiVectorArray sx_;
AmiVectorArray sdx_;
std::vector<realtype> stau_;
realtype tlastroot_{0.0};
bool preequilibrated_{false};
int it_;
};
/**
 * @brief Scope guard that copies the current simulation state of a
 * ForwardProblem into its `final_state_` when the guard is destroyed.
 *
 * The snapshot is taken in the destructor, so it is also attempted when the
 * enclosing scope is left via an exception. Copying is disabled, so the
 * snapshot is written at most once per guard.
 */
class FinalStateStorer : public ContextManager {
  public:
    /**
     * @brief Create the guard.
     * @param fwd forward problem whose final state is stored on destruction;
     * if nullptr, the destructor does nothing
     */
    explicit FinalStateStorer(ForwardProblem* fwd)
        : fwd_(fwd) {}

    FinalStateStorer(FinalStateStorer const& other) = delete;
    FinalStateStorer& operator=(FinalStateStorer const& other) = delete;

    /**
     * @brief Store the current simulation state as the final state.
     *
     * Destructors are implicitly noexcept, so an exception escaping from
     * getSimulationState() would call std::terminate. This is particularly
     * likely if the guard is destroyed while another exception is already
     * propagating. Such exceptions are therefore caught here. In that case
     * `final_state_` may not have been updated.
     */
    ~FinalStateStorer() {
        if (!fwd_)
            return;
        try {
            fwd_->final_state_ = fwd_->getSimulationState();
        } catch (...) {
            // Best effort: an exception must never escape a destructor.
        }
    }

  private:
    /** guarded forward problem (not owned; may be nullptr) */
    ForwardProblem* fwd_;
};
} // namespace amici
#endif // FORWARDPROBLEM_H