pytorch-------测试(18)
程序员文章站
2022-06-11 22:22:18
...
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
logits = torch.rand(4,10)
pred = F.softmax(logits,dim=1)
print(pred.shape)
pred_label = pred.argmax(dim=1)
print(pred_label)
#argmax选取最大值的索引,与是否经过softmax的结果一样
print(logits.argmax(dim=1))
label = torch.tensor([5,5,5,5])
correct = torch.eq(pred_label,label)
print(correct)
#计算准确率
print(correct.sum().float().item()/4)
torch.Size([4, 10])
tensor([5, 7, 9, 9])
tensor([5, 7, 9, 9])
tensor([1, 0, 0, 0], dtype=torch.uint8)
0.25
for data, target in test_loader:
data = data.view(-1, 28 * 28)
data, target = data.to(device), target.cuda()
logits = net(data)
test_loss += criteon(logits, target).item()
pred = logits.data.max(1)[1]
correct += pred.eq(target.data).sum()
test_loss /= len(test_loader.dataset)
上一篇: YY语音中的游戏直播如何开启 YY游戏直播图文教程详细介绍
下一篇: 如何看待微信运营的数据?