最近毕设要接触到联邦学习,在这里记录一下对联邦学习的理解,防止以后忘了。
机器学习小白一个,可能有错,欢迎指出!
FedAvg以及FedSGD的实现
谈到联邦学习,最重要的一个点是梯度融合时的策略。如何巧妙地融合各个客户端之间的梯度,以保证最优的模型效果,是联邦学习中需要着重研究的问题。FedAvg是一个经典且简单的策略。
一句话描述:所有client的梯度取平均值得到最终梯度。具体证明如下。
机器学习问题建模
首先,对于一切损失函数非凸的机器学习问题(例如神经网络),都可以表示为以下式子:
是目标函数,是参数。是第i个数据的损失,或者说第i个数据的代价。上面式子的意思是:一切非凸机器学习问题,都是一个最小化目标函数的问题,这个目标函数是由每个数据的损失平均贡献的。
进一步细化, ,这表示 在参数贡献了的损失。如果我们有个用户,每个用户贡献的数据为 ,其集合内元素数量,那么我们可以改写上面的式子为:
如果数据呈IID,即独立同分布,那么根据期望的性质可以得到:
当然,这个独立同分布不一定能实现,所以需要后面进一步讨论。
SGD算法
SGD算法全称是:stochastic gradient descent,随机梯度下降法,是一种常见的梯度下降算法。
和SGD相对的算法是BGD:batch gradient descent,批量梯度下降法BGD。BGD在每次更新时,要用到所有的样本来得到一个标准梯度,然后沿着这个梯度更新。因此对于凸优化问题BGD肯定可以得到一个收敛的解。而SGD则是每次取一个样本,来代替整个样本集合进行梯度下降,这样虽然不是每次迭代得到的损失函数都向着全局最优方向, 但是大的整体的方向是向全局最优解发展的。还有一种mini-batch梯度下降,是这两个方法的折中。
根据一些证明,SGD和BGD都能收敛,所以都是可用的。因为联邦学习是分布式的,所以肯定只能用SGD。
基线算法FederatedSGD
定义一个值C:每次参与联邦学习聚合的client数量占总client数量的比例。当C=1时,代表全员参与聚合。FedSGD就是在C=1时的一个基线算法,也就是每次让所有client参与,把本地所有的数据进行训练,在本地只训练一次,然后进行聚合(说实话我很迷惑,这不应该叫FedBGD吗)。
聚合时的操作是这样的:
首先要有一个固定学习律,然后每个用户计算自己的损失变化:,其中代表客户端此时的模型参数。服务端收收集损失变化,利用的平均值对整个模型进行更新:
因为的变化可以这么表示:
所以,这个式子也可以写为:
然后把所有全加起来,再把那一项换掉,可以得到:
上面这个式子就很清晰了,新的模型=每个设备的权重*每个设备的模型。在聚合之前,每个设备自己还可以自娱自乐,自己迭代多轮:
至此,FedSGD的操作就介绍完毕了,其实就是个求均值。
Federated Averaging
上面介绍了FedSGD,然而FedSGD其实是FedAvg的特殊情况。
我们定义三个参数:
:每轮参与联邦学习聚合的client数量占总client数量的比例。
:每个client在本地的训练次数(即自娱自乐的次数)。
:每个client在本地训练时的BatchSize。
然后FedAvg可以表示为如下伪代码:
Server executes: initialize for each round do
for each client in parallel do
Client Update // Run on client
split into batches of size
for each local epoch from 1 to do for batch do
return to server
可以看出,当C=B=1,且B为无穷大时,FedAvg与FedSGD一样。
至此,FedAvg的原理也介绍完毕了。
HierFAVG策略
具体请见这篇论文:https://arxiv.org/abs/1905.06641
目前来说,FedAvg是基于云的,client连到云上,然后进行FL。这里有一个巨大缺陷:设备连接时产生了巨大的网络资源消耗,与云服务器的连接也不一定稳定,一旦断掉很麻烦。所以,可以引入边缘计算来解决这个问题。但是边缘节点毕竟接入量有限,可能导致训练性能的大量损失。
因此,上面那篇论文提出了一种边云协同策略,具体来说就是:
- 首先,client先训练,然后把参数上传到边缘
- 边缘进行聚合,当边缘聚合k轮后,把聚合好的数据上传到云。
- 云总共聚合n轮。
所以总共训练n*k论,根据那个文章,训练性能还不错。
来看看
不叫FedBGD是不是可能是分布式学习中客户端可能不是全部参与模型的训练呀?