Program Listing for File forwardproblem.h
↰ Return to documentation for file (`include/amici/forwardproblem.h`)
#ifndef AMICI_FORWARDPROBLEM_H
#define AMICI_FORWARDPROBLEM_H
#include "amici/defines.h"
#include "amici/vector.h"
#include "amici/model.h"
#include "amici/misc.h"
#include "amici/sundials_matrix_wrapper.h"
#include <sundials/sundials_direct.h>
#include <algorithm>
#include <map>
#include <memory>
#include <vector>
namespace amici {
class ExpData;
class Solver;
class SteadystateProblem;
class FinalStateStorer;
/**
 * @brief Snapshot of a simulation at a single point in time.
 *
 * ForwardProblem keeps instances of this struct for the initial state, the
 * final state, stored output timepoints and stored events.
 */
struct SimulationState{
/**
 * Time of the snapshot.
 * NOTE(review): no default initializer, so the value is indeterminate unless
 * it is explicitly assigned.
 */
realtype t;
/** State vector at time t (presumably a copy of ForwardProblem::getState() — confirm) */
AmiVector x;
/** State derivative at time t (presumably a copy of ForwardProblem::getStateDerivative() — confirm) */
AmiVector dx;
/** State sensitivities at time t (presumably a copy of ForwardProblem::getStateSensitivity() — confirm) */
AmiVectorArray sx;
/** Model-internal state; ModelState is declared outside this file, contents not visible here */
ModelState state;
};
/**
 * @brief Forward simulation of a model with a given solver.
 *
 * Holds the current integration state (time, state, state derivative and
 * their sensitivities) together with snapshots (SimulationState) taken at
 * the initial time, the final time, stored output timepoints and stored
 * events, plus bookkeeping recorded at discontinuities.
 */
class ForwardProblem {
  public:
    /**
     * @brief Constructor.
     * @param edata experimental data (presumably may be nullptr when no
     * data is available — confirm in the implementation)
     * @param model model instance
     * @param solver solver instance
     * @param preeq preequilibration problem (presumably nullptr if no
     * preequilibration was performed — confirm in the implementation)
     */
    ForwardProblem(const ExpData *edata, Model *model, Solver *solver,
                   const SteadystateProblem *preeq);

    ~ForwardProblem() = default;

    /** FinalStateStorer needs access to final_state_ and getSimulationState() */
    friend ::amici::FinalStateStorer;

    /**
     * @brief Solve the forward problem.
     *
     * NOTE(review): defined outside this header.
     */
    void workForwardProblem();

    /**
     * @brief Compute quantities needed by the adjoint computation.
     *
     * NOTE(review): defined outside this header; presumably fills dJydx_ and
     * dJzdx_ (exposed via getDJydx() / getDJzdx()) — confirm.
     * @param model model instance
     * @param edata experimental data
     */
    void getAdjointUpdates(Model &model, const ExpData &edata);

    /** @return current time */
    realtype getTime() const { return t_; }

    /** @return current state vector */
    AmiVector const &getState() const { return x_; }

    /** @return current state derivative */
    AmiVector const &getStateDerivative() const { return dx_; }

    /** @return current state sensitivities */
    AmiVectorArray const &getStateSensitivity() const { return sx_; }

    /** @return state vectors recorded at discontinuities */
    std::vector<AmiVector> const &getStatesAtDiscontinuities() const {
        return x_disc_;
    }

    /** @return right-hand sides recorded at discontinuities */
    std::vector<AmiVector> const &getRHSAtDiscontinuities() const {
        return xdot_disc_;
    }

    /** @return right-hand sides recorded before discontinuities */
    std::vector<AmiVector> const &getRHSBeforeDiscontinuities() const {
        return xdot_old_disc_;
    }

    /**
     * @return root counts (the values compared against nmaxevent in
     * checkEventsToFill())
     */
    std::vector<int> const &getNumberOfRoots() const { return nroots_; }

    /** @return times recorded as discontinuities (fillEvents() appends t_) */
    std::vector<realtype> const &getDiscontinuities() const { return discs_; }

