上面就是用欧拉方法解常微分方程的代码。
用●Midpoint method (or RK2) - 2nd order method方法只需
这里odeint是一种通用的ODE求解器,必须提供fun(t,ht),初始条件,评估函数的时间步和求解器
像Runge–Kutta(RK4)或Adams–Bashforth这样的高阶方法可以保证更好的数值精度
所有这些都可以在形式通用的接口中实现(例如scipy
结果如图所示
● We can use existing (and efficient) implementation of solvers to integrate NNs dynamics ● The memory cost is O(1) , due to reversibility i.e. we don’t need to store all activations in the graph, we can easily recover them by backward integration (i.e. time reversed integration) ● Complex dynamics can be modeled with fewer parameters ● We can control accuracy/speed trade-off with adaptive solvers by setting lower/higher error tolerances ● Hidden states can be accessed at any value of t - no discrete time steps as in RestNet skip connectionNeuralODE - adjoint method
● Adjoint method can be understand as a continuous version of chain rule ● Chain rule: Consider following sequence of operations ( L is a scalar loss): ● We can compute gradient of L w.r.t input state using chain rule此公式是任何深度学习autograd的核心