背景
首先,本质上这也是一个关于 对比学习(Contrast Learning)的方法。关于对比学习的一些背景介绍可以看一下 MoCo 的笔记。
对于对比学习,我们希望最终的网络有两个能力:
- 将同样本的不同 view/arg 在特征空间尽量靠近。alignment
- 将不同样本在特征空间内尽量分开。uniformity/dispersion
BYOL 做了一件反直觉的事情,在训练中把 负样本 的对比部分给去掉了。按照直觉,我们会觉得,网络失去了将不同类的特征向量分开(dispersion)的能力,那么网络很容易得到退化解 / 奇异解(collapsed solution/trivial solution),但是 BYOL 并没有这样。这很大程度上因为 Moving Average/Mean Teacher(也类似于 MoCo 的 momentum 方法),下文详述。
主要贡献 / 动机
动机 1 – 结构改进
- 类似于 MoCo v2,BYOL 也添加了两个 MLP,为提升性能
- BYOL 添加了 2 个 MLP 之后,额外添加了一个 predictor,增强灵活性
动机 2 – 去除负样本
为什么想到去除负样本呢?
- 取负样本的开销很大,memory bank、MoCo 这些工作都在尝试取得尽可能多的负样本。
- 直觉 — 不取负样本的话无法完成对比学习的任务——不同样本之间的分离,容易获得退化解。
文中给了一个简单的证明,如果想避免坍缩,为什么不能用梯度下降来优化我们的目标。
1)假设,预测器 predictor 最优时,满足如下条件
2)更新 θ 根据
3)若要最小化
a. 对于 z\_θ
显然成立,所以 z\_θ 不会坍缩。online 网络我们用梯度下降更新参数。
b. 对于 z ’_ξ
显然一个常数解会令
最小,容易坍缩。/target 网络我们采用更新参数。
动机 3 – Moving Average / mean teacher
其实和 MoCo 的 momentum 实际上是一样的,target 是 online 的 缓慢收敛、smooth 的版本。
首先我们看一个 ema 曲线的示例,蓝点是原始数据,红线是拟合的数据,绿色的线是 EMA 曲线,可以看出特点很明显。
文中给出了一个实验,来说明 τ 对效果的实验。
这个实验说明什么呢:
- Student 不能学的太快了,学得太快会很容易退化。
- Student 也不能学的太慢,学得太慢会收敛很慢。
- 一个随机初始化的 CNN 很可能本来就具有将不同样本在特征空间分离的能力,只不过可能不强,我们需要一个速度适中、平缓的一个学习过程,来保持这种能力。
动机 4 – 左脚踩右脚
- 一个有趣的现象
文章第 3 部分的第 2 段(To prevent…)指出,随机初始化网络 A、B,A 作为 teacher,A 随机初始化之后的性能为 Top1-1.4%。固定 A 的参数,使用 x 的不同 aug 版本,x1、x2 作为两个网络的输入,使用 A 的网络输出作为 B 学习的 pseudo label,引导 B 的学习,最终 B 达到 Top1-18.8%。这个动机很神奇,如果可以无穷无尽的这样下去,C 再跟 B 学,D 再跟 C 学 … 性能是否会非常高。
- 优化 – 小碎步交替进行,螺旋升天
每次迭代 x 增强为 x1,x2;
输入 x1 到 online,x2 到 target,算一个 loss;
输入 x2 到 online,x1 到 target,再算一个 loss;
然后把两个 loss 加在一起,更新 online 的参数,之后 mean teacher 更新 target 的参数。
- Loss 防止退化
这是一个 cos 相似度,只关心方向,不关心大小,也防止 MSE 把 feature 的 scale 都拉倒接近 0。
- 如何理解这种优化?
相当于将上一段提到的每次提升一大截过程,变成了一个小碎步交替上升的过程。如何解释呢,上段中提到的训练过程,在整个训练的很多次迭代中 A 的参数都固定,然后 B 根据 A 的指导进行参数更新。而 Moving Average 策略,每一次迭代中,Online 网络(相当于 A)参数都进行了更新,Target 网络(相当于 B)都根据 online 网络的参数进行 momentum 更新。这样,相当于 A 网络在训练 B 的同时,也在提升自身的性能。
另外 EMA 策略也相当于让 target 慢步跟上,也算是小碎步的一部分。
实验结果
作者的意思就是——“你看你想不到把负样本去掉吧,你更想不到我把负样本去掉之后还很有效吧”。
小结
那么,我们回顾一下 BYOL 干了什么
1)去除负样本的同时防止退化
- mean teacher
- cos 相似度只关注方向
- 并不是一些文章里说的,MLP 里有 BN,BN 可以防止退化。「BYOL works even without batch statistics」