    /** @return recorded root indices (presumably one entry per discontinuity) */
    std::vector<std::vector<int>> const &getRootIndexes() const {
        return root_idx_;
    }

    /** @return dJydx_ (NOTE(review): filled outside this header) */
    std::vector<realtype> const &getDJydx() const { return dJydx_; }

    /** @return dJzdx_ (NOTE(review): filled outside this header) */
    std::vector<realtype> const &getDJzdx() const { return dJzdx_; }

    /** @return mutable pointer to the current state vector */
    AmiVector *getStatePointer() { return &x_; }

    /** @return mutable pointer to the current state derivative */
    AmiVector *getStateDerivativePointer() { return &dx_; }

    /** @return mutable pointer to the current state sensitivities */
    AmiVectorArray *getStateSensitivityPointer() { return &sx_; }

    /** @return mutable pointer to the state derivative sensitivities */
    AmiVectorArray *getStateDerivativeSensitivityPointer() { return &sdx_; }

    /** @return current time iteration index */
    int getCurrentTimeIteration() const { return it_; }

    /** @return time of the stored final simulation state */
    realtype getFinalTime() const { return final_state_.t; }

    /**
     * @return number of stored event states minus one, i.e. -1 when no
     * event state has been stored
     */
    int getEventCounter() const {
        // Convert before subtracting: size() - 1 on an empty vector would
        // wrap around to SIZE_MAX before the narrowing cast.
        return static_cast<int>(event_states_.size()) - 1;
    }

    /**
     * @return number of recorded discontinuity times minus one, i.e. -1
     * when none have been recorded
     */
    int getRootCounter() const {
        // Same reasoning as getEventCounter(): avoid unsigned wrap-around.
        return static_cast<int>(discs_.size()) - 1;
    }

    /**
     * @brief Get the simulation state for an output timepoint.
     * @param it timepoint index, as passed to Model::getTimepoint()
     * @return the initial simulation state if the timepoint equals the
     * initial time, otherwise the state stored for that timepoint
     * @throws std::out_of_range if no state was stored for the timepoint
     */
    const SimulationState &getSimulationStateTimepoint(int it) const {
        auto const timepoint = model->getTimepoint(it);
        if (timepoint == initial_state_.t)
            return getInitialSimulationState();
        // at() rather than find()->second: dereferencing end() for a
        // timepoint that was never stored is undefined behaviour.
        return timepoint_states_.at(timepoint);
    }

    /**
     * @brief Get the simulation state stored for an event.
     * @param iroot index into the stored event states
     * @return stored event state
     * @throws std::out_of_range if iroot is out of range
     */
    const SimulationState &getSimulationStateEvent(int iroot) const {
        return event_states_.at(iroot);
    }

    /** @return simulation state stored for the initial time */
    const SimulationState &getInitialSimulationState() const {
        return initial_state_;
    }

    /** @return simulation state stored as the final state */
    const SimulationState &getFinalSimulationState() const {
        return final_state_;
    }

    /** model instance; not deleted by ForwardProblem (defaulted destructor) */
    Model *model;

    /** solver instance; not deleted by ForwardProblem (defaulted destructor) */
    Solver *solver;

    /** experimental data; not deleted by ForwardProblem (defaulted destructor) */
    const ExpData *edata;

  private:
    /** NOTE(review): defined outside this header */
    void handlePresimulation();

    /**
     * @brief Handle an event (defined outside this header).
     * @param tlastroot pointer to the time of the last root (presumably
     * updated in place — confirm)
     * @param seflag NOTE(review): meaning not visible in this header
     * @param initial_event NOTE(review): presumably true for an event at the
     * initial time — confirm
     */
    void handleEvent(realtype *tlastroot, bool seflag, bool initial_event);

    /** Store event data (defined outside this header; called by fillEvents()) */
    void storeEvent();

    /**
     * @brief Handle an output timepoint (defined outside this header).
     * @param it timepoint index
     */
    void handleDataPoint(int it);

    /** NOTE(review): defined outside this header */
    void applyEventBolus();

    /** NOTE(review): defined outside this header */
    void applyEventSensiBolusFSA();

