欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

[pysyft-002]联邦学习pysyft从入门到精通--三个节点训练一个线性分类器

程序员文章站 2022-07-14 13:31:37
...
import syft as sy
import torch
from torch import nn
from torch import optim

"""
https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%2002%20-%20Intro%20to%20Federated%20Learning.ipynb
Part 02 - Intro to Federated Learning.ipynb
"""

"""
本例演示:
在A节点上有一个模型。B、C节点上分别有两个样本集。A节点把模型分别送到B和C节点上进行多轮训练。
本脚本运行在A节点上。
"""

#syft需要对pytorch做hook
hook = sy.TorchHook(torch)

#两个worker,每个worker是一个训练节点
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")

#数据集,data是样本属性,target是样本类别标记
data = torch.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
target = torch.tensor([[0],[0],[1],[1.]], requires_grad=True)

#数据集拆分成两部分,一部分发给bob训练,一部分发给alice训练。训练出两个模型。
#bob和alice都不知道对方的模型,bot和alice是独立的。

#给bob worker的数据
data_bob = data[0:2]
target_bob = target[0:2]

#给alice worker的数据
data_alice = data[2:]
target_alice = target[2:]

#把数据发给bob和alice,返回的是指针,指向bot和alice上的数据 
p_data_bob = data_bob.send(bob)
p_data_alice = data_alice.send(alice)
p_target_bob = target_bob.send(bob)
p_target_alice = target_alice.send(alice)

#在正式环境上,可以通过其他方式上述数据的指针传过来。
#保存指针,至此,数据准备完成了,开始进行正式训练过程。
datasets = [(p_data_bob, p_target_bob), (p_data_alice, p_target_alice)]

#初始化一个线性分类器y = w_1*x_1+w_2*x_2+b
model = nn.Linear(2,1)

#sgd优化器
opt = optim.SGD(params=model.parameters(),lr=0.1)

#训练过程
def train():
    #10次迭代
    for iter in range(10):
         for data,target in datasets:
             #把上一轮训练好的模型,发给一个worker
             model.send(data.location)
             #梯度清零
             opt.zero_grad()
             #做预测
             pred = model(data)
             #计算loss
             loss = ((pred - target)**2).sum()
             #求导
             loss.backward()
             #更新模型参数
             opt.step()
             #更新模型
             model.get()
             #输出训练误差
             print(loss.get())

#运行
train()

 

相关标签: 联邦学习 pysyft