A Small Study of Neural ODEs

Tech · 2022-07-15

     

The code above solves an ordinary differential equation with the Euler method.
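The Euler-method code referenced above did not survive extraction. A minimal sketch of what such code typically looks like (the names `euler_solve` and `f` are assumptions, not from the original):

```python
import numpy as np

def euler_solve(f, h0, t0, t1, n_steps):
    """Integrate dh/dt = f(t, h) from t0 to t1 with fixed-step explicit Euler."""
    h = np.asarray(h0, dtype=float)
    t = t0
    dt = (t1 - t0) / n_steps
    for _ in range(n_steps):
        h = h + dt * f(t, h)  # one first-order Euler step
        t += dt
    return h

# Example: dh/dt = -h with h(0) = 1; exact solution is exp(-t).
approx_euler = euler_solve(lambda t, h: -h, 1.0, 0.0, 1.0, 1000)
```

With 1000 steps the result is close to exp(-1) ≈ 0.3679, but only to first-order accuracy, which motivates the higher-order methods below.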

● Midpoint method (or RK2) - a 2nd-order method that needs only two evaluations of f per step.
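As a sketch of the midpoint idea (the names `rk2_step` and `rk2_solve` are hypothetical): the derivative is first evaluated at the current state, then re-evaluated at the midpoint of the step, and the midpoint slope is used for the full update.

```python
def rk2_step(f, t, h, dt):
    """One midpoint (RK2) step: two evaluations of f, second-order accurate."""
    k1 = f(t, h)                           # slope at the start of the step
    k2 = f(t + dt / 2, h + dt / 2 * k1)    # slope at the midpoint
    return h + dt * k2

def rk2_solve(f, h0, t0, t1, n_steps):
    h, t = float(h0), t0
    dt = (t1 - t0) / n_steps
    for _ in range(n_steps):
        h = rk2_step(f, t, h, dt)
        t += dt
    return h

# Same test problem: dh/dt = -h, h(0) = 1, integrated to t = 1.
approx_rk2 = rk2_solve(lambda t, h: -h, 1.0, 0.0, 1.0, 100)
```

Even with 10x fewer steps than the Euler example, the midpoint method lands much closer to exp(-1).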

     

Here, odeint is a general-purpose ODE solver: you must supply the dynamics fun(t, ht), an initial condition, the time points at which to evaluate the function, and the choice of solver.

Higher-order methods such as Runge–Kutta (RK4) or Adams–Bashforth offer better numerical accuracy.
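For concreteness, a sketch of a single classical RK4 step (the name `rk4_step` is an assumption): four slope evaluations are combined into a fourth-order-accurate update.

```python
def rk4_step(f, t, h, dt):
    """One classical RK4 step: four evaluations of f, fourth-order accurate."""
    k1 = f(t, h)
    k2 = f(t + dt / 2, h + dt / 2 * k1)
    k3 = f(t + dt / 2, h + dt / 2 * k2)
    k4 = f(t + dt, h + dt * k3)
    return h + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

# dh/dt = -h, h(0) = 1, integrated to t = 1 in only 10 steps.
h_rk4, t, dt = 1.0, 0.0, 0.1
for _ in range(10):
    h_rk4 = rk4_step(lambda t, y: -y, t, h_rk4, dt)
    t += dt
```

With just 10 steps, RK4 matches exp(-1) far more closely than Euler did with 1000.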

All of these can be implemented behind a uniform interface (e.g., SciPy).

Integrating neural networks with ODE solvers
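The integration code from the original did not survive extraction. As a sketch of the idea under assumed names (`f_theta` and `odeint_euler` are hypothetical, and the weights here are random rather than trained): a small MLP plays the role of the dynamics function and is handed to an ODE solver.

```python
import numpy as np

rng = np.random.default_rng(0)

# A tiny random MLP standing in for the learned dynamics f_theta(t, h).
W1 = rng.normal(scale=0.1, size=(16, 2))
b1 = np.zeros(16)
W2 = rng.normal(scale=0.1, size=(2, 16))
b2 = np.zeros(2)

def f_theta(t, h):
    """Neural dynamics: dh/dt = W2 tanh(W1 h + b1) + b2."""
    return W2 @ np.tanh(W1 @ h + b1) + b2

def odeint_euler(f, h0, t_grid):
    """Fixed-step Euler integration of the dynamics over t_grid."""
    hs = [np.asarray(h0, dtype=float)]
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        hs.append(hs[-1] + (t1 - t0) * f(t0, hs[-1]))
    return np.stack(hs)

# The hidden state h(t) is now a continuous trajectory driven by the network.
traj = odeint_euler(f_theta, [1.0, 0.0], np.linspace(0.0, 1.0, 101))
```

In practice one would use an adaptive solver (or a library such as torchdiffeq) instead of fixed-step Euler, but the interface is the same: the network is just the right-hand side of the ODE.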

     

The results are shown in the figure.

● We can use existing (and efficient) solver implementations to integrate the NN dynamics.
● The memory cost is O(1), thanks to reversibility: we do not need to store all activations in the graph, since we can recover them by backward (i.e. time-reversed) integration.
● Complex dynamics can be modeled with fewer parameters.
● We can control the accuracy/speed trade-off of adaptive solvers by setting lower/higher error tolerances.
● Hidden states can be accessed at any value of t - there are no discrete time steps as in ResNet skip connections.

Neural ODE - adjoint method

The adjoint method can be understood as a continuous version of the chain rule. Chain rule: consider the following sequence of operations (L is a scalar loss):
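The sequence itself did not survive extraction; in the standard notation of the Neural ODE setting (symbols assumed here) it is a discrete chain of hidden states ending in the loss:

$$ h_0 \;\xrightarrow{f_0}\; h_1 \;\xrightarrow{f_1}\; \cdots \;\xrightarrow{f_{T-1}}\; h_T \;\longrightarrow\; L, \qquad h_{t+1} = f_t(h_t). $$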

We can compute the gradient of L with respect to the input state using the chain rule.
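The chain-rule formula is also missing from this copy; with the notation above it reads

$$ \frac{\partial L}{\partial h_t} \;=\; \frac{\partial L}{\partial h_{t+1}}\,\frac{\partial h_{t+1}}{\partial h_t}. $$

In the continuous limit, this backward recursion becomes the adjoint ODE of Chen et al.: defining $a(t) = \partial L / \partial h(t)$, the adjoint state evolves as

$$ \frac{\mathrm{d}a(t)}{\mathrm{d}t} \;=\; -\,a(t)\,\frac{\partial f(h(t), t, \theta)}{\partial h}, $$

which is solved backward in time by the same ODE solver.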

This formula is at the core of any deep-learning autograd system.

     
