pytorch 扩充Tensor的维度
程序员文章站
2022-06-11 22:24:30
...
- dataloader读进来的数据标签是torch.Size([n]),需要和数据做torch.cat操作,要保持两者的维度一致,变成torch.Size([n, 1])
- 代码举例
import torch import numpy as np label = torch.randperm(5) print(label) # tensor([1, 3, 0, 4, 2]) print(label.size()) # torch.Size([5]) # 方法1 label2 = label[:,np.newaxis] print(label2, label2.size()) # tensor([[1], [3], [0], [4], [2]]) torch.Size([5, 1]) # 方法2 label3 = torch.unsqueeze(label,1) print(label3, label3.size()) # tensor([[1], [3], [0], [4], [2]]) torch.Size([5, 1])
上一篇: Pytorch官方指南(二) 翻译版本