Program Listing for File forwardproblem.h

Return to documentation for file (include/amici/forwardproblem.h)

#ifndef AMICI_FORWARDPROBLEM_H
#define AMICI_FORWARDPROBLEM_H

#include "amici/defines.h"
#include "amici/misc.h"
#include "amici/model.h"
#include "amici/vector.h"
#include <amici/amici.h>

#include <sundials/sundials_direct.h>

#include <algorithm>
#include <map>
#include <vector>

namespace amici {

class ExpData;
class Solver;
class SteadystateProblem;
class FinalStateStorer;

/**
 * @brief Forward simulation problem.
 *
 * Holds the state vector, its derivative and sensitivities, discontinuity
 * bookkeeping, and the simulation-state snapshots recorded at the initial
 * time, at output timepoints, at events and at the final time. The numerical
 * work is done by member functions defined out of line (workForwardProblem(),
 * handleEvent(), ...); the inline members mostly expose collected results.
 */
class ForwardProblem {
  public:
    /**
     * @brief Constructor (defined out of line).
     * @param edata experimental data (presumably may be nullptr when no data
     * is available — confirm against the implementation)
     * @param model model to simulate
     * @param solver solver used for the simulation
     * @param preeq preequilibration problem (presumably nullptr when no
     * preequilibration was run — confirm)
     */
    ForwardProblem(
        ExpData const* edata, Model* model, Solver* solver,
        SteadystateProblem const* preeq
    );

    ~ForwardProblem() = default;

    /** FinalStateStorer assigns final_state_ from its destructor. */
    friend ::amici::FinalStateStorer;

    /** @brief Entry point of the forward simulation (defined out of line). */
    void workForwardProblem();

    /**
     * @brief Adjoint-related post-processing (defined out of line).
     *
     * NOTE(review): presumably fills dJydx_/dJzdx_ — confirm in the .cpp.
     * @param model model instance
     * @param edata experimental data
     */
    void getAdjointUpdates(Model& model, ExpData const& edata);

    /** @return current time t_ */
    realtype getTime() const { return t_; }

    /** @return current state x_ */
    AmiVector const& getState() const { return x_; }

    /** @return current state derivative dx_ */
    AmiVector const& getStateDerivative() const { return dx_; }

    /** @return current state sensitivities sx_ */
    AmiVectorArray const& getStateSensitivity() const { return sx_; }

    /** @return states stored at discontinuities */
    std::vector<AmiVector> const& getStatesAtDiscontinuities() const {
        return x_disc_;
    }

    /** @return right-hand sides stored at discontinuities */
    std::vector<AmiVector> const& getRHSAtDiscontinuities() const {
        return xdot_disc_;
    }

    /** @return right-hand sides stored before discontinuities */
    std::vector<AmiVector> const& getRHSBeforeDiscontinuities() const {
        return xdot_old_disc_;
    }

    /** @return per-entry root counts (compared against nmaxevent) */
    std::vector<int> const& getNumberOfRoots() const { return nroots_; }

    /** @return times recorded as discontinuities */
    std::vector<realtype> const& getDiscontinuities() const { return discs_; }

    /** @return root index vectors, one per stored root */
    std::vector<std::vector<int>> const& getRootIndexes() const {
        return root_idx_;
    }

    /** @return stored dJydx values */
    std::vector<realtype> const& getDJydx() const { return dJydx_; }

    /** @return stored dJzdx values */
    std::vector<realtype> const& getDJzdx() const { return dJzdx_; }

    /** @return mutable pointer to the state x_ */
    AmiVector* getStatePointer() { return &x_; }

    /** @return mutable pointer to the state derivative dx_ */
    AmiVector* getStateDerivativePointer() { return &dx_; }

    /** @return mutable pointer to the state sensitivities sx_ */
    AmiVectorArray* getStateSensitivityPointer() { return &sx_; }

