Converting a TF pretrained model into a PyTorch pretrained model

程序员文章站 2022-06-13 16:03:22
...

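I ran into a lot of pitfalls while converting ALBERT's TensorFlow pretrained model into a PyTorch pretrained model, and finally got it working. I hope this helps.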

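  1. Preparation
    Create one environment that has both torch and tf (transformers' checkpoint loader reads the TF checkpoint through TensorFlow). Steps:
    First, create and activate the environment:
    conda create -n torchtf_env python=3.7
    conda activate torchtf_env
    Then install torch (pick the cudatoolkit version that matches your machine's CUDA):
    conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c conda-forge
    Next, install the GPU build of TensorFlow 1.15:
    conda install tensorflow-gpu==1.15
    Finally, install transformers:
    pip install transformers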
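After installing, a quick way to confirm that all three packages landed in the same environment is sketched below. It is a minimal, stdlib-only check; `check_packages` is a hypothetical helper name. It imports each package and reads its `__version__`, so TensorFlow should report a 1.15.x version here.

```python
import importlib
import importlib.util


def check_packages(names=("torch", "tensorflow", "transformers")):
    """Return {package: version string, or None if it is not installed}."""
    report = {}
    for name in names:
        if importlib.util.find_spec(name) is None:
            report[name] = None  # not importable from this environment
            continue
        module = importlib.import_module(name)
        report[name] = getattr(module, "__version__", "unknown")
    return report


if __name__ == "__main__":
    report = check_packages()
    for pkg, version in report.items():
        print("{:<13} {}".format(pkg, version or "MISSING"))
    if report["torch"]:
        import torch
        print("CUDA available to torch:", torch.cuda.is_available())
```

The conversion itself runs fine on CPU, so a `False` for CUDA does not block it.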

  2. Download the TensorFlow pretrained ALBERT model from GitHub, then convert it with the following script
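Before converting, checking that the download is complete can save a confusing failure inside `load_tf_weights_in_albert`. This sketch assumes a TF1-style checkpoint referenced by its prefix (e.g. `albert_base_en/model.ckpt-best`), i.e. a `<prefix>.index` file plus one or more `<prefix>.data-*` shards. The helper `check_tf_checkpoint` and the list of required config keys are illustrative assumptions, not part of transformers.

```python
import glob
import json
import os

# Assumed subset of keys an ALBERT config should define
REQUIRED_CONFIG_KEYS = ("vocab_size", "embedding_size", "hidden_size",
                        "num_hidden_layers", "num_attention_heads")


def check_tf_checkpoint(ckpt_prefix, config_file):
    """Return a list of problems found; an empty list means the files look complete."""
    problems = []
    if not os.path.isfile(ckpt_prefix + ".index"):
        problems.append("missing " + ckpt_prefix + ".index")
    if not glob.glob(glob.escape(ckpt_prefix) + ".data-*"):
        problems.append("missing " + ckpt_prefix + ".data-* shard(s)")
    if not os.path.isfile(config_file):
        problems.append("missing " + config_file)
    else:
        with open(config_file, encoding="utf-8") as f:
            config = json.load(f)
        for key in REQUIRED_CONFIG_KEYS:
            if key not in config:
                problems.append("config has no '{}'".format(key))
    return problems


if __name__ == "__main__":
    issues = check_tf_checkpoint("albert_base_en/model.ckpt-best",
                                 "albert_base_en/albert_config.json")
    print("\n".join(issues) or "checkpoint looks complete")
```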

```python
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""
Created on 19/03/2021 20:22
@Author: lixj

Convert a TensorFlow ALBERT checkpoint into a PyTorch state_dict (pytorch_model.bin).
"""

import argparse
import logging

import torch
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert

logging.basicConfig(level=logging.INFO)


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Build the PyTorch ALBERT model from the original albert_config.json
    config = AlbertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(config))
    model = AlbertForPreTraining(config)

    # Copy the weights out of the TF checkpoint (this step imports TensorFlow)
    load_tf_weights_in_albert(model, config, tf_checkpoint_path)

    # Save only the parameters (state_dict); no config.json is written here
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Input/output paths; the defaults assume the albert_base_en download
    parser.add_argument(
        "--tf_checkpoint_path",
        default="albert_base_en/model.ckpt-best",
        type=str,
        help="Path to the TensorFlow checkpoint.",
    )
    parser.add_argument(
        "--bert_config_file",
        default="albert_base_en/albert_config.json",
        type=str,
        help="The config JSON file of the pre-trained ALBERT model. "
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path",
        default="albert_base_en/pytorch_model.bin",
        type=str,
        help="Path to the output PyTorch model.",
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
```
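The conversion script stores only the weights. `AlbertForPreTraining.from_pretrained(model_dir)` also looks for a `config.json` in the same directory, so one way to make the output folder loadable, and to sanity-check the conversion, is sketched below. It assumes `pytorch_model.bin` and `albert_config.json` sit in the same folder (as with the script's defaults); `make_loadable` and `reload_and_check` are hypothetical helper names.

```python
import os

import torch
from transformers import AlbertConfig, AlbertForPreTraining


def make_loadable(model_dir, albert_config_file="albert_config.json"):
    """Write config.json next to pytorch_model.bin so from_pretrained() can find it."""
    config = AlbertConfig.from_json_file(os.path.join(model_dir, albert_config_file))
    config.save_pretrained(model_dir)  # creates <model_dir>/config.json
    return config


def reload_and_check(model_dir):
    """Reload the converted model and confirm every saved tensor came back unchanged."""
    model = AlbertForPreTraining.from_pretrained(model_dir)
    saved = torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu")
    loaded = model.state_dict()
    return all(name in loaded and torch.equal(tensor, loaded[name])
               for name, tensor in saved.items())
```

With the script's defaults this would be `make_loadable("albert_base_en")` followed by `reload_and_check("albert_base_en")`, which should return `True` if the weights survived the round trip.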