From 7e3eff7b2de0eb192a6df66ea669b3383c168be3 Mon Sep 17 00:00:00 2001 From: Siddharth Bharat Purohit Date: Fri, 3 Feb 2017 23:50:09 +0530 Subject: [PATCH] remove unnecessary duplicate matrices from inverse --- matrix/SquareMatrix.hpp | 63 +++++++++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 12 deletions(-) diff --git a/matrix/SquareMatrix.hpp b/matrix/SquareMatrix.hpp index fbd61641c8..72b493e2a3 100644 --- a/matrix/SquareMatrix.hpp +++ b/matrix/SquareMatrix.hpp @@ -47,9 +47,23 @@ public: // inverse alias inline SquareMatrix I() const { - return inv(*this); + SquareMatrix i; + if(inv(*this, i)) { + return i; + } else { + i.setZero(); + return i; + } } + + // inverse alias + inline bool I(SquareMatrix &i) const + { + return inv(*this, i); + } + + Vector diag() const { Vector res; @@ -108,11 +122,12 @@ SquareMatrix expm(const Matrix & A, size_t order=5) return res; } + /** * inverse based on LU factorization with partial pivotting */ template -SquareMatrix inv(const SquareMatrix & A) +bool inv(const SquareMatrix & A, SquareMatrix & inv) { SquareMatrix L; L.setIdentity(); @@ -147,14 +162,12 @@ SquareMatrix inv(const SquareMatrix & A) //printf("U:\n"); U.print(); //printf("P:\n"); P.print(); //fflush(stdout); - ASSERT(fabsf(U(n, n)) > 1e-8f); + //ASSERT(fabsf(U(n, n)) > 1e-8f); #endif // failsafe, return zero matrix if (fabsf(U(n, n)) < 1e-8f) { - SquareMatrix zero; - zero.setZero(); - return zero; + return false; } // for all rows below diagonal @@ -173,7 +186,7 @@ SquareMatrix inv(const SquareMatrix & A) //printf("U:\n"); U.print(); // solve LY=P*I for Y by forward subst - SquareMatrix Y = P; + //SquareMatrix Y = P; // for all columns of Y for (size_t c = 0; c < M; c++) { @@ -184,7 +197,7 @@ SquareMatrix inv(const SquareMatrix & A) // for all existing y // subtract the component they // contribute to the solution - Y(i, c) -= L(i, j) * Y(j, c); + P(i, c) -= L(i, j) * P(j, c); } // divide by the factor @@ -198,7 +211,7 @@ SquareMatrix inv(const SquareMatrix & A) //printf("Y:\n"); Y.print(); // solve Ux=y for x by back subst - SquareMatrix X = Y; + //SquareMatrix X = Y; // for all columns of X for (size_t c = 0; c < M; c++) { @@ -212,20 +225,46 @@ SquareMatrix inv(const SquareMatrix & A) // for all existing x // subtract the component they // contribute to the solution - X(i, c) -= U(i, j) * X(j, c); + P(i, c) -= U(i, j) * P(j, c); } // divide by the factor // on current // term to be solved - X(i, c) /= U(i, i); + if(fabsf(U(i,i)) < 1e-8f) { + return false; + } + P(i, c) /= U(i, i); } } + //check sanity of results + for (uint8_t i = 0; i < M; i++) { + for (uint8_t j = 0; j < M; j++) { + if (!PX4_ISFINITE(P(i,j))) { + return false; + } + } + } //printf("X:\n"); X.print(); - return X; + inv = P; + return true; } +/** + * inverse based on LU factorization with partial pivotting + */ +template +SquareMatrix inv(const SquareMatrix & A) +{ + SquareMatrix i; + if(inv(A, i)) { + return i; + } else { + i.setZero(); + return i; + } +} typedef SquareMatrix Matrix3f; } // namespace matrix