signed

QiShunwang

“诚信为本、客户至上”

CONTRASTIVE REPRESENTATION DISTILLATION

2021/6/3 17:06:20   来源:

CONTRASTIVE REPRESENTATION DISTILLATION

我们常常希望将表征性知识从一个神经网络转移到另一个神经网络。这方面的例子包括将一个大型网络提炼成一个较小的网络,将知识从一种感觉模式转移到另一种感觉模式,或者将一系列模型集合成一个单一的估计器。知识提炼是解决这些问题的标准方法,它使教师和学生网络的概率输出之间的KL背离最小。我们证明这个目标忽略了教师网络的重要结构知识。这促使我们提出了另一个目标,即训练学生在教师的数据表述中捕捉到更多的信息。我们把这个目标表述为对比性学习。实验证明,我们的新目标在各种知识转移任务上优于知识蒸馏和其他尖端蒸馏器,包括单一模型压缩、集合蒸馏和跨模式转移。我们的方法在许多转移任务中创造了新的最先进的技术,当与知识蒸馏相结合时,有时甚至超过了教师网络。

1 INTRODUCTION

知识提炼(KD)将知识从一个深度学习模型(教师)转移到另一个(学生)。最初由Hinton等人(2015)提出的目标是最小化教师和学生输出之间的KL散度。当输出是一个分布时,这种表述具有直观的意义,例如,在类上的probability mass function。然而,我们经常希望转移关于一个表示的知识。例如,在 "跨模式提炼 "的问题中,我们可能希望将图像处理网络的表示转移到声音(Aytar等人,2016)或深度(Gupta等人,2016)处理网络,这样,图像的深度特征和相关的声音或深度特征是高度相关的。在这种情况下,KL发散是不确定的。

表征性知识是结构化的–各维度呈现出复杂的相互依赖关系。在(Hinton等人,2015)中引入的原始KD目标将所有维度视为独立的,并以输入为条件。让 y T y^T yT为老师的输出, y S y^S yS是学生的输出。然后是原始KD目标函数,ψ, has the fully factored form: ψ ( y S , y T ) = Σ i ϕ i ( y i S , y i T ) ψ(y^S,y^T)=\Sigma_i\phi_ i(y^S_i,y^T_i) ψ(yS,yT)=Σiϕi(yiS,yiT). 这样一个考虑因素的目标不足以转移结构知识,即输出维度i和j之间的依赖关系。这类似于图像生成中,由于输出维度之间的独立假设,L2目标产生模糊结果的情况。

为了克服这个问题,我们希望有一个能够捕捉到相关性和高阶输出依赖性的目标。为了实现这一点,在本文中,我们利用了e the family of contrastive objectives(Gutmann & Hyvärinen,2010;Oord等人,2018;Arora等人,2019;Hjelm等人,2018)。近年来,这些目标函数被成功地用于density estimation和表征学习,特别是在自我监督的情况下。在这里,我们将它们调整为从一个深度网络到另一个深度网络的知识提炼任务。我们表明,在表示空间中工作是很重要的,与最近的工作如Zagoruyko & Komodakis(2016a);Romero等人(2014)类似。然而,请注意,这些作品中使用的损失函数并没有明确地试图捕捉表征空间中的相关性或高阶依赖关系。

我们的目标是最大限度地降低师生之间的互信息表示。我们发现,这导致在一些知识转移任务中表现更好,我们推测这是因为对比目标更好地传递了教师表征中的所有信息,而不是仅仅传递关于条件独立输出类概率的知识。有些令人惊讶的是,对比目标甚至改进了最初提出的提取类概率知识的任务的结果,例如,将一个大型CIFAR100网络压缩成一个性能几乎相同的小型网络。我们认为这是因为不同类别概率之间的相关性包含了有用的信息,可以规范学习问题。我们的论文在两个主要独立发展的文献之间建立了联系:知识提炼和表征学习。这种联系使我们能够利用表征学习的强大方法来显著改进知识提炼的SOTA。我们的贡献是:

  1. 一个基于对比的目标,在深度网络之间转移知识。
  2. 应用于模型压缩、跨模式转移和ensemble distillation。
  3. 对最近的12种蒸馏方法进行基准测试;CRD优于所有其他方法

2 RELATED WORK

注意力转移(Zagoruyko & Komodakis, 2016a)侧重于网络的特征图,而不是输出logits。这里的想法是在教师和学生的特征图中激发出类似的反应模式(称为 “attention”)。然而,在这种方法中,只有具有相同空间分辨率的特征图可以被结合,这是一个重要的限制,因为它需要学生和教师网络具有非常相似的架构。

