diff --git a/test/least_squares.cpp b/test/least_squares.cpp index 46042f2d96..14868218f7 100644 --- a/test/least_squares.cpp +++ b/test/least_squares.cpp @@ -5,6 +5,7 @@ using namespace matrix; int test_4x3(void); int test_4x4(void); +int test_div_zero(void); int main() { @@ -16,6 +17,9 @@ int main() ret = test_4x3(); if (ret != 0) return ret; + ret = test_div_zero(); + if (ret != 0) return ret; + return 0; } @@ -70,4 +74,22 @@ int test_4x4() { return 0; } +int test_div_zero() { + float data[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + Matrix A(data); + + float b_data[2] = {1.0, 1.0}; + Vector b(b_data); + + // Implement such that x returns zeros if it reaches div by zero + float x_check_data[2] = {0.0f, 0.0f}; + Vector x_check(x_check_data); + + LeastSquaresSolver qrd = LeastSquaresSolver(A); + + Vector x = qrd.solve(b); + TEST(isEqual(x, x_check)); + return 0; +} + /* vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : */