static assert M>=N. floats to Type, arguments as const reference

This commit is contained in:
Bart Slinger
2018-09-13 10:53:13 +02:00
committed by Beat Küng
parent 7495794386
commit 480c5f1f8e
2 changed files with 47 additions and 19 deletions
+18 -19
View File
@@ -30,31 +30,30 @@ public:
* *
* Initialize the class with a MxN matrix. The constructor starts the * Initialize the class with a MxN matrix. The constructor starts the
* QR decomposition. This class does not check the rank of the matrix. * QR decomposition. This class does not check the rank of the matrix.
* The user needs to make sure that rank(A) = N and N >= M. * The user needs to make sure that rank(A) = N and M >= N.
*/ */
LeastSquaresSolver(Matrix<Type, M, N> A) LeastSquaresSolver(const Matrix<Type, M, N> &A)
{ {
if (M < N) { static_assert(M >= N, "Matrix dimension should be M >= N");
return;
}
// Copy contentents of matrix A // Copy contentents of matrix A
_A = A; _A = A;
for (size_t j = 0; j < N; j++) { for (size_t j = 0; j < N; j++) {
float normx = 0.0f; Type normx = 0.;
for (size_t i = j; i < M; i++) { for (size_t i = j; i < M; i++) {
normx += _A(i,j) * _A(i,j); normx += _A(i,j) * _A(i,j);
} }
normx = sqrt(normx); normx = sqrt(normx);
float s = _A(j,j) > 0 ? -1.0f : 1.0f; Type s = _A(j,j) > 0 ? -1. : 1.;
float u1 = _A(j,j) - s*normx; Type u1 = _A(j,j) - s*normx;
// prevent divide by zero // prevent divide by zero
// also covers u1. normx is never negative // also covers u1. normx is never negative
if (normx < 1e-8f) { if (normx < 1e-8) {
break; break;
} }
float w[M] = {}; Type w[M] = {};
w[0] = 1.0f; w[0] = 1.;
for (size_t i = j+1; i < M; i++) { for (size_t i = j+1; i < M; i++) {
w[i-j] = _A(i,j) / u1; w[i-j] = _A(i,j) / u1;
_A(i,j) = w[i-j]; _A(i,j) = w[i-j];
@@ -63,7 +62,7 @@ public:
_tau(j) = -s*u1/normx; _tau(j) = -s*u1/normx;
for (size_t k = j+1; k < N; k++) { for (size_t k = j+1; k < N; k++) {
float tmp = 0.0f; Type tmp = 0.;
for (size_t i = j; i < M; i++) { for (size_t i = j; i < M; i++) {
tmp += w[i-j] * _A(i,k); tmp += w[i-j] * _A(i,k);
} }
@@ -83,17 +82,17 @@ public:
* This function calculates Q^T * b. This is useful for the solver * This function calculates Q^T * b. This is useful for the solver
* because R*x = Q^T*b. * because R*x = Q^T*b.
*/ */
Vector<Type, M> qtb(Vector<Type, M> b) { Vector<Type, M> qtb(const Vector<Type, M> &b) {
Vector<Type, M> qtbv = b; Vector<Type, M> qtbv = b;
for (size_t j = 0; j < N; j++) { for (size_t j = 0; j < N; j++) {
float w[M]; Type w[M];
w[0] = 1.0f; w[0] = 1.;
// fill vector w // fill vector w
for (size_t i = j+1; i < M; i++) { for (size_t i = j+1; i < M; i++) {
w[i-j] = _A(i,j); w[i-j] = _A(i,j);
} }
float tmp = 0.0f; Type tmp = 0.;
for (size_t i = j; i < M; i++) { for (size_t i = j; i < M; i++) {
tmp += w[i-j] * qtbv(i); tmp += w[i-j] * qtbv(i);
} }
@@ -113,7 +112,7 @@ public:
* Find x in the equation Ax = b. * Find x in the equation Ax = b.
* A is provided in the initializer of the class. * A is provided in the initializer of the class.
*/ */
Vector<Type, N> solve(Vector<Type, M> b) { Vector<Type, N> solve(const Vector<Type, M> &b) {
Vector<Type, M> qtbv = qtb(b); Vector<Type, M> qtbv = qtb(b);
Vector<Type, N> x; Vector<Type, N> x;
@@ -124,9 +123,9 @@ public:
x(i) -= _A(i,r) * x(r); x(i) -= _A(i,r) * x(r);
} }
// divide by zero, return vector of zeros // divide by zero, return vector of zeros
if (fabs(_A(i,i)) < 1e-8f) { if (fabs(_A(i,i)) < 1e-8) {
for (size_t z = 0; z < N; z++) { for (size_t z = 0; z < N; z++) {
x(z) = 0.0f; x(z) = 0.;
} }
break; break;
} }
+29
View File
@@ -5,6 +5,7 @@ using namespace matrix;
int test_4x3(void); int test_4x3(void);
int test_4x4(void); int test_4x4(void);
int test_4x4_type_double(void);
int test_div_zero(void); int test_div_zero(void);
int main() int main()
@@ -17,6 +18,8 @@ int main()
ret = test_4x3(); ret = test_4x3();
if (ret != 0) return ret; if (ret != 0) return ret;
ret = test_4x4_type_double();
ret = test_div_zero(); ret = test_div_zero();
if (ret != 0) return ret; if (ret != 0) return ret;
@@ -74,6 +77,32 @@ int test_4x4() {
return 0; return 0;
} }
int test_4x4_type_double() {
// Start with an (m x n) A matrix
double data[16] = { 20., -10., -13., 21.,
17., 16., -18., -14.,
0.7, -0.8, 0.9, -0.5,
-1., -1.1, -1.2, -1.3
};
Matrix<double, 4, 4> A(data);
double b_data[4] = {2.0, 3.0, 4.0, 5.0};
Vector<double, 4> b(b_data);
double x_check_data[4] = { 0.97893433,
-2.80798701,
-0.03175765,
-2.19387649
};
Vector<double, 4> x_check(x_check_data);
LeastSquaresSolver<double, 4, 4> qrd = LeastSquaresSolver<double, 4, 4>(A);
Vector<double, 4> x = qrd.solve(b);
TEST(isEqual(x, x_check));
return 0;
}
int test_div_zero() { int test_div_zero() {
float data[4] = {0.0f, 0.0f, 0.0f, 0.0f}; float data[4] = {0.0f, 0.0f, 0.0f, 0.0f};
Matrix<float, 2, 2> A(data); Matrix<float, 2, 2> A(data);