欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch-------测试(18)

程序员文章站 2022-06-11 22:22:18
...
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

logits = torch.rand(4,10)
pred = F.softmax(logits,dim=1)
print(pred.shape)

pred_label = pred.argmax(dim=1)
print(pred_label)
#argmax选取最大值的索引,与是否经过softmax的结果一样
print(logits.argmax(dim=1))

label = torch.tensor([5,5,5,5])
correct = torch.eq(pred_label,label)
print(correct)
#计算准确率
print(correct.sum().float().item()/4)

torch.Size([4, 10])
tensor([5, 7, 9, 9])
tensor([5, 7, 9, 9])
tensor([1, 0, 0, 0], dtype=torch.uint8)
0.25
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        data, target = data.to(device), target.cuda()
        logits = net(data)
        test_loss += criteon(logits, target).item()

        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

    test_loss /= len(test_loader.dataset)