Program Listing for File vector.h
↰ Return to documentation for file (include/amici/vector.h)
#ifndef AMICI_VECTOR_H
#define AMICI_VECTOR_H
#include <type_traits>
#include <utility>
#include <vector>
#include <amici/exception.h>
#include <nvector/nvector_serial.h>
#include <gsl/gsl-lite.hpp>
// Forward declaration so the Boost.Serialization hook below can name the type.
namespace amici {
class AmiVector;
} // namespace amici

// Declared ahead of AmiVector so the class can befriend it
// (needed for serialization of the private members).
namespace boost {
namespace serialization {
/**
 * @brief Boost.Serialization entry point for amici::AmiVector.
 * @tparam Archive archive type
 * @param ar archive to read from / write to
 * @param av vector being serialized
 * @param version archive format version
 */
template <class Archive>
void serialize(Archive& ar, amici::AmiVector& av, unsigned int version);
} // namespace serialization
} // namespace boost
namespace amici {
using const_N_Vector
= std::add_const_t<typename std::remove_pointer_t<N_Vector>>*;
/**
 * @brief Read-only access to the data array of an N_Vector.
 * @param x vector to inspect
 * @return pointer to the first element, as returned by N_VGetArrayPointer,
 * exposed as pointer-to-const
 */
inline realtype const* N_VGetArrayPointerConst(const_N_Vector x) {
    // N_VGetArrayPointer is not const-correct; the cast only adapts the
    // signature, and the result is handed back as pointer-to-const.
    N_Vector const mutable_handle = const_cast<N_Vector>(x);
    realtype const* const values = N_VGetArrayPointer(mutable_handle);
    return values;
}
class AmiVector {
public:
AmiVector() = default;
explicit AmiVector(long int const length)
: vec_(static_cast<decltype(vec_)::size_type>(length), 0.0)
, nvec_(N_VMake_Serial(length, vec_.data())) {}
explicit AmiVector(std::vector<realtype> rvec)
: vec_(std::move(rvec))
, nvec_(N_VMake_Serial(gsl::narrow<long int>(vec_.size()), vec_.data())
) {}
explicit AmiVector(gsl::span<realtype const> rvec)
: AmiVector(std::vector<realtype>(rvec.begin(), rvec.end())) {}
AmiVector(AmiVector const& vold)
: vec_(vold.vec_) {
nvec_ = N_VMake_Serial(
gsl::narrow<long int>(vold.vec_.size()), vec_.data()
);
}
AmiVector(AmiVector&& other) noexcept
: nvec_(nullptr) {
vec_ = std::move(other.vec_);
synchroniseNVector();
}
~AmiVector();
AmiVector& operator=(AmiVector const& other);
AmiVector& operator*=(AmiVector const& multiplier) {
N_VProd(
getNVector(), const_cast<N_Vector>(multiplier.getNVector()),
getNVector()
);
return *this;
}
AmiVector& operator/=(AmiVector const& divisor) {
N_VDiv(
getNVector(), const_cast<N_Vector>(divisor.getNVector()),
getNVector()
);
return *this;
}
auto begin() { return vec_.begin(); }
auto end() { return vec_.end(); }
realtype* data();
realtype const* data() const;
N_Vector getNVector();
const_N_Vector getNVector() const;
std::vector<realtype> const& getVector() const;
int getLength() const;
void zero();
void minus();
void set(realtype val);
realtype& operator[](int pos);
realtype& at(int pos);
realtype const& at(int pos) const;
void copy(AmiVector const& other);
void abs() { N_VAbs(getNVector(), getNVector()); };
template <class Archive>
friend void boost::serialization::serialize(
Archive& ar, AmiVector& s, unsigned int version
);
private:
std::vector<realtype> vec_;
N_Vector nvec_{nullptr};
void synchroniseNVector();
};
/**
 * @brief Array of AmiVector objects plus a parallel array of N_Vector
 * handles.
 *
 * NOTE(review): nvec_array_ presumably caches vec_array_[i].getNVector() so
 * that getNVectorArray() can pass an N_Vector* to SUNDIALS. It is maintained
 * by the out-of-line constructors and copy assignment; confirm in
 * vector.cpp.
 */
class AmiVectorArray {
  public:
    /** @brief Empty array. */
    AmiVectorArray() = default;

    /**
     * @brief Allocate the array (defined out of line).
     * @param length_inner presumably the length of each AmiVector
     * @param length_outer presumably the number of AmiVectors
     */
    AmiVectorArray(long int length_inner, long int length_outer);

    /**
     * @brief Copy constructor (defined out of line).
     * @param vaold array to copy
     */
    AmiVectorArray(AmiVectorArray const& vaold);

    /**
     * @brief Move constructor.
     *
     * The user-declared copy operations and destructor suppress the implicit
     * move members. Without this declaration, every move (return by value,
     * std::vector<AmiVectorArray> growth, ...) fell back to the copy
     * constructor.
     *
     * Moving a std::vector hands over its buffer without relocating the
     * elements. The AmiVector objects, and the N_Vectors they hold, keep
     * their addresses, so handles cached in nvec_array_ remain valid.
     * @param other array to move from
     */
    AmiVectorArray(AmiVectorArray&& other) = default;

