欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch:tensor的broadcast

程序员文章站 2022-07-13 10:08:08
...

broadcasting

# 按照尾部维度对齐
x=torch.ones(5,1,4,2)
y=torch.ones(3,1,1)
(x+y).size()
                   torch.Size([5, 3, 4, 2])
  • 首先按尾部维度对齐 即1对2,1对4,3对1,然后没有维度能对5,就补一个维度,size为1,则1对5
  • 每一对中,如果两个数字不一样,那么就把1变成另一个,比如说第一对是1对2,那么就把1变成2(代表着y在最后一个维度(列)上copy了一次)
    又比如3对1,那么就把1变成3(代表着x在第二个维度上copy了两次)
    1对5,那么就把1变成5(这个1是新增加的维度,所以1变成5就代表着把y再copy4次)
  • PS:如果一对数字中,两个数字不同,并且又没有1,那么就会报错
a = torch.arange(6).reshape(2,1,3)
a
tensor([[[0, 1, 2]],

        [[3, 4, 5]]])
b = torch.tensor([[100],[200],[300]])
b
tensor([[100],
        [200],
        [300]])
a+b
tensor([[[100, 101, 102],
         [200, 201, 202],
         [300, 301, 302]],

        [[103, 104, 105],
         [203, 204, 205],
         [303, 304, 305]]])
a = torch.arange(1,9).reshape(2,4)
a
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])
b = torch.arange(1,3).reshape(1,2)
b
tensor([[1, 2]])
  • 现在a的size是24 b的size是12 如果进行broadcast,那么其中一对就是2对4,所以不能进行broadcast,那么怎么办呢?
    先把b的第二个维度(列)copy一次,让他变成1*4
b=b.repeat(1,2)
a+b
tensor([[ 2,  4,  4,  6],
        [ 6,  8,  8, 10]])


a = torch.full((2,3,2),90) #2个班级,每个班级3名学生,每个学生考了两门考试,假设目前都是90分
a
tensor([[[90., 90.],
         [90., 90.],
         [90., 90.]],

        [[90., 90.],
         [90., 90.],
         [90., 90.]]])
  • 现在我们想给第二个班级的所有人加5分,使用broadcast
b = torch.tensor([0.,5.]).reshape(2,1,1)
b
tensor([[[0.]],

        [[5.]]])
a + b 
tensor([[[90., 90.],
         [90., 90.],
         [90., 90.]],

        [[95., 95.],
         [95., 95.],
         [95., 95.]]])
  • 如果想给每个班的第1个同学加5分:
b = torch.tensor([5.,0.,0.]).reshape(1,3,1)
b
tensor([[[5.],
         [0.],
         [0.]]])
a + b
tensor([[[95., 95.],
         [90., 90.],
         [90., 90.]],

        [[95., 95.],
         [90., 90.],
         [90., 90.]]])
  • 如果想给每个同学的第1门功课加5分:
b = torch.tensor([5.,0.]).reshape(1,1,2)
b
tensor([[[5., 0.]]])
a + b
tensor([[[95., 90.],
         [95., 90.],
         [95., 90.]],

        [[95., 90.],
         [95., 90.],
         [95., 90.]]])

相关标签: pytorch