共计 7297 个字符,预计需要花费 19 分钟才能阅读完成。
本文将介绍如何应用 Flower 构建现有机器学习工作的联邦学习版本。咱们将应用 PyTorch 在 CIFAR-10 数据集上训练卷积神经网络,而后将展现如何批改训练代码以联邦的形式运行训练。
什么是联邦学习?
咱们将在这篇文章中辨别两种次要办法:集中式和联邦式 (本文的图例示意如下)
集中式
每个设施都会将其数据发送到全局服务器,而后服务器将应用它来训练全局模型。训练实现后服务器将经过训练的全局模型发送到设施。
这并不是咱们所说的联邦学习的解决方案,传输了数据,会带来很多问题
联邦式
每个设施都不会与服务器共享数据,而是将数据保留在本地并用它来训练模型。模型的权重会被发送到全局服务器,而后全局服务器会将收到的所有权重聚合到一个全局模型中,服务器最终将经过训练的全局模型发送到设施。这种形式是个别模式的联邦学习,它的次要长处是爱护用户的隐衷,防止数据泄露。
咱们先实现集中式训练代码,因为该训练模式基本上与传统的 PyTorch 训练雷同,而后再将其改为联邦学习的形式。
集中式 PyTorch 训练
让咱们创立一个名为 cifar.py 的新文件,其中蕴含在 CIFAR-10 上进行传统(集中式)训练所需的所有组件。首先,须要导入所有的包(例如 torch 和 torchvision)。咱们当初没有导入任何用于联邦学习的包。能够稍后再进行导入。
fromtypingimportTuple, Dict | |
importtorch | |
importtorch.nnasnn | |
importtorch.nn.functionalasF | |
importtorchvision | |
importtorchvision.transformsastransforms | |
fromtorchimportTensor | |
fromtorchvision.datasetsimportCIFAR10 |
模型架构(一个非常简单的卷积神经网络)在 Net() 类中定义。
classNet(nn.Module): | |
def__init__(self) ->None: | |
super(Net, self).__init__() | |
self.conv1=nn.Conv2d(3, 6, 5) | |
self.pool=nn.MaxPool2d(2, 2) | |
self.conv2=nn.Conv2d(6, 16, 5) | |
self.fc1=nn.Linear(16*5*5, 120) | |
self.fc2=nn.Linear(120, 84) | |
self.fc3=nn.Linear(84, 10) | |
defforward(self, x: Tensor) ->Tensor: | |
x=self.pool(F.relu(self.conv1(x))) | |
x=self.pool(F.relu(self.conv2(x))) | |
x=x.view(-1, 16*5*5) | |
x=F.relu(self.fc1(x)) | |
x=F.relu(self.fc2(x)) | |
x=self.fc3(x) | |
returnx |
load_data() 函数加载 CIFAR-10 训练和测试集。转换在加载后规范化了数据。
DATA_ROOT="~/data/cifar-10" | |
defload_data() ->Tuple[ | |
torch.utils.data.DataLoader, | |
torch.utils.data.DataLoader, | |
Dict | |
]: | |
"""Load CIFAR-10 (training and test set).""" | |
transform=transforms.Compose([transforms.ToTensor(), | |
transforms.Normalize((0.5, 0.5, 0.5), | |
(0.5, 0.5, 0.5) | |
) | |
] | |
) | |
trainset=CIFAR10(DATA_ROOT, | |
train=True, | |
download=True, | |
transform=transform) | |
trainloader=torch.utils.data.DataLoader(trainset, | |
batch_size=32, | |
shuffle=True) | |
testset=CIFAR10(DATA_ROOT, | |
train=False, | |
download=True, | |
transform=transform) | |
testloader=torch.utils.data.DataLoader(testset, | |
batch_size=32, | |
shuffle=False) | |
num_examples= {"trainset" : len(trainset), "testset" : len(testset)} | |
returntrainloader, testloader, num_examples |
咱们当初须要定义训练函数 train(),它循环遍历训练集、计算损失、反向流传,而后对每批训练执行一个优化步骤。
模型的评估在函数 test() 中定义。该函数遍历所有测试样本并依据测试数据集测量模型的损失。
deftrain( | |
net: Net, | |
trainloader: torch.utils.data.DataLoader, | |
epochs: int, | |
device: torch.device, | |
) ->None: | |
"""Train the network.""" | |
# Define loss and optimizer | |
criterion=nn.CrossEntropyLoss() | |
optimizer=torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) | |
print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each") | |
# Train the network | |
forepochinrange(epochs): # loop over the dataset multiple times | |
running_loss=0.0 | |
fori, datainenumerate(trainloader, 0): | |
images, labels=data[0].to(device), data[1].to(device) | |
# zero the parameter gradients | |
optimizer.zero_grad() | |
# forward + backward + optimize | |
outputs=net(images) | |
loss=criterion(outputs, labels) | |
loss.backward() | |
optimizer.step() | |
# print statistics | |
running_loss+=loss.item() | |
ifi%100==99: # print every 100 mini-batches | |
print("[%d, %5d] loss: %.3f"% (epoch+1, | |
i+1, | |
running_loss/2000)) | |
running_loss=0.0 | |
deftest( | |
net: Net, | |
testloader: torch.utils.data.DataLoader, | |
device: torch.device, | |
) ->Tuple[float, float]: | |
"""Validate the network on the entire test set.""" | |
criterion=nn.CrossEntropyLoss() | |
correct=0 | |
total=0 | |
loss=0.0 | |
withtorch.no_grad(): | |
fordataintestloader: | |
images, labels=data[0].to(device), data[1].to(device) | |
outputs=net(images) | |
loss+=criterion(outputs, labels).item() | |
_, predicted=torch.max(outputs.data, 1) | |
total+=labels.size(0) | |
correct+= (predicted==labels).sum().item() | |
accuracy=correct/total | |
returnloss, accuracy |
定义了数据加载、模型架构、训练和评估后,咱们能够将所有内容放在一起并在 CIFAR-10 上训练咱们的 CNN。
defmain(): | |
DEVICE=torch.device("cuda:0"iftorch.cuda.is_available() else"cpu") | |
print("Centralized PyTorch training") | |
print("Load data") | |
trainloader, testloader, _=load_data() | |
print("Start training") | |
net=Net().to(DEVICE) | |
train(net=net, trainloader=trainloader, epochs=2, device=DEVICE) | |
print("Evaluate model") | |
loss, accuracy=test(net=net, testloader=testloader, device=DEVICE) | |
print("Loss:", loss) | |
print("Accuracy:", accuracy) | |
if__name__=="__main__": | |
main() |
当初就能够间接运行了:
python3 cifar.py
到目前为止,如果你以前应用过 PyTorch,这所有看起来应该相当相熟。上面开始进入正题,咱们开始构建一个简略的联邦学习零碎,该零碎由一个服务器和两个客户端组成。
PyTorch 的联邦学习
咱们曾经在单个数据集 (CIFAR-10) 上训练了模型,咱们称之为集中学习。这种集中学习的概念是咱们以前罕用的形式。通常,如果你想以联邦学习的形式运行,则必须更改大部分代码并从头开始设置所有内容。然而,这里有一个包 Flower,它能够将事后存在的代码以联邦学习运行(当然须要大量的批改)。
既然是联邦学习,咱们必须有服务器,而后 cifar.py 代码也须要连贯到服务器的客户端。服务器向客户端发送模型参数。客户端运行训练并更新参数。更新后的参数被发送回服务器,服务器对所有接管到的参数更新进行均匀,这就是联邦学习的一个简略的流程。
咱们这个例子是由一台服务器和两个客户端组成。咱们先设置 server.py。服务端须要导入 Flower 包 flwr,而后应用 start_server 函数启动服务器并通知它执行三轮联邦学习。
importflwrasfl | |
if__name__=="__main__": | |
fl.server.start_server( | |
server_address="0.0.0.0:8080", | |
config=fl.server.ServerConfig(num_rounds=3) | |
) |
而后就能够启动服务器了:
python3 server.py
咱们还要在 client.py 中定义客户端逻辑,次要就是将之前在 cifar.py 中定义的集中训练的代码进行整合:
fromcollectionsimportOrderedDict | |
fromtypingimportDict, List, Tuple | |
importnumpyasnp | |
importtorch | |
importcifar | |
importflwrasfl | |
DEVICE: str=torch.device("cuda:0"iftorch.cuda.is_available() else"cpu") |
Flower 客户端须要实现 flwr.client.Client 或 flwr.client.NumPyClient 类。这里的实现将基于 flwr.client.NumPyClient,咱们将其称为 CifarClient。因为咱们应用了 NumPy,而 PyTorch 或 TensorFlow/Keras)都是间接是吃 NumPy 的互操作,所以应用 NumPyClient 比 Client 更容易。
实现咱们的 CifarClient 须要实现四个办法,两个获取 / 设置模型参数的办法,一个训练模型的办法,一个测试模型的办法:
1、set_parameters
这个办法有 2 个作用:
- 在从服务器接管的本地模型上设置模型参数
- 遍历作为 NumPy ndarray 接管的模型参数列表
2、get_parameters
获取模型参数并将它们作为 NumPy ndarray 的列表返回(这是 flwr.client.NumPyClient 所须要的)
3、fit
一看就晓得,这是训练本地模型的办法,它有 3 个作用:
- 应用从服务器接管到的参数更新本地模型的参数
- 在本地训练集上训练模型
- 训练本地模型,并将权重上传服务器
4、evaluate
验证模型的办法:
- 从服务器接管到的参数更新本地模型的参数
- 在本地测试集上评估更新后的模型
- 将本地损失和准确率等指标返回给服务器
咱们先前在 cifar.py 中定义的函数 train() 和 test() 能够作为 fit 和 evaluate 应用。所以在这里真正要做的是通过咱们的 NumPyClient 类通知 Flower 曾经定义的哪些函数,剩下的两个办法实现起来也不简单:
classCifarClient(fl.client.NumPyClient): | |
"""Flower client implementing CIFAR-10 image classification using | |
PyTorch.""" | |
def__init__( | |
self, | |
model: cifar.Net, | |
trainloader: torch.utils.data.DataLoader, | |
testloader: torch.utils.data.DataLoader, | |
num_examples: Dict, | |
) ->None: | |
self.model=model | |
self.trainloader=trainloader | |
self.testloader=testloader | |
self.num_examples=num_examples | |
defget_parameters(self, config) ->List[np.ndarray]: | |
# Return model parameters as a list of NumPy ndarrays | |
return [val.cpu().numpy() for_, valinself.model.state_dict().items()] | |
defset_parameters(self, parameters: List[np.ndarray]) ->None: | |
# Set model parameters from a list of NumPy ndarrays | |
params_dict=zip(self.model.state_dict().keys(), parameters) | |
state_dict=OrderedDict({k: torch.tensor(v) fork, vinparams_dict}) | |
self.model.load_state_dict(state_dict, strict=True) | |
deffit(self, parameters: List[np.ndarray], config: Dict[str, str] | |
) ->Tuple[List[np.ndarray], int, Dict]: | |
# Set model parameters, train model, return updated model parameters | |
self.set_parameters(parameters) | |
cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE) | |
returnself.get_parameters(config={}), self.num_examples["trainset"], {} | |
defevaluate(self, parameters: List[np.ndarray], config: Dict[str, str] | |
) ->Tuple[float, int, Dict]: | |
# Set model parameters, evaluate model on local test dataset, return result | |
self.set_parameters(parameters) | |
loss, accuracy=cifar.test(self.model, self.testloader, device=DEVICE) | |
returnfloat(loss), self.num_examples["testset"], {"accuracy": float(accuracy)} |
最初咱们要定义一个函数来加载模型和数据,创立并启动这个 CifarClient 客户端。
defmain() ->None: | |
"""Load data, start CifarClient.""" | |
# Load model and data | |
model=cifar.Net() | |
model.to(DEVICE) | |
trainloader, testloader, num_examples=cifar.load_data() | |
# Start client | |
client=CifarClient(model, trainloader, testloader, num_examples) | |
fl.client.start_numpy_client(server_address="0.0.0.0:8080", client) | |
if__name__=="__main__": | |
main() |
这样就实现了。当初能够关上两个额定的终端窗口并运行(因为咱们要演示 2 个客户端的联邦学习)
python3 client.py
在每个窗口中(请确保后面的服务器正在运行)能够看到你的 PyTorch 我的项目在两个客户端上进行训练了。
总结
本文介绍了如何应用 Flower 将咱们原有 pytorch 代码革新为联邦学习的形式进行训练,残缺的代码能够在这里找到:
https://avoid.overfit.cn/post/8d05a12c208c4f499573c9966d0fe415
作者:Charles Beauville