linear-library
Loading...
Searching...
No Matches
FlatMatrix.hpp
Go to the documentation of this file.
1#pragma once
2
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <stdexcept>
#include <utility>
#include <vector>
8
9namespace LinearAlgebra {
/**
 * @brief Dense matrix whose elements are stored row by row in a single
 *        std::vector<T>.
 *
 * Element (i, j) is kept at flat index `i * cols + j`.
 *
 * @tparam T Element type.
 */
template <typename T>
class FlatMatrix {
private:
    // In-class defaults give a default-constructed matrix a well-defined
    // 0 x 0 shape instead of indeterminate dimensions.
    int cols = 0;
    int rows = 0;
    std::vector<T> flatMatrix;

    /// Throws std::invalid_argument unless every row of `v` holds exactly `cols` elements.
    void validation(const std::vector<std::vector<T>>& v) const {
        const std::size_t width = static_cast<std::size_t>(cols);
        for (const auto& row : v) {
            if (row.size() != width) {
                throw std::invalid_argument("All rows must have the same size");
            }
        }
    }

    /// Returns r * c as a size, rejecting negative dimensions before any allocation happens.
    static std::size_t checkedSize(int r, int c) {
        if (r < 0 || c < 0) {
            throw std::invalid_argument("rows or cols params is negative");
        }
        return static_cast<std::size_t>(r) * static_cast<std::size_t>(c);
    }

    /// Flat index of element (i, j), computed in std::size_t to avoid int overflow.
    std::size_t offset(int i, int j) const {
        return static_cast<std::size_t>(i) * static_cast<std::size_t>(cols)
             + static_cast<std::size_t>(j);
    }

public:
    /**
     * @brief Builds a matrix from a vector of equally sized rows.
     * @param v Rows of the matrix.
     * @throws std::invalid_argument if `v` or its first row is empty, or if
     *         the rows differ in size.
     */
    FlatMatrix(const std::vector<std::vector<T>>& v) {
        if (v.empty() || v[0].empty()) {
            throw std::invalid_argument("Matrix is empty");
        }

        rows = static_cast<int>(v.size());
        cols = static_cast<int>(v[0].size());
        validation(v);

        flatMatrix.reserve(v.size() * v[0].size());
        for (const auto& row : v) {
            flatMatrix.insert(flatMatrix.end(), row.begin(), row.end());
        }
    }

    /**
     * @brief Adopts an already flattened buffer of `r * c` elements.
     * @param r Number of rows, must be positive.
     * @param c Number of columns, must be positive.
     * @param m Source buffer; its contents are moved out, so the caller's
     *          vector is left in a moved-from state on success.
     * @throws std::invalid_argument if `r` or `c` is not positive, or if
     *         `m.size() != r * c`.
     */
    FlatMatrix(int r, int c, std::vector<T>& m) : cols(c), rows(r) {
        if (r == 0 || c == 0) {
            throw std::invalid_argument("rows or cols params is 0");
        }
        // Two negative dimensions would otherwise yield a positive product
        // and slip through the size check below.
        const std::size_t expected = checkedSize(r, c);
        if (m.size() != expected) {
            throw std::invalid_argument("wrong size");
        }

        flatMatrix = std::move(m);
    }

    /**
     * @brief Builds a matrix from nested braces, e.g. `{{1, 2}, {3, 4}}`.
     * @param v Rows of the matrix.
     * @throws std::invalid_argument if there are no rows, the first row is
     *         empty, or the rows differ in size.
     */
    FlatMatrix(std::initializer_list<std::initializer_list<T>> v) {
        // v.begin() must not be dereferenced when the list is empty.
        if (v.size() == 0 || v.begin()->size() == 0) {
            throw std::invalid_argument("Matrix is empty");
        }

        const std::size_t width = v.begin()->size();
        rows = static_cast<int>(v.size());
        cols = static_cast<int>(width);

        flatMatrix.reserve(v.size() * width);
        for (const auto& row : v) {
            if (row.size() != width) {
                throw std::invalid_argument("All rows must have the same size");
            }
            flatMatrix.insert(flatMatrix.end(), row.begin(), row.end());
        }
    }

    /**
     * @brief Builds an `r x c` matrix of value-initialized elements.
     * @throws std::invalid_argument if `r` or `c` is negative.
     */
    FlatMatrix(int r, int c) : cols(c), rows(r), flatMatrix(checkedSize(r, c), T{}) {}

    /// Builds an empty 0 x 0 matrix.
    FlatMatrix() = default;

    /// @return Number of rows.
    int getRows() const {
        return rows;
    }

    /// @return Number of columns.
    int getCols() const {
        return cols;
    }

    /// Mutable access to element (i, j); the indices are checked by assert only.
    T& operator()(int i, int j) {
        assert(i >= 0 && i < rows && j >= 0 && j < cols && "FlatMatrix: index out of range");
        return flatMatrix[offset(i, j)];
    }

    /// Read-only access to element (i, j); the indices are checked by assert only.
    const T& operator()(int i, int j) const {
        assert(i >= 0 && i < rows && j >= 0 && j < cols && "FlatMatrix: index out of range");
        return flatMatrix[offset(i, j)];
    }

    FlatMatrix<T> operator+(const FlatMatrix<T>& B) const;
    FlatMatrix<T> operator*(const FlatMatrix<T>& B) const;
    FlatMatrix<T> operator*(const T scalar) const;
    FlatMatrix<T> operator~() const;

    template <typename U>
    friend std::ostream& operator<<(std::ostream& os, FlatMatrix<U>& m);
};
122
123 template<typename T>
125 FlatMatrix result(cols,rows);
126
127 for(int i = 0; i < cols; i++) {
128 for (int j = 0; j < rows; j++) {
129 result.flatMatrix[j*cols + i] = flatMatrix[i*cols +j];
130 }
131 }
132
133 return result;
134 }
135
136 template<typename T>
138 FlatMatrix result(rows,cols);
139 const int size = rows * cols;
140 for (int i = 0; i < size; i++) {
141 result.flatMatrix[i] = flatMatrix[i] * scalar;
142 }
143 return result;
144 }
145
146 template<typename T>
148 if(cols != B.getCols()) {
149 throw std::invalid_argument("num columns A not equal num rows B");
150 }
151
152 int colsB = B.getCols();
153 FlatMatrix result(rows,colsB);
154 const int bs = 32;
155
156 for (int i = 0; i < rows; i += bs) {
157 for (int k = 0; k < cols; k += bs) {
158 for (int j = 0; j < colsB; j += bs) {
159
160 int i_end = std::min(i + bs, rows);
161 int k_end = std::min(k + bs, cols);
162 int j_end = std::min(j + bs, colsB);
163
164 for (int ii = i; ii < i_end; ++ii) {
165 for (int kk = k; kk < k_end; ++kk) {
166 for (int jj = j; jj < j_end; ++jj) {
167 result.flatMatrix[ii*cols+jj] +=
168 flatMatrix[ii*cols+kk] * B.flatMatrix[kk*cols+jj];
169 }
170 }
171 }
172 }
173 }
174 }
175 return result;
176 }
177
178 template<typename U>
179 std::ostream& operator<<(std::ostream& os, FlatMatrix<U>& m) {
180 int rows = m.rows;
181 int cols = m.cols;
182 for(int i =0; i < rows; ++i) {
183 for(int j = 0; j < rows; ++j) {
184 os<<m(i,j);
185 if (j + 1 < cols)
186 os << ' ';
187 }
188 if (i + 1 < rows)
189 os << '\n';
190 }
191 return os;
192 }
193
194
195
196}
Examples:
Definition FlatMatrix.hpp:33
int getCols() const
Definition FlatMatrix.hpp:98
int getRows() const
Definition FlatMatrix.hpp:94
FlatMatrix(int r, int c, std::vector< T > &m)
Definition FlatMatrix.hpp:62
FlatMatrix(int r, int c)
Definition FlatMatrix.hpp:91
FlatMatrix< T > operator+(const FlatMatrix< T > &B) const
FlatMatrix(std::initializer_list< std::initializer_list< T > > v)
Definition FlatMatrix.hpp:75
FlatMatrix< T > operator*(const FlatMatrix< T > &B) const
Definition FlatMatrix.hpp:147
FlatMatrix(const std::vector< std::vector< T > > &v)
Definition FlatMatrix.hpp:43
T & operator()(int i, int j)
Definition FlatMatrix.hpp:102
const T & operator()(int i, int j) const
Definition FlatMatrix.hpp:107
friend std::ostream & operator<<(std::ostream &os, FlatMatrix< U > &m)
Definition FlatMatrix.hpp:179
FlatMatrix< T > operator~() const
Definition FlatMatrix.hpp:124
Examples:
Definition VectorMatrix.hpp:32
VectorMatrix()
Definition VectorMatrix.hpp:61
Definition DecomposeLU.hpp:9
std::ostream & operator<<(std::ostream &os, FlatMatrix< U > &m)
Definition FlatMatrix.hpp:179