Program Listing for File sundials_matrix_wrapper.h
↰ Return to documentation for file (include/amici/sundials_matrix_wrapper.h)
#ifndef AMICI_SUNDIALS_MATRIX_WRAPPER_H
#define AMICI_SUNDIALS_MATRIX_WRAPPER_H
#include <sunmatrix/sunmatrix_band.h> // SUNMatrix_Band
#include <sunmatrix/sunmatrix_dense.h> // SUNMatrix_Dense
#include <sunmatrix/sunmatrix_sparse.h> // SUNMatrix_Sparse
#include <gsl/gsl-lite.hpp>
#include <algorithm>
#include <vector>
#include <assert.h>
#include "amici/vector.h"
namespace amici {
class SUNMatrixWrapper {
public:
SUNMatrixWrapper() = default;
SUNMatrixWrapper(
sunindextype M, sunindextype N, sunindextype NNZ, int sparsetype
);
SUNMatrixWrapper(sunindextype M, sunindextype N);
SUNMatrixWrapper(sunindextype M, sunindextype ubw, sunindextype lbw);
SUNMatrixWrapper(
SUNMatrixWrapper const& A, realtype droptol, int sparsetype
);
explicit SUNMatrixWrapper(SUNMatrix mat);
~SUNMatrixWrapper();
operator SUNMatrix() { return get(); };
SUNMatrixWrapper(SUNMatrixWrapper const& other);
SUNMatrixWrapper(SUNMatrixWrapper&& other);
SUNMatrixWrapper& operator=(SUNMatrixWrapper const& other);
SUNMatrixWrapper& operator=(SUNMatrixWrapper&& other);
void reallocate(sunindextype nnz);
void realloc();
SUNMatrix get() const;
sunindextype rows() const {
assert(
!matrix_
|| (matrix_id() == SUNMATRIX_SPARSE
? num_rows_ == SM_ROWS_S(matrix_)
: num_rows_ == SM_ROWS_D(matrix_))
);
return num_rows_;
}
sunindextype columns() const {
assert(
!matrix_
|| (matrix_id() == SUNMATRIX_SPARSE
? num_columns_ == SM_COLUMNS_S(matrix_)
: num_columns_ == SM_COLUMNS_D(matrix_))
);
return num_columns_;
}
sunindextype num_nonzeros() const;
sunindextype num_indexptrs() const;
sunindextype capacity() const;
realtype* data();
realtype const* data() const;
realtype get_data(sunindextype idx) const {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(idx < capacity());
assert(SM_DATA_S(matrix_) == data_);
return data_[idx];
}
realtype get_data(sunindextype irow, sunindextype icol) const {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_DENSE);
assert(irow < rows());
assert(icol < columns());
return SM_ELEMENT_D(matrix_, irow, icol);
}
void set_data(sunindextype idx, realtype data) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(idx < capacity());
assert(SM_DATA_S(matrix_) == data_);
data_[idx] = data;
}
void set_data(sunindextype irow, sunindextype icol, realtype data) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_DENSE);
assert(irow < rows());
assert(icol < columns());
SM_ELEMENT_D(matrix_, irow, icol) = data;
}
sunindextype get_indexval(sunindextype idx) const {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(idx < capacity());
assert(indexvals_ == SM_INDEXVALS_S(matrix_));
return indexvals_[idx];
}
void set_indexval(sunindextype idx, sunindextype val) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(idx < capacity());
assert(indexvals_ == SM_INDEXVALS_S(matrix_));
indexvals_[idx] = val;
}
void set_indexvals(gsl::span<sunindextype const> const vals) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(gsl::narrow<sunindextype>(vals.size()) == capacity());
assert(indexvals_ == SM_INDEXVALS_S(matrix_));
std::copy_n(vals.begin(), capacity(), indexvals_);
}
sunindextype get_indexptr(sunindextype ptr_idx) const {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(ptr_idx <= num_indexptrs());
assert(indexptrs_ == SM_INDEXPTRS_S(matrix_));
return indexptrs_[ptr_idx];
}
void set_indexptr(sunindextype ptr_idx, sunindextype ptr) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(ptr_idx <= num_indexptrs());
assert(ptr <= capacity());
assert(indexptrs_ == SM_INDEXPTRS_S(matrix_));
indexptrs_[ptr_idx] = ptr;
if (ptr_idx == num_indexptrs())
num_nonzeros_ = ptr;
}
void set_indexptrs(gsl::span<sunindextype const> const ptrs) {
assert(matrix_);
assert(matrix_id() == SUNMATRIX_SPARSE);
assert(gsl::narrow<sunindextype>(ptrs.size()) == num_indexptrs() + 1);
assert(indexptrs_ == SM_INDEXPTRS_S(matrix_));
std::copy_n(ptrs.begin(), num_indexptrs() + 1, indexptrs_);
num_nonzeros_ = indexptrs_[num_indexptrs()];
}
int sparsetype() const;
void scale(realtype a);
void multiply(N_Vector c, const_N_Vector b, realtype alpha = 1.0) const;
void
multiply(AmiVector& c, AmiVector const& b, realtype alpha = 1.0) const {
multiply(c.getNVector(), b.getNVector(), alpha);
}
void multiply(
gsl::span<realtype> c, gsl::span<realtype const> b,
realtype const alpha = 1.0
) const;
void multiply(
N_Vector c, const_N_Vector b, gsl::span<int const> cols, bool transpose
) const;
void multiply(
gsl::span<realtype> c, gsl::span<realtype const> b,
gsl::span<int const> cols, bool transpose
) const;
void sparse_multiply(SUNMatrixWrapper& C, SUNMatrixWrapper const& B) const;
void sparse_add(
SUNMatrixWrapper const& A, realtype alpha, SUNMatrixWrapper const& B,
realtype beta
);
void sparse_sum(std::vector<SUNMatrixWrapper> const& mats);
sunindextype scatter(
sunindextype const k, realtype const beta, sunindextype* w,
gsl::span<realtype> x, sunindextype const mark, SUNMatrixWrapper* C,
sunindextype nnz
) const;
void transpose(
SUNMatrixWrapper& C, realtype const alpha, sunindextype blocksize
) const;
void to_dense(SUNMatrixWrapper& D) const;
void to_diag(N_Vector v) const;
void zero();
SUNMatrix_ID matrix_id() const { return id_; };
void refresh();
private:
SUNMatrix matrix_{nullptr};
SUNMatrix_ID id_{SUNMATRIX_CUSTOM};
int sparsetype_{CSC_MAT};
sunindextype num_nonzeros_{0};
sunindextype capacity_{0};
realtype* data_{nullptr};
sunindextype* indexptrs_{nullptr};
sunindextype* indexvals_{nullptr};
sunindextype num_rows_{0};
sunindextype num_columns_{0};
sunindextype num_indexptrs_{0};
void finish_init();
void update_ptrs();
void update_size();
bool ownmat = true;
};
auto unravel_index(sunindextype i, SUNMatrix m)
-> std::pair<sunindextype, sunindextype>;
} // namespace amici
namespace gsl {

/**
 * @brief Non-owning view over the value array of a SUNMatrix.
 *
 * Dense matrices expose SM_LDATA_D(m) values; sparse matrices expose
 * SM_NNZ_S(m) values. NOTE(review): for SUNDIALS sparse matrices SM_NNZ_S is
 * the allocated capacity, which may exceed the number of nonzeros in use —
 * confirm callers expect that.
 *
 * @param m matrix to view (must be non-null)
 * @return span over the matrix's value storage
 * @throws amici::AmiException for any other matrix type
 */
inline span<realtype> make_span(SUNMatrix m) {
    SUNMatrix_ID const matrix_type = SUNMatGetID(m);

    if (matrix_type == SUNMATRIX_DENSE) {
        return span<realtype>(SM_DATA_D(m), SM_LDATA_D(m));
    }
    if (matrix_type == SUNMATRIX_SPARSE) {
        return span<realtype>(SM_DATA_S(m), SM_NNZ_S(m));
    }

    throw amici::AmiException("Unimplemented SUNMatrix type for make_span");
}

} // namespace gsl
#endif // AMICI_SUNDIALS_MATRIX_WRAPPER_H