联邦学习大综述阅读笔记

Advances and Open Problems in Federated Learning

介绍

联邦学习是一种多个客户端(移动设备等)在中心服务器的组织下统一进行训练的机器学习方法,在联邦学习中,客户端上的数据私有的。本篇论文介绍了联邦学习当前在理论和实践方面被关注的问题。

联邦学习的各个客户端实在中心服务器的松散组织下进行训练的,它有三个主要的挑战。第一,其训练数据是不平衡且不独立同分布(identically and independently distributed, i.i.d)的。第二,其参与学习的客户端是不可靠的。第三,其通信带宽是有瓶颈的。

联邦学习本质上是一种跨学科的研究领域,涵盖机器学习、分布式优化、密码学、安全、查分隐私、公平、压缩感知、系统、信息论、统计等领域。

联邦学习最初是定义在手机等边缘设备和服务器的,但是一些领域将联邦学习应用于少量相对可靠的客户端,例如多个组织联合训练一个模型。基于此,将联邦学习定义为两类,分别是跨设备(cross-device)联邦学习和跨数据孤岛(cross-silo)联邦学习。

给出联邦学习更广泛的定义:联邦学习是一种多个实体(客户端)在中心服务器的协调下合作解决一个机器学习问题的机器学习方法。每个客户端的数据储存在本地且不进行交换或传输,而是通过旨在立即聚合的集中更新来实现学习目标。

Federated learning is a machine learning setting where multiple entities (clients) collaborate in solving a machine learning problem, under the coordination of a central server or service provider. Each client’s raw data is stored locally and not exchanged or transferred; instead, focused updates intended for immediate aggregation are used to achieve the learning objective.

集中更新是最小范围的更新,期望只包含当前特定学习任务所需的最低限度的信息。为了服务于数据最小化,更新应尽早进行聚合。

跨设备联邦学习被应用于消费级电子设备,谷歌将联邦学习应用于Gboard输入法,Pixel手机和安卓短信中,苹果将联邦学习应用于iOS13的QuickType输入法和语音唤醒”Hey Siri“中。doc.ai正在将联邦学习应用于医学研究,Snips探索了联邦学习的热词检测。

跨数据孤岛联邦学习被用于再保险的金融风险预测,药品发现,电子健康记录挖掘,医疗数据分割等。

已经有多个联邦学习框架出现,包括TensorFlow Fedrated, PySyft, Leaf, PaddleFL等。

下表是跨设备联邦学习,跨数据孤岛联邦学习与数据中心分布式计算特征的对比。

数据中心分布式学习 跨数据孤岛联邦学习 跨设备联邦学习
设置 在一个大而”平“的数据集上训练模型。客户端是一个簇或一个数据中心的计算节点。 在孤岛数据上训练模型,客户端是不同的组织(医疗或财经)或者地理分布式的数据中心。 客户端是大量的手机或者IoT设备。
数据分布 数据是集中存储并且可以被随机重排均匀分配给客户端的。每个客户端可以读取数据集的一部分。 数据在客户端本地生成并保持分散。每个客户端储存自己的数据并且不允许从其他客户端读取数据。数据并非独立同分布。
组织 中心组织。 中心服务器组织训练,但是无法见到原始数据。
广域通信 无,在一个集群中每个客户端是完全连接的。 通常是中枢辐射(hub-and-spoke)式拓扑,中心代表服务提供商,辐射代表连接的客户端。
数据可用性 所有客户端都是几乎随时可用的。 每次只有一小部分客户端可用。
分布规模 通常是1-1000个客户端。 通常是2-100个客户端。 大规模并行,可达到10的10次方个客户端。
主要瓶颈 通常是计算,可以假设网络速度非常快。 是计算或者通信。 通常是通信,但是也根据任务有所不同。通常跨设备联邦学习使用非常慢的wifi网络传输。
可寻址性 每个客户端都有身份或者名称,允许系统专门访问他。 客户端不能被直接索引。
客户端状态 有状态,每个客户端都可能参与了每一轮计算,在每一轮中都携带状态。 无状态,每个客户端在整个任务中可能只参与一次训练,因此可以假设在每一轮计算中都有客户端从未见过的新样本。
客户端可靠性 极少出错。 高度不可靠,5%或更多的参与本轮计算的客户端可能会出错或者脱节。
数据划分 数据可在客户机之间任意划分或重新划分。 分区是固定的,可能是按例划分(水平·)或按特征划分(垂直)。 固定按例划分(水平)。