    /** @return mutable pointer to the state-derivative sensitivities sdx_ */
    AmiVectorArray* getStateDerivativeSensitivityPointer() { return &sdx_; }

    /** @return current output-timepoint index it_ */
    int getCurrentTimeIteration() const { return it_; }

    /** @return time of the stored final simulation state */
    realtype getFinalTime() const { return final_state_.t; }

    /** @return number of stored event states minus one */
    int getEventCounter() const {
        return gsl::narrow<int>(event_states_.size()) - 1;
    }

    /** @return number of recorded discontinuities minus one */
    int getRootCounter() const { return gsl::narrow<int>(discs_.size()) - 1; }

    /**
     * @brief Simulation state recorded at the output timepoint with index
     * @p it.
     *
     * If that timepoint equals the initial time, the initial state is
     * returned instead of a map lookup.
     * @param it timepoint index, passed to Model::getTimepoint
     * @return reference to the stored state
     * @throws std::out_of_range if no state was stored for that timepoint
     */
    SimulationState const& getSimulationStateTimepoint(int it) const {
        auto const timepoint = model->getTimepoint(it);
        if (timepoint == initial_state_.t)
            return getInitialSimulationState();
        // map::at throws rather than dereferencing end() in NDEBUG builds,
        // matching getSimulationStateEvent's use of vector::at.
        return timepoint_states_.at(timepoint);
    }

    /**
     * @param iroot index into the stored event states
     * @return simulation state stored for that event
     * @throws std::out_of_range if @p iroot is out of range
     */
    SimulationState const& getSimulationStateEvent(int iroot) const {
        return event_states_.at(iroot);
    }

    /** @return simulation state stored at the initial time */
    SimulationState const& getInitialSimulationState() const {
        return initial_state_;
    }

    /** @return simulation state stored as the final state */
    SimulationState const& getFinalSimulationState() const {
        return final_state_;
    }

    /** Model being simulated (non-owning; not released by this class). */
    Model* model{nullptr};

    /** Solver used for the simulation (non-owning). */
    Solver* solver{nullptr};

    /** Experimental data (non-owning). */
    ExpData const* edata{nullptr};

  private:
    /** @brief Presimulation handling (defined out of line). */
    void handlePresimulation();

    /**
     * @brief Event handling (defined out of line).
     * @param tlastroot time of the last root (presumably updated in place —
     * confirm)
     * @param seflag presumably flags a secondary event — confirm
     * @param initial_event presumably flags an event at the initial time —
     * confirm
     */
    void handleEvent(realtype* tlastroot, bool seflag, bool initial_event);

    /** @brief Stores the pre-event state (defined out of line). */
    void store_pre_event_state(bool seflag, bool initial_event);

    /** @brief Secondary-event handling (defined out of line). */
    void handle_secondary_event(realtype* tlastroot);

    /** @brief Stores event data (defined out of line). */
    void storeEvent();

    /** @brief Handles the data point at time @p t (defined out of line). */
    void handleDataPoint(realtype t);

    /** @brief Applies the event bolus to the state (defined out of line). */
    void applyEventBolus();

    /** @brief Applies the sensitivity event bolus (defined out of line). */
    void applyEventSensiBolusFSA();

    /**
     * @param nmaxevent maximum number of events to record
     * @return true if any entry of nroots_ is still below @p nmaxevent
     */
    bool checkEventsToFill(int nmaxevent) const {
        return std::any_of(
            nroots_.cbegin(), nroots_.cend(),
            [nmaxevent](int curNRoots) { return curNRoots < nmaxevent; }
        );
    }

    /**
     * @brief If events remain to be filled, records the current time t_ as a
     * discontinuity and calls storeEvent().
     * @param nmaxevent maximum number of events to record
     */
    void fillEvents(int nmaxevent) {
        if (checkEventsToFill(nmaxevent)) {
            discs_.push_back(t_);
            storeEvent();
        }
    }

