PyTorch：输入 tensor，查看对应的输出数值
程序员文章站
2022-07-13 10:08:32
...
在 PyTorch 中给模型输入相应的数值并查看对应的输出，做法与 TensorFlow 类似。下面给出一段 PyTorch 由输入计算对应输出的示例代码。
例如下面的代码：
# coding:utf-8
import os
import pickle
import torch
import random
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple, Dict
from collections import defaultdict
from torch.utils.data import Dataset
from transformers import (
BertTokenizer,
DataCollatorForLanguageModeling,
DataCollatorForWholeWordMask,
PreTrainedTokenizer, BertConfig
)
from transformers.utils import logging
from modeling.modeling_nezha.modeling import NeZhaForMaskedLM,NeZhaModel
from modeling.modeling_nezha.configuration import NeZhaConfig
from simple_trainer import Trainer
from pretrain_args import TrainingArguments
# Ignore every Python warning for the rest of the run (presumably to keep the
# printed output uncluttered).
warnings.filterwarnings('ignore')
# Module-level logger from transformers' logging utilities; it is not used by
# the code shown in this script.
logger = logging.get_logger(__name__)
def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's random number generators.

    The PyTorch CPU generator and the generators of all CUDA devices are seeded.

    Args:
        seed: integer seed applied to every generator.

    Returns:
        The same ``seed`` value, so callers can log or reuse it.
    """
    # Applied in a fixed order: stdlib, NumPy, torch CPU, torch CUDA (all devices).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    return seed
def main():
    """Build a NeZha encoder from its config, feed it a toy batch of token ids, and print the output.

    Pretrained NeZha models can be downloaded from
    https://github.com/lonePatient/NeZha_Chinese_PyTorch; the original author
    only uses the ``nezha-cn-base`` and ``nezha-base-wwm`` models.
    """
    seed = 2021
    # The original script defined `seed` but never applied it, so the randomly
    # initialised parameters (and hence the printed output) changed every run.
    seed_everything(seed)

    config_path = '/home/xiaoguzai/数据/nezha-chinese-base/config.json'
    model_config = NeZhaConfig.from_pretrained(config_path)
    nezha = NeZhaModel(config=model_config)
    # NOTE(review): only the architecture is built from the config; the
    # pretrained checkpoint (pytorch_model.bin in the same directory) is never
    # loaded, so the output comes from randomly initialised parameters.

    # Modules start in training mode; switch to inference mode so training-only
    # behaviour such as dropout (if the model uses any) does not make the
    # output vary for the same input.
    nezha.eval()

    # Batch of two sequences, each two token ids long: shape (2, 2), int64.
    input_ids = torch.tensor([[1, 2], [3, 4]])
    # Pure inspection: no gradients needed, so skip building the autograd graph.
    with torch.no_grad():
        output = nezha(input_ids)
    print('output = ')
    print(output)


if __name__ == '__main__':
    main()
关键代码:
nezha = NeZhaModel(config=model_config)
# Switch from the default training mode to inference mode so that training-only
# behaviour such as dropout (if the model uses any) does not change the output.
nezha.eval()
# Batch of two sequences, each two token ids long: shape (2, 2), int64.
input_ids = torch.tensor([[1, 2], [3, 4]])
# Only inspecting the output: no gradients needed.
with torch.no_grad():
    output = nezha(input_ids)
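从这段代码可以看出，PyTorch 的计算过程与 TensorFlow 有类似之处：都是定义好模型之后直接传入输入数值，即可获得相应的输出数值。需要注意两点：一是 PyTorch 模型创建后默认处于训练模式，若希望同一输入每次得到相同的输出，应先调用 nezha.eval() 关闭 dropout 等训练期行为，并在 torch.no_grad() 中执行前向计算；二是示例只根据 config 构建了模型结构，并没有加载 pytorch_model.bin 中的预训练权重，因此得到的输出来自随机初始化的参数。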