联邦学习模型的生命周期

  1. 问题定义:确定一个使用联邦学习解决的问题。
  2. 客户端检测:根据不同的需要,客户端需要储存本地的训练数据,在一些情况下还要额外储存一些数据或元数据,例如为监督学习任务储存用户的交互数据。
  3. 模拟原型(可选):使用代理数据集(proxy dataset)建立联邦学习模型的原型,并测试超参数。
  4. 模型训练:使用多个联邦学习任务训练联邦学习模型的不同变体,或不同超参数。
  5. 联合模型验证:经过模型的训练和挑选后,在数据中心的标准数据集上或从未参与过联邦学习训练的客户端的本地数据上验证模型。
  6. 模型部署:包含人为质量保障,实时A/B测试(在一些设备上使用新模型,另一些设备上使用上一代模型)和分段检出(保障新模型性能不佳时能够及时回滚并且不会影响到太多客户端)。

典型联邦学习训练流程

以FedAvg算法为例讨论联邦学习模板。

  1. 客户端选择:服务器从客户端中选择一部分有资格参与学习的客户端,例如手机在充电、连接网络以及空闲的时候才能够参与学习。
  2. 广播:被选中的客户端下载当前模型的最新权重文件和训练程序。
  3. 客户端训练:每个参与当前轮训练的客户端使用本地数据和训练程序(例如FedAvg的SGD优化)训练模型。
  4. 模型合并:服务器收集各个客户端的权重更新并进行合并,为了效率,掉队的客户端会被丢弃。

下图给出了典型跨设备联邦学习的数据尺度。

客户端训练,模型合并和模型更新的步骤并不是严格分开的,例如异步SGD算法可以将客户端的权重更新立即合并到全局模型中。但是将这三步分开有利于多学科对于联邦学习的研究,包括数据加密,隐私查分等方面的研究可以在标准算法上进行并且可以被方便的融入到其他算法流程中。

联邦学习不应当影响用户体验,具体体现在两个方面。首先,联邦学习模型不应当给用户提供实时模型的预测,用户可见的模型应当是在生命周期的第六步中的分段检出中提供的。其次,应当在客户端空闲以及连接电源时进行联邦学习,联邦学习不应当导致设备变慢或消耗电量。

联邦学习研究

大部分对于联邦学习的研究人员不会部署联邦学习系统,也无法访问大量现实世界中的设备,这导致了联邦学习在模拟实验和现实应用中的区别。

作者对于联邦学习的研究和问题定位给出如下建议:

  • 联邦学习涵盖了多个领域的问题,所以准确的描述特定联邦学习设置的细节是十分重要的,特别是所提方法做出了可能不适合所有环境的假设时(例如假设所有客户端在每一轮训练中都能够参与)。
  • 在模拟中的任何细节都应当被描述,以便进行复现。另外一个重点在于解释当前的联邦学习模拟设置针对的是现实世界中的那个问题,以便高效的将当前的研究应用于现实中的系统。
  • 隐私和通信效率是联邦学习中需要首先考虑的问题,即使实验是在单独的机器上使用公开数据集进行的也是如此。需要明确给出计算在哪里发生以及传输的数据是什么。

为不同的联邦学习指定开发标准的评估指标和建立基准数据集仍是正在进行工作的一个重要方向。

联邦学习的新兴场景及应用

全分散/点对点式学习

在联邦学习中,中心服务器组织训练的过程并接收客户端的更新。因此中心服务器是联邦学习的重点,但是其也可能成为一个故障点。尽管大公司或组织可以承担这个任务,在许多联邦学习场景下也没有一个可靠且强大的中心服务器随时可用。进一步讲,一旦客户端数量巨大,中心服务器也可能会成为瓶颈(虽然可以通过系统设计避免这个问题)。

