least squares solver for MxN matrices using QR householder algorithm

This commit is contained in:
Bart Slinger
2018-09-09 12:48:33 +02:00
committed by Beat Küng
parent dc3af80977
commit 0009328257
5 changed files with 238 additions and 0 deletions
+145
View File
@@ -0,0 +1,145 @@
/**
* @file LeastSquaresSolver.hpp
*
* Least Squares Solver using QR householder decomposition.
* It calculates x for Ax = b.
* A = Q*R
* where R is an upper triangular matrix.
*
* R*x = Q^T*b
* This is efficiently solved for x because of the upper triangular property of R.
*
* @author Bart Slinger <bartslinger@gmail.com>
*/
#pragma once
#include "math.hpp"
namespace matrix {
template<typename Type, size_t M, size_t N>
class LeastSquaresSolver
{
public:
/**
* @brief Class calculates QR decomposition which can be used for linear
* least squares
* @param A Matrix of size MxN
*
* Initialize the class with a MxN matrix. The constructor starts the
* QR decomposition. This class does not check the rank of the matrix.
* The user needs to make sure that rank(A) = N and N >= M.
*/
LeastSquaresSolver(Matrix<Type, M, N> A)
{
if (M < N) {
return;
}
// Copy contentents of matrix A
memcpy(_data, A._data, sizeof(_data));
for (size_t j = 0; j < N; j++) {
float normx = 0.0f;
for (size_t i = j; i < M; i++) {
normx += _data[i][j] * _data[i][j];
}
normx = sqrt(normx);
float s = _data[j][j] > 0 ? -1.0f : 1.0f;
float u1 = _data[j][j] - s*normx;
// prevent divide by zero
// also covers u1. normx is never negative
if (normx < 1e-8f) {
return;
}
float w[M] = {};
w[0] = 1.0f;
for (size_t i = j+1; i < M; i++) {
w[i-j] = _data[i][j] / u1;
_data[i][j] = w[i-j];
}
_data[j][j] = s*normx;
_tau[j] = -s*u1/normx;
for (size_t k = j+1; k < N; k++) {
float tmp = 0.0f;
for (size_t i = j; i < M; i++) {
tmp += w[i-j] * _data[i][k];
}
for (size_t i = j; i < M; i++) {
_data[i][k] -= _tau[j] * w[i-j] * tmp;
}
}
}
}
/**
* @brief qtb Calculate Q^T * b
* @param b
* @return Q^T*b
*
* This function calculates Q^T * b. This is useful for the solver
* because R*x = Q^T*b.
*/
Vector<Type, M> qtb(Vector<Type, M> b) {
Vector<Type, M> qtbv = b;
for (size_t j = 0; j < N; j++) {
float w[M];
w[0] = 1.0f;
// fill vector w
for (size_t i = j+1; i < M; i++) {
w[i-j] = _data[i][j];
}
float tmp = 0.0f;
for (size_t i = j; i < M; i++) {
tmp += w[i-j] * qtbv(i);
}
for (size_t i = j; i < M; i++) {
qtbv(i) -= _tau[j] * w[i-j] * tmp;
}
}
return qtbv;
}
/**
* @brief Solve Ax=b for x
* @param b
* @return Vector x
*
* Find x in the equation Ax = b.
* A is provided in the initializer of the class.
*/
Vector<Type, N> solve(Vector<Type, M> b) {
Vector<Type, M> qtbv = qtb(b);
Vector<Type, N> x;
for (size_t l = N; l > 0 ; l--) {
size_t i = l - 1;
x(i) = qtbv(i);
for (size_t r = i+1; r < N; r++) {
x(i) -= _data[i][r] * x(r);
}
// divide by zero, return vector of zeros
if (fabs(_data[i][i]) < 1e-8f) {
for (size_t z = 0; z < N; z++) {
x(z) = 0.0f;
}
}
x(i) = x(i) / _data[i][i];
}
return x;
}
private:
Type _data[M][N] {};
Type _tau[N] {};
};
} // namespace matrix
/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */
+1
View File
@@ -14,3 +14,4 @@
#include "Scalar.hpp"
#include "Quaternion.hpp"
#include "AxisAngle.hpp"
#include "LeastSquaresSolver.hpp"
+1
View File
@@ -17,6 +17,7 @@ set(tests
helper
hatvee
copyto
least_squares
)
add_custom_target(test_build)
+69
View File
@@ -0,0 +1,69 @@
#include "test_macros.hpp"
#include <matrix/math.hpp>
using namespace matrix;
int test_4x3(void);
int test_4x4(void);
int main()
{
int ret;
ret = test_4x4();
if (ret != 0) return ret;
ret = test_4x3();
if (ret != 0) return ret;
return 0;
}
int test_4x3() {
// Start with an (m x n) A matrix
float data[12] = {20.f , -10.f , -13.f ,
17.f , 16.f , -18.f ,
0.7f, -0.8f, 0.9f,
-1.f , -1.1f, -1.2f};
Matrix<float, 4, 3> A(data);
float b_data[4] = {2.0, 3.0, 4.0, 5.0};
Vector<float, 4> b(b_data);
float x_check_data[3] = {-0.69168233f,
-0.26227593f,
-1.03767522f};
Vector<float, 3> x_check(x_check_data);
LeastSquaresSolver<float, 4, 3> qrd = LeastSquaresSolver<float, 4, 3>(A);
Vector<float, 3> x = qrd.solve(b);
TEST(isEqual(x, x_check));
return 0;
}
int test_4x4() {
// Start with an (m x n) A matrix
float data[16] = { 20.f , -10.f , -13.f , 21.f ,
17.f , 16.f , -18.f , -14.f ,
0.7f, -0.8f, 0.9f, -0.5f,
-1.f , -1.1f, -1.2f, -1.3f};
Matrix<float, 4, 4> A(data);
float b_data[4] = {2.0, 3.0, 4.0, 5.0};
Vector<float, 4> b(b_data);
float x_check_data[4] = { 0.97893433f,
-2.80798701f,
-0.03175765f,
-2.19387649f};
Vector<float, 4> x_check(x_check_data);
LeastSquaresSolver<float, 4, 4> qrd = LeastSquaresSolver<float, 4, 4>(A);
Vector<float, 4> x = qrd.solve(b);
TEST(isEqual(x, x_check));
return 0;
}
/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */
+22
View File
@@ -124,4 +124,26 @@ for i in range(1,3):
A_pow = A_pow.dot(A)
print(eA_approx)
print('\nqr decomposition 4x4')
A = array([[20.0, -10.0, -13.0, 21.0], [ 17.0, 16.0, -18.0, -14], [0.7, -0.8, 0.9, -0.5], [-1.0, -1.1, -1.2, -1.3]])
b = array([[2.], [3.], [4.], [5.]])
x = scipy.linalg.lstsq(A,b)[0]
print('A:')
pprint(A)
print('b:')
pprint(b)
print('x:')
pprint(scipy.linalg.lstsq(A,b)[0])
print('\nqr decomposition 4x3')
A = array([[20.0, -10.0, -13.0], [ 17.0, 16.0, -18.0], [0.7, -0.8, 0.9], [-1.0, -1.1, -1.2]])
b = array([[2.], [3.], [4.], [5.]])
x = scipy.linalg.lstsq(A,b)[0]
print('A:')
pprint(A)
print('b:')
pprint(b)
print('x:')
pprint(x)
# vim: set et ft=python fenc=utf-8 ff=unix sts=4 sw=4 ts=8 :