欢迎您访问程序员文章站！本站旨在为大家分享程序员计算机编程知识！
您现在的位置是: 首页

SiamRPN++测试过程

程序员文章站 2024-03-05 15:39:13
...
import argparse
import collections
import datetime
import imp
import os
import pickle
import time
import lmdb
import ipdb
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.nn.modules.module import Module
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from dataloader.dataset import ImagnetVIDDataset
from network.SiamRPN import *
from utils.AverageMeter import AverageMeter
from utils.Logger import Logger
from utils.loss import rpn_cross_entropy_balance, rpn_smoothL1

# Smoke test: load trained SiamRPN weights and run a single forward pass on
# random inputs to check the shapes of the two outputs (aa, bb).
net = SiamRPN()
# The network and the inputs below live on the CPU, so map the checkpoint to
# the CPU as well; this also lets a checkpoint saved on a GPU load on a
# machine without CUDA.
params = torch.load('your_own_trained_weight.pth', map_location='cpu')
# The checkpoint is read as a dict whose 'network' entry is the state_dict.
net.load_state_dict(params['network'])
# Inference only: put dropout / BatchNorm layers (if any) into eval mode.
net.eval()

# Random N(0, 1) inputs: a 127x127 crop and a 255x255 crop (presumably the
# template and search region in SiamRPN's convention). Both must be filled
# explicitly -- torch.Tensor(...) alone returns uninitialized memory.
a = torch.randn(1, 3, 127, 127)
b = torch.randn(1, 3, 255, 255)

# No autograd graph is needed for a test pass.
with torch.no_grad():
    aa, bb = net(a, b)

print(aa.shape)
print(bb.shape)

aa 为分类（cls）分支输出的 tensor，形状为 (1, 10, 25, 25)，即 5 个 anchor 各自的前景/背景 2 个得分（2 × 5 = 10）。

bb 为回归（reg）分支输出的 tensor（不是分类），形状为 (1, 20, 25, 25)，即 5 个 anchor 各自的 4 个边框偏移量（4 × 5 = 20）。

相关标签: SiamRPN_plus_plus