mirror of
https://gitee.com/mirrors_PX4/PX4-Autopilot.git
synced 2026-05-02 05:04:08 +08:00
Moved inverse outside of matrix definition.
This commit is contained in:
parent
5566b3dc77
commit
00a0b36836
@ -41,128 +41,10 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
/**
|
||||
* inverse based on LU factorization with partial pivotting
|
||||
*/
|
||||
SquareMatrix <Type, M> inverse() const
|
||||
{
|
||||
SquareMatrix<Type, M> L;
|
||||
L.setIdentity();
|
||||
const SquareMatrix<Type, M> &A = (*this);
|
||||
SquareMatrix<Type, M> U = A;
|
||||
SquareMatrix<Type, M> P;
|
||||
P.setIdentity();
|
||||
|
||||
//printf("A:\n"); A.print();
|
||||
|
||||
// for all diagonal elements
|
||||
for (size_t n = 0; n < M; n++) {
|
||||
|
||||
// if diagonal is zero, swap with row below
|
||||
if (fabsf(U(n, n)) < 1e-8f) {
|
||||
//printf("trying pivot for row %d\n",n);
|
||||
for (size_t i = 0; i < M; i++) {
|
||||
if (i == n) {
|
||||
continue;
|
||||
}
|
||||
|
||||
//printf("\ttrying row %d\n",i);
|
||||
if (fabsf(U(i, n)) > 1e-8f) {
|
||||
//printf("swapped %d\n",i);
|
||||
U.swapRows(i, n);
|
||||
P.swapRows(i, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef MATRIX_ASSERT
|
||||
//printf("A:\n"); A.print();
|
||||
//printf("U:\n"); U.print();
|
||||
//printf("P:\n"); P.print();
|
||||
//fflush(stdout);
|
||||
ASSERT(fabsf(U(n, n)) > 1e-8f);
|
||||
#endif
|
||||
|
||||
// failsafe, return zero matrix
|
||||
if (fabsf(U(n, n)) < 1e-8f) {
|
||||
SquareMatrix<Type, M> zero;
|
||||
zero.setZero();
|
||||
return zero;
|
||||
}
|
||||
|
||||
// for all rows below diagonal
|
||||
for (size_t i = (n + 1); i < M; i++) {
|
||||
L(i, n) = U(i, n) / U(n, n);
|
||||
|
||||
// add i-th row and n-th row
|
||||
// multiplied by: -a(i,n)/a(n,n)
|
||||
for (size_t k = n; k < M; k++) {
|
||||
U(i, k) -= L(i, n) * U(n, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//printf("L:\n"); L.print();
|
||||
//printf("U:\n"); U.print();
|
||||
|
||||
// solve LY=P*I for Y by forward subst
|
||||
SquareMatrix<Type, M> Y = P;
|
||||
|
||||
// for all columns of Y
|
||||
for (size_t c = 0; c < M; c++) {
|
||||
// for all rows of L
|
||||
for (size_t i = 0; i < M; i++) {
|
||||
// for all columns of L
|
||||
for (size_t j = 0; j < i; j++) {
|
||||
// for all existing y
|
||||
// subtract the component they
|
||||
// contribute to the solution
|
||||
Y(i, c) -= L(i, j) * Y(j, c);
|
||||
}
|
||||
|
||||
// divide by the factor
|
||||
// on current
|
||||
// term to be solved
|
||||
// Y(i,c) /= L(i,i);
|
||||
// but L(i,i) = 1.0
|
||||
}
|
||||
}
|
||||
|
||||
//printf("Y:\n"); Y.print();
|
||||
|
||||
// solve Ux=y for x by back subst
|
||||
SquareMatrix<Type, M> X = Y;
|
||||
|
||||
// for all columns of X
|
||||
for (size_t c = 0; c < M; c++) {
|
||||
// for all rows of U
|
||||
for (size_t k = 0; k < M; k++) {
|
||||
// have to go in reverse order
|
||||
size_t i = M - 1 - k;
|
||||
|
||||
// for all columns of U
|
||||
for (size_t j = i + 1; j < M; j++) {
|
||||
// for all existing x
|
||||
// subtract the component they
|
||||
// contribute to the solution
|
||||
X(i, c) -= U(i, j) * X(j, c);
|
||||
}
|
||||
|
||||
// divide by the factor
|
||||
// on current
|
||||
// term to be solved
|
||||
X(i, c) /= U(i, i);
|
||||
}
|
||||
}
|
||||
|
||||
//printf("X:\n"); X.print();
|
||||
return X;
|
||||
}
|
||||
|
||||
// inverse alias
|
||||
inline SquareMatrix<Type, M> I() const
|
||||
{
|
||||
return inverse();
|
||||
return inv(*this);
|
||||
}
|
||||
|
||||
Vector<Type, M> diag() const
|
||||
@ -212,6 +94,126 @@ SquareMatrix<Type, M> expm(const SquareMatrix<Type, M> & A, size_t order=5)
|
||||
return res;
|
||||
}
|
||||
|
||||
/**
|
||||
* inverse based on LU factorization with partial pivotting
|
||||
*/
|
||||
template<typename Type, size_t M>
|
||||
SquareMatrix <Type, M> inv(const SquareMatrix<Type, M> & A)
|
||||
{
|
||||
SquareMatrix<Type, M> L;
|
||||
L.setIdentity();
|
||||
SquareMatrix<Type, M> U = A;
|
||||
SquareMatrix<Type, M> P;
|
||||
P.setIdentity();
|
||||
|
||||
//printf("A:\n"); A.print();
|
||||
|
||||
// for all diagonal elements
|
||||
for (size_t n = 0; n < M; n++) {
|
||||
|
||||
// if diagonal is zero, swap with row below
|
||||
if (fabsf(U(n, n)) < 1e-8f) {
|
||||
//printf("trying pivot for row %d\n",n);
|
||||
for (size_t i = 0; i < M; i++) {
|
||||
if (i == n) {
|
||||
continue;
|
||||
}
|
||||
|
||||
//printf("\ttrying row %d\n",i);
|
||||
if (fabsf(U(i, n)) > 1e-8f) {
|
||||
//printf("swapped %d\n",i);
|
||||
U.swapRows(i, n);
|
||||
P.swapRows(i, n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef MATRIX_ASSERT
|
||||
//printf("A:\n"); A.print();
|
||||
//printf("U:\n"); U.print();
|
||||
//printf("P:\n"); P.print();
|
||||
//fflush(stdout);
|
||||
ASSERT(fabsf(U(n, n)) > 1e-8f);
|
||||
#endif
|
||||
|
||||
// failsafe, return zero matrix
|
||||
if (fabsf(U(n, n)) < 1e-8f) {
|
||||
SquareMatrix<Type, M> zero;
|
||||
zero.setZero();
|
||||
return zero;
|
||||
}
|
||||
|
||||
// for all rows below diagonal
|
||||
for (size_t i = (n + 1); i < M; i++) {
|
||||
L(i, n) = U(i, n) / U(n, n);
|
||||
|
||||
// add i-th row and n-th row
|
||||
// multiplied by: -a(i,n)/a(n,n)
|
||||
for (size_t k = n; k < M; k++) {
|
||||
U(i, k) -= L(i, n) * U(n, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//printf("L:\n"); L.print();
|
||||
//printf("U:\n"); U.print();
|
||||
|
||||
// solve LY=P*I for Y by forward subst
|
||||
SquareMatrix<Type, M> Y = P;
|
||||
|
||||
// for all columns of Y
|
||||
for (size_t c = 0; c < M; c++) {
|
||||
// for all rows of L
|
||||
for (size_t i = 0; i < M; i++) {
|
||||
// for all columns of L
|
||||
for (size_t j = 0; j < i; j++) {
|
||||
// for all existing y
|
||||
// subtract the component they
|
||||
// contribute to the solution
|
||||
Y(i, c) -= L(i, j) * Y(j, c);
|
||||
}
|
||||
|
||||
// divide by the factor
|
||||
// on current
|
||||
// term to be solved
|
||||
// Y(i,c) /= L(i,i);
|
||||
// but L(i,i) = 1.0
|
||||
}
|
||||
}
|
||||
|
||||
//printf("Y:\n"); Y.print();
|
||||
|
||||
// solve Ux=y for x by back subst
|
||||
SquareMatrix<Type, M> X = Y;
|
||||
|
||||
// for all columns of X
|
||||
for (size_t c = 0; c < M; c++) {
|
||||
// for all rows of U
|
||||
for (size_t k = 0; k < M; k++) {
|
||||
// have to go in reverse order
|
||||
size_t i = M - 1 - k;
|
||||
|
||||
// for all columns of U
|
||||
for (size_t j = i + 1; j < M; j++) {
|
||||
// for all existing x
|
||||
// subtract the component they
|
||||
// contribute to the solution
|
||||
X(i, c) -= U(i, j) * X(j, c);
|
||||
}
|
||||
|
||||
// divide by the factor
|
||||
// on current
|
||||
// term to be solved
|
||||
X(i, c) /= U(i, i);
|
||||
}
|
||||
}
|
||||
|
||||
//printf("X:\n"); X.print();
|
||||
return X;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}; // namespace matrix
|
||||
|
||||
/* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */
|
||||
|
||||
@ -20,7 +20,7 @@ int main()
|
||||
1. , -2. , 1. };
|
||||
|
||||
SquareMatrix<float, 3> A(data);
|
||||
SquareMatrix<float, 3> A_I = A.inverse();
|
||||
SquareMatrix<float, 3> A_I = inv(A);
|
||||
SquareMatrix<float, 3> A_I_check(data_check);
|
||||
A_I.print();
|
||||
A_I_check.print();
|
||||
@ -33,7 +33,7 @@ int main()
|
||||
A_large_I.setZero();
|
||||
|
||||
for (size_t i = 0; i < n_large; i++) {
|
||||
A_large_I = A_large.inverse();
|
||||
A_large_I = inv(A_large);
|
||||
assert(A_large == A_large_I);
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user