这种技术实现了最先进的提炼结果(以学生网络的泛化为标准)。FitNets(Romero等人,2014)也通过使用回归来指导学生网络的特征激活来处理中间表征。由于Zagoruyko和Komodakis(2016a)做了这种回归的加权形式,他们往往表现得更好。其他论文(Yim等人,2017;Huang & Wang,2017;Kim等人,2018;Yim等人,2017;Huang & Wang,2017;Ahn等人,2019;Koratana等人,2019)执行了基于表示的各种标准。我们在本文中使用的对比性目标与CMC中使用的目标相同(Tian et al., 2019)。但我们从不同的角度推导,并给出严格的证明,我们的目标是互信息的下限。我们的目标也与(Oord等人,2018;Gutmann & Hyvärinen,2010)中介绍的InfoNCE和NCE目标有关。Oord等人(2018)在自我监督的表征学习的背景下使用对比学习。他们表明,他们的目标是最大化相互信息的下限。Hjelm等人(2018)使用了一种非常相关的方法。InfoNCE和NCE密切相关,但与对抗性学习不同(Goodfellow等人,2014)。在(Goodfellow,2014)中,表明Gutmann & Hyvärinen(2010)的NCE目标可以导致最大似然学习,但不是对抗性目标。

3 METHOD

对比学习的关键思想是非常普遍的:对于 "positive " pair,学习一个在某些度量空间中接近的表征,并在 "negative " pair之间分离表征。图1直观地解释了我们如何为我们考虑的三个任务构建对比性学习:模型压缩、跨模式转移和ensemble distillation。

在这里插入图片描述

图1:我们考虑的三种提取设置:(a)压缩模型,(b)将知识从一种模式(如RGB)转移到另一种模式(如深度),(c)将网络集合提取到单个网络中。建构目标鼓励教师和学生将相同的输入映射到接近的表示(在某些度量空间中),并将不同的输入映射到遥远的表示,如阴影圆所示。

3.1 CONTRASTIVE LOSS

给定两个深度神经网络,一个是教师 f T f^T fT,一个是学生 f S f^S fS。设x为网络输入;我们将倒数第二层(在logits之前)的表征分别表示为 f T ( x ) 和 f S ( x ) f^T(x)和f^S(x) fT(x)fS(x)。设 x i x_i xi代表训练样本, x j x_j xj为随机选取样本。我们想推近表示 f T ( x i ) 和 f S ( x i ) f^T(x_i)和f^S(x_i) fT(xi)fS(xi),而分离 f T ( x j ) 和 f S ( x j ) f^T(x_j)和f^S(x_j) fT(xj)fS(xj)。为了便于记法,我们分别为学生和教师的数据表示定义了随机变量S和T:

在这里插入图片描述

直观地说,我们将考虑联合分布 p ( S , T ) p(S, T) p(S,T)和边际分布的乘积 p ( S ) p ( T ) p(S)p(T) p(S)p(T),因此,通过最大化这些分布之间的KL散度,我们可以最大化学生和教师表示之间的互信息。为了设置一个能够实现这一目标的适当的损失,让我们定义一个带有latent variable C的分布q,它决定一个tuple ( f T ( x i ) , f S ( x j ) ) (f^T(x_i), f^S(x_j )) (fT(xi),fS(xj))是来自联合(C = 1)还是边际分布的乘积(C = 0)

在这里插入图片描述

Now, suppose in our data, we are given 1 congruent pair (drawn from the joint distribution, i.e. the same input provided to T and S) for every N incongruent pairs (drawn from the product of marginals; independent randomly drawn inputs provided to T and S). Then the priors on the latent C are:

在这里插入图片描述

通过简单的应用贝叶斯法则,C=1类的后验为

在这里插入图片描述

接下来,我们观察到与互信息的联系,如下所示

在这里插入图片描述

然后取两边的期望 p ( t , S ) p(t,S) p(t,S)(相当于 q ( T , S ∣ C = 1 ) q(T,S | C=1) q(T,SC=1))并重新排列((equivalently w.r.t. q ( T , S ∣ C = 1 ) q(T,S | C=1) q(T,SC=1)) and rearranging),得到:

在这里插入图片描述

其中 I ( T ; S ) I(T; S) I(T;S)是教师和学生embeddings分布之间的互信息。因此,最大化 E q ( T , S ∣ C = 1 ) l o g   q ( C = 1 ∣ T , S ) \mathbb E_{q(T ,S|C=1)}log\ q(C = 1|T, S) Eq(T,SC=1)log q(C=1T,S),通过学生网络的参数S,增加了互信息的下限。然而,我们不知道真实的分布q(C = 1|T, S);所以,我们通过拟合一个模型 h : { T , S } → [ 0 , 1 ] h:\{\mathcal T ,\mathcal S\} → [0, 1] h:{T,S}[0,1]来估计它,通过来自数据分布 q ( C = 1 ∣ T , S ) q(C = 1|T, S) q(C=1T,S)的样本,其中 T 和 S \mathcal T和\mathcal S TS代表embeddings的域( the domains of the embeddings)。我们最大化这个模型下的数据的对数似然(一个二元分类问题):

