Program Listing for File sundials_matrix_wrapper.h

Return to documentation for file (include/amici/sundials_matrix_wrapper.h)

#ifndef AMICI_SUNDIALS_MATRIX_WRAPPER_H
#define AMICI_SUNDIALS_MATRIX_WRAPPER_H

#include <sunmatrix/sunmatrix_band.h>   // SUNMatrix_Band
#include <sunmatrix/sunmatrix_dense.h>  // SUNMatrix_Dense
#include <sunmatrix/sunmatrix_sparse.h> // SUNMatrix_Sparse

#include <gsl/gsl-lite.hpp>

#include <algorithm>
#include <vector>

#include <assert.h>

#include "amici/vector.h"

namespace amici {

class SUNMatrixWrapper {
  public:
    SUNMatrixWrapper() = default;

    SUNMatrixWrapper(
        sunindextype M, sunindextype N, sunindextype NNZ, int sparsetype
    );

    SUNMatrixWrapper(sunindextype M, sunindextype N);

    SUNMatrixWrapper(sunindextype M, sunindextype ubw, sunindextype lbw);

    SUNMatrixWrapper(
        SUNMatrixWrapper const& A, realtype droptol, int sparsetype
    );

    explicit SUNMatrixWrapper(SUNMatrix mat);

    ~SUNMatrixWrapper();

    operator SUNMatrix() { return get(); };

    SUNMatrixWrapper(SUNMatrixWrapper const& other);

    SUNMatrixWrapper(SUNMatrixWrapper&& other);

    SUNMatrixWrapper& operator=(SUNMatrixWrapper const& other);

    SUNMatrixWrapper& operator=(SUNMatrixWrapper&& other);

    void reallocate(sunindextype nnz);

    void realloc();

    SUNMatrix get() const;

    sunindextype rows() const {
        assert(
            !matrix_
            || (matrix_id() == SUNMATRIX_SPARSE
                    ? num_rows_ == SM_ROWS_S(matrix_)
                    : num_rows_ == SM_ROWS_D(matrix_))
        );
        return num_rows_;
    }

    sunindextype columns() const {
        assert(
            !matrix_
            || (matrix_id() == SUNMATRIX_SPARSE
                    ? num_columns_ == SM_COLUMNS_S(matrix_)
                    : num_columns_ == SM_COLUMNS_D(matrix_))
        );
        return num_columns_;
    }

    sunindextype num_nonzeros() const;

    sunindextype num_indexptrs() const;

    sunindextype capacity() const;

    realtype* data();

    realtype const* data() const;

    realtype get_data(sunindextype idx) const {
        assert(matrix_);
        assert(matrix_id() == SUNMATRIX_SPARSE);
        assert(idx < capacity());
        assert(SM_DATA_S(matrix_) == data_);
        return data_[idx];
    }

    realtype get_data(sunindextype irow, sunindextype icol) const {
        assert(matrix_);
        assert(matrix_id() == SUNMATRIX_DENSE);
        assert(irow < rows());
        assert(icol < columns());
        return SM_ELEMENT_D(matrix_, irow, icol);
    }

    void set_data(sunindextype idx, realtype data) {
        assert(matrix_);
        assert(matrix_id() == SUNMATRIX_SPARSE);
        assert(idx < capacity());
        assert(SM_DATA_S(matrix_) == data_);
        data_[idx] = data;
    }

    void set_data(sunindextype irow, sunindextype icol, realtype data) {
        assert(matrix_);
        assert(matrix_id() == SUNMATRIX_DENSE);
        assert(irow < rows());
        assert(icol < columns());
        SM_ELEMENT_D(matrix_, irow, icol) = data;
    }

    sunindextype get_indexval(sunindextype idx) const {
        assert(matrix_);
        assert(matrix_id() == SUNMATRIX_SPARSE);
        assert(idx < capacity());
        assert(indexvals_ == SM_INDEXVALS_S(matrix_));
        return indexvals_[idx];
    }

    void set_indexval(sunindextype idx, sunindextype val) {
        assert(matrix_);
        assert(matrix_id() == SUNMATRIX_SPARSE);
        assert(idx < capacity());
        assert(indexvals_ == SM_INDEXVALS_S(matrix_));
        indexvals_[idx] = val;
    }

    void set_indexvals(gsl::span<sunindextype const> const vals) {
        assert(matrix_);
        assert(matrix_id() == SUNMATRIX_SPARSE);
        assert(gsl::narrow<sunindextype>(vals.size()) == capacity());
        assert(indexvals_ == SM_INDEXVALS_S(matrix_));
        std::copy_n(vals.begin(), capacity(), indexvals_);
    }

