diff --git a/matrix/integration.hpp b/matrix/integration.hpp index a0ad55c68a..8909a17c60 100644 --- a/matrix/integration.hpp +++ b/matrix/integration.hpp @@ -10,18 +10,33 @@ int integrate_rk4( const Matrix & y0, const Matrix & u, Type t0, - Type h, + Type tf, + Type h0, Matrix & y1 ) { // https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods + Type t1 = t0; + y1 = y0; + Type h = h0; Vector k1, k2, k3, k4; - k1 = f(t0, y0, u); - k2 = f(t0 + h/2, y0 + k1*h/2, u); - k3 = f(t0 + h/2, y0 + k2*h/2, u); - k4 = f(t0 + h, y0 + k3*h, u); - y1 = y0 + (k1 + k2*2 + k3*2 + k4)*(h/6); + if (tf < t0) return -1; // make sure t1 > t0 + while (t1 < tf) { + if (t1 + h0 < tf) { + h = h0; + } else { + h = tf - t1; + } + k1 = f(t1, y1, u); + k2 = f(t1 + h/2, y1 + k1*h/2, u); + k3 = f(t1 + h/2, y1 + k2*h/2, u); + k4 = f(t1 + h, y1 + k3*h, u); + y1 += (k1 + k2*2 + k3*2 + k4)*(h/6); + t1 += h; + } return 0; } } // namespace matrix + +// vim: set et fenc=utf-8 ff=unix sts=0 sw=4 ts=4 : diff --git a/test/integration.cpp b/test/integration.cpp index ebc456adb0..abbfe885b3 100644 --- a/test/integration.cpp +++ b/test/integration.cpp @@ -8,17 +8,20 @@ using namespace matrix; Vector f(float t, const Matrix & y, const Matrix & u); Vector f(float t, const Matrix & y, const Matrix & u) { - return ones(); + float v = -sinf(t); + return v*ones(); } int main() { Vector y = ones(); Vector u = ones(); - float t = 1; - float h = 0.1f; - integrate_rk4(f, y, u, t, h, y); - TEST(isEqual(y, (ones()*1.1f))); + float t0 = 0; + float tf = 2; + float h = 0.001f; + integrate_rk4(f, y, u, t0, tf, h, y); + float v = 1 + cosf(tf) - cosf(t0); + TEST(isEqual(y, (ones()*v))); return 0; }