欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

PyTorch常用知识点

程序员文章站 2024-01-29 19:37:34
...

torch.cat

  • torch.cat是将两个张量(tensor)拼接在一起。
  • 使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数数值需相同,方能对齐。
  • 
    >>> import torch
    >>> A=torch.ones(2,3) #2x3的张量(矩阵)                                     
    >>> A
    tensor([[ 1.,  1.,  1.],
            [ 1.,  1.,  1.]])
    >>> B=2*torch.ones(4,3)#4x3的张量(矩阵)                                    
    >>> B
    tensor([[ 2.,  2.,  2.],
            [ 2.,  2.,  2.],
            [ 2.,  2.,  2.],
            [ 2.,  2.,  2.]])
    >>> C=torch.cat((A,B),0)#按维数0(行)拼接
    >>> C
    tensor([[ 1.,  1.,  1.],
             [ 1.,  1.,  1.],
             [ 2.,  2.,  2.],
             [ 2.,  2.,  2.],
             [ 2.,  2.,  2.],
             [ 2.,  2.,  2.]])
    >>> C.size()
    torch.Size([6, 3])
    >>> D=2*torch.ones(2,4) #2x4的张量(矩阵)
    >>> C=torch.cat((A,D),1)#按维数1(列)拼接
    >>> C
    tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],
            [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])
    >>> C.size()
    torch.Size([2, 7])

 

torch.view

  • 把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其他维度的tensor。比如说是不管你原先的数据是[[[1,2,3],[4,5,6]]]还是[1,2,3,4,5,6],因为它们排成一维向量都是6个元素,所以只要view后面的参数一致,得到的结果都是一样的。比如:
  • >>>a=torch.Tensor([[[1,2,3],[4,5,6]]])
    >>>b=torch.Tensor([1,2,3,4,5,6])
    >>>print(a.view(1,6))
    >>>print(b.view(1,6))
    tensor([[1., 2., 3., 4., 5., 6.]]) 
    tensor([[1., 2., 3., 4., 5., 6.]]) 
    
    >>>a=torch.Tensor([[[1,2,3],[4,5,6]]])
    >>>print(a.view(3,2))
    tensor([[1., 2.],
            [3., 4.],
            [5., 6.]])
  • 上面相当于就是从1,2,3,4,5,6顺序的拿数组来填充需要的形状。
  • 另外,参数不可为空。参数中的-1就代表这个位置由其他位置的数字来推断,只要在不致歧义的情况的下,view参数就可以推断出来,也就是人可以推断出形状的情况下,view函数也可以推断出来。比如一个tensor的数据个数是6个,如果view(1,-1),我们就可以根据tensor的元素个数推断出-1代表6。而如果是view(-1,-1,2),人不知道怎么推断,机器也不知道。还有一种情况是人可以推断出来,但是机器推断不出来的:view(-1,-1,6),人可以知道-1都代表1,但是机器不允许同时有两个负1。如果没有-1,那么所有参数的乘积就要和tensor中元素的总个数一致了,否则就会出现错误。

torch.stack