[pysyft-005]联邦学习pysyft从入门到精通--使用plan
程序员文章站
2022-07-14 13:31:55
...
import torch
import torch.nn as nn
import torch.nn.functional as F
import syft as sy
# Tutorial source: PySyft examples/tutorials/"Part 08 - Introduction to Plans.ipynb"
# (the URL below points at a local Jupyter server, not a public page).
'''
Part 8 - Introduction to Plans
http://localhost:8888/notebooks/git-home/github/PySyft/examples/tutorials/Part%2008%20-%20Introduction%20to%20Plans.ipynb
'''
# Demonstrates plans. A plan is a composition of several operations and can be
# written either as a function or as a class. Plans can be sent to remote
# workers and executed there asynchronously.
'''
演示 plan
一个plan,表示若干个operation的组合,可以是一个函数,也可以是一个类。
plan可以发送到远程节点,可以异步执行。
'''

# Hook torch into PySyft; the .tag(...) calls below rely on this hook.
hook = sy.TorchHook(torch)
server = hook.local_worker
# NOTE(review): presumably makes the local worker keep/register the objects it
# creates, as the Plans tutorial requires -- confirm against the PySyft docs.
server.is_client_worker = False

# Two identical pairs of example tensors, one pair per virtual worker; each
# pair is tagged so it can later be located with worker.search(<tag>).
x11 = torch.tensor([-1.0, 2.0]).tag('input_data')
x12 = torch.tensor([1.0, -2.0]).tag('input_data2')
x21 = torch.tensor([-1.0, 2.0]).tag('input_data')
x22 = torch.tensor([1.0, -2.0]).tag('input_data2')

# Create the remote (virtual) workers, each pre-loaded with its tensor pair.
device_1 = sy.VirtualWorker(hook, id="device_1", data=(x11, x12))
device_2 = sy.VirtualWorker(hook, id="device_2", data=(x21, x22))
devices = (device_1, device_2)
# A plan can be a plain function: the func2plan decorator turns it into a plan.
@sy.func2plan()
def plan_double_abs(x):
    """Return the element-wise absolute value of ``x + x``."""
    return torch.abs(x + x)
def test_func_plan():
    """Demo: build the function plan, send it to device_1 and run it remotely.

    Prints, in order: the plan's ``is_built`` flag before and after building,
    the pointer returned by the remote call, and the result tensor fetched
    back from device_1.
    """
    # A plan has to be built (run once on an example input) before use.
    print(plan_double_abs.is_built)
    plan_double_abs.build(torch.tensor([1., -2.]))
    print(plan_double_abs.is_built)

    # Send the built plan to the remote worker; we get a pointer to it back.
    pointer_plan = plan_double_abs.send(device_1)

    # Execute the plan remotely on the tensor that already lives on device_1,
    # located through its tag.
    pointer_to_data = device_1.search('input_data')[0]
    pointer_to_result = pointer_plan(pointer_to_data)
    print(pointer_to_result)

    # Fetch the result to the local worker and display it, so the demo
    # actually shows the value the remote plan computed.
    print(pointer_to_result.get())
# A plan can also be a class: subclass sy.Plan and implement forward().
class Net(sy.Plan):
    """Small two-layer network (2 -> 3 -> 2) expressed as a PySyft plan."""

    def __init__(self):
        super().__init__()
        # 2 input features -> 3 hidden units -> 2 outputs.
        self.fc1 = nn.Linear(2, 3)
        self.fc2 = nn.Linear(3, 2)

    def forward(self, x):
        """Return log-softmax scores of the network output, taken over dim 0."""
        hidden = F.relu(self.fc1(x))
        scores = self.fc2(hidden)
        # dim=0 assumes a single 1-D sample (as in this file's examples);
        # NOTE(review): batched 2-D input would need dim=1 -- confirm intent.
        return F.log_softmax(scores, dim=0)
def test_class_plan():
    """Demo: build the Net plan, send it to device_1 and run it on remote data.

    Prints the pointer returned by the remote call, then the result tensor
    fetched back from device_1.
    """
    model = Net()
    # Plans must be built on an example input before they can be sent.
    model.build(torch.tensor([1., 2.]))
    remote_model = model.send(device_1)

    # Locate the tensor tagged 'input_data' on device_1 and run the plan on it.
    matches = device_1.search('input_data')
    remote_input = matches[0]
    remote_output = remote_model(remote_input)

    print(remote_output)
    print(remote_output.get())
if __name__ == '__main__':
    # Run the function-plan demo first, then the class-plan demo.
    for demo in (test_func_plan, test_class_plan):
        demo()