欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch计算一个矩阵每一行分别与另一个矩阵的差值

程序员文章站 2022-07-12 19:52:42
...
import torch

x = torch.randint(10, 20, (3,4))  # [3,1,4]
y = torch.randint(0, 9, (2,4))   # [2,1,4]

print(x)
print(y)

z = x.unsqueeze(1) - y.unsqueeze(0) # p[3, 2, 4]
'''
tensor([[10, 19, 14, 14],
        [17, 15, 16, 12],
        [15, 16, 12, 17]])
tensor([[7, 3, 3, 8],
        [7, 8, 5, 3]])
tensor([[[ 3, 16, 11,  6],
         [ 3, 11,  9, 11]],

        [[10, 12, 13,  4],
         [10,  7, 11,  9]],

        [[ 8, 13,  9,  9],
         [ 8,  8,  7, 14]]])
'''

上面的代码就是计算x矩阵的每一行分别与y矩阵的元素的差值。