欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

HuggingFace的DistilBERT学习笔记-MyToolkit

程序员文章站 2022-03-05 13:21:59
HuggingFace的DistilBERT学习笔记-顺序版学习DistillBERT My Toolkitpython 代码规范os模块argparse 模块logging 模块shutil 模块深浅拷贝 地址引用还是值引用 函数 循环读写 json pickle numpy torch等todo torch的DistributedParall多机多卡训练todo from torch.utils.data import BatchSampler, DataLoader, RandomSampler, D...

DistillBERT My Toolkit

学习Facebook的DistillBERT中所使用的工具包

python 代码规范

函数名、文件名、变量名:big_apple
类名:BigApple
导入argument后要进行sanity_checks(args),包括 todo

os模块

# os.path.dirname/abspath/isdir/isfile/exits/join/curdir/
# os.getcwd() get current work directory
# os.mkdir("dir")
import os
shell_path = os.getcwd()
file_path_name = os.path.abspath(__file__)
file_path = os.path.dirname(os.path.abspath(__file__))
print(shell_path) #/home/zhangmengyu/distill
print(file_path_name) #/home/zhangmengyu/distill/test.py
print(file_path) #/home/zhangmengyu/distill
print(os.path.isfile(file_path_name)) #True
print(os.path.isdir(file_path)) #True

# create logs file folder
logs_dir = os.path.join(os.getcwd(), "logs") # '/home/zhangmengyu/distill/logs'
logs_dir = os.path.join(os.path.curdir, "logs") # './logs'
if os.path.exists(logs_dir) and os.path.isdir(logs_dir):
    pass
else:
    os.mkdir(logs_dir)

argparse 模块

主要用到了 type default required choices=[1,2] action=“store_true” help等参数
但是要注意,如果用了action=“store_true”,那么这个flag出现就是true,不出现就是false。

import argparse
parser = argparse.ArgumentParser(description="study argparse module")
parser.add_argument("--bool_arg", action="store_true", help="bool arg: use --bool_arg")
parser.add_argument("--str_arg", type=str, default="default arg", required=True, choices=["test1", "test2"], help="str arg: use --str_arg str")
parser.add_argument("--int_arg", type=int, default="default arg", required=True, choices=[1,2], help="str arg: use --int_arg int")
args = parser.parse_args()
print(args) # Namespace(bool_arg=True, int_arg=1, str_arg='test1')
# shell$ python test.py - -bool_arg - -str_arg test1 - -int_arg 1

logging 模块

用的比较多的是Logger Formatter Handler

import logging
import os
import logging.handlers

LEVELS = {'NOSET': logging.NOTSET,
          'DEBUG': logging.DEBUG,
          'INFO': logging.INFO,
          'WARNING': logging.WARNING,
          'ERROR': logging.ERROR,
          'CRITICAL': logging.CRITICAL}

## choice 1 可以同时支持输出到console 输出到文件 输出到回滚文件##
# create logs file folder
logs_dir = os.path.join(os.getcwd(), "logs") # '/home/zhangmengyu/distill/logs'
logs_dir = os.path.join(os.path.curdir, "logs") # './logs'
if os.path.exists(logs_dir) and os.path.isdir(logs_dir):
    pass
else:
    os.mkdir(logs_dir)

# init logger
logger = logging.getLogger(__name__)
formatter = logging.Formatter("%(asctime)s - %(levelname)8s - %(name)10s - PID: %(process)d -  %(message)s")
logger.setLevel(logging.INFO)

# define a rotating file handler
rotatingFileHandler = logging.handlers.RotatingFileHandler(
    filename="test_rotating.txt",
    maxBytes=1024 * 1024 * 50,
    backupCount=5,
) # 新的run的log 会 append 到旧的里面
rotatingFileHandler.setFormatter(formatter)
logger.addHandler(rotatingFileHandler)

# define a handler whitch writes messages to sys
console = logging.StreamHandler()
console.setFormatter(formatter)
logger.addHandler(console)

# define a file handler
FileHandler = logging.FileHandler(
    filename="test_file.txt",
    mode="w",
)
FileHandler.setFormatter(formatter)
logger.addHandler(FileHandler)

## choice 2 ##
# 或者可以使用 basic config
# 要么输出到console 要么输出到文件 根据filename参数是否为空来决定 默认filemode="a"
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d -  %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

## use logger example ##
logger.info(type(a))
logger.info(f"Saaaapecial tokens {a}")
logger.error("error XXX")

shutil 模块

shutil.rmtree(path) 递归的删除文件夹

深浅拷贝 地址引用还是值引用 函数 循环

todo这个内容理解的还不深刻,每次用到都会忘记。
核心是?引用的都是地址,不管是for aa in a还是函数,只不过当对地址里面的值进行操作的时候。如果遇到【数字 string tuple】不可变类型,则不会更改,如果遇到【list dict】可变类型则会修改。
有个例外,不管什么时候,不管是在for循环里还是在函数体里还是在main中,对变量进行重新赋值都会重新开辟空间,不会对原来的值进行修改。

