[pysyft-002]联邦学习pysyft从入门到精通--三个节点训练一个线性分类器
程序员文章站
2022-07-14 13:31:37
...
import syft as sy
import torch
from torch import nn
from torch import optim
"""
https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%2002%20-%20Intro%20to%20Federated%20Learning.ipynb
Part 02 - Intro to Federated Learning.ipynb
"""
"""
本例演示:
在A节点上有一个模型。B、C节点上分别有两个样本集。A节点把模型分别送到B和C节点上进行多轮训练。
本脚本运行在A节点上。
"""
#syft需要对pytorch做hook
hook = sy.TorchHook(torch)
#两个worker,每个worker是一个训练节点
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
#数据集,data是样本属性,target是样本类别标记
data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)
#数据集拆分成两部分,一部分发给bob训练,一部分发给alice训练。训练出两个模型。
#bob和alice都不知道对方的模型,bot和alice是独立的。
#给bob worker的数据
data_bob = data[0:2]
target_bob = target[0:2]
#给alice worker的数据
data_alice = data[2:]
target_alice = target[2:]
#把数据发给bob和alice,返回的是指针,指向bot和alice上的数据
p_data_bob = data_bob.send(bob)
p_data_alice = data_alice.send(alice)
p_target_bob = target_bob.send(bob)
p_target_alice = target_alice.send(alice)
#在正式环境上,可以通过其他方式上述数据的指针传过来。
#保存指针,至此,数据准备完成了,开始进行正式训练过程。
datasets = [(p_data_bob, p_target_bob), (p_data_alice, p_target_alice)]
#初始化一个线性分类器y = w_1*x_1+w_2*x_2+b
model = nn.Linear(2,1)
#sgd优化器
opt = optim.SGD(params=model.parameters(),lr=0.1)
#训练过程
def train():
#10次迭代
for iter in range(10):
for data,target in datasets:
#把上一轮训练好的模型,发给一个worker
model.send(data.location)
#梯度清零
opt.zero_grad()
#做预测
pred = model(data)
#计算loss
loss = ((pred - target)**2).sum()
#求导
loss.backward()
#更新模型参数
opt.step()
#更新模型
model.get()
#输出训练误差
print(loss.get())
#运行
train()