在这里插入图片描述

We term h the critic since we will be learning representations that optimize the critic’s score.假设h有足够的表现力, h ∗ ( T , S ) = q ( C = 1 ∣ T , S ) h^∗(T, S)=q(C = 1|T, S) h(T,S)=q(C=1T,S)(通过吉布斯不等式;证明见第6.2.1节),因此我们可以用h∗重写公式9

在这里插入图片描述

因此,我们看到,最佳critic是一个estimator ,其期望值降低了互信息的范围。我们希望学习一个学生,使其表征与教师的表征之间的互信息最大化,这就提出了以下优化问题

在这里插入图片描述

这里的一个明显的困难是,最佳critich∗取决于当前的学生。我们可以通过将(12)中的约束弱化来规避这个困难:

在这里插入图片描述

在这里插入图片描述

这说明我们可以在学习h的同时共同优化 f S f^S fS。我们注意到,由于(16), f S ∗ = a r g   m a x f S L c r i t i c ( h ) f^{S∗} = arg\ max_{f^S} \mathcal L_{critic}(h) fS=arg maxfSLcritic(h)对于任何h,也是一个基于互信息优化下限(较弱的下限)的表示,因此我们的公式不依赖于h的完美优化。

我们可以选择用满足 h : { T , S } → [ 0 , 1 ] h:\{\mathcal T,\mathcal S\}\rightarrow[0,1] h:{T,S}[0,1]的任何函数族. 在实践中,我们使用以下方法:

在这里插入图片描述

其中M是数据集的cardinality, τ τ τ是调整concentration level的温度。在实践中,由于S和T的维度可能不同, g S 和 g T g^S和g^T gSgT将它们线性转化为相同的维度,并在内积之前通过L-2 norm进一步归一化它们。公式(18)的形式受到NCE的启发(Gutmann & Hyvärinen, 2010; Wu et al., 2018)。我们的表述与InfoNCE损失(Oord等人,2018)相似,即我们最大化了互信息的下限。然而,我们使用了一个不同的目标和约束,在我们的实验中,我们发现它比InfoNCE更有效。

Implementation. 理论上,公式16中较大的N会导致MI的更严格的下限。在实践中。为了避免使用非常大的batch规模,我们遵循Wu等人(2018)的做法,实现了一个内存缓冲器存储每个数据样本的潜在特征,这些特征是由以前的批次计算出来的。因此,在 训练中,我们可以有效地从存储缓冲区中检索大量的负面样本。

3.2 KNOWLEDGE DISTILLATION OBJECTIVE

Hinton等人(2015)提出了知识提炼损失。除了学生输出 y S y^S yS和 one-hot label y之间的常规交叉熵损失外,它还要求学生网络输出尽可能地与教师输出相似,即最小化他们输出概率之间的交叉熵。完整的目标是

在这里插入图片描述

3.3 CROSS-MODAL TRANSFER LOSS

在图1(b)所示的跨模态转移任务中,教师网络是在一个具有大规模标记的数据集的源模态X上训练的。然后,我们希望将知识转移给学生网络,但要将其适应于另一个数据集或模态Y。但教师网络的特征仍然有价值,可以帮助学生在另一个领域学习。在这个转移任务中,我们使用对比性损失公式10来匹配学生和教师的特征。此外,我们还考虑了其他的提炼目标,比如上一节中讨论的KD,注意力转移Zagoruyko & Komodakis(2016a)和FitNet Romero等人(2014)。

这种transfer是在一个成对但无标签的数据集 D = { ( x i , y i ) ∣ i = 1 , . . . , L , x i ∈ X , y i ∈ Y } D = \{(x_i, y_i)|i = 1, ..., L, x_i ∈ \mathcal X , y_i ∈\mathcal Y\} D={(xi,yi)i=1,...,L,xiX,yiY}上进行。在这种情况下,对于源模态上的原始训练任务,没有这种数据的真实标签y,因此我们在所有测试的目标中忽略了 H ( y , y S ) H(y, y^S) H(y,yS)项。之前的跨模态工作Aytar等人(2016);Hoffman等人(2016b;a)使用L2回归或KL-散度。

4 ENSEMBLE DISTILLATION LOSS

在1©所示的集合蒸馏的情况下,我们有M>1个教师网络, f T i f^{T_i} fTi和一个学生网络 f S f^S fS。我们采用对比框架,在每个教师网络 f T i f^{T_i} fTi和学生网络 f S f^S fS的特征之间定义了多个成对的对比性损失。这些损失相加,得出最终的损失(要最小化)

网络, f T i f^{T_i} fTi和一个学生网络 f S f^S fS。我们采用对比框架,在每个教师网络 f T i f^{T_i} fTi和学生网络 f S f^S fS的特征之间定义了多个成对的对比性损失。这些损失相加,得出最终的损失(要最小化)

在这里插入图片描述