Program Listing for File forwardproblem.h

Return to documentation for file (include/amici/forwardproblem.h)

#ifndef AMICI_FORWARDPROBLEM_H
#define AMICI_FORWARDPROBLEM_H

#include "amici/defines.h"
#include "amici/vector.h"
#include "amici/model.h"
#include "amici/misc.h"
#include "amici/sundials_matrix_wrapper.h"

#include <sundials/sundials_direct.h>

#include <algorithm>
#include <map>
#include <memory>
#include <vector>

namespace amici {

class ExpData;
class Solver;
class SteadystateProblem;
class FinalStateStorer;

/**
 * @brief Snapshot of the simulation at a single time.
 *
 * ForwardProblem keeps one of these for the initial state, the final state,
 * each stored output timepoint and each stored event.
 */
struct SimulationState {
    /** time of the snapshot; zero-initialized so that a default-constructed
     *  snapshot never carries an indeterminate time */
    realtype t {0.0};
    /** state vector */
    AmiVector x;
    /** state derivative (presumably only meaningful for DAE models — TODO
     *  confirm) */
    AmiVector dx;
    /** state sensitivities */
    AmiVectorArray sx;
    /** model state */
    ModelState state;
};


class ForwardProblem {
  public:
    ForwardProblem(const ExpData *edata, Model *model, Solver *solver,
                   const SteadystateProblem *preeq);

    ~ForwardProblem() = default;

    friend ::amici::FinalStateStorer;

    void workForwardProblem();

    void getAdjointUpdates(Model &model, const ExpData &edata);

    realtype getTime() const {
        return t_;
    }

    AmiVector const& getState() const {
        return x_;
    }

    AmiVector const& getStateDerivative() const {
        return dx_;
    }

    AmiVectorArray const& getStateSensitivity() const {
        return sx_;
    }

    std::vector<AmiVector> const& getStatesAtDiscontinuities() const {
        return x_disc_;
    }

    std::vector<AmiVector> const& getRHSAtDiscontinuities() const {
        return xdot_disc_;
    }

    std::vector<AmiVector> const& getRHSBeforeDiscontinuities() const {
        return xdot_old_disc_;
    }

    std::vector<int> const& getNumberOfRoots() const {
        return nroots_;
    }

    std::vector<realtype> const& getDiscontinuities() const {
        return discs_;
    }

    std::vector<std::vector<int>> const& getRootIndexes() const {
        return root_idx_;
    }

   std::vector<realtype> const& getDJydx() const {
        return dJydx_;
    }

    std::vector<realtype> const& getDJzdx() const {
        return dJzdx_;
    }

    AmiVector *getStatePointer() {
        return &x_;
    }

    AmiVector *getStateDerivativePointer() {
        return &dx_;
    }

    AmiVectorArray *getStateSensitivityPointer() {
        return &sx_;
    }

    AmiVectorArray *getStateDerivativeSensitivityPointer() {
        return &sdx_;
    }

    int getCurrentTimeIteration() const {
        return it_;
    }

    realtype getFinalTime() const {
        return final_state_.t;
    }

    int getEventCounter() const {
        return static_cast<int>(event_states_.size() - 1);
    }

    int getRootCounter() const {
        return static_cast<int>(discs_.size() - 1);
    }

    const SimulationState &getSimulationStateTimepoint(int it) const {
        if (model->getTimepoint(it) == initial_state_.t)
            return getInitialSimulationState();
        return timepoint_states_.find(model->getTimepoint(it))->second;
    };

    const SimulationState &getSimulationStateEvent(int iroot) const {
        return event_states_.at(iroot);
    };

    const SimulationState &getInitialSimulationState() const {
        return initial_state_;
    };

    const SimulationState &getFinalSimulationState() const {
        return final_state_;
    };

    Model *model;

    Solver *solver;

    const ExpData *edata;

  private:

    void handlePresimulation();

    void handleEvent(realtype *tlastroot,bool seflag);

    void storeEvent();

    void handleDataPoint(int it);

    void applyEventBolus();

    void applyEventSensiBolusFSA();

    bool checkEventsToFill(int nmaxevent) const {
        return std::any_of(nroots_.cbegin(), nroots_.cend(),
                           [nmaxevent](int curNRoots) {
                return curNRoots < nmaxevent;
        });
    };

    void fillEvents(int nmaxevent) {
        if (checkEventsToFill(nmaxevent)) {
            discs_.push_back(t_);
            storeEvent();
        }
    }

    SimulationState getSimulationState() const;

    std::vector<std::vector<int>> root_idx_;

    std::vector<int> nroots_;

    std::vector<realtype> rootvals_;

    std::vector<realtype> rval_tmp_;

    std::vector<realtype> discs_;

    std::vector<realtype> irdiscs_;

    std::vector<AmiVector> x_disc_;

    std::vector<AmiVector> xdot_disc_;

    std::vector<AmiVector> xdot_old_disc_;

    std::vector<realtype> dJydx_;

    std::vector<realtype> dJzdx_;

    realtype t_;

    std::vector<int> roots_found_;

    std::map<realtype, SimulationState> timepoint_states_;

    std::vector<SimulationState> event_states_;

    SimulationState initial_state_;

    SimulationState final_state_;

    AmiVector x_;

    AmiVector x_old_;

    AmiVector dx_;

    AmiVector dx_old_;

    AmiVector xdot_;

    AmiVector xdot_old_;

    AmiVectorArray sx_;

    AmiVectorArray sdx_;

    std::vector<realtype> stau_;

    realtype tlastroot_ {0.0};

    bool preequilibrated_ {false};

    int it_;

};

/**
 * @brief Scope guard that records a ForwardProblem's final state.
 *
 * When destroyed, and if it was constructed with a non-null ForwardProblem,
 * it assigns ForwardProblem::getSimulationState() to the problem's final
 * state. The state is therefore stored however the enclosing scope is left.
 */
class FinalStateStorer : public ContextManager {
  public:
    /**
     * @param fwd forward problem whose final state is stored on destruction;
     * not owned. If nullptr, the destructor does nothing.
     */
    explicit FinalStateStorer(ForwardProblem *fwd) : fwd_(fwd) {
    }

    // A copy would be a second guard that stores the final state again when
    // it goes out of scope. Copy assignment was already deleted; delete
    // copy construction as well.
    FinalStateStorer(const FinalStateStorer &other) = delete;

    FinalStateStorer &operator=(const FinalStateStorer &other) = delete;

    ~FinalStateStorer() {
        if (fwd_)
            fwd_->final_state_ = fwd_->getSimulationState();
    }

  private:
    /** forward problem updated on destruction (not owned, may be nullptr) */
    ForwardProblem *fwd_;
};

} // namespace amici

#endif // AMICI_FORWARDPROBLEM_H