pytorch计算一个矩阵每一行分别与另一个矩阵的差值
程序员文章站
2022-07-12 19:52:42
...
import torch
x = torch.randint(10, 20, (3,4)) # [3,1,4]
y = torch.randint(0, 9, (2,4)) # [2,1,4]
print(x)
print(y)
z = x.unsqueeze(1) - y.unsqueeze(0) # p[3, 2, 4]
'''
tensor([[10, 19, 14, 14],
[17, 15, 16, 12],
[15, 16, 12, 17]])
tensor([[7, 3, 3, 8],
[7, 8, 5, 3]])
tensor([[[ 3, 16, 11, 6],
[ 3, 11, 9, 11]],
[[10, 12, 13, 4],
[10, 7, 11, 9]],
[[ 8, 13, 9, 9],
[ 8, 8, 7, 14]]])
'''
上面的代码就是计算x矩阵的每一行分别与y矩阵的元素的差值。