欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch 扩充Tensor的维度

程序员文章站 2022-06-11 22:24:30
...
  • dataloader读进来的数据标签是torch.Size([n]),需要和数据做torch.cat操作,要保持两者的维度一致,变成torch.Size([n, 1])
  • 代码举例
    import torch
    import numpy as np
    
    label = torch.randperm(5)
    print(label)  # tensor([1, 3, 0, 4, 2])
    print(label.size()) # torch.Size([5])
    
    # 方法1
    label2 = label[:,np.newaxis]
    print(label2, label2.size())
    
    # tensor([[1],
            [3],
            [0],
            [4],
            [2]]) torch.Size([5, 1])
            
    # 方法2
    label3 = torch.unsqueeze(label,1)
    print(label3, label3.size())
    # tensor([[1],
            [3],
            [0],
            [4],
            [2]]) torch.Size([5, 1])