Converting a TF pre-trained model to a torch pre-trained model
程序员文章站
2022-06-13 16:03:22
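While converting the TensorFlow pre-trained ALBERT model to a torch pre-trained model, I ran into a lot of pitfalls. I finally got it working, and I hope this is useful to others.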
1. Preparation
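Create a single environment that contains both torch and tf. The steps are as follows.
First, create the environment and activate it:
conda create -n torchtf_env python=3.7
conda activate torchtf_env
Then install torch (pick the cudatoolkit version that matches the CUDA on your own machine):
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c conda-forge
Next, install the GPU build of tensorflow:
conda install tensorflow-gpu==1.15
Finally, install transformers:
pip install transformers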
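Before going further, it can help to confirm that all three packages are visible from the same interpreter. The following is a minimal sketch using only the standard library; the helper name `find_missing_packages` is an illustrative assumption and not part of torch, tensorflow, or transformers.

```python
import importlib.util


def find_missing_packages(names):
    """Return the subset of `names` that cannot be found in this environment."""
    # find_spec() locates a top-level module without importing it,
    # so the check is quick and does not initialise CUDA.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = find_missing_packages(["torch", "tensorflow", "transformers"])
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("torch, tensorflow and transformers are all installed")
```

If anything is reported as missing, the most likely cause is that the install command was run outside the activated environment.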
2. Download the TensorFlow pre-trained ALBERT model from GitHub
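A TensorFlow checkpoint is a group of files that share one prefix (an `.index` file plus one or more `.data-*` shards), so the checkpoint path given to the converter is that prefix rather than a single file. Here is a minimal sketch for checking the download before converting. The directory `albert_base_en` and the prefix `model.ckpt-best` are simply the defaults used by the script in this article, and the helper name `checkpoint_problems` is hypothetical.

```python
import glob
import os


def checkpoint_problems(ckpt_prefix, config_file):
    """Return a list of human-readable problems; an empty list means all good."""
    problems = []
    # The prefix itself is not a file: TensorFlow stores <prefix>.index
    # and one or more <prefix>.data-XXXXX-of-YYYYY shards next to it.
    if not os.path.isfile(ckpt_prefix + ".index"):
        problems.append("missing " + ckpt_prefix + ".index")
    if not glob.glob(glob.escape(ckpt_prefix) + ".data-*"):
        problems.append("missing " + ckpt_prefix + ".data-* shards")
    if not os.path.isfile(config_file):
        problems.append("missing " + config_file)
    return problems


if __name__ == "__main__":
    for problem in checkpoint_problems(
        "albert_base_en/model.ckpt-best", "albert_base_en/albert_config.json"
    ):
        print(problem)
```

No output means the checkpoint files and the config file were all found at the expected paths.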
3. Convert the checkpoint with the following script
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""
Created on 19/03/2021 20:22
@Author: lixj
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import logging

import torch
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert

logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise the PyTorch model from the ALBERT config file
    config = AlbertConfig.from_pretrained(bert_config_file)
    # print("Building PyTorch model from configuration: {}".format(str(config)))
    model = AlbertForPreTraining(config)
    # Load weights from the TensorFlow checkpoint into the PyTorch model
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)
    # Save only the weights (state dict) of the PyTorch model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Paths (every argument has a default; override as needed)
    parser.add_argument(
        "--tf_checkpoint_path",
        default='albert_base_en/model.ckpt-best',
        type=str,
        help="Path to the TensorFlow checkpoint (the checkpoint prefix).",
    )
    parser.add_argument(
        "--bert_config_file",
        default='albert_base_en/albert_config.json',
        type=str,
        help="The config json file corresponding to the pre-trained ALBERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path",
        default='albert_base_en/pytorch_model.bin',
        type=str,
        help="Path to the output PyTorch model.",
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
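Because every argument has a default, the script runs without any flags when the files sit in `albert_base_en/`, and a flag passed explicitly overrides only that one path. The sketch below illustrates this by rebuilding just the parser (help strings omitted). The directory `my_albert` is a made-up example, and `build_parser` is a hypothetical helper that is not part of the script.

```python
import argparse


def build_parser():
    """Recreate the three path arguments of the conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", default="albert_base_en/model.ckpt-best", type=str)
    parser.add_argument("--bert_config_file", default="albert_base_en/albert_config.json", type=str)
    parser.add_argument("--pytorch_dump_path", default="albert_base_en/pytorch_model.bin", type=str)
    return parser


# No flags: all three defaults are used.
defaults = build_parser().parse_args([])
print(defaults.pytorch_dump_path)  # albert_base_en/pytorch_model.bin

# One flag: only that path changes, the others keep their defaults.
custom = build_parser().parse_args(["--tf_checkpoint_path", "my_albert/model.ckpt-best"])
print(custom.tf_checkpoint_path)  # my_albert/model.ckpt-best
print(custom.bert_config_file)    # albert_base_en/albert_config.json
```

Note that overriding only the checkpoint path still writes the output to `albert_base_en/pytorch_model.bin`, so a different model usually needs all three flags.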