Program Listing for File vector.h

Return to documentation for file (include/amici/vector.h)

#ifndef AMICI_VECTOR_H
#define AMICI_VECTOR_H

#include <type_traits>
#include <vector>

#include <amici/exception.h>

#include <nvector/nvector_serial.h>

#include <gsl/gsl-lite.hpp>

// Forward declaration so that the boost::serialization::serialize declaration
// below can name amici::AmiVector before the class itself is defined.
namespace amici {
class AmiVector;
}

// for serialization friend
// Declared up front so AmiVector can befriend this function template (see the
// friend declaration inside AmiVector). The template is not defined in this
// header.
namespace boost {
namespace serialization {
template <class Archive>
void serialize(Archive& ar, amici::AmiVector& s, unsigned int version);
}
} // namespace boost

namespace amici {

/**
 * @brief Pointer-to-const counterpart of N_Vector.
 *
 * N_Vector is itself a pointer type, so `N_Vector const` would only make the
 * pointer const. This alias instead points to a const pointee.
 */
using const_N_Vector = std::remove_pointer_t<N_Vector> const*;

/**
 * @brief Const-correct wrapper around N_VGetArrayPointer.
 *
 * N_VGetArrayPointer only accepts a non-const N_Vector. The constness of @p x
 * is cast away solely to make that call, and the data pointer is handed back
 * as pointer-to-const.
 *
 * @param x vector to query
 * @return pointer to the vector's data array
 */
inline realtype const* N_VGetArrayPointerConst(const_N_Vector x) {
    N_Vector const mutable_view = const_cast<N_Vector>(x);
    realtype const* const values = N_VGetArrayPointer(mutable_view);
    return values;
}

class AmiVector {
  public:
    AmiVector() = default;

    explicit AmiVector(long int const length)
        : vec_(static_cast<decltype(vec_)::size_type>(length), 0.0)
        , nvec_(N_VMake_Serial(length, vec_.data())) {}

    explicit AmiVector(std::vector<realtype> rvec)
        : vec_(std::move(rvec))
        , nvec_(N_VMake_Serial(gsl::narrow<long int>(vec_.size()), vec_.data())
          ) {}

    explicit AmiVector(gsl::span<realtype const> rvec)
        : AmiVector(std::vector<realtype>(rvec.begin(), rvec.end())) {}

    AmiVector(AmiVector const& vold)
        : vec_(vold.vec_) {
        nvec_ = N_VMake_Serial(
            gsl::narrow<long int>(vold.vec_.size()), vec_.data()
        );
    }

    AmiVector(AmiVector&& other) noexcept
        : nvec_(nullptr) {
        vec_ = std::move(other.vec_);
        synchroniseNVector();
    }

    ~AmiVector();

    AmiVector& operator=(AmiVector const& other);

    AmiVector& operator*=(AmiVector const& multiplier) {
        N_VProd(
            getNVector(), const_cast<N_Vector>(multiplier.getNVector()),
            getNVector()
        );
        return *this;
    }

    AmiVector& operator/=(AmiVector const& divisor) {
        N_VDiv(
            getNVector(), const_cast<N_Vector>(divisor.getNVector()),
            getNVector()
        );
        return *this;
    }

    auto begin() { return vec_.begin(); }

    auto end() { return vec_.end(); }

    realtype* data();

    realtype const* data() const;

    N_Vector getNVector();

    const_N_Vector getNVector() const;

    std::vector<realtype> const& getVector() const;

    int getLength() const;

    void zero();

    void minus();

    void set(realtype val);

    realtype& operator[](int pos);
    realtype& at(int pos);

    realtype const& at(int pos) const;

    void copy(AmiVector const& other);

    void abs() { N_VAbs(getNVector(), getNVector()); };

    template <class Archive>
    friend void boost::serialization::serialize(
        Archive& ar, AmiVector& s, unsigned int version
    );

  private:
    std::vector<realtype> vec_;

    N_Vector nvec_{nullptr};

    void synchroniseNVector();
};

/**
 * @brief Collection of AmiVector objects with an additional array of their
 * N_Vector handles.
 *
 * All non-trivial members are defined out of line (not in this header). The
 * per-member notes below are inferred from names and signatures; confirm them
 * against the implementation.
 *
 * NOTE(review): a copy constructor and copy assignment are user-declared, so no
 * move constructor or move assignment is implicitly declared. Rvalue sources
 * are copied.
 */
class AmiVectorArray {
  public:
    /** @brief Empty array. */
    AmiVectorArray() = default;

    /**
     * @brief Presumably creates @p length_outer vectors of @p length_inner
     * elements each. TODO confirm.
     * @param length_inner length of each contained vector
     * @param length_outer number of contained vectors
     */
    AmiVectorArray(long int length_inner, long int length_outer);

