
【BERT for Tensorflow】Using BERT from local ckpt files


Using BERT from local ckpt files


Abstract

In this article you will learn:

  • How to convert the official ckpt files into a pytorch_model.bin that can be used from pytorch/tensorflow
  • How to attach extra layers on top of BERT to solve downstream tasks

Official BERT ckpt files

First, download an official BERT release, e.g. uncased_L-12_H-768_A-12, and unzip it.
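The unzipped folder should contain the TensorFlow checkpoint, the config file and the vocabulary. The small sketch below just verifies that; the local path Models/uncased_L-12_H-768_A-12 is an assumed example, adjust it to wherever you unzipped the archive.

import os

# Assumed local path to the unzipped official BERT release
bert_dir = "Models/uncased_L-12_H-768_A-12"

# Files shipped in the official Google BERT checkpoints
expected = [
    "bert_config.json",
    "vocab.txt",
    "bert_model.ckpt.index",
    "bert_model.ckpt.meta",
    "bert_model.ckpt.data-00000-of-00001",
]

for name in expected:
    path = os.path.join(bert_dir, name)
    print(f"{name}: {'found' if os.path.exists(path) else 'MISSING'}")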

Use the following conversion script, convert_bert_original_tf_checkpoint_to_pytorch.py (it comes from the HuggingFace transformers repository):

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""


import argparse

import torch

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

ckpt to bin

Then run the following on the command line:

python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt.index --bert_config_file bert_config.json  --pytorch_dump_path  pytorch_model.bin

Note: on the Windows command line the \ line-continuation character does not work (cmd uses ^ instead), so the single-line command above, with spaces in place of the \ line breaks, is safer. The equivalent multi-line form is:

python convert_bert_original_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path Models/chinese_L-12_H-768_A-12/bert_model.ckpt.index \
--bert_config_file Models/chinese_L-12_H-768_A-12/bert_config.json  \
--pytorch_dump_path  Models/chinese_L-12_H-768_A-12/pytorch_model.bin

After that you will have pytorch_model.bin; copy this file into the ckpt folder (the same folder that holds bert_config.json and vocab.txt).
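Optionally, a quick sanity check that the converted weights load correctly (a minimal sketch; the path below is assumed to match the conversion command above):

import torch

# Assumed path to the converted weights produced by the command above
state_dict = torch.load(
    "Models/chinese_L-12_H-768_A-12/pytorch_model.bin", map_location="cpu"
)

# A BERT-Base checkpoint should contain a few hundred weight tensors
print(len(state_dict), "tensors")
print(next(iter(state_dict)))  # name of the first tensor, e.g. an embedding weight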

Tensorflow Fine-Tune (or whatever :)

Finally, the model can be loaded and used through Tensorflow.

Loading:

from transformers import BertConfig, TFBertModel
import os

# This folder should contain bert_config.json, vocab.txt and the converted pytorch_model.bin
pretrained_path = "../input/uncased_L-12_H-768_A-12/"
config_path = os.path.join(pretrained_path, "bert_config.json")
checkpoint_path = os.path.join(pretrained_path, "bert_model.ckpt")
vocab_path = os.path.join(pretrained_path, "vocab.txt")


# Load the config
config = BertConfig.from_json_file(config_path)
# Load the base model (from_pt=True reads the pytorch_model.bin created above)
tfbert_model1 = TFBertModel.from_pretrained(pretrained_path, from_pt=True, config=config)

# # Load a sequence-classification head instead (add TFBertForSequenceClassification to the import)
# tfbert_model2 = TFBertForSequenceClassification.from_pretrained(pretrained_path, from_pt=True, config=config)
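To feed text into the model you also need a tokenizer. The following is a minimal sketch (not part of the original post) that builds a BertTokenizer from the local vocab.txt defined above and runs a single sentence through the loaded model; max_length is an assumed value:

from transformers import BertTokenizer

max_length = 128  # assumed sequence length; pick whatever fits your task

# Build the tokenizer directly from the local vocab file (uncased model, so lowercase)
tokenizer = BertTokenizer(vocab_file=vocab_path, do_lower_case=True)

encoded = tokenizer(
    "hello bert",
    max_length=max_length,
    padding="max_length",
    truncation=True,
    return_tensors="tf",
)

# encoded contains input_ids, token_type_ids and attention_mask
outputs = tfbert_model1(encoded)
print(outputs[0].shape)  # sequence output: (1, max_length, hidden_size)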

Custom model
This article uses BERT + a Bi-LSTM to solve a three-class text classification task:

import tensorflow as tf


def build_model(max_length):
    # Encoded token ids from BERT tokenizer.
    input_ids = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="input_ids"
    )
    # Attention masks indicate to the model which tokens should be attended to.
    attention_masks = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="attention_masks"
    )
    # Token type ids are binary masks identifying different sequences in the model.
    token_type_ids = tf.keras.layers.Input(
        shape=(max_length,), dtype=tf.int32, name="token_type_ids"
    )
    # Loading pretrained BERT model from the local folder converted above.
    bert_model = TFBertModel.from_pretrained(pretrained_path, from_pt=True, config=config)

    # Freeze the BERT model to reuse the pretrained features without modifying them.
    bert_model.trainable = False
    '''Customize from here!!!'''
    bert_output = bert_model(
        input_ids, attention_mask=attention_masks, token_type_ids=token_type_ids
    )
    # Indexing with [0] works whether the model returns a tuple or a ModelOutput.
    sequence_output = bert_output[0]
    # Add trainable layers on top of frozen layers to adapt the pretrained features on the new data.
    bi_lstm = tf.keras.layers.Bidirectional(
        tf.keras.layers.LSTM(64, return_sequences=True)
    )(sequence_output)
    # Applying hybrid pooling approach to bi_lstm sequence output.
    avg_pool = tf.keras.layers.GlobalAveragePooling1D()(bi_lstm)
    max_pool = tf.keras.layers.GlobalMaxPooling1D()(bi_lstm)
    concat = tf.keras.layers.concatenate([avg_pool, max_pool])
    dropout = tf.keras.layers.Dropout(0.3)(concat)
    output = tf.keras.layers.Dense(3, activation="softmax")(dropout)
    model = tf.keras.models.Model(
        inputs=[input_ids, attention_masks, token_type_ids], outputs=output
    )

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss="categorical_crossentropy",
        metrics=["acc"],
    )
    return model
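A minimal usage sketch for the function above; the training batch here is purely hypothetical placeholder data, shown only to illustrate the expected input format:

import numpy as np

model = build_model(max_length=128)
model.summary()

# Hypothetical placeholder batch: 8 random "sentences" and one-hot labels for 3 classes
dummy_ids = np.random.randint(0, config.vocab_size, size=(8, 128)).astype("int32")
dummy_masks = np.ones((8, 128), dtype="int32")
dummy_types = np.zeros((8, 128), dtype="int32")
dummy_labels = tf.keras.utils.to_categorical(
    np.random.randint(0, 3, size=(8,)), num_classes=3
)

model.fit([dummy_ids, dummy_masks, dummy_types], dummy_labels, epochs=1, batch_size=4)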

Following the code above, everything after the '''Customize from here!!!''' marker can be replaced with whatever downstream-task layers you need; finally, call model.summary() to inspect the resulting model.

Original article: https://blog.csdn.net/qq_44574333/article/details/109631559