What is being transferred in transfer learning?

    科技2024-12-07  17

    文章目录

    概主要内容feature reusemistakes and feature similarityloss landscapemodule criticalitypre-trained checkpoint

    Neyshabur B., Sedghi H., Zhang C. What is being transferred in transfer learning? arXiv preprint arXiv 2008.11687, 2020.

    迁移学习到底迁移了什么?

    主要内容

    T: 普通训练的模型P: 预训练的模型RI: 随机初始化的模型RI-T: 随机初始化再经过普通训练的模型P-T: 在预训练的基础上再fine-tuning的模型

    本文的预训练都是在ImageNet上, 然后在CheXpert和DomainNet(分为real, clipart, quickdraw)上测试.

    feature reuse

    大家认为迁移学习有用的一个直觉就是迁移学习通过特征的复用来样本少的数据提供一个较好的特征先验.

    通过上面的图可以看到, P-T总是能够表现优于RI-T, 这能够支撑我们的观点. 但是, 为什么数据差别特别大的时候, 预训练还是有用呢(此时feature reuse的作用应该不是很明显)? 作者将图片按照不同的block size打乱(就像最开始的那些乱七八糟的图片). 这个时候, 模型应该只能抓住浅层的特征, 抽象的特征是没法被很好提取的, 结果如下图所示.

    当打乱的程度加剧(block size变小), 任务越发困难;相对正确率差距 ( A P − T − A R I − T ) / A P − T % (A_{P-T}-A_{RI-T})/A_{P-T} \% (APTARIT)/APT%随着block size减小而减小(clipart, real), 这说明feature reuse很有效果, quickdraw 相反是由于其数据集和预训练的数据集相差过大, 但是即便如此, 在quickdraw上预训练还是有效的, 说明存在除了feature reuse外的因素;P-T的训练速度(右图)一直很稳定, 而RI-T的训练速度则在block size下降的时候有一个急剧的下降, 这说明feature reuse并不是影响P-T训练速度的主要因素.

    mistakes and feature similarity

    这部分通过探究不同模型有哪些common和uncommon的mistakes来揭示预训练的作用.

    P-T在简单样本上的成功率很高, 而在比较模糊难以判断的样本上比较难(而此时RI-T往往比较好), 这说明P-T有着很强的先验.

    通过 centered kernel alignment (CKA) 来衡量特征之间的相似度: 可以发现, 基于预训练的模型之间的特征相似度很高, 而RI-T与别的模型相似度很低, 即便是两个相同初始化的RI-T. 说明预训练模型之间往往是在重复利用相同的特征.

    下表为不同模型的参数的 ℓ 2 \ell_2 2距离, 同样能够反映上面一点.

    loss landscape

    Θ , Θ ~ \Theta, \tilde{\Theta} Θ,Θ~表示两个checkpoint的参数, 通过线性插值 { Θ λ = ( 1 − λ ) Θ + λ Θ ~ : λ ∈ [ 0 , 1 ] } , \{\Theta_{\lambda} = (1- \lambda) \Theta + \lambda \tilde{\Theta}: \lambda \in [0, 1]\}, {Θλ=(1λ)Θ+λΘ~:λ[0,1]}, 考量模型在 Θ λ \Theta_{\lambda} Θλ下的表现.

    上图, 左为DomainNET real, 右为quickdraw, 可见预训练模型之间的loss landscape是很光滑的, 不同于RI-T.

    module criticality

    如果我们将训练好后的模型的某一层参数替换为其初始参数, 然后观察替换前后的正确率就能一定程度上判断这个层在整个网络中的重要性, module criticality就是一个这样的类似的指标.

    下图反映了不同模型的不同层的criticality.

    下图反映了RI-T的训练后的参数 θ \theta θ其实加了扰动反而性能更好? 而P-T的就相当稳定.

    pre-trained checkpoint

    我们选pre-trained模型的时候, 往往是通过正确率指标来判断的, 但是事实上, 这个判断并不十分准确, 事实上我们可以早一步地选取checkpoint (直观上理解, 大概是只要参数进入了那个光滑的盆地就行了).

    Processed: 0.014, SQL: 8