    /**
     * @brief Copy constructor.
     * @param vaold array to copy
     */
    AmiVectorArray(AmiVectorArray const& vaold);

    ~AmiVectorArray() = default;

    /**
     * @brief Copy assignment.
     * @param other array to copy from
     * @return *this
     */
    AmiVectorArray& operator=(AmiVectorArray const& other);

    /** @brief Mutable data pointer of the vector at index @p pos. */
    realtype* data(int pos);

    /** @brief Read-only data pointer of the vector at index @p pos. */
    realtype const* data(int pos) const;

    /**
     * @brief Mutable element access. Presumably @p ipos indexes into a vector
     * and @p jpos selects the vector. TODO confirm the index order.
     */
    realtype& at(int ipos, int jpos);

    /** @brief Read-only element access; same index convention as above. */
    realtype const& at(int ipos, int jpos) const;

    /** @brief Pointer to the array of N_Vector handles (nvec_array_). */
    N_Vector* getNVectorArray();

    /** @brief N_Vector of the vector at index @p pos. */
    N_Vector getNVector(int pos);

    /** @brief Read-only N_Vector of the vector at index @p pos. */
    const_N_Vector getNVector(int pos) const;

    /** @brief Mutable access to the vector at index @p pos. */
    AmiVector& operator[](int pos);

    /** @brief Read-only access to the vector at index @p pos. */
    AmiVector const& operator[](int pos) const;

    /** @brief Number of contained vectors (presumably). */
    int getLength() const;

    /** @brief Presumably zeroes every contained vector. */
    void zero();

    /**
     * @brief Writes all values into @p vec as one flat vector. The layout and
     * whether @p vec is resized are defined by the implementation. TODO
     * document.
     * @param vec output buffer
     */
    void flatten_to_vector(std::vector<realtype>& vec) const;

    /**
     * @brief Presumably copies the values of @p other into this array.
     * @param other source array
     */
    void copy(AmiVectorArray const& other);

  private:
    /** The contained vectors. */
    std::vector<AmiVector> vec_array_;

    /**
     * Presumably the N_Vector of each element of vec_array_, kept contiguous
     * for getNVectorArray(). TODO confirm.
     */
    std::vector<N_Vector> nvec_array_;
};

/**
 * @brief Linear combination z = a*x + b*y, computed by SUNDIALS N_VLinearSum.
 * @param a scalar factor applied to @p x
 * @param x first input vector
 * @param b scalar factor applied to @p y
 * @param y second input vector
 * @param z output vector
 */
inline void linearSum(
    realtype a, AmiVector const& x, realtype b, AmiVector const& y, AmiVector& z
) {
    // N_VLinearSum takes non-const N_Vectors even for its inputs, so the
    // constness of x and y is cast away only for this call.
    N_Vector const x_view = const_cast<N_Vector>(x.getNVector());
    N_Vector const y_view = const_cast<N_Vector>(y.getNVector());
    N_Vector const z_view = z.getNVector();
    N_VLinearSum(a, x_view, b, y_view, z_view);
}

/**
 * @brief Dot product of two vectors, computed by SUNDIALS N_VDotProd.
 * @param x first vector
 * @param y second vector
 * @return sum over i of x[i] * y[i]
 */
inline realtype dotProd(AmiVector const& x, AmiVector const& y) {
    // N_VDotProd is not const-correct; both operands are only read.
    N_Vector const lhs = const_cast<N_Vector>(x.getNVector());
    N_Vector const rhs = const_cast<N_Vector>(y.getNVector());
    realtype const product = N_VDotProd(lhs, rhs);
    return product;
}

} // namespace amici

namespace gsl {
/**
 * @brief Mutable span over the data of a serial N_Vector.
 * @param nv serial N_Vector; may be nullptr
 * @return span over the vector's data, or an empty span if @p nv is nullptr
 */
inline span<realtype> make_span(N_Vector nv) {
    // A default-constructed amici::AmiVector holds a null N_Vector.
    // N_VGetArrayPointer / N_VGetLength_Serial would dereference it.
    if (nv == nullptr)
        return span<realtype>();
    return span<realtype>(N_VGetArrayPointer(nv), N_VGetLength_Serial(nv));
}

/**
 * @brief Read-only span over the values of an AmiVector.
 * @param av vector to view; must outlive the returned span
 * @return span over av.getVector()
 */
inline span<realtype const> make_span(amici::AmiVector const& av) {
    std::vector<realtype> const& values = av.getVector();
    return make_span(values);
}

} // namespace gsl

#endif /* AMICI_VECTOR_H */