欢迎您访问程序员文章站，本站旨在为大家提供、分享程序员计算机编程知识！
您现在的位置是: 首页

PyTorch：输入 tensor，查看对应的输出数值

程序员文章站 2022-07-13 10:08:32
...

在 PyTorch 中输入相应的数值并查看对应的输出数值，做法与 TensorFlow 类似。下面给出一段用 PyTorch 由输入计算对应输出的过程。
例如如下代码：

# coding:utf-8
import os
import pickle

import torch
import random
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
from typing import List, Tuple, Dict
from collections import defaultdict
from torch.utils.data import Dataset

from transformers import (
    BertTokenizer,
    DataCollatorForLanguageModeling,
    DataCollatorForWholeWordMask,
    PreTrainedTokenizer, BertConfig
)
from transformers.utils import logging

from modeling.modeling_nezha.modeling import NeZhaForMaskedLM,NeZhaModel
from modeling.modeling_nezha.configuration import NeZhaConfig
from simple_trainer import Trainer
from pretrain_args import TrainingArguments

# Silence every Python warning for the whole script.
# NOTE(review): this blanket filter also hides deprecation/user warnings raised
# by torch/transformers; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')
# Module-level logger from transformers' logging utilities (not referenced by
# the code visible in this script).
logger = logging.get_logger(__name__)


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed applied to every random number generator.

    Returns:
        The same ``seed`` value that was passed in.
    """
    # Applied in a fixed order: Python, NumPy, torch CPU, then every CUDA device.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    return seed

def main():
    """Build a NeZha model from its config and print its output for a toy input.

    Pretrained NeZha checkpoints can be downloaded from
    https://github.com/lonePatient/NeZha_Chinese_PyTorch (only the
    ``nezha-cn-base`` and ``nezha-base-wwm`` checkpoints are used).
    Note that this demo reads only ``config.json``: the weights at
    ``model_path`` are never loaded, so the printed values come from the
    model's freshly constructed parameters, not from the pretrained ones.
    """
    # Pretraining settings kept from the original script. Apart from `seed`,
    # nothing below uses them, and the tokenizer is loaded but not used either.
    config = {
        'pretrain_type': 'dynamic_mask',  # dynamic_mask, whole_word_mask
        'data_cache_path': '',
        'train_data_path': '/home/xiaoguzai/数据/data/train.txt',
        'test_data_path': '/home/xiaoguzai/数据/data/test.txt',
    }

    mlm_probability = 0.15
    num_train_epochs = 1
    seq_length = 90
    batch_size = 32
    learning_rate = 6e-5
    save_steps = 5000
    seed = 2021

    config['data_cache_path'] = '../user_data/pretrain/'+config['pretrain_type']+'/data.pkl'

    model_path = '/home/xiaoguzai/数据/nezha-chinese-base/pytorch_model.bin'
    config_path = '/home/xiaoguzai/数据/nezha-chinese-base/config.json'

    vocab_file = '/home/xiaoguzai/数据/nezha-chinese-base/vocab.txt'
    tokenizer = BertTokenizer.from_pretrained(vocab_file)
    model_config = NeZhaConfig.from_pretrained(config_path)

    # Seed every RNG before constructing the model so that its initial
    # parameters, and therefore the printed output, are reproducible.
    seed_everything(seed)
    nezha = NeZhaModel(config=model_config)
    # eval() turns off training-time behaviour such as dropout, so the same
    # input always yields the same output for this model instance.
    nezha.eval()
    input_ids = torch.tensor([[1, 2], [3, 4]])
    # Inspection-only forward pass: no need to build an autograd graph.
    with torch.no_grad():
        output = nezha(input_ids)
    print('output = ')
    print(output)
# Script entry point: run the demo only when the file is executed directly,
# not when it is imported.
if __name__ == '__main__':
    main()

关键代码：

# Construct the model from its config only; no checkpoint weights are loaded here.
nezha = NeZhaModel(config=model_config)
# Toy integer input, presumably (batch=2, seq_len=2) token ids — confirm
# against NeZhaModel's expected input layout.
input_ids = torch.tensor([[1,2],[3,4]])
# Calling the module runs its forward pass on the input and returns its outputs.
output = nezha(input_ids)

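从这段代码可以看出，PyTorch 的计算过程与 TensorFlow 有相似之处：都是先定义模型，然后直接传入输入数值，即可获得相应的输出数值。需要注意的是，这里的模型只由配置文件构建，并未加载 model_path 中的预训练权重。若希望同一模型对相同输入给出确定的输出，应在前向计算前调用 nezha.eval() 关闭 dropout 等训练期行为；还可以在 torch.no_grad() 下执行，以避免构建计算图。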