查看原文
其他

AAAI 2024 | 上交等提出自适应间距强化对比学习,增强多个模型的分类能力

张剑清 PaperWeekly
2024-08-22


©PaperWeekly 原创 · 作者 | 张剑清

单位 | 上海交通大学、清华大学(AIR)

研究方向 | 联邦学习


本文介绍的是我们的一篇收录于 AAAI 2024 的论文,主要考虑的是数据异质和模型异构场景下的联邦学习框架。在异构联邦学习中,由于模型架构不同,传统联邦学习中的参数聚合方法不再适用,取而代之的是基于知识蒸馏的知识共享方法。
在这些方法中,我们关注不引入额外数据集的(data-free)这一类方法。这类方法普遍通过共享类别表征向量(prototype)实现,但在模型架构差异较大的场景,每个客户机生成的表征向量差异悬殊,直接在服务器端聚合表征向量会造成表征能力的下降。于是,我们提出一种在服务器端基于自适应间距强化的对比学习来提高表征向量的表征能力的方法 FedTGP,进一步提升客户端模型的分类能力。

论文标题:

FedTGP: Trainable Global Prototypes with Adaptive-Margin-Enhanced Contrastive Learning for Data and Model Heterogeneity in Federated Learning

论文链接:

https://arxiv.org/abs/2401.03230

代码链接:

https://github.com/TsingZ0/FedTGP(含有PPT和Poster)

运行实验所需仓库-个性化联邦学习算法库:

https://github.com/TsingZ0/PFLlib

运行实验所需仓库-异构联邦学习算法库:

https://github.com/TsingZ0/HtFLlib




异构联邦学习背景

传统联邦学习通过在每一次迭代中传递模型参数的方式实现知识共享,但该方式存在局限,无法适应更广泛的场景,尤其是不易寻找到参与联邦学习的客户机。客户机在参与联邦学习之前,有自己本地的模型训练任务,也有自研的模型架构和训练得到的模型参数。每个客户机参加联邦学习的动机是为了通过联邦学习增强自己模型的表现能力。若强制要求参与的客户机都使用相同的模型结构且进行模型参数共享,则需要每个客户机重新训练模型。

另一方面,每个客户机训练得到的模型参数也是一种数字资产,尤其是在大模型时代保护模型参数的知识产权尤为重要。此外,共享模型参数也有通讯量大的问题。通过允许异构模型参与联邦学习,并共享轻量化的知识载体,异构联邦学习拓展了传统联邦学习的边界,变得更加实用。
▲ 图1:异构联邦学习技术
目前异构联邦学习技术还未形成统一的知识共享机制,我们考虑一种轻量化且不需要额外数据的知识共享机制:共享 prototype。本文考虑的是面向图像的多分类任务,其 prototype 的定义就是每个类别的代表性特征向量,可通过平均该类所有的特征向量获得。现有工作中,FedProto [1] 是这方面最具代表性的方法之一,如下图所示。


▲ 图2:异构联邦学习中使用prototype作为知识载体


FedProto的局限性

虽然 FedProto 得到了广泛使用,但之前的工作要么将其用在传统联邦学习场景(异构联邦学习技术在传统场景也都适用),要么采用异构性不强的异构模型(比如增减全连接层数和改变 CNN 网络的卷积核等)。在这些场景下,通过加权平均聚合 prototype 的方式确实具有不错的表现。 

但当我们考虑更一般的场景:参与联邦学习的客户机训练的模型的架构差异巨大,比如两层 CNN 模型和 ResNet-152 模型。此时 FedProto 的 prototype 聚合方法就出现了一些问题。我们观察到,由于模型架构相差巨大,不同模型的特征提取能力也天差地别,它们生成的 prototype 也天差地别。

当我们通过加权平均去计算全局 prototype(global prototype)时,具有较好表征能力(不同 prototype 之间的间距(margin)较大)的 prototype 会被较差表征能力的 prototype 影响,导致最终得到的 global prototype 表征能力弱于最好的客户机模型。我们称这种现象为间距收缩(margin shrink),如下图所示。进一步地,当这个特征提取能力最好的客户机模型使用了 global prototype 之后,其表征能力则会下降。

▲ 图3:FedProto在模型异构性较大场景下的间距收缩现象(Cifar10)


自适应间距强化的对比学习(ACL)

为了解决上述间距收缩的问题,我们提出了一种自适应间距强化的对比学习方法(ACL),如下图所示。

▲ 图4:FedProto与FedTGP的对比。其中圆形代表客户机上传的prototype,三角形代表global prototype。
该方法的核心思想是训练一个 global prototype,使其能够最大限度地保留最强客户机模型生成的 prototype 的表征能力,同时也汲取来自其他客户机的 prototype 信息。为了实现这一点,我们首先给传统对比学习方法加上一个间距限制,即尽可能保证 prototype 之间的间距不低于所设置的阈值 。考虑类别 对应的 trainable global prototype(TGP),我们定义其训练时候的损失函数为:


其中, 是在第 轮参与联邦学习的客户机集合, 是客户机 上生成的对应类别 的 prototype, 是间距计算函数。
但在联邦学习的过程中,各个客户机模型的特征提取能力不断变化,若设置一个固定的阈值,则会导致间距过大或过小。于是我们考虑将 设置为一个自适应的值,其计算细节如下,其描述的就是每一轮不同类别之间的最大间距,且具有最大值

从而我们得到最终的对比学习目标:

使用 ACL 之后,我们便可以消除间距收缩的问题:

 ▲ 图5:我们的FedTGP在使用ACL之后,消除了间距收缩的问题(Cifar10)


部分实验
由于篇幅原因,我们只展示部分实验结果,更多实验结果和分析详见论文。

▲ 表1:在4个数据集和8种异构模型场景下的测试准确率

▲ 表2:在Cifar100数据集和不同模型异构级别情况下的测试准确率

参考文献

[1] Tan Y, Long G, Liu L, et al. Fedproto: Federated prototype learning across heterogeneous clients[C]//Proceedings of the AAAI Conference on Artificial Intelligence. 2022.


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·
·
·

继续滑动看下一个
PaperWeekly
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存