全分散式学习的中心思想是通过点对点的通信代替客户端与服务器的通信。通信拓扑可以用一个连接图表示,其中节点代表客户端,边代表客户端之间的通信信道。网络图通常被设计为稀疏且最大度较小,这样每个节点只需要向少量节点发送或者接收消息,这与客户端服务器的星型架构形成了鲜明的对比。全分散算法的每一轮训练执行客户端的本地更新并和图中的邻居节点交换信息。在机器学习的角度看,本地更新是本地(随机)梯度下降,通信包括与邻节点平均自己的本地模型参数。需要注意的是,在这种学习方式下没有传统联邦学习的全局模型,而是设计为所有本地模型都汇集到所需的全局解,即每个模型最终都达成一致。虽然多智能体优化(multi-agent optimization)在控制领域中有悠久的历史,全分散的SGD以及其他优化算法变种是在最近才被发掘的,以增强数据中心的可扩展性和网络设备的分散性。一些工作考虑了无向图和有向图的应用。

尽管在全分散式的学习中,仍然会有一个中心控制学习任务的设置。一些无可避免的问题例如:谁来决定在无中心学习中训练哪个模型?要用哪个算法?超参数如何设置?当任务出错时谁应该负责debug?任然需要一个被信任的客户端去拥有回答这些问题的权利。或者这些决定可以由提出当前学习任务的客户端做出,再或者可以通过协商一致的方法决定(例如投票)。

下图给出了点对点学习和联邦学习的对比。尽管分散学习和联邦学习的架构假设不同,但它们仍旧可以用于许多相同的问题领域,它们面临许多相同的挑战,并且它们的研究领域也有许多重合。因此这篇文章中包含了点对点学习的内容,在这一节中明确考虑了分散式学习面临的独特挑战,但是许多在其他节中提出的联邦学习的挑战也适用于分散式学习。

算法挑战

在机器学习分散式方案的可用性问题上,许多重要算法的问题仍然没有解决。一些问题类似于使用中央服务器进行联邦学习的特殊情况,而其他问题则完全是全分散化和没有可信的服务器带来的副作用。

网络拓扑和异步去分散式SGD的影响:全分散算法需要对客户端的不可靠性(包括客户端暂时连接,训练中途加入或退出)和网络的不可靠性(丢包)鲁棒。对于广义线性模型的特殊情况,使用二元结构可以使部分算法鲁棒,但对于深度学习和SGD算法这仍是一个问题。有工作展示了在网络图完整并且以一个固定概率丢包时可以达到与可靠网络相当的收敛速度。

连接质量良好或密集连接到网络能够更快得达到一致,并提供更好的理论收敛效率,这取决于网络图的谱隙(最小的非零特征值)。但当数据独立同分布时,稀疏拓扑不一定会影响到收敛速度。密集的网络通常会引起延迟,这种延迟随着节点度的增加而增加。大部分优化理论工作没有明确的考虑到网络拓扑如何影响运行时,即完成每次SGD迭代所需的系统时间(wall-clock time)。MATCHA是一种基于匹配分解采样的分散式SGD方法,可以减少任何点拓扑下每轮迭代所需的通信延迟并且保持相同的收敛速度。关键思想是将图拓扑分解成可并行执行的不相交的通信链路对,并在每次迭代中仔细选择这些匹配的子集。这种子图序列导致在关键链路上更频繁的通信(确保快速收敛)和在其他链路上更不频繁的通信(降低通信延迟)。

分散式SGD也天然适用于异步算法,其中每个客户端在随机时间独立激活,消除了全局同步需求并提高了可扩展性。

本地更新的分散式SGD:对于在通信前只执行一次SGD更新,在通信之前进行多次本地更新的方法的理论分析(例如mini-batch SGD)更具挑战性。在非IID本地数据集的情况下,依赖单次本地更新步骤的方法通常被证明是收敛的。对于具有多个本地更新步骤的情况,一些最近的工作提供了收敛性分析。此外,有工作对于非IID数据给出了收敛性分析但仅限于上文中提到的匹配分解采样方案。总的来说理解非IID数据分布下的收敛性以及如何设计模型的平均策略以实现最快收敛仍是一个问题。

个性化和信任机制:与跨设备联邦学习相似,在每个客户端使用非IID数据分布的全分散场景下的一个重要任务是设计算法学习个性化模型的集合。一些工作通过平滑任务相似的客户端(例如数据分布一致)的模型参数使用全分散算法对每个客户端训练个性化模型。Zantedeschi 等人进一步学习了个性化模型的相似性图。全分散模型的一个关键问题是算法对恶意用户或者不可靠数据的鲁棒性。将激励机制或者机制设计与全分散学习结合目前是一个重要的目标,这在没有可信赖的中心服务器下难以实现。

现实挑战

跨数据孤岛联邦学习

