
PyTorch AttributeError: 'tuple' object has no attribute 'dim'

程序员文章站 2022-03-03 14:36:48

After building the model, training fails with this error:

Traceback (most recent call last):
  File "/home/user1/alexnet_test.py", line 117, in <module>
    main()
  File "/home/user1/alexnet_test.py", line 89, in main
    train_loss, train_acc, train_bacc = train(model, optimizer, train_loader, criterion, weights=attrWeights)
  File "/home/user1/pjs/0fea/multiTaskCNN/mt_test/train.py", line 61, in train
    outputs = model(inputs)  # list, len = 40, every one is torch.Size([32, 2])
  File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/user1/pjs/0fea/multiTaskCNN/mt_test/models/alexnet.py", line 73, in forward
    out = [tower(h_shared) for tower in self.towers]
  File "/home/user1/pjs/0fea/multiTaskCNN/mt_test/models/alexnet.py", line 73, in <listcomp>
    out = [tower(h_shared) for tower in self.towers]
  File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/modules/container.py", line 92, in forward
    input = module(input)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 92, in forward
    return F.linear(input, self.weight, self.bias)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/torch/nn/functional.py", line 1404, in linear
    if input.dim() == 2 and bias is not None:
AttributeError: 'tuple' object has no attribute 'dim'
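The last two frames show where it breaks: F.linear calls input.dim(), a method that only tensors have, so whatever reached the nn.Linear layer inside the tower was a tuple. Below is a minimal sketch that reproduces this outside the model; the layer sizes are arbitrary choices for illustration. Newer PyTorch versions implement linear in C++ and raise a TypeError for a tuple input instead of this AttributeError, so the sketch catches both.

```python
import torch
import torch.nn as nn

layer = nn.Linear(9216, 2)
h = torch.randn(16, 9216)

bad_input = (h,)  # a 1-tuple wrapping the tensor, which is what a stray comma produces

caught = None
try:
    layer(bad_input)
except (AttributeError, TypeError) as exc:
    # Old PyTorch (as in the traceback): AttributeError: 'tuple' object has no attribute 'dim'
    # Newer PyTorch: a TypeError saying the input must be a Tensor, not a tuple
    caught = type(exc)
    print(caught.__name__, exc)

out = layer(bad_input[0])  # unwrapping the tuple gives the expected result
print(out.shape)  # torch.Size([16, 2])
```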

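If none of the usual fixes found online solve your problem, for example:
1. Write self.avgpool = nn.AdaptiveAvgPool2d((6, 6)), not self.avgpool = nn.AdaptiveAvgPool2d(6, 6);
2. The layer's return value includes the indices because you set return_indices=True on a pooling layer; change it to nn.MaxPool2d(kernel_size=3, stride=2, return_indices=False);
3. inception_v3, which in training mode with aux_logits enabled returns its main and auxiliary outputs together;
4. LSTM, whose forward returns (output, (h_n, c_n)) rather than a single tensor;
...and so on.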
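Items 2 and 4 share one root cause: the layer returns a tuple instead of a single tensor. A minimal sketch with toy shapes chosen for illustration, showing how to unpack these outputs so that only a tensor reaches the next layer; item 1's AdaptiveAvgPool2d form is included for comparison.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 13, 13)  # toy input: (batch, channels, height, width)

# Item 1: the target size is passed as a single tuple argument
avgpool = nn.AdaptiveAvgPool2d((6, 6))
avg_out = avgpool(x)
print(avg_out.shape)  # torch.Size([1, 3, 6, 6])

# Item 2: with return_indices=True, MaxPool2d returns (output, indices)
pool = nn.MaxPool2d(kernel_size=3, stride=2, return_indices=True)
result = pool(x)
print(isinstance(result, tuple))  # True
pooled, indices = result          # pass only `pooled` on to the next layer
print(pooled.shape)               # torch.Size([1, 3, 6, 6])

# Item 4: nn.LSTM returns (output, (h_n, c_n))
lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
seq = torch.randn(4, 10, 8)       # (batch, time steps, features)
output, (h_n, c_n) = lstm(seq)
head = nn.Linear(16, 2)
logits = head(output[:, -1, :])   # last time step: a [4, 16] tensor, not the tuple
print(logits.shape)               # torch.Size([4, 2])
```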
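If none of the above is the cause, my friend, you may have typed an extra comma in your forward function, which silently turns the result into a tuple. Here is the infuriating mistake I made: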

    def forward(self, x):
        x = self.features(x), # the devil comma is here: x is now a 1-tuple (tensor,)
        h_shared = torch.flatten(x, 1) # torch.Size([16, 9216])
        out = [tower(h_shared) for tower in self.towers]
        return out
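Deleting the comma is the whole fix. Below is a minimal, self-contained sketch of the corrected forward inside a hypothetical, much smaller multi-task model; MultiTaskNet, its layer sizes, and num_tasks are illustrative assumptions, not the author's AlexNet. Because x stays a tensor, every tower receives a tensor.

```python
import torch
import torch.nn as nn


class MultiTaskNet(nn.Module):
    """Hypothetical toy stand-in for an AlexNet-style multi-task model."""

    def __init__(self, num_tasks=3):
        super().__init__()
        # Shared feature extractor (far smaller than AlexNet's, for illustration)
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((6, 6)),
        )
        # One small classification head ("tower") per task, each with 2 outputs
        self.towers = nn.ModuleList(
            [nn.Sequential(nn.Linear(8 * 6 * 6, 16), nn.ReLU(), nn.Linear(16, 2))
             for _ in range(num_tasks)]
        )

    def forward(self, x):
        x = self.features(x)            # no trailing comma: x stays a Tensor
        h_shared = torch.flatten(x, 1)  # [batch, 8 * 6 * 6]
        return [tower(h_shared) for tower in self.towers]


model = MultiTaskNet()
outputs = model(torch.randn(4, 3, 32, 32))
print(len(outputs), outputs[0].shape)  # 3 torch.Size([4, 2])
```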

I'm at a loss for words: one comma, probably left over from a copy-paste, cost me half a day...
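To avoid losing hours to a stray comma again, one option is to fail fast with a clearer message. This is a minimal sketch, assuming you are willing to attach a forward pre-hook (nn.Module.register_forward_pre_hook) to the Linear layers; require_tensor_inputs and add_input_guards are hypothetical helper names, not part of PyTorch.

```python
import torch
import torch.nn as nn


def require_tensor_inputs(module, args):
    """Hypothetical forward pre-hook: raise if any positional input is not a Tensor."""
    for i, arg in enumerate(args):
        if not isinstance(arg, torch.Tensor):
            raise TypeError(
                f"{module.__class__.__name__} got {type(arg).__name__} "
                f"as input #{i}; check for a stray trailing comma"
            )


def add_input_guards(model):
    """Hypothetical helper: attach the guard to every nn.Linear in the model."""
    for m in model.modules():
        if isinstance(m, nn.Linear):
            m.register_forward_pre_hook(require_tensor_inputs)
    return model


head = add_input_guards(nn.Sequential(nn.Linear(4, 2)))
x = torch.randn(3, 4)
good = head(x)
print(good.shape)  # torch.Size([3, 2])

message = None
try:
    head[0]((x,))  # a 1-tuple, as `x = f(x),` would produce
except TypeError as exc:
    message = str(exc)
print(message)  # Linear got tuple as input #0; check for a stray trailing comma
```

The hook runs before forward, so the error names the offending layer and the input's type instead of surfacing deep inside F.linear.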