mirror of
https://gitee.com/mirrors_PX4/PX4-Autopilot.git
synced 2026-06-25 00:20:34 +08:00
least squares solver for MxN matrices using QR householder algorithm
This commit is contained in:
@@ -0,0 +1,145 @@
|
||||
/**
|
||||
* @file LeastSquaresSolver.hpp
|
||||
*
|
||||
* Least Squares Solver using QR householder decomposition.
|
||||
* It calculates x for Ax = b.
|
||||
* A = Q*R
|
||||
* where R is an upper triangular matrix.
|
||||
*
|
||||
* R*x = Q^T*b
|
||||
* This is efficiently solved for x because of the upper triangular property of R.
|
||||
*
|
||||
* @author Bart Slinger <bartslinger@gmail.com>
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "math.hpp"
|
||||
|
||||
namespace matrix {
|
||||
|
||||
template<typename Type, size_t M, size_t N>
|
||||
class LeastSquaresSolver
|
||||
{
|
||||
public:
|
||||
|
||||
/**
|
||||
* @brief Class calculates QR decomposition which can be used for linear
|
||||
* least squares
|
||||
* @param A Matrix of size MxN
|
||||
*
|
||||
* Initialize the class with a MxN matrix. The constructor starts the
|
||||
* QR decomposition. This class does not check the rank of the matrix.
|
||||
* The user needs to make sure that rank(A) = N and N >= M.
|
||||
*/
|
||||
LeastSquaresSolver(Matrix<Type, M, N> A)
|
||||
{
|
||||
if (M < N) {
|
||||
return;
|
||||
}
|
||||
// Copy contentents of matrix A
|
||||
memcpy(_data, A._data, sizeof(_data));
|
||||
|
||||
for (size_t j = 0; j < N; j++) {
|
||||
float normx = 0.0f;
|
||||
for (size_t i = j; i < M; i++) {
|
||||
normx += _data[i][j] * _data[i][j];
|
||||
}
|
||||
normx = sqrt(normx);
|
||||
float s = _data[j][j] > 0 ? -1.0f : 1.0f;
|
||||
float u1 = _data[j][j] - s*normx;
|
||||
// prevent divide by zero
|
||||
// also covers u1. normx is never negative
|
||||
if (normx < 1e-8f) {
|
||||
return;
|
||||
}
|
||||
float w[M] = {};
|
||||
w[0] = 1.0f;
|
||||
for (size_t i = j+1; i < M; i++) {
|
||||
w[i-j] = _data[i][j] / u1;
|
||||
_data[i][j] = w[i-j];
|
||||
}
|
||||
_data[j][j] = s*normx;
|
||||
_tau[j] = -s*u1/normx;
|
||||
|
||||
for (size_t k = j+1; k < N; k++) {
|
||||
float tmp = 0.0f;
|
||||
for (size_t i = j; i < M; i++) {
|
||||
tmp += w[i-j] * _data[i][k];
|
||||
}
|
||||
for (size_t i = j; i < M; i++) {
|
||||
_data[i][k] -= _tau[j] * w[i-j] * tmp;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief qtb Calculate Q^T * b
|
||||
* @param b
|
||||
* @return Q^T*b
|
||||
*
|
||||
* This function calculates Q^T * b. This is useful for the solver
|
||||
* because R*x = Q^T*b.
|
||||
*/
|
||||
Vector<Type, M> qtb(Vector<Type, M> b) {
|
||||
Vector<Type, M> qtbv = b;
|
||||
|
||||
for (size_t j = 0; j < N; j++) {
|
||||
float w[M];
|
||||
w[0] = 1.0f;
|
||||
// fill vector w
|
||||
for (size_t i = j+1; i < M; i++) {
|
||||
w[i-j] = _data[i][j];
|
||||
}
|
||||
float tmp = 0.0f;
|
||||
for (size_t i = j; i < M; i++) {
|
||||
tmp += w[i-j] * qtbv(i);
|
||||
}
|
||||
|
||||
for (size_t i = j; i < M; i++) {
|
||||
qtbv(i) -= _tau[j] * w[i-j] * tmp;
|
||||
}
|
||||
}
|
||||
return qtbv;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Solve Ax=b for x
|
||||
* @param b
|
||||
* @return Vector x
|
||||
*
|
||||
* Find x in the equation Ax = b.
|
||||
* A is provided in the initializer of the class.
|
||||
*/
|
||||
Vector<Type, N> solve(Vector<Type, M> b) {
|
||||
Vector<Type, M> qtbv = qtb(b);
|
||||
Vector<Type, N> x;
|
||||
|
||||
for (size_t l = N; l > 0 ; l--) {
|
||||
size_t i = l - 1;
|
||||
x(i) = qtbv(i);
|
||||
for (size_t r = i+1; r < N; r++) {
|
||||
x(i) -= _data[i][r] * x(r);
|
||||
}
|
||||
// divide by zero, return vector of zeros
|
||||
if (fabs(_data[i][i]) < 1e-8f) {
|
||||
for (size_t z = 0; z < N; z++) {
|
||||
x(z) = 0.0f;
|
||||
}
|
||||
}
|
||||
x(i) = x(i) / _data[i][i];
|
||||
}
|
||||
return x;
|
||||
}
|
||||
|
||||
private:
|
||||
Type _data[M][N] {};
|
||||
Type _tau[N] {};
|
||||
|
||||
};
|
||||
|
||||
} // namespace matrix
|
||||
|
||||
/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */
|
||||
@@ -14,3 +14,4 @@
|
||||
#include "Scalar.hpp"
|
||||
#include "Quaternion.hpp"
|
||||
#include "AxisAngle.hpp"
|
||||
#include "LeastSquaresSolver.hpp"
|
||||
|
||||
@@ -17,6 +17,7 @@ set(tests
|
||||
helper
|
||||
hatvee
|
||||
copyto
|
||||
least_squares
|
||||
)
|
||||
|
||||
add_custom_target(test_build)
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
#include "test_macros.hpp"
|
||||
#include <matrix/math.hpp>
|
||||
|
||||
using namespace matrix;
|
||||
|
||||
int test_4x3(void);
|
||||
int test_4x4(void);
|
||||
|
||||
int main()
|
||||
{
|
||||
int ret;
|
||||
|
||||
ret = test_4x4();
|
||||
if (ret != 0) return ret;
|
||||
|
||||
ret = test_4x3();
|
||||
if (ret != 0) return ret;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int test_4x3() {
|
||||
// Start with an (m x n) A matrix
|
||||
float data[12] = {20.f , -10.f , -13.f ,
|
||||
17.f , 16.f , -18.f ,
|
||||
0.7f, -0.8f, 0.9f,
|
||||
-1.f , -1.1f, -1.2f};
|
||||
Matrix<float, 4, 3> A(data);
|
||||
|
||||
float b_data[4] = {2.0, 3.0, 4.0, 5.0};
|
||||
Vector<float, 4> b(b_data);
|
||||
|
||||
float x_check_data[3] = {-0.69168233f,
|
||||
-0.26227593f,
|
||||
-1.03767522f};
|
||||
Vector<float, 3> x_check(x_check_data);
|
||||
|
||||
LeastSquaresSolver<float, 4, 3> qrd = LeastSquaresSolver<float, 4, 3>(A);
|
||||
|
||||
Vector<float, 3> x = qrd.solve(b);
|
||||
TEST(isEqual(x, x_check));
|
||||
return 0;
|
||||
}
|
||||
|
||||
int test_4x4() {
|
||||
// Start with an (m x n) A matrix
|
||||
float data[16] = { 20.f , -10.f , -13.f , 21.f ,
|
||||
17.f , 16.f , -18.f , -14.f ,
|
||||
0.7f, -0.8f, 0.9f, -0.5f,
|
||||
-1.f , -1.1f, -1.2f, -1.3f};
|
||||
Matrix<float, 4, 4> A(data);
|
||||
|
||||
float b_data[4] = {2.0, 3.0, 4.0, 5.0};
|
||||
Vector<float, 4> b(b_data);
|
||||
|
||||
float x_check_data[4] = { 0.97893433f,
|
||||
-2.80798701f,
|
||||
-0.03175765f,
|
||||
-2.19387649f};
|
||||
Vector<float, 4> x_check(x_check_data);
|
||||
|
||||
LeastSquaresSolver<float, 4, 4> qrd = LeastSquaresSolver<float, 4, 4>(A);
|
||||
|
||||
Vector<float, 4> x = qrd.solve(b);
|
||||
TEST(isEqual(x, x_check));
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */
|
||||
@@ -124,4 +124,26 @@ for i in range(1,3):
|
||||
A_pow = A_pow.dot(A)
|
||||
print(eA_approx)
|
||||
|
||||
print('\nqr decomposition 4x4')
|
||||
A = array([[20.0, -10.0, -13.0, 21.0], [ 17.0, 16.0, -18.0, -14], [0.7, -0.8, 0.9, -0.5], [-1.0, -1.1, -1.2, -1.3]])
|
||||
b = array([[2.], [3.], [4.], [5.]])
|
||||
x = scipy.linalg.lstsq(A,b)[0]
|
||||
print('A:')
|
||||
pprint(A)
|
||||
print('b:')
|
||||
pprint(b)
|
||||
print('x:')
|
||||
pprint(scipy.linalg.lstsq(A,b)[0])
|
||||
|
||||
print('\nqr decomposition 4x3')
|
||||
A = array([[20.0, -10.0, -13.0], [ 17.0, 16.0, -18.0], [0.7, -0.8, 0.9], [-1.0, -1.1, -1.2]])
|
||||
b = array([[2.], [3.], [4.], [5.]])
|
||||
x = scipy.linalg.lstsq(A,b)[0]
|
||||
print('A:')
|
||||
pprint(A)
|
||||
print('b:')
|
||||
pprint(b)
|
||||
print('x:')
|
||||
pprint(x)
|
||||
|
||||
# vim: set et ft=python fenc=utf-8 ff=unix sts=4 sw=4 ts=8 :
|
||||
|
||||
Reference in New Issue
Block a user