linear-library
Loading...
Searching...
No Matches
LU.hpp
Go to the documentation of this file.
#pragma once
#include "../vector_matrix/FlatMatrix.hpp"
#include <cmath>
#include <limits>
#include <numeric>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
9
10namespace LinearAlgebra{
22 template<typename T>
23 class LU {
24 static_assert(std::is_floating_point_v<T>,
25 "T must be a floating-point type (float or double)");
26 private:
27
28 FlatMatrix<T> matrix;
29 int signP = 1;
30 void elimination(int col);
31 void initP() {
32 P.resize(matrix.getRows());
33 std::iota(P.begin(), P.end(), 0);
34 };
35 std::vector<int> initInvP() const {
36 int n = matrix.getRows();
37 std::vector<int>invP(n,-1);
38 for (int i = 0; i < n; i++) {
39 int pi = P[i];
40 if (pi < 0 || pi >= n) throw std::invalid_argument("invalid permutation");
41 invP[pi] = i;
42 }
43 return invP;
44 };
45 int pivoting(int col);
46 void forwardSubstitution(std::vector<T>& y, const std::vector<T>& b, int n) const;
47 void backwardSubstitution(std::vector<T>& x, const std::vector<T>& y, int n) const;
48 std::vector<int> P; //vector of swap
49 static constexpr T eps = std::numeric_limits<T>::epsilon() * static_cast<T>(100);
50 void decomposition();
51 public:
52
53 LU(const FlatMatrix<T>& m) : matrix(m) {
54 decomposition();
55 }
56
57 LU(FlatMatrix<T>&& m) : matrix(std::move(m)) {
58 decomposition();
59 }
60 T det() const;
61 FlatMatrix<T> inv() const;
62 const std::vector<int>& getP() const{ return P; }
63 const FlatMatrix<T>& getMatrix() const{ return matrix;}
64 };
65
66
67 template<typename T>
68 int LU<T>::pivoting(int col) {
69 T pivotVal = std::abs(matrix(col,col));
70 int pivot = col;
71 int n = matrix.getRows();
72 for (int i = col+1; i < n; i++) {
73 T val = std::abs(matrix(i,col));
74 if(val > pivotVal) {
75 pivotVal = val;
76 pivot = i;
77 }
78 }
79 return pivot;
80 }
81
82 template<typename T>
83 void LU<T>::elimination(int col) {
84 T pivot = matrix(col,col);
85 int n = matrix.getRows();
86 for (int i = col+1; i < n; i++)
87 {
88 T factor = matrix(i,col)/pivot;
89 matrix(i,col) = factor;
90 for (int j = col+1; j < n; j++)
91 {
92 matrix(i,j) -= factor * matrix(col,j);
93 }
94 }
95
96 }
97
98 template<typename T>
99 void LU<T>::decomposition() {
100 if (matrix.getRows() != matrix.getCols())
101 throw std::runtime_error("Matrix must be square for LU decomposition");
102 int n = matrix.getRows();
103 initP();
104 int swapCount = 0;
105 for (int k = 0; k < n; ++k) {
106 int pivot = pivoting(k);
107 if(pivot != k) {
108 for (int j = 0; j < n; j++) {
109 std::swap(matrix(k,j), matrix(pivot,j));
110 }
111
112 std::swap(P[k], P[pivot]);
113 swapCount++;
114 }
115 if (std::abs(matrix(k,k)) <= eps) {
116 throw std::runtime_error("Matrix is singular or nearly singular at pivot " + std::to_string(k));
117 }
118 elimination(k);
119 }
120 signP = swapCount % 2 == 0 ? 1 : -1;
121 }
122
123 template<typename T>
124 T LU<T>::det() const {
125 int n = matrix.getRows();
126 T res = T(1);
127 for(int i = 0; i < n; i++) {
128 res *= matrix(i,i);
129 }
130
131 return res*T(signP);
132 }
133
134 template<typename T>
135 void LU<T>::forwardSubstitution(std::vector<T>& y, const std::vector<T>& b, int n) const {
136 y[0] = b[0];
137 for (int i = 1; i < n; ++i) {
138 T sum = 0;
139 for (int j = 0; j < i; ++j) {
140 sum += matrix(i,j) * y[j];
141 }
142 y[i] = (b[i] - sum);
143 }
144 }
145
146 template<typename T>
147 void LU<T>::backwardSubstitution(std::vector<T>& x, const std::vector<T>& y, int n) const {
148 for (int i = n-1; i >= 0; --i){
149 T sum = 0;
150 for (int j = i+1; j < n; ++j) {
151 sum += matrix(i,j) * x[j];
152 }
153 x[i] = (y[i] - sum)/matrix(i,i);
154 }
155 }
156
157 template<typename T>
159 int n = matrix.getRows();
161 std::vector<int> invP = initInvP();
162 std::vector<T> b(n), y(n), x(n);
163 int prev = invP[0];
164 b[prev] = T(1);
165
166 for(int i = 0; i < n; ++i) {
167 int curr = invP[i];
168
169 if (i > 0) {
170 b[prev] = T(0);
171 b[curr] = T(1);
172 }
173
174 forwardSubstitution(y, b, n);
175 backwardSubstitution(x, y, n);
176
177 for (int k = 0; k < n; ++k) {
178 X(k,i) = x[k];
179 }
180
181 prev = curr;
182 }
183 return X;
184 }
185}
LU decomposition with partial pivoting.
Definition LU.hpp:23
const FlatMatrix< T > & getMatrix() const
Definition LU.hpp:63
T det() const
Definition LU.hpp:124
LU(const FlatMatrix< T > &m)
Definition LU.hpp:53
FlatMatrix< T > inv() const
Definition LU.hpp:158
LU(FlatMatrix< T > &&m)
Definition LU.hpp:57
const std::vector< int > & getP() const
Definition LU.hpp:62
Examples:
Definition VectorMatrix.hpp:32
VectorMatrix()
Definition VectorMatrix.hpp:61
int getRows() const
return count rows
Definition VectorMatrix.hpp:107
Definition DecomposeLU.hpp:9