SiamRPN++测试过程
程序员文章站
2024-03-05 15:39:13
...
import argparse
import collections
import datetime
import imp
import os
import pickle
import time
import lmdb
import ipdb
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.nn.modules.module import Module
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from dataloader.dataset import ImagnetVIDDataset
from network.SiamRPN import *
from utils.AverageMeter import AverageMeter
from utils.Logger import Logger
from utils.loss import rpn_cross_entropy_balance, rpn_smoothL1
# Smoke test for a trained SiamRPN model: load a checkpoint, push one random
# input pair through the network and print the shapes of the two outputs.
net = SiamRPN()
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a CPU-only machine; load_state_dict then copies the weights into `net`.
params = torch.load('your_own_trained_weight.pth', map_location='cpu')
# The checkpoint is a dict whose model weights are stored under 'network'.
net.load_state_dict(params['network'])
# Switch to inference mode so layers such as BatchNorm/Dropout (if the model
# has any) behave as they do at test time rather than in training.
net.eval()

# torch.Tensor(...) alone returns uninitialised memory, so BOTH inputs must be
# filled explicitly; randn draws from N(0, 1) like normal_(mean=0, std=1).
# First input is 127x127 (presumably the template/exemplar crop), second is
# 255x255 (presumably the search-region crop) - confirm against SiamRPN.forward.
a = torch.randn(1, 3, 127, 127)
b = torch.randn(1, 3, 255, 255)

# No gradients are needed for a forward-only test.
with torch.no_grad():
    aa, bb = net(a, b)

# Expected for the author's configuration: aa is the classification (cls)
# output (1, 10, 25, 25) and bb the regression (reg) output (1, 20, 25, 25).
print(aa.shape)
print(bb.shape)
aa 为分类（cls）分支的输出 tensor，形状为 (1, 10, 25, 25)；
bb 为回归（reg）分支的输出 tensor，形状为 (1, 20, 25, 25)。