pytorch:tensor的broadcast
程序员文章站
2022-07-13 10:08:08
...
broadcasting
# 按照尾部维度对齐
x=torch.ones(5,1,4,2)
y=torch.ones(3,1,1)
(x+y).size()
torch.Size([5, 3, 4, 2])
- 首先按尾部维度对齐 即1对2,1对4,3对1,然后没有维度能对5,就补一个维度,size为1,则1对5
- 每一对中,如果两个数字不一样,那么就把1变成另一个,比如说第一对是1对2,那么就把1变成2(代表着y在最后一个维度(列)上copy了一次)
又比如3对1,那么就把1变成3(代表着x在第二个维度上copy了两次)
1对5,那么就把1变成5(这个1是新增加的维度,所以1变成5就代表着把y再copy4次) - PS:如果一对数字中,两个数字不同,并且又没有1,那么就会报错
a = torch.arange(6).reshape(2,1,3)
a
tensor([[[0, 1, 2]],
[[3, 4, 5]]])
b = torch.tensor([[100],[200],[300]])
b
tensor([[100],
[200],
[300]])
a+b
tensor([[[100, 101, 102],
[200, 201, 202],
[300, 301, 302]],
[[103, 104, 105],
[203, 204, 205],
[303, 304, 305]]])
a = torch.arange(1,9).reshape(2,4)
a
tensor([[1, 2, 3, 4],
[5, 6, 7, 8]])
b = torch.arange(1,3).reshape(1,2)
b
tensor([[1, 2]])
- 现在a的size是24 b的size是12 如果进行broadcast,那么其中一对就是2对4,所以不能进行broadcast,那么怎么办呢?
先把b的第二个维度(列)copy一次,让他变成1*4
b=b.repeat(1,2)
a+b
tensor([[ 2, 4, 4, 6],
[ 6, 8, 8, 10]])
a = torch.full((2,3,2),90) #2个班级,每个班级3名学生,每个学生考了两门考试,假设目前都是90分
a
tensor([[[90., 90.],
[90., 90.],
[90., 90.]],
[[90., 90.],
[90., 90.],
[90., 90.]]])
- 现在我们想给第二个班级的所有人加5分,使用broadcast
b = torch.tensor([0.,5.]).reshape(2,1,1)
b
tensor([[[0.]],
[[5.]]])
a + b
tensor([[[90., 90.],
[90., 90.],
[90., 90.]],
[[95., 95.],
[95., 95.],
[95., 95.]]])
- 如果想给每个班的第1个同学加5分:
b = torch.tensor([5.,0.,0.]).reshape(1,3,1)
b
tensor([[[5.],
[0.],
[0.]]])
a + b
tensor([[[95., 95.],
[90., 90.],
[90., 90.]],
[[95., 95.],
[90., 90.],
[90., 90.]]])
- 如果想给每个同学的第1门功课加5分:
b = torch.tensor([5.,0.]).reshape(1,1,2)
b
tensor([[[5., 0.]]])
a + b
tensor([[[95., 90.],
[95., 90.],
[95., 90.]],
[[95., 90.],
[95., 90.],
[95., 90.]]])
上一篇: pytorch的tensor除法
下一篇: PyTorch中的Tensor