From 00a0b36836e49b61db8b52127e0a42bae34db9ba Mon Sep 17 00:00:00 2001 From: jgoppert Date: Thu, 5 Nov 2015 15:43:36 -0500 Subject: [PATCH] Moved inverse outside of matrix definition. --- matrix/SquareMatrix.hpp | 240 ++++++++++++++++++++-------------------- test/inverse.cpp | 4 +- 2 files changed, 123 insertions(+), 121 deletions(-) diff --git a/matrix/SquareMatrix.hpp b/matrix/SquareMatrix.hpp index ec3a450188..ad74b217da 100644 --- a/matrix/SquareMatrix.hpp +++ b/matrix/SquareMatrix.hpp @@ -41,128 +41,10 @@ public: { } - /** - * inverse based on LU factorization with partial pivotting - */ - SquareMatrix inverse() const - { - SquareMatrix L; - L.setIdentity(); - const SquareMatrix &A = (*this); - SquareMatrix U = A; - SquareMatrix P; - P.setIdentity(); - - //printf("A:\n"); A.print(); - - // for all diagonal elements - for (size_t n = 0; n < M; n++) { - - // if diagonal is zero, swap with row below - if (fabsf(U(n, n)) < 1e-8f) { - //printf("trying pivot for row %d\n",n); - for (size_t i = 0; i < M; i++) { - if (i == n) { - continue; - } - - //printf("\ttrying row %d\n",i); - if (fabsf(U(i, n)) > 1e-8f) { - //printf("swapped %d\n",i); - U.swapRows(i, n); - P.swapRows(i, n); - } - } - } - -#ifdef MATRIX_ASSERT - //printf("A:\n"); A.print(); - //printf("U:\n"); U.print(); - //printf("P:\n"); P.print(); - //fflush(stdout); - ASSERT(fabsf(U(n, n)) > 1e-8f); -#endif - - // failsafe, return zero matrix - if (fabsf(U(n, n)) < 1e-8f) { - SquareMatrix zero; - zero.setZero(); - return zero; - } - - // for all rows below diagonal - for (size_t i = (n + 1); i < M; i++) { - L(i, n) = U(i, n) / U(n, n); - - // add i-th row and n-th row - // multiplied by: -a(i,n)/a(n,n) - for (size_t k = n; k < M; k++) { - U(i, k) -= L(i, n) * U(n, k); - } - } - } - - //printf("L:\n"); L.print(); - //printf("U:\n"); U.print(); - - // solve LY=P*I for Y by forward subst - SquareMatrix Y = P; - - // for all columns of Y - for (size_t c = 0; c < M; c++) { - // for all rows of L - for (size_t i = 0; i < M; i++) { - // for all columns of L - for (size_t j = 0; j < i; j++) { - // for all existing y - // subtract the component they - // contribute to the solution - Y(i, c) -= L(i, j) * Y(j, c); - } - - // divide by the factor - // on current - // term to be solved - // Y(i,c) /= L(i,i); - // but L(i,i) = 1.0 - } - } - - //printf("Y:\n"); Y.print(); - - // solve Ux=y for x by back subst - SquareMatrix X = Y; - - // for all columns of X - for (size_t c = 0; c < M; c++) { - // for all rows of U - for (size_t k = 0; k < M; k++) { - // have to go in reverse order - size_t i = M - 1 - k; - - // for all columns of U - for (size_t j = i + 1; j < M; j++) { - // for all existing x - // subtract the component they - // contribute to the solution - X(i, c) -= U(i, j) * X(j, c); - } - - // divide by the factor - // on current - // term to be solved - X(i, c) /= U(i, i); - } - } - - //printf("X:\n"); X.print(); - return X; - } - // inverse alias inline SquareMatrix I() const { - return inverse(); + return inv(*this); } Vector diag() const @@ -212,6 +94,126 @@ SquareMatrix expm(const SquareMatrix & A, size_t order=5) return res; } +/** + * inverse based on LU factorization with partial pivotting + */ +template +SquareMatrix inv(const SquareMatrix & A) +{ + SquareMatrix L; + L.setIdentity(); + SquareMatrix U = A; + SquareMatrix P; + P.setIdentity(); + + //printf("A:\n"); A.print(); + + // for all diagonal elements + for (size_t n = 0; n < M; n++) { + + // if diagonal is zero, swap with row below + if (fabsf(U(n, n)) < 1e-8f) { + //printf("trying pivot for row %d\n",n); + for (size_t i = 0; i < M; i++) { + if (i == n) { + continue; + } + + //printf("\ttrying row %d\n",i); + if (fabsf(U(i, n)) > 1e-8f) { + //printf("swapped %d\n",i); + U.swapRows(i, n); + P.swapRows(i, n); + } + } + } + +#ifdef MATRIX_ASSERT + //printf("A:\n"); A.print(); + //printf("U:\n"); U.print(); + //printf("P:\n"); P.print(); + //fflush(stdout); + ASSERT(fabsf(U(n, n)) > 1e-8f); +#endif + + // failsafe, return zero matrix + if (fabsf(U(n, n)) < 1e-8f) { + SquareMatrix zero; + zero.setZero(); + return zero; + } + + // for all rows below diagonal + for (size_t i = (n + 1); i < M; i++) { + L(i, n) = U(i, n) / U(n, n); + + // add i-th row and n-th row + // multiplied by: -a(i,n)/a(n,n) + for (size_t k = n; k < M; k++) { + U(i, k) -= L(i, n) * U(n, k); + } + } + } + + //printf("L:\n"); L.print(); + //printf("U:\n"); U.print(); + + // solve LY=P*I for Y by forward subst + SquareMatrix Y = P; + + // for all columns of Y + for (size_t c = 0; c < M; c++) { + // for all rows of L + for (size_t i = 0; i < M; i++) { + // for all columns of L + for (size_t j = 0; j < i; j++) { + // for all existing y + // subtract the component they + // contribute to the solution + Y(i, c) -= L(i, j) * Y(j, c); + } + + // divide by the factor + // on current + // term to be solved + // Y(i,c) /= L(i,i); + // but L(i,i) = 1.0 + } + } + + //printf("Y:\n"); Y.print(); + + // solve Ux=y for x by back subst + SquareMatrix X = Y; + + // for all columns of X + for (size_t c = 0; c < M; c++) { + // for all rows of U + for (size_t k = 0; k < M; k++) { + // have to go in reverse order + size_t i = M - 1 - k; + + // for all columns of U + for (size_t j = i + 1; j < M; j++) { + // for all existing x + // subtract the component they + // contribute to the solution + X(i, c) -= U(i, j) * X(j, c); + } + + // divide by the factor + // on current + // term to be solved + X(i, c) /= U(i, i); + } + } + + //printf("X:\n"); X.print(); + return X; +} + + + }; // namespace matrix /* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */ diff --git a/test/inverse.cpp b/test/inverse.cpp index c2d60c9003..2dd17eb9a4 100644 --- a/test/inverse.cpp +++ b/test/inverse.cpp @@ -20,7 +20,7 @@ int main() 1. , -2. , 1. }; SquareMatrix A(data); - SquareMatrix A_I = A.inverse(); + SquareMatrix A_I = inv(A); SquareMatrix A_I_check(data_check); A_I.print(); A_I_check.print(); @@ -33,7 +33,7 @@ int main() A_large_I.setZero(); for (size_t i = 0; i < n_large; i++) { - A_large_I = A_large.inverse(); + A_large_I = inv(A_large); assert(A_large == A_large_I); }