Program Listing for File sundials_matrix_wrapper.h

Return to documentation for file (include/amici/sundials_matrix_wrapper.h)

#ifndef AMICI_SUNDIALS_MATRIX_WRAPPER_H
#define AMICI_SUNDIALS_MATRIX_WRAPPER_H

#include <sunmatrix/sunmatrix_band.h>   // SUNMatrix_Band
#include <sunmatrix/sunmatrix_dense.h>  // SUNMatrix_Dense
#include <sunmatrix/sunmatrix_sparse.h> // SUNMatrix_Sparse

#include <gsl/gsl-lite.hpp>

#include <vector>

#include "amici/exception.h" // amici::AmiException (thrown by gsl::make_span)
#include "amici/vector.h"

namespace amici {

class SUNMatrixWrapper {
  public:
    SUNMatrixWrapper() = default;

    SUNMatrixWrapper(sunindextype M, sunindextype N, sunindextype NNZ,
                     int sparsetype);

    SUNMatrixWrapper(sunindextype M, sunindextype N);

    SUNMatrixWrapper(sunindextype M, sunindextype ubw, sunindextype lbw);

    SUNMatrixWrapper(const SUNMatrixWrapper &A, realtype droptol,
                     int sparsetype);

    explicit SUNMatrixWrapper(SUNMatrix mat);

    ~SUNMatrixWrapper();

    SUNMatrixWrapper(const SUNMatrixWrapper &other);

    SUNMatrixWrapper(SUNMatrixWrapper &&other);

    SUNMatrixWrapper &operator=(const SUNMatrixWrapper &other);

    SUNMatrixWrapper &operator=(SUNMatrixWrapper &&other);

    void reallocate(sunindextype nnz);

    void realloc();

    SUNMatrix get() const;

    sunindextype rows() const;

    sunindextype columns() const;

    sunindextype num_nonzeros() const;

    sunindextype num_indexptrs() const;

    sunindextype capacity() const;

    realtype *data();

    const realtype *data() const;

    realtype get_data(sunindextype idx) const;

    realtype get_data(sunindextype irow, sunindextype icol) const;

    void set_data(sunindextype idx, realtype data);

    void set_data(sunindextype irow, sunindextype icol, realtype data);

    sunindextype get_indexval(sunindextype idx) const;

    void set_indexval(sunindextype idx, sunindextype val);

    void set_indexvals(const gsl::span<const sunindextype> vals);

    sunindextype get_indexptr(sunindextype ptr_idx) const;

    void set_indexptr(sunindextype ptr_idx, sunindextype ptr);

    void set_indexptrs(const gsl::span<const sunindextype> ptrs);

    int sparsetype() const;

    void scale(realtype a);

    void multiply(N_Vector c, const_N_Vector b, realtype alpha = 1.0) const;

    void multiply(gsl::span<realtype> c, gsl::span<const realtype> b,
                  const realtype alpha = 1.0) const;

    void multiply(N_Vector c,
                  const_N_Vector b,
                  gsl::span <const int> cols,
                  bool transpose) const;

    void multiply(gsl::span<realtype> c,
                  gsl::span<const realtype> b,
                  gsl::span <const int> cols,
                  bool transpose) const;

    void sparse_multiply(SUNMatrixWrapper &C,
                         const SUNMatrixWrapper &B) const;

    void sparse_add(const SUNMatrixWrapper &A, realtype alpha,
                    const SUNMatrixWrapper &B, realtype beta);

    void sparse_sum(const std::vector<SUNMatrixWrapper> &mats);

    sunindextype scatter(const sunindextype k, const realtype beta,
                         sunindextype *w, gsl::span<realtype> x,
                         const sunindextype mark,
                         SUNMatrixWrapper *C, sunindextype nnz) const;

    void transpose(SUNMatrixWrapper &C, const realtype alpha,
                   sunindextype blocksize) const;

    void to_dense(SUNMatrixWrapper &D) const;

    void to_diag(N_Vector v) const;

    void zero();

    SUNMatrix_ID matrix_id() const {return id_;};

    void refresh();

  private:

    SUNMatrix matrix_ {nullptr};

    SUNMatrix_ID id_ {SUNMATRIX_CUSTOM};

    int sparsetype_ {CSC_MAT};

    sunindextype num_nonzeros_ {0};
    sunindextype capacity_ {0};

    realtype *data_ {nullptr};
    sunindextype *indexptrs_ {nullptr};
    sunindextype *indexvals_ {nullptr};

    sunindextype num_rows_ {0};
    sunindextype num_columns_ {0};
    sunindextype num_indexptrs_ {0};

    void finish_init();
    void update_ptrs();
    void update_size();
    bool ownmat = true;
};

} // namespace amici

namespace gsl {
/**
 * @brief Create a mutable, non-owning span over the value array of a
 * SUNMatrix.
 *
 * Dense matrices yield all SM_LDATA_D(m) stored values. Sparse matrices
 * yield SM_NNZ_S(m) values. Per the SUNDIALS documentation, SM_NNZ_S is the
 * allocated capacity and may exceed the number of entries actually in use.
 *
 * The span is invalidated when the matrix is destroyed or its storage is
 * reallocated.
 *
 * @param m matrix to view; nullptr yields an empty span
 * @return span over the matrix's value array
 * @throws amici::AmiException for matrix types other than dense and sparse
 */
inline span<realtype> make_span(SUNMatrix m)
{
    // SUNMatGetID dereferences its argument, so a null handle must be caught
    // here instead of crashing.
    if (!m)
        return {};

    switch (SUNMatGetID(m)) {
    case SUNMATRIX_DENSE:
        return span<realtype>(SM_DATA_D(m), SM_LDATA_D(m));
    case SUNMATRIX_SPARSE:
        return span<realtype>(SM_DATA_S(m), SM_NNZ_S(m));
    default:
        throw amici::AmiException("Unimplemented SUNMatrix type for make_span");
    }
}
} // namespace gsl

#endif // AMICI_SUNDIALS_MATRIX_WRAPPER_H