VisDA-2020亚军技术方案分享

©作者|葛艺潇

学校|香港中文大学博士生

研究方向|图像检索、图像生成等

本文介绍我们在 ECCV 2020 的 Visual Domain Adaptation Challenge 中取得第二名的技术方案,该方案主要结合了两个我们自己的工作:MMT [1] 和 SDA [2],并且在原始 MMT 的基础上进行了升级,提出 MMT+,用以应对更嘈杂的目标域数据。比赛结束**个月了,终于抽出空来总结一下,也方便以后的参赛者参考。代码和模型均已公开。

比赛链接:

http://ai.bu.edu/visda-2020/

技术报告链接:

https://arxiv.org/abs/2008.10313

代码链接:

https://github.com/yxgeee/MMT-plus

视频介绍:

https://www.zhihu.com/zvideo/1300859379609722880

比赛

今年视觉领域自适应的比赛关注于行人重识别问题,参见主办方推文:

https://zhuanlan.zhihu.com/p/137644578

背景技术

目前针对领域自适应行人重识别的算法主要分为两类:

1)域转换类,如:SPGAN、PTGAN、SDA等。算法主要分为两步,第一步先训练一个 GAN 模型将源域图像转换为目标域风格,并保持原本 ID;第二步利用转换后的源域图像及其真实标签进行训练,获得最终模型。优势在于可以充分利用源域图像;劣势在于目前来看单独使用时性能不佳。

2)伪标签类,如:SSG、PAST、MMT 等。训练过程中在两步之间交替进行,其中第一步是利用聚类算法(也有的算法不使用聚类,这里主要介绍基于聚类的算法)为目标域图像生成伪标签,第二步是利用伪标签与目标域图像进行训练。优势在于目前公共 benchmark 上一直保持 SoTA 的性能;劣势在于当伪标签噪声较大时,训练不稳定,甚至误差放大。

挑战

该比赛与公共 benchmark 相比,最大的两项挑战在于:

1)源域为合成行人图像(PersonX [3]),目标域为真实行人图像,域差异较大。若直接使用源域图像进行模型的预训练,预训练的模型在目标域上表现非常差,从而生成的伪标签也不理想。解决方案为,先利用域转换算法(SDA)将源域图像的风格迁移到目标域,再做预训练,这样可以较大程度上提升初始伪标签的质量,从而有助于后续的伪标签法。

2)目标域 ID 分布较为嘈杂,有的人可能有多张图片,而有的人可能只有很少的图片,为聚类型算法带来挑战。解决方案为,进一步改进伪标签算法(MMT+)以对抗较大的伪标签噪声,保证模型鲁棒性。

技术框架

我们的总体训练框架主要分为三步:利用 SDA 算法训练的 GAN 模型将源域图像转换到目标域、利用转换后的源域图像进行预训练、在预训练后的模型基础上利用 MMT+ 算法继续训练。下面具体介绍每个步骤。

训练步骤一:SDA域转换

Structured Domain Adaptation(SDA)是我们在 19 年下半年所做的一篇工作,很不幸的是一直还没有被收录,但是该方法在域转换上还是很有效的,论文如下:https://arxiv.org/abs/2003.06650

该方法的 idea 很简单,主要围绕一个 loss,文中称作关系一致性损失(relation-consistency loss)。Intuition 是,在训练域转换 GAN 模型时,要求源域图像风格迁移后,图像间的关系保持不变,这样可以更好地维持源域图像原有的信息和数据分布。

和经典工作 SPGAN 相比时,SDA 迁移的源域图像在目标域仍保持良好的类内关系,如下图所示,蓝色裙子的人在经过 SPGAN 的域转换后变成了蓝裙子和白裙子,而 SDA 维持了蓝色裙子。

具体训练细节在这里不展开了,感兴趣的同学可以参阅 SDA 的原论文以及比赛的 technique report。

训练步骤二:转换后的源域图像预训练

对源域图像进行域转换的目的是,提供更好的预训练模型。所以,我们使用训练好的 SDA 将所有源域图像转移到目标域,并用其进行网络的预训练。如下图所示,无论是 SPGAN 还是 SDA,用域转换后的源域图像进行预训练比原始源域图像预训练所得到的的模型精度要明显高出许多。

关于源域的预训练,我们总结了几点有用的 training tricks:

  • 使用 auto-augmentation 可以有效避免 overfit;

  • 使用 GeM pooling 代替 average pooling;

  • 一个在 MMT 开源代码中涉及的 trick:虽然目标域图像无标签,但是也可以用于在训练过程中进行 forward computation,可以一定程度上有效地将 BN adapt 到目标域,以下是伪代码。

训练步骤三:MMT+目标域训练

Mutual Mean-Teaching(MMT)是我们发表于 ICLR 2020 上的工作,框架非常有效,在公共 benchmark 上一直保持 SoTA 水平。

以下是以前写的论文讲解,

https://zhuanlan.zhihu.com/p/11607494zhuanlan.zhihu.com

我们将原始的 MMT 画成以下示意图:

之前提到,该比赛的目标域数据集 ID 分布较为嘈杂,导致伪标签噪声很大,为了进一步减轻伪标签噪声对 MMT 训练带来的影响,我们采取了以下两个措施:

  • 加入源域的图像进行协同训练,等于是加入了有真实标签的干净数据,这里使用 Domain-specific BN [4] 来消除 domain gap 对训练的影响;

  • 加入 MoCo [5] loss 进行实例区分任务,一定程度上抵消错误的伪标签带来的影响。值得注意的是,由于 MMT 中的 Mean-Net 和 MoCo 中的 momentum encoder 基本上一致,所以在 MMT 中加入 MoCo loss 很方便。

结合以上两点后,我们将新的框架称之为 MMT+:

这里还是送大家一点干货,training tricks:

  • 使用 ArcFace 或 CosFace 代替普通的 linear classification loss;

  • 由于 ArcFace 或 CosFace 已经很强了,所以 triplet loss 作用不大了;

  • 使用 GeM pooling 代替 average pooling;

  • 一个实验性的结论,不要在这一步使用 auto-augmentation;

  • 一个在 MMT 开源代码中涉及的 trick:每次重新聚类后重置优化器。

对比一下 MMT+ 与原始 MMT 的性能差异:

测试后处理:

  • 模型融合

融合了四个 backbone,具体做法很简单,就是提取特征 ->concate->L2-norm。

  • 消除相机偏差 [6]:单独训练一个相机分类模型,在原始 person similarity 的基础上减去 camera similarity。

  • K-reciprocal re-ranking [7]

总结

我们主要在模型训练上进行了改进,充分利用了域转换和伪标签两类方法,每个模型单独的性能都是不错的。但是比赛小白,测试后处理上差了点意思,还是有很大提升空间的,再接再厉。感谢组委会,祝以后的参赛者好运。

参考文献

[1] Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification. ICLR 2020. 

[2] Structured Domain Adaptation with Online Relation Regularization for Unsupervised Person Re-ID. 

[3] Dissecting Person Re-identification from the Viewpoint of Viewpoint. CVPR 2019.

[4] Domain-Specific Batch Normalization for Unsupervised Domain Adaptation. CVPR 2019.

[5] Momentum Contrast for Unsupervised Visual Representation Learning. CVPR 2020.

[6] Voc-reid: Vehicle re-identification based on vehicle orientation camera. CVPRW 2020.

[7] Re-ranking person re-identification with k-reciprocal encoding. CVPR 2017.

更多阅读

#投 稿 通 道#

 让你的论文被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。

???? 来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志

???? 投稿邮箱:

• 投稿邮箱:[email protected] 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通

????

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。