    sunindextype get_indexptr(sunindextype ptr_idx) const {
        assert(matrix_);
        assert(matrix_id() == SUNMATRIX_SPARSE);
        assert(ptr_idx <= num_indexptrs());
        assert(indexptrs_ == SM_INDEXPTRS_S(matrix_));
        return indexptrs_[ptr_idx];
    }

    void set_indexptr(sunindextype ptr_idx, sunindextype ptr) {
        assert(matrix_);
        assert(matrix_id() == SUNMATRIX_SPARSE);
        assert(ptr_idx <= num_indexptrs());
        assert(ptr <= capacity());
        assert(indexptrs_ == SM_INDEXPTRS_S(matrix_));
        indexptrs_[ptr_idx] = ptr;
        if (ptr_idx == num_indexptrs())
            num_nonzeros_ = ptr;
    }

    void set_indexptrs(gsl::span<sunindextype const> const ptrs) {
        assert(matrix_);
        assert(matrix_id() == SUNMATRIX_SPARSE);
        assert(gsl::narrow<sunindextype>(ptrs.size()) == num_indexptrs() + 1);
        assert(indexptrs_ == SM_INDEXPTRS_S(matrix_));
        std::copy_n(ptrs.begin(), num_indexptrs() + 1, indexptrs_);
        num_nonzeros_ = indexptrs_[num_indexptrs()];
    }

    int sparsetype() const;

    void scale(realtype a);

    void multiply(N_Vector c, const_N_Vector b, realtype alpha = 1.0) const;

    void
    multiply(AmiVector& c, AmiVector const& b, realtype alpha = 1.0) const {
        multiply(c.getNVector(), b.getNVector(), alpha);
    }

    void multiply(
        gsl::span<realtype> c, gsl::span<realtype const> b,
        realtype const alpha = 1.0
    ) const;

    void multiply(
        N_Vector c, const_N_Vector b, gsl::span<int const> cols, bool transpose
    ) const;

    void multiply(
        gsl::span<realtype> c, gsl::span<realtype const> b,
        gsl::span<int const> cols, bool transpose
    ) const;

    void sparse_multiply(SUNMatrixWrapper& C, SUNMatrixWrapper const& B) const;

    void sparse_add(
        SUNMatrixWrapper const& A, realtype alpha, SUNMatrixWrapper const& B,
        realtype beta
    );

    void sparse_sum(std::vector<SUNMatrixWrapper> const& mats);

    sunindextype scatter(
        sunindextype const k, realtype const beta, sunindextype* w,
        gsl::span<realtype> x, sunindextype const mark, SUNMatrixWrapper* C,
        sunindextype nnz
    ) const;

    void transpose(
        SUNMatrixWrapper& C, realtype const alpha, sunindextype blocksize
    ) const;

    void to_dense(SUNMatrixWrapper& D) const;

    void to_diag(N_Vector v) const;

    void zero();

    SUNMatrix_ID matrix_id() const { return id_; };

    void refresh();

  private:
    SUNMatrix matrix_{nullptr};

    SUNMatrix_ID id_{SUNMATRIX_CUSTOM};

    int sparsetype_{CSC_MAT};

    sunindextype num_nonzeros_{0};
    sunindextype capacity_{0};

    realtype* data_{nullptr};
    sunindextype* indexptrs_{nullptr};
    sunindextype* indexvals_{nullptr};

    sunindextype num_rows_{0};
    sunindextype num_columns_{0};
    sunindextype num_indexptrs_{0};

    void finish_init();
    void update_ptrs();
    void update_size();
    bool ownmat = true;
};

auto unravel_index(sunindextype i, SUNMatrix m)
    -> std::pair<sunindextype, sunindextype>;

} // namespace amici

namespace gsl {
/**
 * @brief Create a span over the value storage of a `SUNMatrix`.
 *
 * Dense matrices yield `SM_LDATA_D(m)` elements starting at `SM_DATA_D(m)`,
 * sparse matrices yield `SM_NNZ_S(m)` elements starting at `SM_DATA_S(m)`.
 *
 * @param m Matrix to view; its id is queried via `SUNMatGetID` first
 * @return Non-owning span over the matrix's data array
 * @throws amici::AmiException if `m` is neither dense nor sparse
 */
inline span<realtype> make_span(SUNMatrix m) {
    auto const matrix_type = SUNMatGetID(m);

    if (matrix_type == SUNMATRIX_DENSE) {
        return span<realtype>(SM_DATA_D(m), SM_LDATA_D(m));
    }

    if (matrix_type == SUNMATRIX_SPARSE) {
        return span<realtype>(SM_DATA_S(m), SM_NNZ_S(m));
    }

    // every other matrix type is unsupported
    throw amici::AmiException("Unimplemented SUNMatrix type for make_span");
}
} // namespace gsl

#endif // AMICI_SUNDIALS_MATRIX_WRAPPER_H