欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

[pysyft-005]联邦学习pysyft从入门到精通--使用plan

程序员文章站 2022-07-14 13:31:55
...
import torch
import torch.nn as nn
import torch.nn.functional as F
import syft as sy

'''
Part 8 - Introduction to Plans
http://localhost:8888/notebooks/git-home/github/PySyft/examples/tutorials/Part%2008%20-%20Introduction%20to%20Plans.ipynb
'''

'''
演示 plan
一个plan,表示若干个operation的组合,可以是一个函数,也可以是一个类。
plan可以发送到远程节点,可以异步执行。
'''


hook = sy.TorchHook(torch)
hook.local_worker.is_client_worker = False
server = hook.local_worker

x11 = torch.tensor([-1, 2.]).tag('input_data')
x12 = torch.tensor([1, -2.]).tag('input_data2')
x21 = torch.tensor([-1, 2.]).tag('input_data')
x22 = torch.tensor([1, -2.]).tag('input_data2')

#创建远程节点
device_1 = sy.VirtualWorker(hook, id="device_1", data=(x11, x12))
device_2 = sy.VirtualWorker(hook, id="device_2", data=(x21, x22))
devices = device_1, device_2

#plan是一个函数
@sy.func2plan()
def plan_double_abs(x):
    x = x + x
    x = torch.abs(x)
    return x


def test_func_plan():
    #plan在运行前要先build
    print(plan_double_abs.is_built)
    plan_double_abs.build(torch.tensor([1., -2.]))
    print(plan_double_abs.is_built)


    #把plan发送给远程节点
    pointer_plan = plan_double_abs.send(device_1)
                      
                      
    #远程执行build
    pointer_to_data = device_1.search('input_data')[0]
    pointer_to_result = pointer_plan(pointer_to_data)
    print(pointer_to_result)
    pointer_to_result.get()


#plan是个类
class Net(sy.Plan):
        def __init__(self):
            super(Net, self).__init__()
            self.fc1 = nn.Linear(2, 3)
            self.fc2 = nn.Linear(3, 2)

        def forward(self, x):
            x = F.relu(self.fc1(x))
            x = self.fc2(x)
            return F.log_softmax(x, dim=0)

def test_class_plan():
    #实例化net
    net = Net()
    #发送先build
    net.build(torch.tensor([1., 2.]))
    #发送到远程节点
    pointer_to_net = net.send(device_1)
    #用plan做计算
    pointer_to_data = device_1.search('input_data')[0]
    pointer_to_result = pointer_to_net(pointer_to_data)
    #输出结果
    print(pointer_to_result)
    print(pointer_to_result.get())

    
if __name__ == '__main__':
    test_func_plan()
    test_class_plan()

 

相关标签: 联邦学习 pysyft