From 480c5f1f8ef0811fc88e782d68561e7566b7ad87 Mon Sep 17 00:00:00 2001 From: Bart Slinger Date: Thu, 13 Sep 2018 10:53:13 +0200 Subject: [PATCH] static assert M>=N. floats to Type, arguments as const reference --- matrix/LeastSquaresSolver.hpp | 37 +++++++++++++++++------------------ test/least_squares.cpp | 29 +++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/matrix/LeastSquaresSolver.hpp b/matrix/LeastSquaresSolver.hpp index 266e2ebe40..f594e16e9d 100644 --- a/matrix/LeastSquaresSolver.hpp +++ b/matrix/LeastSquaresSolver.hpp @@ -30,31 +30,30 @@ public: * * Initialize the class with a MxN matrix. The constructor starts the * QR decomposition. This class does not check the rank of the matrix. - * The user needs to make sure that rank(A) = N and N >= M. + * The user needs to make sure that rank(A) = N and M >= N. */ - LeastSquaresSolver(Matrix A) + LeastSquaresSolver(const Matrix &A) { - if (M < N) { - return; - } + static_assert(M >= N, "Matrix dimension should be M >= N"); + // Copy contentents of matrix A _A = A; for (size_t j = 0; j < N; j++) { - float normx = 0.0f; + Type normx = 0.; for (size_t i = j; i < M; i++) { normx += _A(i,j) * _A(i,j); } normx = sqrt(normx); - float s = _A(j,j) > 0 ? -1.0f : 1.0f; - float u1 = _A(j,j) - s*normx; + Type s = _A(j,j) > 0 ? -1. : 1.; + Type u1 = _A(j,j) - s*normx; // prevent divide by zero // also covers u1. normx is never negative - if (normx < 1e-8f) { + if (normx < 1e-8) { break; } - float w[M] = {}; - w[0] = 1.0f; + Type w[M] = {}; + w[0] = 1.; for (size_t i = j+1; i < M; i++) { w[i-j] = _A(i,j) / u1; _A(i,j) = w[i-j]; @@ -63,7 +62,7 @@ public: _tau(j) = -s*u1/normx; for (size_t k = j+1; k < N; k++) { - float tmp = 0.0f; + Type tmp = 0.; for (size_t i = j; i < M; i++) { tmp += w[i-j] * _A(i,k); } @@ -83,17 +82,17 @@ public: * This function calculates Q^T * b. This is useful for the solver * because R*x = Q^T*b. */ - Vector qtb(Vector b) { + Vector qtb(const Vector &b) { Vector qtbv = b; for (size_t j = 0; j < N; j++) { - float w[M]; - w[0] = 1.0f; + Type w[M]; + w[0] = 1.; // fill vector w for (size_t i = j+1; i < M; i++) { w[i-j] = _A(i,j); } - float tmp = 0.0f; + Type tmp = 0.; for (size_t i = j; i < M; i++) { tmp += w[i-j] * qtbv(i); } @@ -113,7 +112,7 @@ public: * Find x in the equation Ax = b. * A is provided in the initializer of the class. */ - Vector solve(Vector b) { + Vector solve(const Vector &b) { Vector qtbv = qtb(b); Vector x; @@ -124,9 +123,9 @@ public: x(i) -= _A(i,r) * x(r); } // divide by zero, return vector of zeros - if (fabs(_A(i,i)) < 1e-8f) { + if (fabs(_A(i,i)) < 1e-8) { for (size_t z = 0; z < N; z++) { - x(z) = 0.0f; + x(z) = 0.; } break; } diff --git a/test/least_squares.cpp b/test/least_squares.cpp index 14868218f7..f9cae89cd6 100644 --- a/test/least_squares.cpp +++ b/test/least_squares.cpp @@ -5,6 +5,7 @@ using namespace matrix; int test_4x3(void); int test_4x4(void); +int test_4x4_type_double(void); int test_div_zero(void); int main() @@ -17,6 +18,8 @@ int main() ret = test_4x3(); if (ret != 0) return ret; + ret = test_4x4_type_double(); + ret = test_div_zero(); if (ret != 0) return ret; @@ -74,6 +77,32 @@ int test_4x4() { return 0; } +int test_4x4_type_double() { + // Start with an (m x n) A matrix + double data[16] = { 20., -10., -13., 21., + 17., 16., -18., -14., + 0.7, -0.8, 0.9, -0.5, + -1., -1.1, -1.2, -1.3 + }; + Matrix A(data); + + double b_data[4] = {2.0, 3.0, 4.0, 5.0}; + Vector b(b_data); + + double x_check_data[4] = { 0.97893433, + -2.80798701, + -0.03175765, + -2.19387649 + }; + Vector x_check(x_check_data); + + LeastSquaresSolver qrd = LeastSquaresSolver(A); + + Vector x = qrd.solve(b); + TEST(isEqual(x, x_check)); + return 0; +} + int test_div_zero() { float data[4] = {0.0f, 0.0f, 0.0f, 0.0f}; Matrix A(data);