a=[1,2,[3,4]]
b=a
b[0]=0
b[2][0]=0
a[1]=0
print(id(b))
print(id(a))
print(a) #[0, 0, [0, 4]]
print(b) #[0, 0, [0, 4]]

aa=0
a = [1, 2, 3]
for aa in a:
    aa = 0
print(a) # [1,2,3] 

for i in range(len(a)):
    a[i] = 0
print(a) # [0, 0, 0]

读写 json pickle numpy torch等

json和pickle读写方式类似
numpy和torch读写方式类似
正常存储时把data存到f中,这就是正常的顺序
但是numpy 由于有一个np.savez(f, key_a=data_a, key_b=data_b)所以会导致numpy的顺序和其他三个不同

json
dumps是将dict转化成json字符串格式,loads是将json字符串转化成dict格式。
dump和load也是类似的功能,只是与文件操作结合起来了。
只能序列化python的object

with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
    json.dump(vars(args), f, indent=4)
with open('data.json','r') as f:
    data = json.load(f)

pickle
pickle序列化后的数据,可读性差,人一般无法识别。
可以序列化函数和类,但是要让pickle找到类或函数的定义

with open("test.txt","wb") as f:
    pickle.dump(a, f) #重点在于rb和wb 二进制形式dump和load
with open("test.txt","rb") as f:
    d = pickle.load(f)

numpy 要注意的是后缀npy npz要写对 不然会自动加上去

    import numpy as np
    a = np.random.random_sample([10,5])
    np.save("./test.npy", a)
    a_load = np.load("test.npy")
    b = np.random.randint([2,1])
    np.savez("./test.npz", aa=a, bb=b)
    data=np.load("./test.npz")
    print(data["aa"])
    print(data["bb"])

torch

torch.save(data, path)
data = torch.load(path)

todo torch的DistributedParall多机多卡训练

from torch.utils.data.distributed import DistributedSampler
init_gpu_params(args)
set_seed(args)

todo from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Dataset

看了一部分的Dataset和DataLoader,学到的知识暂时解了迷惑,但是还有很多地方没有学会,总要一点一点学习的
DataLoader是一个封装的类,可以看作一个黑盒子,这个黑盒子和DataLoadery一起配合使用,用来返回一个batch的tensor格式的数据。要注意什么时候是一个batch的tensor数据(即inputs.tensor,labels.tensor),什么时候是一个数据对(input, label),什么时候是一个list列表类型的数据([input_1, label_1], [input_2, label_2]…)

对DataSet不管__init__的数据格式是怎么存储的,只需要实现__len__(self)和__getitem__(self, index)(输出index对应的一个数据对)和collate_fn(也就是例子代码的batch_sequence,实际上是用于DataLoader的collate_fn)

DataLoader的数据处理流大概是这样的,首先根据Dataset的__len__(self)和__getitem__(self, index)(输出index对应的一个数据对)和batch_size和sampler,输出一个list列表类型的数据([input_1, label_1], [input_2, label_2]…),然后将这个list传递给collate_fn,collate_fn返回一个batch的tensor数据(即inputs.tensor,labels.tensor)

from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler

train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)

if params.n_gpu <= 1:
    sampler = RandomSampler(dataset)
else:
    sampler = DistributedSampler(dataset)

my_dataloader = DataLoader(
	dataset=train_lm_seq_dataset, 
	batch_sampler=sampler,  
	collate_fn=dataset.batch_sequences
	)

class LmSeqsDataset(Dataset):
    """Custom Dataset wrapping language modeling sequences.

    Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.

    Input:
    ------
        params: `NameSpace` parameters
        data: `List[np.array[int]]
    """

    def __init__(self, params, data):
        self.params = params
        self.token_ids = np.array(data)
        self.lengths = np.array([len(t) for t in data])

    def __getitem__(self, index):
        return (self.token_ids[index], self.lengths[index])

    def __len__(self):
        return len(self.lengths)

    def batch_sequences(self, batch):
        """
        Do the padding and transform into torch.tensor.
        """
        token_ids = [t[0] for t in batch]
        lengths = [t[1] for t in batch]
        assert len(token_ids) == len(lengths)

        # Max for paddings
        max_seq_len_ = max(lengths)

        # Pad token ids
        if self.params.mlm:
            pad_idx = self.params.special_tok_ids["pad_token"]
        else:
            pad_idx = self.params.special_tok_ids["unk_token"]
        tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)

        tk_t = torch.tensor(tk_)  # (bs, max_seq_len_)
        lg_t = torch.tensor(lengths)  # (bs)
        return tk_t, lg_t

todo tqdm

这是一个看起来最高大上,用起来也最高大上的,学起来最简单的一个工具包呀
和for aa in a:是相同的,不同的是可以定义desc 可以定义disable=True则不显示进度条

from tqdm import tqdm
char_test = [11,22,33,44,55,66,77,88,99]
import time
tchar = tqdm(char_test,desc="desc", disable=False)
for t in tchar: # 或者for t in trange(100):
    time.sleep(1.0)
    print(f"processing {t}")
    # desc:  22%|████████▋         | 2/9 [00:02<00:07,  1.00s/it]processing 33

todo optimizer

本文地址:https://blog.csdn.net/weixin_43526074/article/details/107636816