【BERT for TensorFlow】Using BERT from local ckpt files
程序员文章站
2022-03-26 16:45:07
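Summary
In this article you will learn:
- How to convert the official ckpt files into pytorch_model.bin for use in PyTorch/TensorFlow
- How to build a model on top of BERT for downstream tasks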
Official BERT ckpt files
First, download one of the official BERT releases, e.g. uncased_L-12_H-768_A-12.
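Google's release folders encode the architecture in their names: L is the number of Transformer layers, H the hidden size and A the number of attention heads. Each folder ships bert_config.json, vocab.txt and the bert_model.ckpt.* checkpoint files. Before converting, you can run a small sanity check. The sketch below uses only the standard library, and parse_bert_dir_name and config_matches_dir_name are hypothetical helper names. It decodes the folder name and compares the result with the matching fields of bert_config.json, so it only works if the folder has kept its original name.

```python
import json
import re
from pathlib import Path

# Matches the "L-<layers>_H-<hidden>_A-<heads>" part of names such as
# "uncased_L-12_H-768_A-12" or "chinese_L-12_H-768_A-12".
_NAME_PATTERN = re.compile(r"L-(\d+)_H-(\d+)_A-(\d+)")


def parse_bert_dir_name(name: str) -> dict:
    """Return the architecture encoded in a release folder name,
    keyed like the corresponding fields in bert_config.json."""
    match = _NAME_PATTERN.search(name)
    if match is None:
        raise ValueError(f"not a BERT release folder name: {name!r}")
    layers, hidden, heads = (int(g) for g in match.groups())
    return {
        "num_hidden_layers": layers,
        "hidden_size": hidden,
        "num_attention_heads": heads,
    }


def config_matches_dir_name(model_dir: str) -> bool:
    """Hypothetical check: does bert_config.json agree with the folder name?"""
    expected = parse_bert_dir_name(Path(model_dir).name)
    config_file = Path(model_dir) / "bert_config.json"
    config = json.loads(config_file.read_text(encoding="utf-8"))
    return all(config.get(key) == value for key, value in expected.items())


if __name__ == "__main__":
    print(parse_bert_dir_name("uncased_L-12_H-768_A-12"))
    # {'num_hidden_layers': 12, 'hidden_size': 768, 'num_attention_heads': 12}
```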
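Then save the following code as convert_bert_original_tf_checkpoint_to_pytorch.py. This is the conversion script from HuggingFace transformers. It needs both PyTorch and TensorFlow installed, because load_tf_weights_in_bert reads the checkpoint with TensorFlow: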
```python
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""

import argparse

import torch

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging

logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
```
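Most of the work happens in load_tf_weights_in_bert. Roughly, it walks through each TensorFlow variable name one segment at a time and maps it onto the matching attribute of the PyTorch model:
- kernel and gamma become weight.
- beta and output_bias become bias.
- layer_3 becomes index 3 of the layer list.
- Optimizer slots such as adam_m and adam_v are skipped.

Dense kernels are also transposed, because TensorFlow stores them as (in_features, out_features) while torch.nn.Linear.weight is (out_features, in_features). The function below is a simplified, standard-library sketch of that name mapping only. tf_name_to_pt_name is a hypothetical name, not a transformers API, and the real function also checks shapes and copies the arrays.

```python
import re
from typing import Optional, Tuple

# Optimizer / bookkeeping variables that carry no model weights and are skipped.
SKIPPED_PARTS = {"adam_v", "adam_m", "AdamWeightDecayOptimizer",
                 "AdamWeightDecayOptimizer_1", "global_step"}


def tf_name_to_pt_name(tf_name: str) -> Optional[Tuple[str, bool]]:
    """Simplified sketch: map a TF BERT variable name to a PyTorch state_dict key.

    Returns (pytorch_key, needs_transpose), or None for skipped variables.
    """
    parts = tf_name.split("/")
    if any(part in SKIPPED_PARTS for part in parts):
        return None
    out = []
    for part in parts:
        numbered = re.fullmatch(r"([A-Za-z]+)_(\d+)", part)
        if numbered:                          # "layer_3" -> "layer.3"
            out.extend(numbered.groups())
        elif part in ("kernel", "gamma", "output_weights"):
            out.append("weight")
        elif part in ("beta", "output_bias"):
            out.append("bias")
        else:
            out.append(part)
    if parts[-1].endswith("_embeddings"):     # embedding tables map to <name>.weight
        out.append("weight")
    # TF dense kernels are (in, out); torch.nn.Linear.weight is (out, in).
    needs_transpose = parts[-1] == "kernel"
    return ".".join(out), needs_transpose


if __name__ == "__main__":
    print(tf_name_to_pt_name("bert/encoder/layer_0/attention/self/query/kernel"))
    # ('bert.encoder.layer.0.attention.self.query.weight', True)
```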
ckpt to bin
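Then run the following on the command line. --tf_checkpoint_path takes the checkpoint prefix bert_model.ckpt, which is the name shared by bert_model.ckpt.index and bert_model.ckpt.data-00000-of-00001. Do not pass one of those individual files:

```shell
python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt --bert_config_file bert_config.json --pytorch_dump_path pytorch_model.bin
```

Note: the Windows command prompt does not recognize \ as a line-continuation character (cmd uses ^ instead). On Windows, the single-line form above is therefore safer, because it separates the arguments with spaces rather than \ line breaks. On Linux/macOS shells, the same command can be split across lines:

```shell
python convert_bert_original_tf_checkpoint_to_pytorch.py \
  --tf_checkpoint_path Models/chinese_L-12_H-768_A-12/bert_model.ckpt \
  --bert_config_file Models/chinese_L-12_H-768_A-12/bert_config.json \
  --pytorch_dump_path Models/chinese_L-12_H-768_A-12/pytorch_model.bin
```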
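If the shell differences get in the way, one option is to start the script from Python. subprocess.run, given an argument list and the default shell=False, passes each argument straight to the program, so no line-continuation character is involved. The helpers below are a hypothetical sketch. They assume the script sits in the current working directory and the checkpoint files keep their original names. checkpoint_prefix also turns a path to one of the checkpoint files, such as bert_model.ckpt.index, back into the prefix.

```python
import re
import sys
from pathlib import Path

SCRIPT = "convert_bert_original_tf_checkpoint_to_pytorch.py"  # assumed to be in the working directory


def checkpoint_prefix(path: str) -> str:
    """Strip a .index / .meta / .data-NNNNN-of-NNNNN suffix to get the TF checkpoint prefix."""
    return re.sub(r"\.(index|meta|data-\d{5}-of-\d{5})$", "", path)


def build_convert_command(model_dir: str, script: str = SCRIPT) -> list:
    """Argument list for subprocess.run; no shell is involved, so no line continuations."""
    d = Path(model_dir)
    return [
        sys.executable, script,
        "--tf_checkpoint_path", str(d / "bert_model.ckpt"),  # the prefix, not a single file
        "--bert_config_file", str(d / "bert_config.json"),
        "--pytorch_dump_path", str(d / "pytorch_model.bin"),
    ]
```

To run it, call subprocess.run(build_convert_command("Models/chinese_L-12_H-768_A-12"), check=True). With check=True, a failed conversion raises subprocess.CalledProcessError instead of passing silently.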
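This produces pytorch_model.bin. Copy it into the folder that holds the ckpt files, the one that also contains bert_config.json and vocab.txt. The loading code below reads everything from that single directory.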
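When TFBertModel.from_pretrained gets a local directory and from_pt=True, it looks for the PyTorch weights under the standard file name pytorch_model.bin (newer transformers releases may also accept model.safetensors). The converted file should therefore keep that name. A quick standard-library check can catch a misplaced file before TensorFlow starts up. missing_files is a hypothetical helper, and the file list is an assumption based on what the loading code below reads.

```python
from pathlib import Path

# Files the loading code in this article reads from pretrained_path (assumed list):
# bert_config.json for BertConfig, pytorch_model.bin for from_pt=True, vocab.txt for the tokenizer.
REQUIRED_FILES = ("bert_config.json", "pytorch_model.bin", "vocab.txt")


def missing_files(model_dir: str, required=REQUIRED_FILES) -> list:
    """Return the required file names that are not present in model_dir."""
    folder = Path(model_dir)
    return [name for name in required if not (folder / name).is_file()]


if __name__ == "__main__":
    missing = missing_files("../input/uncased_L-12_H-768_A-12/")
    print("all files present" if not missing else f"missing: {missing}")
```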
TensorFlow Fine-Tune (or whatever :)
Finally, the model can be loaded and used from TensorFlow.
Loading:
```python
import os

from transformers import BertConfig, TFBertModel

pretrained_path = "../input/uncased_L-12_H-768_A-12/"
config_path = os.path.join(pretrained_path, "bert_config.json")
checkpoint_path = os.path.join(pretrained_path, "bert_model.ckpt")  # original TF checkpoint; not used below
vocab_path = os.path.join(pretrained_path, "vocab.txt")             # read by the tokenizer

# Load the config
config = BertConfig.from_json_file(config_path)
# Load the base BERT model; from_pt=True reads pytorch_model.bin from pretrained_path
tfbert_model1 = TFBertModel.from_pretrained(pretrained_path, from_pt=True, config=config)
# # Load a sequence-classification model instead (also add TFBertForSequenceClassification to the import)
# tfbert_model2 = TFBertForSequenceClassification.from_pretrained(pretrained_path, from_pt=True, config=config)
```
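vocab_path is defined above but not used by the loading code. vocab.txt is what the tokenizer reads, and in practice you would pass it to transformers' BertTokenizer(vocab_file=vocab_path). The sketch below shows what the three model inputs in the next section look like for a single sentence that has already been split into token ids. It uses only the standard library, and load_vocab and encode_single are hypothetical names. Look up the special-token ids in your own vocab.txt rather than hard-coding them. In the uncased English vocabulary, [PAD], [CLS] and [SEP] are ids 0, 101 and 102.

```python
from typing import Dict, List, Tuple


def load_vocab(path: str) -> Dict[str, int]:
    """vocab.txt holds one token per line; the 0-based line number is the token id."""
    with open(path, encoding="utf-8") as f:
        return {line.rstrip("\n"): idx for idx, line in enumerate(f)}


def encode_single(
    token_ids: List[int], max_length: int, cls_id: int, sep_id: int, pad_id: int = 0
) -> Tuple[List[int], List[int], List[int]]:
    """Build (input_ids, attention_masks, token_type_ids) for one sentence:
    [CLS] tokens [SEP], truncated to max_length and right-padded with pad_id."""
    body = token_ids[: max_length - 2]       # leave room for [CLS] and [SEP]
    input_ids = [cls_id] + body + [sep_id]
    attention_masks = [1] * len(input_ids)   # 1 = real token, 0 = padding
    n_pad = max_length - len(input_ids)
    input_ids += [pad_id] * n_pad
    attention_masks += [0] * n_pad
    token_type_ids = [0] * max_length        # a single sentence uses segment 0 throughout
    return input_ids, attention_masks, token_type_ids


if __name__ == "__main__":
    # Arbitrary example token ids; 101/102 are [CLS]/[SEP] in the uncased vocab.txt.
    print(encode_single([2023, 2003, 2307], max_length=8, cls_id=101, sep_id=102))
    # ([101, 2023, 2003, 2307, 102, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0])
```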
Custom model
This article uses BERT + Bi-LSTM for a three-class text classification task:
```python
import tensorflow as tf
from transformers import TFBertModel

max_length = 128  # assumed example value: the length every input is padded/truncated to

# Encoded token ids from BERT tokenizer.
input_ids = tf.keras.layers.Input(
    shape=(max_length,), dtype=tf.int32, name="input_ids"
)
# Attention masks indicate to the model which tokens should be attended to.
attention_masks = tf.keras.layers.Input(
    shape=(max_length,), dtype=tf.int32, name="attention_masks"
)
# Token type ids are binary masks identifying different sequences in the model.
token_type_ids = tf.keras.layers.Input(
    shape=(max_length,), dtype=tf.int32, name="token_type_ids"
)

# Loading pretrained BERT model (pretrained_path and config come from the loading code above).
bert_model = TFBertModel.from_pretrained(pretrained_path, from_pt=True, config=config)
# Freeze the BERT model to reuse the pretrained features without modifying them.
bert_model.trainable = False

# ---- Customize from here!!! ----
bert_outputs = bert_model(
    input_ids, attention_mask=attention_masks, token_type_ids=token_type_ids
)
# Index the outputs instead of tuple-unpacking them: this works whether the model returns
# a tuple (transformers 3.x) or a ModelOutput (4.x), where unpacking yields the field names.
sequence_output = bert_outputs[0]  # (batch, max_length, hidden_size)
pooled_output = bert_outputs[1]    # (batch, hidden_size); not used below

# Add trainable layers on top of frozen layers to adapt the pretrained features on the new data.
bi_lstm = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(64, return_sequences=True)
)(sequence_output)
# Applying hybrid pooling approach to bi_lstm sequence output.
avg_pool = tf.keras.layers.GlobalAveragePooling1D()(bi_lstm)
max_pool = tf.keras.layers.GlobalMaxPooling1D()(bi_lstm)
concat = tf.keras.layers.concatenate([avg_pool, max_pool])
dropout = tf.keras.layers.Dropout(0.3)(concat)
output = tf.keras.layers.Dense(3, activation="softmax")(dropout)

model = tf.keras.models.Model(
    inputs=[input_ids, attention_masks, token_type_ids], outputs=output
)

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    # categorical_crossentropy expects one-hot labels of shape (batch, 3);
    # use "sparse_categorical_crossentropy" for integer class labels.
    loss="categorical_crossentropy",
    metrics=["acc"],
)
```
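In this head, the Bi-LSTM returns a (max_length, 128) sequence per example, because each direction has 64 units and the two are concatenated. The hybrid pooling then collapses it to a 128-value average plus a 128-value maximum, so the dropout and softmax layers see 256 features. As far as I can tell, nothing in this graph turns attention_masks into a Keras mask, so the padded positions probably take part in the pooling as well. The standard-library sketch below reproduces avg+max pooling on one example stored as nested lists (hybrid_pool is a hypothetical name). It shows how the average changes when padded time steps are excluded.

```python
from typing import List, Optional

Sequence2D = List[List[float]]  # (time_steps, features): one example of the Bi-LSTM output


def hybrid_pool(seq: Sequence2D, mask: Optional[List[int]] = None) -> List[float]:
    """Concatenate the per-feature mean and max over time, like
    concatenate([GlobalAveragePooling1D, GlobalMaxPooling1D]).
    If mask is given, time steps with mask 0 are ignored."""
    steps = seq if mask is None else [row for row, keep in zip(seq, mask) if keep]
    n_features = len(steps[0])
    avg = [sum(row[j] for row in steps) / len(steps) for j in range(n_features)]
    mx = [max(row[j] for row in steps) for j in range(n_features)]
    return avg + mx


if __name__ == "__main__":
    seq = [[1.0, 4.0], [3.0, 0.0], [0.0, 0.0]]  # last step stands in for padding
    print(hybrid_pool(seq))             # padding pulls the average down
    print(hybrid_pool(seq, [1, 1, 0]))  # [2.0, 2.0, 3.0, 4.0]
```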
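Starting from the "customize from here" marker in the code above, you can attach whatever downstream-task layers you need on top of BERT's outputs. Then check the resulting architecture with model.summary().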
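When reading model.summary(), the frozen BERT layer should be counted under the non-trainable parameters. The trainable count should therefore come from the new head alone. The arithmetic below is a sketch for the BERT + Bi-LSTM head above. It assumes BERT-Base's hidden size of 768 and Keras's default LSTM and Dense layers with biases. The pooling, concatenate and dropout layers add no parameters. lstm_params and dense_params are illustrative helpers, not Keras functions.

```python
def lstm_params(input_dim: int, units: int) -> int:
    """Keras LSTM: kernel (input_dim x 4*units) + recurrent kernel (units x 4*units) + bias (4*units)."""
    return input_dim * 4 * units + units * 4 * units + 4 * units


def dense_params(input_dim: int, units: int) -> int:
    """Keras Dense: weight matrix (input_dim x units) + bias (units)."""
    return input_dim * units + units


HIDDEN_SIZE = 768   # the H in uncased_L-12_H-768_A-12 (BERT-Base)
LSTM_UNITS = 64     # per direction, as in LSTM(64) above
NUM_CLASSES = 3

bi_lstm_params = 2 * lstm_params(HIDDEN_SIZE, LSTM_UNITS)  # forward + backward LSTM
pooled_features = 2 * (2 * LSTM_UNITS)                      # avg + max of the 128-wide Bi-LSTM output
head_params = dense_params(pooled_features, NUM_CLASSES)

if __name__ == "__main__":
    print(bi_lstm_params, head_params, bi_lstm_params + head_params)  # 426496 771 427267
```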
Original article: https://blog.csdn.net/qq_44574333/article/details/109631559