本文提供一个基于 PySyft 和 Torch 的联邦学习案例,使用自编码器(AE)来进行图像重建任务。我们将使用 Federated Average 算法来合并每个客户端的 AE 权重,并保护每个客户端的隐私。下面是实现该案例的代码:
首先,我们导入必要的库。
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
import syft as sy
然后,我们定义自编码器的模型类。
class AE(nn.Module):
    """Convolutional autoencoder for 28x28 single-channel (MNIST) images.

    The encoder compresses a ``(N, 1, 28, 28)`` batch to a ``(N, 64, 1, 1)``
    latent code; the decoder mirrors the encoder back to ``(N, 1, 28, 28)``,
    with a final sigmoid so reconstructed pixels lie in ``[0, 1]``.
    """

    def __init__(self):
        super(AE, self).__init__()
        # Encoder: 28x28 -> 14x14 -> 7x7 (strided convs) -> 1x1 latent code.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=7),
        )
        # Decoder: transposed convs invert the encoder; output_padding=1
        # recovers the exact even spatial sizes lost by stride-2 convs.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=7),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode ``x``; returns the reconstruction."""
        latent = self.encoder(x)
        return self.decoder(latent)
接下来,我们定义训练和测试函数。在训练函数中,我们使用 PySyft 在每个客户端上训练 AE,并使用 Federated Average 算法在每个轮次结束时加权平均客户端权重。在测试函数中,我们使用联邦学习的模型进行图像重建,并计算测试损失。
# Training function
def train(model_ptr, optimizer, criterion, data_loader, device):
    """Run one local training epoch of the autoencoder.

    Each input batch doubles as its own reconstruction target. When
    ``model_ptr`` is a PySyft pointer (it exposes ``.location``), every
    batch is shipped to that worker first so training happens remotely;
    otherwise the batch is moved to ``device`` and trained locally.

    Args:
        model_ptr: model (or PySyft model pointer) to train.
        optimizer: optimizer built over ``model_ptr``'s parameters.
        criterion: reconstruction loss, e.g. ``nn.MSELoss()``.
        data_loader: iterable of ``(images, labels)`` batches; labels unused.
        device: torch device used for the plain (non-remote) case.
    """
    model_ptr.train()
    for data, _ in data_loader:
        # Send the batch to the worker holding the model, if it is remote.
        if hasattr(model_ptr, "location"):
            data = data.send(model_ptr.location)
        else:
            data = data.to(device)
        # Autoencoder target is a detached copy of the input itself.
        target = data.clone().detach()
        optimizer.zero_grad()
        output = model_ptr(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # NOTE(review): the original additionally added raw gradients into
        # `model_ptr.weight` and divided the weights by len(data_loader)
        # after the loop as "federated averaging". `nn.Module` has no
        # `.weight`, and that arithmetic corrupts the model — FedAvg
        # aggregation belongs in the round loop, after local training.
接着上面的代码,我们可以在测试函数中使用联邦学习的模型进行图像重建,并计算测试损失。
# 测试函数
def test(model_ptr, data_loader, device):
model_ptr.eval()
test_loss = 0
with torch.no_grad():
for data, _ in data_loader:
# 发送数据到客户端
data = data.send(model_ptr.location)
target = data.clone().detach()
# 应用联邦学习的模型进行图像重建
output = model_ptr(data)
test_loss += F.mse_loss(output.get(), target, reduction='sum').item()
# 计算均匀测试损失
test_loss /= len(data_loader.dataset)
return test_loss
现在,我们可以开始构建联邦学习环境并进行训练了。首先,我们创建虚拟工人,并将其分配给不同的客户端。
# Create three in-process PySyft virtual workers to simulate remote clients.
hook = sy.TorchHook(torch)  # hooks torch so tensors/models gain .send()/.get()
workers = [sy.VirtualWorker(hook, id="worker{}".format(i)) for i in range(3)]
# Distribute the MNIST training data across the clients.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
federated_train_loader = sy.FederatedDataLoader(train_data.federate(workers), batch_size=64, shuffle=True, num_workers=0, drop_last=True)
然后,我们在每个客户端上训练 AE,并使用 Federated Average 算法对客户端权重进行加权平均。我们训练 10 轮,并在每轮结束时计算并输出平均测试损失。
# Pick the compute device; the original referenced `device` without ever
# defining it, which raises NameError on first use.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the autoencoder and ship it to a worker for remote training.
model = AE().to(device)
model_ptr = model.send(workers[0])

# Hyper-parameters.
criterion = nn.MSELoss()
learning_rate = 0.01
# Optimize the remote model's parameters (the pointer's, not a stale local copy).
optimizer = optim.Adam(model_ptr.parameters(), lr=learning_rate)

# Train for a fixed number of rounds, reporting reconstruction loss each round.
num_epochs = 10
for epoch in range(num_epochs):
    train(model_ptr, optimizer, criterion, federated_train_loader, device)
    test_loss = test(model_ptr, federated_train_loader, device)
    print('Epoch [{}/{}], Test Loss: {:.4f}'.format(epoch+1, num_epochs, test_loss))

# Bring the trained weights back to the local machine and evaluate once more.
# (The original averaged `model_ptr.weight` over `model_ptr.pointers()`, but an
# nn.Module has no `.weight` attribute and `.pointers()` is not a PySyft API.)
final_model = model_ptr.get()
test_loss = test(final_model, federated_train_loader, device)
print('Final Test Loss: {:.4f}'.format(test_loss))
这样,我们就成功地实现了一个基本的联邦学习案例,使用 PySyft 模拟了一个简单的图像重建任务。
本文由 mdnice 多平台发布