diff --git a/matrix/Matrix.hpp b/matrix/Matrix.hpp index ce8fa2833d..ba16ef801f 100644 --- a/matrix/Matrix.hpp +++ b/matrix/Matrix.hpp @@ -14,7 +14,7 @@ #include #include -#include "Vector.hpp" +#include "matrix.hpp" namespace matrix { @@ -175,6 +175,11 @@ public: return res; } + inline Matrix operator/(Type scalar) const + { + return (*this)*(1/scalar); + } + Matrix operator+(Type scalar) const { Matrix res; diff --git a/matrix/SquareMatrix.hpp b/matrix/SquareMatrix.hpp index 06061b420d..ec3a450188 100644 --- a/matrix/SquareMatrix.hpp +++ b/matrix/SquareMatrix.hpp @@ -165,7 +165,7 @@ public: return inverse(); } - Vector diagonal() const + Vector diag() const { Vector res; const SquareMatrix &self = *this; @@ -176,28 +176,6 @@ public: return res; } - SquareMatrix expm(Type dt, size_t n) const - { - SquareMatrix res; - res.setIdentity(); - SquareMatrix A_pow = *this; - size_t k_fact = 1; - size_t k = 1; - - while (k < n) { - res += A_pow * (Type(pow(dt, Type(k))) / Type(k_fact)); - - if (k == n) { - break; - } - - A_pow *= A_pow; - k_fact *= k; - k++; - } - - return res; - } }; typedef SquareMatrix SquareMatrix3f; @@ -218,6 +196,22 @@ SquareMatrix diag(Vector d) { return m; } +template +SquareMatrix expm(const SquareMatrix & A, size_t order=5) +{ + SquareMatrix res; + SquareMatrix A_pow = A; + res.setIdentity(); + size_t i_factorial = 1; + for (size_t i=1; i<=order; i++) { + i_factorial *= i; + res += A_pow / Type(i_factorial); + A_pow *= A_pow; + } + + return res; +} + }; // namespace matrix /* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */ diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 97bf5b5963..2f0570c7a6 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -10,6 +10,7 @@ set(tests vector3 attitude filter + squareMatrix ) foreach(test ${tests}) diff --git a/test/inverse.cpp b/test/inverse.cpp index 1bd92ae47c..c2d60c9003 100644 --- a/test/inverse.cpp +++ b/test/inverse.cpp @@ -12,12 +12,18 @@ template class SquareMatrix; int main() { - float data[9] = {1, 0, 0, 0, 1, 0, 1, 0, 1}; + float data[9] = {1, 2, 3, + 4, 5, 6, + 7, 8, 10}; + float data_check[9] = {-0.66666667f, -1.33333333f, 1. , + -0.66666667f, 3.66666667f, -2. , + 1. , -2. , 1. }; + SquareMatrix A(data); SquareMatrix A_I = A.inverse(); - float data_check[9] = {1, 0, 0, 0, 1, 0, -1, 0, 1}; SquareMatrix A_I_check(data_check); - (void)A_I; + A_I.print(); + A_I_check.print(); assert(A_I == A_I_check); // stess test diff --git a/test/squareMatrix.cpp b/test/squareMatrix.cpp new file mode 100644 index 0000000000..9b64d122c9 --- /dev/null +++ b/test/squareMatrix.cpp @@ -0,0 +1,34 @@ +#include +#include + +#include "matrix.hpp" + +using namespace matrix; + +template class SquareMatrix; + +int main() +{ + float data[9] = {1, 2, 3, + 4, 5, 6, + 7, 8, 10}; + SquareMatrix A(data); + Vector3 diag_check(1, 5, 10); + A.diag().T().print(); + + assert(A.diag() == diag_check); + + float data_check[9] = { + 1.01158503f, 0.02190432f, 0.03238144f, + 0.04349195f, 1.05428524f, 0.06539627f, + 0.07576783f, 0.08708946f, 1.10894048f}; + + printf("expm(A*t)\n"); + float dt = 0.01f; + SquareMatrix eA = expm(SquareMatrix(A*dt), 5); + SquareMatrix eA_check(data_check); + assert((eA - eA_check).abs().max() < 1e-3); + return 0; +} + +/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */ diff --git a/test/test_data.py b/test/test_data.py index 5f64994b06..f6da657d98 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -1,6 +1,7 @@ from __future__ import print_function from pylab import * from pprint import pprint +import scipy.linalg # test cases, derived from doc/nasa_rotation_def.pdf @@ -101,4 +102,25 @@ print('\nq3_norm:') q3_norm =q3 / norm(q3) pprint(q3_norm) +print('\ninverse') +A = array([[1,2,3], [4,5,6], [7,8,10]]) +pprint(A) +pprint(inv(A)) + +print('\nmatrix exponential') +A = 0.01*array([[1.0,2.0,3.0], [4.0,5.0,6.0], [7.0,8.0,10.0]]) +eA_check = scipy.linalg.expm(A) + +pprint(eA_check) + +eA_approx = eye(3) +k = 1.0 +A_pow = A +for i in range(1,3): + k *= i + # print(i, k, '\n', A_pow/k, '\n') + eA_approx += A_pow/k + A_pow = A_pow.dot(A) +print(eA_approx) + # vim: set et ft=python fenc=utf-8 ff=unix sts=4 sw=4 ts=8 : diff --git a/test/vector.cpp b/test/vector.cpp index 5c63140311..dc27463da4 100644 --- a/test/vector.cpp +++ b/test/vector.cpp @@ -14,6 +14,10 @@ int main() (void)n; float r = v.dot(v); (void)r; + + Vector v2(v); + assert(v == v2); + return 0; }