    /**
     * @brief Check whether any root count is still below the limit.
     * @param nmaxevent limit to compare against
     * @return true if any entry of nroots_ is less than nmaxevent
     */
    bool checkEventsToFill(int nmaxevent) const {
        return std::any_of(nroots_.cbegin(), nroots_.cend(),
                           [nmaxevent](int curNRoots) {
                               return curNRoots < nmaxevent;
                           });
    }

    /**
     * @brief Record the current time as a discontinuity and store event data
     * if checkEventsToFill(nmaxevent) holds.
     * @param nmaxevent limit passed to checkEventsToFill()
     */
    void fillEvents(int nmaxevent) {
        if (checkEventsToFill(nmaxevent)) {
            discs_.push_back(t_);
            storeEvent();
        }
    }

    /**
     * @brief Create a snapshot of the current simulation state.
     *
     * Defined outside this header; used by FinalStateStorer.
     */
    SimulationState getSimulationState() const;

    /** root indices, exposed via getRootIndexes() */
    std::vector<std::vector<int>> root_idx_;

    /** root counts, compared against nmaxevent in checkEventsToFill() */
    std::vector<int> nroots_;

    /** NOTE(review): presumably root function values — confirm */
    std::vector<realtype> rootvals_;

    /** NOTE(review): presumably a temporary buffer for root values — confirm */
    std::vector<realtype> rval_tmp_;

    /** discontinuity times; fillEvents() appends t_ */
    std::vector<realtype> discs_;

    /** NOTE(review): purpose not visible in this header */
    std::vector<realtype> irdiscs_;

    /** state vectors at discontinuities, see getStatesAtDiscontinuities() */
    std::vector<AmiVector> x_disc_;

    /** right-hand sides at discontinuities, see getRHSAtDiscontinuities() */
    std::vector<AmiVector> xdot_disc_;

    /** right-hand sides before discontinuities, see getRHSBeforeDiscontinuities() */
    std::vector<AmiVector> xdot_old_disc_;

    /** exposed via getDJydx() */
    std::vector<realtype> dJydx_;

    /** exposed via getDJzdx() */
    std::vector<realtype> dJzdx_;

    /** current time; value-initialized so it is never read indeterminate */
    realtype t_ {0.0};

    /** NOTE(review): presumably per-root flags for the last root found */
    std::vector<int> roots_found_;

    /** snapshots keyed by output timepoint, see getSimulationStateTimepoint() */
    std::map<realtype, SimulationState> timepoint_states_;

    /** snapshots per stored event, see getSimulationStateEvent() */
    std::vector<SimulationState> event_states_;

    /** snapshot at the initial time */
    SimulationState initial_state_;

    /** final snapshot, written by FinalStateStorer */
    SimulationState final_state_;

    /** current state vector */
    AmiVector x_;

    /** previous state vector (presumably before the last event) */
    AmiVector x_old_;

    /** current state derivative */
    AmiVector dx_;

    /** previous state derivative (presumably before the last event) */
    AmiVector dx_old_;

    /** NOTE(review): presumably the current right-hand side */
    AmiVector xdot_;

    /** NOTE(review): presumably the previous right-hand side */
    AmiVector xdot_old_;

    /** current state sensitivities */
    AmiVectorArray sx_;

    /** state derivative sensitivities */
    AmiVectorArray sdx_;

    /** NOTE(review): purpose not visible in this header */
    std::vector<realtype> stau_;

    /** time of the last root */
    realtype tlastroot_ {0.0};

    /** NOTE(review): presumably whether preequilibration was performed */
    bool preequilibrated_ {false};

    /** current time iteration index; value-initialized for safety */
    int it_ {0};
};
class FinalStateStorer : public ContextManager {
public:
explicit FinalStateStorer(ForwardProblem *fwd) : fwd_(fwd) {
}
FinalStateStorer &operator=(const FinalStateStorer &other) = delete;
~FinalStateStorer() {
if(fwd_)
fwd_->final_state_ = fwd_->getSimulationState();
}
private:
ForwardProblem *fwd_;
};
} // namespace amici
#endif // AMICI_FORWARDPROBLEM_H