From 4873dc1c1e74f492ec0fcd5b493c9b500f2784fa Mon Sep 17 00:00:00 2001 From: kritz Date: Wed, 4 Mar 2020 09:14:04 +0100 Subject: [PATCH] Analytic inverse implementation (#122) * Add analytic 2x2 matrix inverse * Add analytical 3x3 matrix inverse --- matrix/SquareMatrix.hpp | 41 +++++++++++++++++++++++++++++++++++++++++ test/inverse.cpp | 39 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/matrix/SquareMatrix.hpp b/matrix/SquareMatrix.hpp index ba7510c068..216ff172d8 100644 --- a/matrix/SquareMatrix.hpp +++ b/matrix/SquareMatrix.hpp @@ -435,6 +435,47 @@ bool inv(const SquareMatrix & A, SquareMatrix & inv) return true; } +template +bool inv(const SquareMatrix & A, SquareMatrix & inv) +{ + Type det = A(0, 0) * A(1, 1) - A(1, 0) * A(0, 1); + + if(fabs(static_cast(det)) < FLT_EPSILON || !is_finite(det)) { + return false; + } + + inv(0, 0) = A(1, 1); + inv(1, 0) = -A(1, 0); + inv(0, 1) = -A(0, 1); + inv(1, 1) = A(0, 0); + inv /= det; + return true; +} + +template +bool inv(const SquareMatrix & A, SquareMatrix & inv) +{ + Type det = A(0, 0) * (A(1, 1) * A(2, 2) - A(2, 1) * A(1, 2)) - + A(0, 1) * (A(1, 0) * A(2, 2) - A(1, 2) * A(2, 0)) + + A(0, 2) * (A(1, 0) * A(2, 1) - A(1, 1) * A(2, 0)); + + if(fabs(static_cast(det)) < FLT_EPSILON || !is_finite(det)) { + return false; + } + + inv(0, 0) = A(1, 1) * A(2, 2) - A(2, 1) * A(1, 2); + inv(0, 1) = A(0, 2) * A(2, 1) - A(0, 1) * A(2, 2); + inv(0, 2) = A(0, 1) * A(1, 2) - A(0, 2) * A(1, 1); + inv(1, 0) = A(1, 2) * A(2, 0) - A(1, 0) * A(2, 2); + inv(1, 1) = A(0, 0) * A(2, 2) - A(0, 2) * A(2, 0); + inv(1, 2) = A(1, 0) * A(0, 2) - A(0, 0) * A(1, 2); + inv(2, 0) = A(1, 0) * A(2, 1) - A(2, 0) * A(1, 1); + inv(2, 1) = A(2, 0) * A(0, 1) - A(0, 0) * A(2, 1); + inv(2, 2) = A(0, 0) * A(1, 1) - A(1, 0) * A(0, 1); + inv /= det; + return true; +} + /** * inverse based on LU factorization with partial pivotting */ diff --git a/test/inverse.cpp b/test/inverse.cpp index d5f75296c4..865c79db32 100644 --- a/test/inverse.cpp +++ b/test/inverse.cpp @@ -22,6 +22,27 @@ int main() SquareMatrix A_I_check(data_check); TEST((A_I - A_I_check).abs().max() < 1e-6f); + float data_2x2[4] = {12, 2, + -7, 5 + }; + float data_2x2_check[4] = { + 0.0675675675f, -0.02702702f, + 0.0945945945f, 0.162162162f + }; + + SquareMatrix A2x2(data_2x2); + SquareMatrix A2x2_I = inv(A2x2); + SquareMatrix A2x2_I_check(data_2x2_check); + TEST(isEqual(A2x2_I, A2x2_I_check)); + + SquareMatrix A2x2_sing = ones(); + SquareMatrix A2x2_sing_I; + TEST(inv(A2x2_sing, A2x2_sing_I) == false); + + SquareMatrix A3x3_sing = ones(); + SquareMatrix A3x3_sing_I; + TEST(inv(A3x3_sing, A3x3_sing_I) == false) + // stess test SquareMatrix A_large; A_large.setIdentity(); @@ -34,7 +55,7 @@ int main() } SquareMatrix zero_test = zeros(); - inv(zero_test); + TEST(isEqual(inv(zero_test), zeros())); // test pivotting float data2[81] = { @@ -64,6 +85,7 @@ int main() SquareMatrix A2_I = inv(A2); SquareMatrix A2_I_check(data2_check); TEST((A2_I - A2_I_check).abs().max() < 1e-3f); + float data3[9] = { 0, 1, 2, 3, 4, 5, @@ -93,6 +115,16 @@ int main() TEST(isEqual(A3_I, Z3)); TEST(isEqual(A3.I(), Z3)); + for(size_t i = 0; i < 9; i++) { + A2(0, i) = 0; + } + A2_I = inv(A2); + SquareMatrix Z9 = zeros(); + TEST(!A2.I(A2_I)); + TEST(!Z9.I(A2_I)); + TEST(isEqual(A2_I, Z9)); + TEST(isEqual(A2.I(), Z9)); + // cover NaN A3(0, 0) = NAN; A3(0, 1) = 0; @@ -101,6 +133,11 @@ int main() TEST(isEqual(A3_I, Z3)); TEST(isEqual(A3.I(), Z3)); + A2(0, 0) = NAN; + A2_I = inv(A2); + TEST(isEqual(A2_I, Z9)); + TEST(isEqual(A2.I(), Z9)); + float data4[9] = { 1.33471626f, 0.74946721f, -0.0531679f, 0.74946721f, 1.07519593f, 0.08036323f,