与跨设备联邦学习的特征不同,跨数据孤岛联邦学习在总体设计的特定方面灵活性更强,但同时实现其他属性也更加困难。

跨数据孤岛联邦学习适用于一定数量的公司或者组织能够共享在他们的数据上训练的模型的激励但是不能直接共享数据时。这可能是由于保密或者法律的限制,甚至是在一家公司内部由于地理区域的不同而无法集中数据所造成的。

数据分割

在跨设备联邦学习设置中,数据是根据例子分割的。在跨数据孤岛联邦学习中,除了根据例子划分数据,根据特征划分也是具有实际意义的。例如两家具有交叉或相同客户的公司,比如同一个城市中的本地银行和本地零售商。这两种不同的分割方法在一些文章中也被称为水平划分和垂直划分。

按特征分割数据的跨数据孤岛联邦学习与按例子分割数据的方法相比,使用了非常不同的训练架构。它可能涉及也可能不需要一个中央服务器作为中立放,并且根据训练的具体算法,客户端交换特定的中间结果而不是模型参数以帮助其他部分的梯度计算。

假设有两个公司A和B要联合训练一个机器学习模型,并且他们的系统各自拥有自己的数据。除此以外公司B还有数据的标签。A和B由于数据隐私和安全原因无法直接交换数据。为了保障训练过程中数据的隐私性,需要引入一个第三方合作者C。假设C是诚实的并且对于A和B没有任何关系,且A和B是诚实的并对对方好奇。在现实场景下C是有政府或者安全计算节点(英特尔 SGXs)扮演的。

联邦学习系统包括两个部分。

第一部分:加密实体对齐。由于两个公司的用户群不是完全相同的,系统通过加密后的用户ID对齐A和B的交叉用户,并且将不交叉的用户数据丢弃。

第二部分:加密模型训练。在确定了哪些数据是公共的之后,可以使用这些实体数据训练机器学习模型,训练过程可以分为以下四部分。

  1. 参与者C创建加密对,并向A和B发送公钥。
  2. A和B为梯度和loss计算,加密并交换中间结果。
  3. A和B计算加密后的梯度并分别增加一个额外的掩膜。B还要计算一个加密后的loss。A和B向C发送加密数据。
  4. C解密并将解密后的梯度和loss发回给A和B。A和B给梯度去掩膜并更新模型参数。

下面用线性回归和同态加密举例。

[to be fill]

Yang, Q., et al. (2019). "Federated Machine Learning." ACM Transactions on Intelligent Systems and Technology 10(2): 1-19.

安全计算方法

安全计算方法可以分为基于噪声的方法和无噪的方法。

无噪的方法主要包括混淆电路,同态加密和密钥分享

https://zhuanlan.zhihu.com/p/31641175

拆分学习

拆分学习的主要思想是将模型进行层(Layer)级的拆分,并设置在客户机和服务器中。这种方法在训练和推理时都可以使用。

在最简单的拆分学习中,每个客户端计算神经网络一个指定层(称为分割层,cut layer)的前向过程,分割层的输出被称为破碎数据(smashed data),破碎数据被送往另一个实体(服务器或另一个客户机),完成余的计算。这样做可以在不传播原始数据的情况下完成一轮前向传播,梯度可以从最后一层反向转播到分割层,类似于正向转播过程。在分割层上的梯度被送回给客户端,并完成反向传播。拆分学习可以在每个客户端不接触其他客户端的数据的情况下完成训练。

拆分学习提供了另一个维度上的模型并行化思路,并行化实在模型层间进行的。一些工作中打破了神经网络层间的依赖关系,通过并行计算各层来减少总的训练时间。还有一些工作研究了将客户端模型组件与最佳的服务端模型组件相结合以实现模型的自动选择。

通信中的数据能够泄露原始数据的信息,至于泄露多少,是否可以接受要具体问题具体分析。一个叫做NoPeek SplitNN的拆分学习变体通过降低原始数据与破碎数据的距离相关性减少了原始数据泄露的可能,同时通过绝对交叉熵保障了模型性能。另一种工程驱动的减少泄露的方法是修剪客户端中激活的通道。

提高效率和有效性

保护私有数据安全

攻击防御

确保公平消除偏差

应对系统挑战


文章作者: keevinzha
版权声明: 咳咳想白嫖文章?本文章著作权归作者所有,任何形式的转载都请注明出处。 https://www.keevinzha.com !
  目录