mirror of
https://gitee.com/mirrors_PX4/PX4-Autopilot.git
synced 2026-06-25 19:40:36 +08:00
least squares solver for MxN matrices using QR householder algorithm
This commit is contained in:
@@ -0,0 +1,145 @@
|
|||||||
|
/**
|
||||||
|
* @file LeastSquaresSolver.hpp
|
||||||
|
*
|
||||||
|
* Least Squares Solver using QR householder decomposition.
|
||||||
|
* It calculates x for Ax = b.
|
||||||
|
* A = Q*R
|
||||||
|
* where R is an upper triangular matrix.
|
||||||
|
*
|
||||||
|
* R*x = Q^T*b
|
||||||
|
* This is efficiently solved for x because of the upper triangular property of R.
|
||||||
|
*
|
||||||
|
* @author Bart Slinger <bartslinger@gmail.com>
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "math.hpp"
|
||||||
|
|
||||||
|
namespace matrix {
|
||||||
|
|
||||||
|
template<typename Type, size_t M, size_t N>
|
||||||
|
class LeastSquaresSolver
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Class calculates QR decomposition which can be used for linear
|
||||||
|
* least squares
|
||||||
|
* @param A Matrix of size MxN
|
||||||
|
*
|
||||||
|
* Initialize the class with a MxN matrix. The constructor starts the
|
||||||
|
* QR decomposition. This class does not check the rank of the matrix.
|
||||||
|
* The user needs to make sure that rank(A) = N and N >= M.
|
||||||
|
*/
|
||||||
|
LeastSquaresSolver(Matrix<Type, M, N> A)
|
||||||
|
{
|
||||||
|
if (M < N) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Copy contentents of matrix A
|
||||||
|
memcpy(_data, A._data, sizeof(_data));
|
||||||
|
|
||||||
|
for (size_t j = 0; j < N; j++) {
|
||||||
|
float normx = 0.0f;
|
||||||
|
for (size_t i = j; i < M; i++) {
|
||||||
|
normx += _data[i][j] * _data[i][j];
|
||||||
|
}
|
||||||
|
normx = sqrt(normx);
|
||||||
|
float s = _data[j][j] > 0 ? -1.0f : 1.0f;
|
||||||
|
float u1 = _data[j][j] - s*normx;
|
||||||
|
// prevent divide by zero
|
||||||
|
// also covers u1. normx is never negative
|
||||||
|
if (normx < 1e-8f) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
float w[M] = {};
|
||||||
|
w[0] = 1.0f;
|
||||||
|
for (size_t i = j+1; i < M; i++) {
|
||||||
|
w[i-j] = _data[i][j] / u1;
|
||||||
|
_data[i][j] = w[i-j];
|
||||||
|
}
|
||||||
|
_data[j][j] = s*normx;
|
||||||
|
_tau[j] = -s*u1/normx;
|
||||||
|
|
||||||
|
for (size_t k = j+1; k < N; k++) {
|
||||||
|
float tmp = 0.0f;
|
||||||
|
for (size_t i = j; i < M; i++) {
|
||||||
|
tmp += w[i-j] * _data[i][k];
|
||||||
|
}
|
||||||
|
for (size_t i = j; i < M; i++) {
|
||||||
|
_data[i][k] -= _tau[j] * w[i-j] * tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief qtb Calculate Q^T * b
|
||||||
|
* @param b
|
||||||
|
* @return Q^T*b
|
||||||
|
*
|
||||||
|
* This function calculates Q^T * b. This is useful for the solver
|
||||||
|
* because R*x = Q^T*b.
|
||||||
|
*/
|
||||||
|
Vector<Type, M> qtb(Vector<Type, M> b) {
|
||||||
|
Vector<Type, M> qtbv = b;
|
||||||
|
|
||||||
|
for (size_t j = 0; j < N; j++) {
|
||||||
|
float w[M];
|
||||||
|
w[0] = 1.0f;
|
||||||
|
// fill vector w
|
||||||
|
for (size_t i = j+1; i < M; i++) {
|
||||||
|
w[i-j] = _data[i][j];
|
||||||
|
}
|
||||||
|
float tmp = 0.0f;
|
||||||
|
for (size_t i = j; i < M; i++) {
|
||||||
|
tmp += w[i-j] * qtbv(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = j; i < M; i++) {
|
||||||
|
qtbv(i) -= _tau[j] * w[i-j] * tmp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return qtbv;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Solve Ax=b for x
|
||||||
|
* @param b
|
||||||
|
* @return Vector x
|
||||||
|
*
|
||||||
|
* Find x in the equation Ax = b.
|
||||||
|
* A is provided in the initializer of the class.
|
||||||
|
*/
|
||||||
|
Vector<Type, N> solve(Vector<Type, M> b) {
|
||||||
|
Vector<Type, M> qtbv = qtb(b);
|
||||||
|
Vector<Type, N> x;
|
||||||
|
|
||||||
|
for (size_t l = N; l > 0 ; l--) {
|
||||||
|
size_t i = l - 1;
|
||||||
|
x(i) = qtbv(i);
|
||||||
|
for (size_t r = i+1; r < N; r++) {
|
||||||
|
x(i) -= _data[i][r] * x(r);
|
||||||
|
}
|
||||||
|
// divide by zero, return vector of zeros
|
||||||
|
if (fabs(_data[i][i]) < 1e-8f) {
|
||||||
|
for (size_t z = 0; z < N; z++) {
|
||||||
|
x(z) = 0.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
x(i) = x(i) / _data[i][i];
|
||||||
|
}
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Type _data[M][N] {};
|
||||||
|
Type _tau[N] {};
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace matrix
|
||||||
|
|
||||||
|
/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */
|
||||||
@@ -14,3 +14,4 @@
|
|||||||
#include "Scalar.hpp"
|
#include "Scalar.hpp"
|
||||||
#include "Quaternion.hpp"
|
#include "Quaternion.hpp"
|
||||||
#include "AxisAngle.hpp"
|
#include "AxisAngle.hpp"
|
||||||
|
#include "LeastSquaresSolver.hpp"
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ set(tests
|
|||||||
helper
|
helper
|
||||||
hatvee
|
hatvee
|
||||||
copyto
|
copyto
|
||||||
|
least_squares
|
||||||
)
|
)
|
||||||
|
|
||||||
add_custom_target(test_build)
|
add_custom_target(test_build)
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
#include "test_macros.hpp"
|
||||||
|
#include <matrix/math.hpp>
|
||||||
|
|
||||||
|
using namespace matrix;
|
||||||
|
|
||||||
|
int test_4x3(void);
|
||||||
|
int test_4x4(void);
|
||||||
|
|
||||||
|
int main()
|
||||||
|
{
|
||||||
|
int ret;
|
||||||
|
|
||||||
|
ret = test_4x4();
|
||||||
|
if (ret != 0) return ret;
|
||||||
|
|
||||||
|
ret = test_4x3();
|
||||||
|
if (ret != 0) return ret;
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int test_4x3() {
|
||||||
|
// Start with an (m x n) A matrix
|
||||||
|
float data[12] = {20.f , -10.f , -13.f ,
|
||||||
|
17.f , 16.f , -18.f ,
|
||||||
|
0.7f, -0.8f, 0.9f,
|
||||||
|
-1.f , -1.1f, -1.2f};
|
||||||
|
Matrix<float, 4, 3> A(data);
|
||||||
|
|
||||||
|
float b_data[4] = {2.0, 3.0, 4.0, 5.0};
|
||||||
|
Vector<float, 4> b(b_data);
|
||||||
|
|
||||||
|
float x_check_data[3] = {-0.69168233f,
|
||||||
|
-0.26227593f,
|
||||||
|
-1.03767522f};
|
||||||
|
Vector<float, 3> x_check(x_check_data);
|
||||||
|
|
||||||
|
LeastSquaresSolver<float, 4, 3> qrd = LeastSquaresSolver<float, 4, 3>(A);
|
||||||
|
|
||||||
|
Vector<float, 3> x = qrd.solve(b);
|
||||||
|
TEST(isEqual(x, x_check));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int test_4x4() {
|
||||||
|
// Start with an (m x n) A matrix
|
||||||
|
float data[16] = { 20.f , -10.f , -13.f , 21.f ,
|
||||||
|
17.f , 16.f , -18.f , -14.f ,
|
||||||
|
0.7f, -0.8f, 0.9f, -0.5f,
|
||||||
|
-1.f , -1.1f, -1.2f, -1.3f};
|
||||||
|
Matrix<float, 4, 4> A(data);
|
||||||
|
|
||||||
|
float b_data[4] = {2.0, 3.0, 4.0, 5.0};
|
||||||
|
Vector<float, 4> b(b_data);
|
||||||
|
|
||||||
|
float x_check_data[4] = { 0.97893433f,
|
||||||
|
-2.80798701f,
|
||||||
|
-0.03175765f,
|
||||||
|
-2.19387649f};
|
||||||
|
Vector<float, 4> x_check(x_check_data);
|
||||||
|
|
||||||
|
LeastSquaresSolver<float, 4, 4> qrd = LeastSquaresSolver<float, 4, 4>(A);
|
||||||
|
|
||||||
|
Vector<float, 4> x = qrd.solve(b);
|
||||||
|
TEST(isEqual(x, x_check));
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */
|
||||||
@@ -124,4 +124,26 @@ for i in range(1,3):
|
|||||||
A_pow = A_pow.dot(A)
|
A_pow = A_pow.dot(A)
|
||||||
print(eA_approx)
|
print(eA_approx)
|
||||||
|
|
||||||
|
print('\nqr decomposition 4x4')
|
||||||
|
A = array([[20.0, -10.0, -13.0, 21.0], [ 17.0, 16.0, -18.0, -14], [0.7, -0.8, 0.9, -0.5], [-1.0, -1.1, -1.2, -1.3]])
|
||||||
|
b = array([[2.], [3.], [4.], [5.]])
|
||||||
|
x = scipy.linalg.lstsq(A,b)[0]
|
||||||
|
print('A:')
|
||||||
|
pprint(A)
|
||||||
|
print('b:')
|
||||||
|
pprint(b)
|
||||||
|
print('x:')
|
||||||
|
pprint(scipy.linalg.lstsq(A,b)[0])
|
||||||
|
|
||||||
|
print('\nqr decomposition 4x3')
|
||||||
|
A = array([[20.0, -10.0, -13.0], [ 17.0, 16.0, -18.0], [0.7, -0.8, 0.9], [-1.0, -1.1, -1.2]])
|
||||||
|
b = array([[2.], [3.], [4.], [5.]])
|
||||||
|
x = scipy.linalg.lstsq(A,b)[0]
|
||||||
|
print('A:')
|
||||||
|
pprint(A)
|
||||||
|
print('b:')
|
||||||
|
pprint(b)
|
||||||
|
print('x:')
|
||||||
|
pprint(x)
|
||||||
|
|
||||||
# vim: set et ft=python fenc=utf-8 ff=unix sts=4 sw=4 ts=8 :
|
# vim: set et ft=python fenc=utf-8 ff=unix sts=4 sw=4 ts=8 :
|
||||||
|
|||||||
Reference in New Issue
Block a user