    /** @brief Snapshot of the current simulation state (defined out of line). */
    SimulationState getSimulationState();

    /** Root index vectors (see getRootIndexes()). */
    std::vector<std::vector<int>> root_idx_;

    /** Per-entry root counts, compared against nmaxevent. */
    std::vector<int> nroots_;

    /** Root function values. */
    std::vector<realtype> rootvals_;

    /** Temporary root function values. */
    std::vector<realtype> rval_tmp_;

    /** Times recorded as discontinuities (fillEvents appends t_). */
    std::vector<realtype> discs_;

    /** NOTE(review): purpose not visible in this header — document. */
    std::vector<realtype> irdiscs_;

    /** States at discontinuities. */
    std::vector<AmiVector> x_disc_;

    /** Right-hand sides at discontinuities. */
    std::vector<AmiVector> xdot_disc_;

    /** Right-hand sides before discontinuities. */
    std::vector<AmiVector> xdot_old_disc_;

    /** dJydx values (see getDJydx()). */
    std::vector<realtype> dJydx_;

    /** dJzdx values (see getDJzdx()). */
    std::vector<realtype> dJzdx_;

    /** Current time; value-initialised so it is never read indeterminate. */
    realtype t_{0.0};

    /** Roots found in the most recent root-finding step (presumably). */
    std::vector<int> roots_found_;

    /** Simulation states keyed by output timepoint. */
    std::map<realtype, SimulationState> timepoint_states_;

    /** Simulation states stored at events. */
    std::vector<SimulationState> event_states_;

    /** Simulation state at the initial time. */
    SimulationState initial_state_;

    /** Simulation state written by FinalStateStorer. */
    SimulationState final_state_;

    /** Current state. */
    AmiVector x_;

    /** Previous state. */
    AmiVector x_old_;

    /** Current state derivative. */
    AmiVector dx_;

    /** Previous state derivative. */
    AmiVector dx_old_;

    /** Current right-hand side. */
    AmiVector xdot_;

    /** Previous right-hand side. */
    AmiVector xdot_old_;

    /** State sensitivities. */
    AmiVectorArray sx_;

    /** State-derivative sensitivities. */
    AmiVectorArray sdx_;

    /** NOTE(review): presumably event-time sensitivities — confirm. */
    std::vector<realtype> stau_;

    /** Time of the last root. */
    realtype tlastroot_{0.0};

    /** Whether preequilibration was performed. */
    bool preequilibrated_{false};

    /**
     * Current output-timepoint index; value-initialised so
     * getCurrentTimeIteration() never reads an indeterminate value before
     * the simulation sets it.
     */
    int it_{0};
};

/**
 * @brief Scope guard that stores a ForwardProblem's current simulation state
 * as its final state when the guard is destroyed.
 *
 * The guard is non-copyable, so each guard writes final_state_ at most once.
 */
class FinalStateStorer : public ContextManager {
  public:
    /**
     * @brief Constructor.
     * @param fwd forward problem to update on destruction (non-owning). If it
     * is nullptr, the destructor does nothing.
     */
    explicit FinalStateStorer(ForwardProblem* fwd)
        : fwd_(fwd) {}

    // Deleting copy assignment alone does not suppress the implicitly
    // declared copy constructor; a copied guard would store the final state a
    // second time from its own destructor.
    FinalStateStorer(FinalStateStorer const& other) = delete;

    FinalStateStorer& operator=(FinalStateStorer const& other) = delete;

    /**
     * @brief Assigns ForwardProblem::getSimulationState() to the forward
     * problem's final_state_ (skipped if constructed with nullptr).
     */
    ~FinalStateStorer() {
        if (fwd_)
            fwd_->final_state_ = fwd_->getSimulationState();
    }

  private:
    /** Forward problem whose final state is stored (non-owning). */
    ForwardProblem* fwd_;
};

} // namespace amici

#endif // AMICI_FORWARDPROBLEM_H