人脸识别任务里,通常利用全连接层,来做人脸的分类。会面临三个实际问题:
- 真实的人脸识别数据噪声严重
- 真实的人脸识别数据存在严重的长尾分布问题,一些类别样本多,多数类别样本少
- 人脸类别越来越多,全连接层训练成本越来越高,难度越来越大
于是,有研究人员提出Partial FC,拒绝全量更新负类别中心,而是仅更新少部分负类别中心。该做法优势在于
- 降低噪声数据被采样的概率
- 降低高频负类别中心被选中的概率
- 降低负类别中心的更新频率,降低训练难度
如下图所示:
问题建模
人脸识别领域,常用的分类损失公式化定义如下:
L
=
−
1
B
∑
i
=
1
B
l
o
g
e
W
y
i
T
⋅
x
i
e
W
y
i
T
⋅
x
i
+
∑
j
=
1
,
j
≠
y
i
C
e
W
j
T
⋅
x
i
(1)
L=-\frac{1}{B}\sum_{i=1}^{B}log\frac{e^{W_{y_i}^T\cdot x_i}}{e^{W_{y_i}^T}\cdot x_i+\sum_{j=1,j\neq y_i}^Ce^{W_j^T\cdot x_i}} \tag{1}
L=−B1i=1∑BlogeWyiT⋅xi+∑j=1,j=yiCeWjT⋅xieWyiT⋅xi(1)
,其中,
B
B
B表示batch size,
C
C
C表示类别个数,
W
j
T
W_{j}^T
WjT表示第
j
j
j个类别中心的特征,
(
x
i
,
y
i
)
(x_i,y_i)
(xi,yi)表示第
i
i
i个样本的特征为
x
i
x_i
xi,类别为
y
i
y_i
yi。
真实大规模人脸数据实际使用时,有以下问题:
- 噪声问题:见上图(a),图片对都是一个人的图片,但是被分到不同的类别,这对模型训练有非常大的干扰。
- 长尾分布:见上图(b),大部分类别(identity)包含的图像数量很少,在WebFace42M中,44.57%的类别包含的图像数量少于10张。这会导致低频类别的类别中心更新缓慢,而高频类别的类别中心更新频繁。
- 训练资源:全连接层一般表示为
W
∈
R
D
×
C
W\in \mathbb{R}^{D\times C}
W∈RD×C,其中
D
D
D表示维度,
C
C
C表示类别数。假设
D
=
512
D=512
D=512,如果类别数是1000,000(一百万)
- fp16下,全连接层的显存消耗为: 512 × 100 , 000 × 2 1024 × 1024 × 1024 = 0.95 G B \frac{512\times 100,000 \times 2}{1024\times 1024\times 1024}=0.95GB 1024×1024×1024512×100,000×2=0.95GB
- 公式(1)中,需要计算 B B B个 x i x_i xi属于类别中心 W j T W_{j}^T WjT的logit,维度是 R B × D × C \mathbb{R}^{B\times D\times C} RB×D×C,显存消耗为 512 × 100 , 000 × 2 1024 × 1024 × 1024 ⋅ B = 0.95 B G B \frac{512\times 100,000 \times 2}{1024\times 1024\times 1024}\cdot B=0.95B \,GB 1024×1024×1024512×100,000×2⋅B=0.95BGB,batchsize越大,需要的显存越大。
- 在下图,进行了模型并行和partial fc在显存消耗和训练速度上的比较,可以发现:
- partial fc显著降低了对logit的显存消耗
- partial fc略微降低了对存储类别中心的显存消耗
- partial fc未降低对特征抽取网络的显存消耗(将原图像转换为特征的模型的消耗)
- 由于partial fc减少了负类别中心的数量,降低了logit计算的复杂度,随着训练类别越多,加速比越高。
partial fc
为了缓解上述问题,提出了partial fc,通过稀疏更新全连接层的参数,来支持大规模人脸识别模型的训练。
整体架构如下图所示:
模型通过数据并行训练的,不同GPU包含了不同数据的特征,整体步骤如下:
- 汇总不同GPU里的图像特征和图像标签
- 将汇总的图像特征和图像标签送到每张GPU上
- 将全连接层(即 C C C个类别中心)均分到每张GPU上
- 在单张卡上,保留需要的正类别中心,以及采样固定比例的负类别中心
- 利用样本、正类别中心、负类别中心计算损失函数
代码实现
def forward(
self,
local_embeddings: torch.Tensor,
local_labels: torch.Tensor,
):
"""
Parameters:
----------
local_embeddings: torch.Tensor
feature embeddings on each GPU(Rank).
local_labels: torch.Tensor
labels on each GPU(Rank).
Returns:
-------
loss: torch.Tensor
pass
"""
local_labels.squeeze_()
local_labels = local_labels.long()
batch_size = local_embeddings.size(0)
if self.last_batch_size == 0:
self.last_batch_size = batch_size
assert self.last_batch_size == batch_size, (
f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")
_gather_embeddings = [
torch.zeros((batch_size, self.embedding_size)).cuda()
for _ in range(self.world_size)
]
_gather_labels = [
torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)
]
_list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
distributed.all_gather(_gather_labels, local_labels)
# 汇总不同GPU里的图像特征和图像标签
embeddings = torch.cat(_list_embeddings)
labels = torch.cat(_gather_labels)
labels = labels.view(-1, 1)
# self.class_start表示该GPU中,分配的类别中心起始id
# self.num_local表示该GPU中,分配的类别中心数量
# 于是,该GPU的类别中心id范围是[类别中心起始id, 类别中心起始id + 类别中心数量]
# 在单张卡上,仅保留需要的正类别中心
index_positive = (self.class_start <= labels) & (
labels < self.class_start + self.num_local
)
labels[~index_positive] = -1
labels[index_positive] -= self.class_start
# 在单张卡上,采样固定比例的负类别中心
if self.sample_rate < 1:
weight = self.sample(labels, index_positive)
else:
weight = self.weight
with torch.cuda.amp.autocast(self.fp16):
norm_embeddings = normalize(embeddings)
norm_weight_activated = normalize(weight)
logits = linear(norm_embeddings, norm_weight_activated)
if self.fp16:
logits = logits.float()
logits = logits.clamp(-1, 1)
# 基于样本特征、样本标签、正类别中心,采样的负类别中心,计算损失
logits = self.margin_softmax(logits, labels)
loss = self.dist_cross_entropy(logits, labels)
return loss
优势
partial fc的核心思想是”降低训练中负类别中心数量,显式得减少需要更新的参数量“。负类别中心采样比例越低,节约的显存越多。
为了更好理解partial fc对长尾分布、噪声问题的影响,计算分类损失对于样本
x
i
x_i
xi的梯度,如下:
∂
L
∂
x
i
=
−
(
(
1
−
p
+
)
W
+
−
∑
j
∈
S
,
j
≠
y
i
p
j
−
W
j
−
)
(2)
\frac{\partial L}{\partial x_i}=-((1-p^+)W^+-\sum_{j\in \mathbb{S}, j\neq y_i}p_j^-W_j^-) \tag{2}
∂xi∂L=−((1−p+)W+−j∈S,j=yi∑pj−Wj−)(2)
其中,
p
+
p^+
p+、
p
−
p^-
p−分别表示通过样本特征
x
i
x_i
xi计算的logit分数,
S
\mathbb{S}
S表示负类别中心,
∣
S
∣
=
C
×
r
|\mathbb{S}|=C\times r
∣S∣=C×r,通过采样比例
r
r
r,调整训练时的负样本数量。
样本特征 x i x_i xi的更新方向和正类别中心和负类别中心都有关系,partial fc随机减少负类别中心数量,减低噪声数据被采样的概率,降低高频负类别中心被选中的概率,有效缓解长尾问题和噪声问题。
注意:采样率为1,等同于选取所有负类别中心,进行模型训练。相当于原始fc分类器
为了进一步验证partial fc的作用原理,做了下述验证下实验
探究采样率与类内、类间相似度关系
(a)图中,采样率越低,APCS收敛至更高数值。APCS表示类内距离 A P C S = 1 B ∑ i = 1 B W y i T x i ∣ ∣ W y i ∣ ∣ ⋅ ∣ ∣ x i ∣ ∣ APCS=\frac{1}{B}\sum_{i=1}^B\frac{W_{y_i}^Tx_i}{||W_{y_i}||\cdot ||x_i||} APCS=B1i=1∑B∣∣Wyi∣∣⋅∣∣xi∣∣WyiTxi,说明采样率越低,类内相似度越大,类内越紧密。
(b)图中,采样率越低,MICS分布越往右,整体数值越大。MICS表示最大的类间余弦相似度 MICS i = max j ≠ i W i T W j ∥ W i ∥ ∥ W j ∥ \text{MICS}_i = \max_{j \neq i} \frac{W_i^T W_j}{\|W_i\| \, \|W_j\|} MICSi=j=imax∥Wi∥∥Wj∥WiTWj,说明采样率越低,类间相似度越大,类间拉不开。
探究采样率与评测集合效果关系
随着采样率越来越低,IJB-C、MFR-All评测集上的效果越来越差,如下:
探究采样率对噪声数据的鲁棒性
为验证在噪声数据上的效果,做了如下实验:
构造了WebFace12M-Conflict数据集,随机将20万个类的样本放到另外60万个类别数据中。
图(a)纵坐标为AMNCS,指最小的负类中心距离(越大,说明样本和负类离得远)。可以发现,降低采样率(1.0->0.1),在干净数据上,效果接近;在噪声数据上,缓解过拟合问题。
图(b)的MICS,指最大的类间余弦相似度(越大,说明类别越相似,越区分不开类别)。可以发现,降低采样率(1.0->0.1),MICS分布往右,整体数值偏大。由于WebFace12M-Conflict数据集中,20万类的样本随机分布在其他类别中,类间余弦相似度本就很大,图(b)更好的刻画实际噪声分布。
上图定义了两个概念,分别是conflict-hard和conflict-noise。conflict-hard表示利用真实负样本计算AMNCS;conflict-noise表示利用噪声负样本计算AMNCS。结果表明:
- r=1.0时,针对AMNCS指标,conflict-hard>conflict-noise,表明负样本不采样,会使得模型过分拟合数据集,导致对噪声数据不鲁棒(按理说应该是conflict-noise>conflict-hard)。
- r=0.1时,针对AMNCS指标,conflict-hard<conflict-noise,刻画出真实数据特性。
消融实验
不同数据集、不同采样率下partial fc
不同网络结构下partial fc
对噪声数据鲁棒
采用WebFace12M-Conflict作为训练集合。