From f3e478cbffedfbe674e71da357f7d0a7437019b4 Mon Sep 17 00:00:00 2001 From: James Goppert Date: Tue, 22 Nov 2016 10:04:45 -0500 Subject: [PATCH] Fix matrix inverse pivotting logic. --- matrix/SquareMatrix.hpp | 8 ++++---- test/inverse.cpp | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/matrix/SquareMatrix.hpp b/matrix/SquareMatrix.hpp index 3e42b64e55..fbd61641c8 100644 --- a/matrix/SquareMatrix.hpp +++ b/matrix/SquareMatrix.hpp @@ -128,16 +128,16 @@ SquareMatrix inv(const SquareMatrix & A) // if diagonal is zero, swap with row below if (fabsf(U(n, n)) < 1e-8f) { //printf("trying pivot for row %d\n",n); - for (size_t i = 0; i < M; i++) { - if (i == n) { - continue; - } + for (size_t i = n + 1; i < M; i++) { //printf("\ttrying row %d\n",i); if (fabsf(U(i, n)) > 1e-8f) { //printf("swapped %d\n",i); U.swapRows(i, n); P.swapRows(i, n); + L.swapRows(i, n); + L.swapCols(i, n); + break; } } } diff --git a/test/inverse.cpp b/test/inverse.cpp index f8af641aea..f04d7f771a 100644 --- a/test/inverse.cpp +++ b/test/inverse.cpp @@ -41,6 +41,45 @@ int main() SquareMatrix zero_test = zeros(); inv(zero_test); + // test pivotting + float data2[81] = { + -2, 1, 1, -1, -5, 1, 2, -1, 0, + -3, 2, -1, 0, 2, 2, -1, -5, 3, + 0, 0, 0, 1, 4, -3, 3, 0, -2, + 2, 2, -1, -2, -1, 0, 3, 0, 1, + -1, 2, -1, -1, -3, 3, 0, -2, 3, + 0, 1, 1, -3, 3, -2, 0, -4, 0, + 1, 0, 0, 0, 0, 0, -2, 4, -3, + 1, -1, 0, -1, -1, 1, -1, -3, 4, + 0, 3, -1, -2, 2, 1, -2, 0, -1}; + + float data2_check[81] = { + 6, -4, 3, -3 , -9, -8, -10, 8, 14, + -2, -7, -5, -3 , -2, -2, -16, -5, 8, + -2, 0, -23, 7, -24, -5, -28, -14, 9, + 3, -7, 2, -5, -4, -6, -13, 4, 13, + -1, 4, -8, 5, -8, 0, -3, -5, -2, + 6, 7, -7, 7, -21, -7, -5, 3, 6, + 1, 4, -4, 4, -7, -1, 0, -1, -1, + -7, 3, -11, 5, 1, 6, -1, -13, -10, + -8, 0, -11, 3, 3, 6, -5, -14, -8}; + SquareMatrix A2(data2); + SquareMatrix A2_I = inv(A2); + SquareMatrix A2_I_check(data2_check); + TEST((A2_I - A2_I_check).abs().max() < 1e-3); + float data3[9] = { + 0, 1, 2, + 3, 4, 5, + 6, 7, 9}; + float data3_check[9] = { + -0.3333333f, -1.6666666f, 1, + -1, 4, -2, + 1, -2, 1 + }; + SquareMatrix A3(data3); + SquareMatrix A3_I = inv(A3); + SquareMatrix A3_I_check(data3_check); + TEST((A3_I - A3_I_check).abs().max() < 1e-5); return 0; }