static assert M>=N. floats to Type, arguments as const reference

This commit is contained in:
Bart Slinger
2018-09-13 10:53:13 +02:00
committed by Beat Küng
parent 7495794386
commit 480c5f1f8e
2 changed files with 47 additions and 19 deletions
+18 -19
View File
@@ -30,31 +30,30 @@ public:
*
* Initialize the class with a MxN matrix. The constructor starts the
* QR decomposition. This class does not check the rank of the matrix.
* The user needs to make sure that rank(A) = N and N >= M.
* The user needs to make sure that rank(A) = N and M >= N.
*/
LeastSquaresSolver(Matrix<Type, M, N> A)
LeastSquaresSolver(const Matrix<Type, M, N> &A)
{
if (M < N) {
return;
}
static_assert(M >= N, "Matrix dimension should be M >= N");
// Copy contentents of matrix A
_A = A;
for (size_t j = 0; j < N; j++) {
float normx = 0.0f;
Type normx = 0.;
for (size_t i = j; i < M; i++) {
normx += _A(i,j) * _A(i,j);
}
normx = sqrt(normx);
float s = _A(j,j) > 0 ? -1.0f : 1.0f;
float u1 = _A(j,j) - s*normx;
Type s = _A(j,j) > 0 ? -1. : 1.;
Type u1 = _A(j,j) - s*normx;
// prevent divide by zero
// also covers u1. normx is never negative
if (normx < 1e-8f) {
if (normx < 1e-8) {
break;
}
float w[M] = {};
w[0] = 1.0f;
Type w[M] = {};
w[0] = 1.;
for (size_t i = j+1; i < M; i++) {
w[i-j] = _A(i,j) / u1;
_A(i,j) = w[i-j];
@@ -63,7 +62,7 @@ public:
_tau(j) = -s*u1/normx;
for (size_t k = j+1; k < N; k++) {
float tmp = 0.0f;
Type tmp = 0.;
for (size_t i = j; i < M; i++) {
tmp += w[i-j] * _A(i,k);
}
@@ -83,17 +82,17 @@ public:
* This function calculates Q^T * b. This is useful for the solver
* because R*x = Q^T*b.
*/
Vector<Type, M> qtb(Vector<Type, M> b) {
Vector<Type, M> qtb(const Vector<Type, M> &b) {
Vector<Type, M> qtbv = b;
for (size_t j = 0; j < N; j++) {
float w[M];
w[0] = 1.0f;
Type w[M];
w[0] = 1.;
// fill vector w
for (size_t i = j+1; i < M; i++) {
w[i-j] = _A(i,j);
}
float tmp = 0.0f;
Type tmp = 0.;
for (size_t i = j; i < M; i++) {
tmp += w[i-j] * qtbv(i);
}
@@ -113,7 +112,7 @@ public:
* Find x in the equation Ax = b.
* A is provided in the initializer of the class.
*/
Vector<Type, N> solve(Vector<Type, M> b) {
Vector<Type, N> solve(const Vector<Type, M> &b) {
Vector<Type, M> qtbv = qtb(b);
Vector<Type, N> x;
@@ -124,9 +123,9 @@ public:
x(i) -= _A(i,r) * x(r);
}
// divide by zero, return vector of zeros
if (fabs(_A(i,i)) < 1e-8f) {
if (fabs(_A(i,i)) < 1e-8) {
for (size_t z = 0; z < N; z++) {
x(z) = 0.0f;
x(z) = 0.;
}
break;
}
+29
View File
@@ -5,6 +5,7 @@ using namespace matrix;
int test_4x3(void);
int test_4x4(void);
int test_4x4_type_double(void);
int test_div_zero(void);
int main()
@@ -17,6 +18,8 @@ int main()
ret = test_4x3();
if (ret != 0) return ret;
ret = test_4x4_type_double();
ret = test_div_zero();
if (ret != 0) return ret;
@@ -74,6 +77,32 @@ int test_4x4() {
return 0;
}
int test_4x4_type_double() {
// Start with an (m x n) A matrix
double data[16] = { 20., -10., -13., 21.,
17., 16., -18., -14.,
0.7, -0.8, 0.9, -0.5,
-1., -1.1, -1.2, -1.3
};
Matrix<double, 4, 4> A(data);
double b_data[4] = {2.0, 3.0, 4.0, 5.0};
Vector<double, 4> b(b_data);
double x_check_data[4] = { 0.97893433,
-2.80798701,
-0.03175765,
-2.19387649
};
Vector<double, 4> x_check(x_check_data);
LeastSquaresSolver<double, 4, 4> qrd = LeastSquaresSolver<double, 4, 4>(A);
Vector<double, 4> x = qrd.solve(b);
TEST(isEqual(x, x_check));
return 0;
}
int test_div_zero() {
float data[4] = {0.0f, 0.0f, 0.0f, 0.0f};
Matrix<float, 2, 2> A(data);