diff --git a/matrix/LeastSquaresSolver.hpp b/matrix/LeastSquaresSolver.hpp new file mode 100644 index 0000000000..08d7df11de --- /dev/null +++ b/matrix/LeastSquaresSolver.hpp @@ -0,0 +1,145 @@ +/** + * @file LeastSquaresSolver.hpp + * + * Least Squares Solver using QR householder decomposition. + * It calculates x for Ax = b. + * A = Q*R + * where R is an upper triangular matrix. + * + * R*x = Q^T*b + * This is efficiently solved for x because of the upper triangular property of R. + * + * @author Bart Slinger + */ + +#pragma once + +#include "math.hpp" + +namespace matrix { + +template +class LeastSquaresSolver +{ +public: + + /** + * @brief Class calculates QR decomposition which can be used for linear + * least squares + * @param A Matrix of size MxN + * + * Initialize the class with a MxN matrix. The constructor starts the + * QR decomposition. This class does not check the rank of the matrix. + * The user needs to make sure that rank(A) = N and N >= M. + */ + LeastSquaresSolver(Matrix A) + { + if (M < N) { + return; + } + // Copy contentents of matrix A + memcpy(_data, A._data, sizeof(_data)); + + for (size_t j = 0; j < N; j++) { + float normx = 0.0f; + for (size_t i = j; i < M; i++) { + normx += _data[i][j] * _data[i][j]; + } + normx = sqrt(normx); + float s = _data[j][j] > 0 ? -1.0f : 1.0f; + float u1 = _data[j][j] - s*normx; + // prevent divide by zero + // also covers u1. normx is never negative + if (normx < 1e-8f) { + return; + } + float w[M] = {}; + w[0] = 1.0f; + for (size_t i = j+1; i < M; i++) { + w[i-j] = _data[i][j] / u1; + _data[i][j] = w[i-j]; + } + _data[j][j] = s*normx; + _tau[j] = -s*u1/normx; + + for (size_t k = j+1; k < N; k++) { + float tmp = 0.0f; + for (size_t i = j; i < M; i++) { + tmp += w[i-j] * _data[i][k]; + } + for (size_t i = j; i < M; i++) { + _data[i][k] -= _tau[j] * w[i-j] * tmp; + } + } + + } + } + + /** + * @brief qtb Calculate Q^T * b + * @param b + * @return Q^T*b + * + * This function calculates Q^T * b. This is useful for the solver + * because R*x = Q^T*b. + */ + Vector qtb(Vector b) { + Vector qtbv = b; + + for (size_t j = 0; j < N; j++) { + float w[M]; + w[0] = 1.0f; + // fill vector w + for (size_t i = j+1; i < M; i++) { + w[i-j] = _data[i][j]; + } + float tmp = 0.0f; + for (size_t i = j; i < M; i++) { + tmp += w[i-j] * qtbv(i); + } + + for (size_t i = j; i < M; i++) { + qtbv(i) -= _tau[j] * w[i-j] * tmp; + } + } + return qtbv; + } + + /** + * @brief Solve Ax=b for x + * @param b + * @return Vector x + * + * Find x in the equation Ax = b. + * A is provided in the initializer of the class. + */ + Vector solve(Vector b) { + Vector qtbv = qtb(b); + Vector x; + + for (size_t l = N; l > 0 ; l--) { + size_t i = l - 1; + x(i) = qtbv(i); + for (size_t r = i+1; r < N; r++) { + x(i) -= _data[i][r] * x(r); + } + // divide by zero, return vector of zeros + if (fabs(_data[i][i]) < 1e-8f) { + for (size_t z = 0; z < N; z++) { + x(z) = 0.0f; + } + } + x(i) = x(i) / _data[i][i]; + } + return x; + } + +private: + Type _data[M][N] {}; + Type _tau[N] {}; + +}; + +} // namespace matrix + +/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */ diff --git a/matrix/math.hpp b/matrix/math.hpp index 342ae5e612..30479f26e0 100644 --- a/matrix/math.hpp +++ b/matrix/math.hpp @@ -14,3 +14,4 @@ #include "Scalar.hpp" #include "Quaternion.hpp" #include "AxisAngle.hpp" +#include "LeastSquaresSolver.hpp" diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f6c56af60e..43fa7e12dc 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -17,6 +17,7 @@ set(tests helper hatvee copyto + least_squares ) add_custom_target(test_build) diff --git a/test/least_squares.cpp b/test/least_squares.cpp new file mode 100644 index 0000000000..57c7064662 --- /dev/null +++ b/test/least_squares.cpp @@ -0,0 +1,69 @@ +#include "test_macros.hpp" +#include + +using namespace matrix; + +int test_4x3(void); +int test_4x4(void); + +int main() +{ + int ret; + + ret = test_4x4(); + if (ret != 0) return ret; + + ret = test_4x3(); + if (ret != 0) return ret; + + return 0; +} + +int test_4x3() { + // Start with an (m x n) A matrix + float data[12] = {20.f , -10.f , -13.f , + 17.f , 16.f , -18.f , + 0.7f, -0.8f, 0.9f, + -1.f , -1.1f, -1.2f}; + Matrix A(data); + + float b_data[4] = {2.0, 3.0, 4.0, 5.0}; + Vector b(b_data); + + float x_check_data[3] = {-0.69168233f, + -0.26227593f, + -1.03767522f}; + Vector x_check(x_check_data); + + LeastSquaresSolver qrd = LeastSquaresSolver(A); + + Vector x = qrd.solve(b); + TEST(isEqual(x, x_check)); + return 0; +} + +int test_4x4() { + // Start with an (m x n) A matrix + float data[16] = { 20.f , -10.f , -13.f , 21.f , + 17.f , 16.f , -18.f , -14.f , + 0.7f, -0.8f, 0.9f, -0.5f, + -1.f , -1.1f, -1.2f, -1.3f}; + Matrix A(data); + + float b_data[4] = {2.0, 3.0, 4.0, 5.0}; + Vector b(b_data); + + float x_check_data[4] = { 0.97893433f, + -2.80798701f, + -0.03175765f, + -2.19387649f}; + Vector x_check(x_check_data); + + LeastSquaresSolver qrd = LeastSquaresSolver(A); + + Vector x = qrd.solve(b); + TEST(isEqual(x, x_check)); + return 0; +} + +/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */ diff --git a/test/test_data.py b/test/test_data.py index 0ae4e958ae..94eed5e278 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -124,4 +124,26 @@ for i in range(1,3): A_pow = A_pow.dot(A) print(eA_approx) +print('\nqr decomposition 4x4') +A = array([[20.0, -10.0, -13.0, 21.0], [ 17.0, 16.0, -18.0, -14], [0.7, -0.8, 0.9, -0.5], [-1.0, -1.1, -1.2, -1.3]]) +b = array([[2.], [3.], [4.], [5.]]) +x = scipy.linalg.lstsq(A,b)[0] +print('A:') +pprint(A) +print('b:') +pprint(b) +print('x:') +pprint(scipy.linalg.lstsq(A,b)[0]) + +print('\nqr decomposition 4x3') +A = array([[20.0, -10.0, -13.0], [ 17.0, 16.0, -18.0], [0.7, -0.8, 0.9], [-1.0, -1.1, -1.2]]) +b = array([[2.], [3.], [4.], [5.]]) +x = scipy.linalg.lstsq(A,b)[0] +print('A:') +pprint(A) +print('b:') +pprint(b) +print('x:') +pprint(x) + # vim: set et ft=python fenc=utf-8 ff=unix sts=4 sw=4 ts=8 :