    ~AmiVectorArray() = default;

    /**
     * @brief Copy assignment (defined out of line).
     * @param other array to copy
     */
    AmiVectorArray& operator=(AmiVectorArray const& other);

    /**
     * @brief Move assignment.
     *
     * Releases the current contents and takes over the buffers of @p other.
     * Handles stay valid for the same reason as in the move constructor.
     * @param other array to move from
     */
    AmiVectorArray& operator=(AmiVectorArray&& other) = default;

    /**
     * @brief Raw element pointer of the @p pos-th vector (defined out of
     * line).
     * @param pos vector index
     */
    realtype* data(int pos);

    /**
     * @brief Read-only raw element pointer of the @p pos-th vector.
     * @param pos vector index
     */
    realtype const* data(int pos) const;

    /**
     * @brief Element access (defined out of line).
     * @param ipos presumably the element index within a vector
     * @param jpos presumably the vector index
     * NOTE(review): confirm index order in vector.cpp.
     */
    realtype& at(int ipos, int jpos);

    /** @brief Read-only element access; same indexing as the mutable overload. */
    realtype const& at(int ipos, int jpos) const;

    /** @brief Pointer to the first N_Vector handle (defined out of line). */
    N_Vector* getNVectorArray();

    /**
     * @brief N_Vector handle of the @p pos-th vector.
     * @param pos vector index
     */
    N_Vector getNVector(int pos);

    /**
     * @brief Read-only N_Vector handle of the @p pos-th vector.
     * @param pos vector index
     */
    const_N_Vector getNVector(int pos) const;

    /**
     * @brief Access the @p pos-th AmiVector.
     * @param pos vector index
     */
    AmiVector& operator[](int pos);

    /**
     * @brief Read-only access to the @p pos-th AmiVector.
     * @param pos vector index
     */
    AmiVector const& operator[](int pos) const;

    /** @brief Number of vectors (defined out of line). */
    int getLength() const;

    /** @brief Presumably zeroes every vector (defined out of line). */
    void zero();

    /**
     * @brief Write all values into a single flat vector (defined out of
     * line).
     * @param vec destination. NOTE(review): expected size/ordering not
     * visible here.
     */
    void flatten_to_vector(std::vector<realtype>& vec) const;

    /**
     * @brief Copy the values of @p other into this array (defined out of
     * line).
     * @param other source array
     */
    void copy(AmiVectorArray const& other);

  private:
    /** The vectors themselves (owning). */
    std::vector<AmiVector> vec_array_;

    /** N_Vector handles, presumably one per element of vec_array_. */
    std::vector<N_Vector> nvec_array_;
};
/**
 * @brief Linear combination z = a*x + b*y, computed by N_VLinearSum.
 * @param a coefficient of x
 * @param x first operand (not modified)
 * @param b coefficient of y
 * @param y second operand (not modified)
 * @param z result vector
 */
inline void linearSum(
    realtype a, AmiVector const& x, realtype b, AmiVector const& y, AmiVector& z
) {
    // N_VLinearSum is not const-correct but only reads its x and y inputs.
    N_Vector const x_handle = const_cast<N_Vector>(x.getNVector());
    N_Vector const y_handle = const_cast<N_Vector>(y.getNVector());
    N_Vector const z_handle = z.getNVector();
    N_VLinearSum(a, x_handle, b, y_handle, z_handle);
}
/**
 * @brief Dot product of two vectors, computed by N_VDotProd.
 * @param x first operand
 * @param y second operand
 * @return sum over i of x[i] * y[i]
 */
inline realtype dotProd(AmiVector const& x, AmiVector const& y) {
    // N_VDotProd is not const-correct but does not modify its arguments.
    N_Vector const lhs = const_cast<N_Vector>(x.getNVector());
    N_Vector const rhs = const_cast<N_Vector>(y.getNVector());
    realtype const result = N_VDotProd(lhs, rhs);
    return result;
}
} // namespace amici
namespace gsl {
/**
 * @brief Mutable span over the elements of an N_Vector.
 * @param nv vector to view. NOTE(review): the length is read with
 * N_VGetLength_Serial, so nv is assumed to be a serial N_Vector.
 * @return span starting at N_VGetArrayPointer(nv) with
 * N_VGetLength_Serial(nv) elements
 */
inline span<realtype> make_span(N_Vector nv) {
    realtype* const first = N_VGetArrayPointer(nv);
    auto const count = N_VGetLength_Serial(nv);
    return span<realtype>(first, count);
}

/**
 * @brief Read-only span over the elements of an AmiVector.
 *
 * The span views the vector returned by av.getVector(), presumably a member
 * of av, so the span should not outlive av.
 * @param av vector to view
 */
inline span<realtype const> make_span(amici::AmiVector const& av) {
    std::vector<realtype> const& values = av.getVector();
    return make_span(values);
}
} // namespace gsl
#endif /* AMICI_VECTOR_H */