[NNDL] 最小二乘法解的矩阵形式

    科技2024-12-24  4

    最小二乘法解的矩阵形式

    文章目录

    最小二乘法解的矩阵形式简介平方损失函数对参数求导求解最优参数

    简介

    最近在看 NNDL,其中有一个经验风险最小化的例子,即最小二乘法,定义如下:

    给定一组包含 N N N 个训练样本的训练机 D = { ( x ( n ) , y ( n ) ) } n = 1 N D=\{(\mathbf{x}^{(n), }y^{(n)})\}_{n = 1}^N D={(x(n),y(n))}n=1N 。使用线性回归。样本和参数均为列向量。 f ( x ; w ) = w T x f(\mathbf{x};\mathbf{w}) = \mathbf{w}^T\mathbf{x} f(x;w)=wTx

    平方损失函数

    经验风险最小化,训练集的风险被定义为, X = [ x 1 , x 2 , ⋯   , x N ] T X=[\mathbf{x}_1, \mathbf{x}_2,\cdots,\mathbf{x}_N]^T X=[x1,x2,,xN]T R ( w ) = ∑ n = 1 N 1 2 ( y n − w T x ( n ) ) = 1 2 ∥ y − X w ∥ 2 = 1 2 ( y − X w ) T ( y − X w ) = 1 2 ( y T − w T X T ) ( y − X w ) = 1 2 ( y T y − y T X w − w T X T y + w T X T X w ) \begin{aligned} R(\mathbf{w}) &= \sum_{n = 1}^N\frac{1}{2}\left(y^{n} - \mathbf{w}^T\mathbf{x}^{(n)}\right)\\&= \frac{1}{2}\|\mathbf{y} - \mathbf{X}\mathbf{w}\|^2\\&= \frac{1}{2}(\mathbf{y} -\mathbf{Xw})^T(\mathbf{y} - \mathbf{Xw})\\&= \frac{1}{2}(\mathbf{y}^T-\mathbf{w}^T\mathbf{X}^T)(\mathbf{y} - \mathbf{Xw}) \\&= \frac{1}{2}\left(\mathbf{y}^T\mathbf{y} - \mathbf{y}^T\mathbf{X}\mathbf{w} - \mathbf{w}^T\mathbf{X}^T\mathbf{y} + \mathbf{w}^T\mathbf{X}^T\mathbf{X}\mathbf{w}\right) \end{aligned} R(w)=n=1N21(ynwTx(n))=21yXw2=21(yXw)T(yXw)=21(yTwTXT)(yXw)=21(yTyyTXwwTXTy+wTXTXw) 损失函数最终是一个标量,可以发现 y T X w = ( w T X T y ) T = s c a l a r \mathbf{y}^T\mathbf{X}\mathbf{w} = (\mathbf{w}^T\mathbf{X}^T\mathbf{y})^T = scalar yTXw=(wTXTy)T=scalar 两个是一个数字,因此 R ( w ) = 1 2 ( y T y − 2 ( y T X w ) + ∥ X w ∥ 2 ) R(\mathbf{w}) = \frac{1}{2}(\mathbf{y}^T\mathbf{y} - 2(\mathbf{y}^T\mathbf{X}\mathbf{w}) +\|\mathbf{Xw}\|^2) R(w)=21(yTy2(yTXw)+Xw2)

    对参数求导

    首先损失函数是一个凸函数,梯度为 0 的点是全局的最小值。需要对 w \mathbf{w} w 求导。 ∂ R ( w ) ∂ w = 1 2 ∂ ( y T y − 2 ( y T X w ) + ∥ X w ∥ 2 ) ∂ w = 1 2 ( 0 − ∂ ( 2 y T X w ) ∂ w + ∂ ∥ X w ∥ 2 ∂ w ) = − ∂ y T X w ∂ w + 1 2 ∥ X w ∥ 2 ∂ w \begin{aligned} \frac{\partial R(\mathbf{w})}{\partial \mathbf{w}}&= \frac{1}{2} \frac{\partial (\mathbf{y}^T\mathbf{y} - 2(\mathbf{y}^T\mathbf{X}\mathbf{w}) +\|\mathbf{Xw}\|^2)}{\partial\mathbf{w}}\\&= \frac{1}{2}(0-\frac{\partial (2\mathbf{y}^T\mathbf{Xw})}{\partial \mathbf{w}} + \frac{\partial\|\mathbf{Xw}\|^2}{\partial \mathbf{w}})\\&= -\frac{\partial\mathbf{y}^T\mathbf{Xw}}{\partial\mathbf{w}} + \frac{1}{2}\frac{\|\mathbf{Xw}\|^2}{\partial\mathbf{w}} \end{aligned} wR(w)=21w(yTy2(yTXw)+Xw2)=21(0w(2yTXw)+wXw2)=wyTXw+21wXw2 分析前半部分,矩阵展开计算依次求导。 ∇ w f ( w ) = [ ∂ f ( w ) / ∂ w 1 ∂ f ( w ) / ∂ w 2 ⋮ ∂ f ( w ) / ∂ w N ] = X T y \nabla_\mathbf{w} f(\mathbf{w})= \begin{bmatrix} \partial f(\mathbf{w})/\partial w_1\\ \partial f(\mathbf{w})/\partial w_2\\ \vdots\\ \partial f(\mathbf{w})/\partial w_N \end{bmatrix} = \mathbf{X}^T\mathbf{y} wf(w)=f(w)/w1f(w)/w2f(w)/wN=XTy 对于后半部分 ∇ w w T X T X w = ∇ w w T A w = [ ∂ w T A w / ∂ w 1 ∂ w T A w / ∂ w 2 ⋮ ∂ w T A w / ∂ w N ] = [ 2 w 1 ( A 11 + A 12 + A 13 + ⋯ A 1 N ) 2 w 2 ( A 21 + A 22 + A 23 + ⋯ A 2 N ) ⋮ 2 w N ( A N 1 + A N 2 + A N 3 + ⋯ A N N ) ] = 2 A w = 2 X T X w \begin{aligned} \nabla_\mathbf{w}\mathbf{w}^TX^TX\mathbf{w}&= \nabla_\mathbf{w}\mathbf{w}^TA\mathbf{w}\\&= \begin{bmatrix} \partial\mathbf{w}^T\mathbf{Aw}/\partial_{w_1}\\ \partial\mathbf{w}^T\mathbf{Aw}/\partial_{w_2}\\ \vdots\\ \partial\mathbf{w}^T\mathbf{Aw}/\partial_{w_N} \end{bmatrix} \\&= \begin{bmatrix} 2w_1(A_{11} + A_{12} + A_{13} + \cdots A_{1N})\\ 2w_2(A_{21} + A_{22} + A_{23} + \cdots A_{2N})\\ \vdots\\ 2w_N(A_{N1} + A_{N2} + A_{N3} + \cdots A_{NN})\\ \end{bmatrix}\\ &= 2A\mathbf{w}\\&= 2X^TX\mathbf{w} \end{aligned} wwTXTXw=wwTAw=wTAw/w1wTAw/w2wTAw/wN=2w1(A11+A12+A13+A1N)2w2(A21+A22+A23+A2N)2wN(AN1+AN2+AN3+ANN)=2Aw=2XTXw 所以有 ∂ R ( w ) ∂ w = − X T y + X T X w \frac{\partial R(\mathbf{w})}{\partial\mathbf{w}}= -X^T\mathbf{y} +X^TX\mathbf{w} wR(w)=XTy+XTXw

    求解最优参数

    让导数为 0, 可得 X T X w = X T y w = ( X T X ) − 1 X T y \begin{aligned} X^TX\mathbf{w} &= X^T\mathbf{y}\\ \mathbf{w}&=(X^TX)^{-1}X^T\mathbf{y} \end{aligned} XTXww=XTy=(XTX)1XTy

    Processed: 0.051, SQL: 8