HuggingFace的DistilBERT学习笔记-MyToolkit
HuggingFace的DistilBERT学习笔记-顺序版学习
DistillBERT My Toolkit
学习Facebook的DistillBERT中所使用的工具包
python 代码规范
函数名、文件名、变量名:big_apple
类名:BigApple
导入argument后要进行sanity_checks(args),包括 todo
os模块
# os.path.dirname/abspath/isdir/isfile/exits/join/curdir/
# os.getcwd() get current work directory
# os.mkdir("dir")
import os
shell_path = os.getcwd()
file_path_name = os.path.abspath(__file__)
file_path = os.path.dirname(os.path.abspath(__file__))
print(shell_path) #/home/zhangmengyu/distill
print(file_path_name) #/home/zhangmengyu/distill/test.py
print(file_path) #/home/zhangmengyu/distill
print(os.path.isfile(file_path_name)) #True
print(os.path.isdir(file_path)) #True
# create logs file folder
logs_dir = os.path.join(os.getcwd(), "logs") # '/home/zhangmengyu/distill/logs'
logs_dir = os.path.join(os.path.curdir, "logs") # './logs'
if os.path.exists(logs_dir) and os.path.isdir(logs_dir):
pass
else:
os.mkdir(logs_dir)
argparse 模块
主要用到了 type default required choices=[1,2] action=“store_true” help等参数
但是要注意,如果用了action=“store_true”,那么这个flag出现就是true,不出现就是false。
import argparse
parser = argparse.ArgumentParser(description="study argparse module")
parser.add_argument("--bool_arg", action="store_true", help="bool arg: use --bool_arg")
parser.add_argument("--str_arg", type=str, default="default arg", required=True, choices=["test1", "test2"], help="str arg: use --str_arg str")
parser.add_argument("--int_arg", type=int, default="default arg", required=True, choices=[1,2], help="str arg: use --int_arg int")
args = parser.parse_args()
print(args) # Namespace(bool_arg=True, int_arg=1, str_arg='test1')
# shell$ python test.py - -bool_arg - -str_arg test1 - -int_arg 1
logging 模块
用的比较多的是Logger Formatter Handler
import logging
import os
import logging.handlers
LEVELS = {'NOSET': logging.NOTSET,
'DEBUG': logging.DEBUG,
'INFO': logging.INFO,
'WARNING': logging.WARNING,
'ERROR': logging.ERROR,
'CRITICAL': logging.CRITICAL}
## choice 1 可以同时支持输出到console 输出到文件 输出到回滚文件##
# create logs file folder
logs_dir = os.path.join(os.getcwd(), "logs") # '/home/zhangmengyu/distill/logs'
logs_dir = os.path.join(os.path.curdir, "logs") # './logs'
if os.path.exists(logs_dir) and os.path.isdir(logs_dir):
pass
else:
os.mkdir(logs_dir)
# init logger
logger = logging.getLogger(__name__)
formatter = logging.Formatter("%(asctime)s - %(levelname)8s - %(name)10s - PID: %(process)d - %(message)s")
logger.setLevel(logging.INFO)
# define a rotating file handler
rotatingFileHandler = logging.handlers.RotatingFileHandler(
filename="test_rotating.txt",
maxBytes=1024 * 1024 * 50,
backupCount=5,
) # 新的run的log 会 append 到旧的里面
rotatingFileHandler.setFormatter(formatter)
logger.addHandler(rotatingFileHandler)
# define a handler whitch writes messages to sys
console = logging.StreamHandler()
console.setFormatter(formatter)
logger.addHandler(console)
# define a file handler
FileHandler = logging.FileHandler(
filename="test_file.txt",
mode="w",
)
FileHandler.setFormatter(formatter)
logger.addHandler(FileHandler)
## choice 2 ##
# 或者可以使用 basic config
# 要么输出到console 要么输出到文件 根据filename参数是否为空来决定 默认filemode="a"
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)
## use logger example ##
logger.info(type(a))
logger.info(f"Saaaapecial tokens {a}")
logger.error("error XXX")
shutil 模块
shutil.rmtree(path) 递归的删除文件夹
深浅拷贝 地址引用还是值引用 函数 循环
todo这个内容理解的还不深刻,每次用到都会忘记。
核心是?引用的都是地址,不管是for aa in a还是函数,只不过当对地址里面的值进行操作的时候。如果遇到【数字 string tuple】不可变类型,则不会更改,如果遇到【list dict】可变类型则会修改。
有个例外,不管什么时候,不管是在for循环里还是在函数体里还是在main中,对变量进行重新赋值都会重新开辟空间,不会对原来的值进行修改。
a=[1,2,[3,4]]
b=a
b[0]=0
b[2][0]=0
a[1]=0
print(id(b))
print(id(a))
print(a) #[0, 0, [0, 4]]
print(b) #[0, 0, [0, 4]]
aa=0
a = [1, 2, 3]
for aa in a:
aa = 0
print(a) # [1,2,3]
for i in range(len(a)):
a[i] = 0
print(a) # [0, 0, 0]
读写 json pickle numpy torch等
json和pickle读写方式类似
numpy和torch读写方式类似
正常存储时把data存到f中,这就是正常的顺序
但是numpy 由于有一个np.savez(f, key_a=data_a, key_b=data_b)所以会导致numpy的顺序和其他三个不同
json
dumps是将dict转化成json字符串格式,loads是将json字符串转化成dict格式。
dump和load也是类似的功能,只是与文件操作结合起来了。
只能序列化python的object
with open(os.path.join(args.dump_path, "parameters.json"), "w") as f:
json.dump(vars(args), f, indent=4)
with open('data.json','r') as f:
data = json.load(f)
pickle
pickle序列化后的数据,可读性差,人一般无法识别。
可以序列化函数和类,但是要让pickle找到类或函数的定义
with open("test.txt","wb") as f:
pickle.dump(a, f) #重点在于rb和wb 二进制形式dump和load
with open("test.txt","rb") as f:
d = pickle.load(f)
numpy 要注意的是后缀npy npz要写对 不然会自动加上去
import numpy as np
a = np.random.random_sample([10,5])
np.save("./test.npy", a)
a_load = np.load("test.npy")
b = np.random.randint([2,1])
np.savez("./test.npz", aa=a, bb=b)
data=np.load("./test.npz")
print(data["aa"])
print(data["bb"])
torch
torch.save(data, path)
data = torch.load(path)
todo torch的DistributedParall多机多卡训练
from torch.utils.data.distributed import DistributedSampler
init_gpu_params(args)
set_seed(args)
todo from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Dataset
看了一部分的Dataset和DataLoader,学到的知识暂时解了迷惑,但是还有很多地方没有学会,总要一点一点学习的
DataLoader是一个封装的类,可以看作一个黑盒子,这个黑盒子和DataLoadery一起配合使用,用来返回一个batch的tensor格式的数据。要注意什么时候是一个batch的tensor数据(即inputs.tensor,labels.tensor),什么时候是一个数据对(input, label),什么时候是一个list列表类型的数据([input_1, label_1], [input_2, label_2]…)
对DataSet不管__init__的数据格式是怎么存储的,只需要实现__len__(self)和__getitem__(self, index)(输出index对应的一个数据对)和collate_fn(也就是例子代码的batch_sequence,实际上是用于DataLoader的collate_fn)
DataLoader的数据处理流大概是这样的,首先根据Dataset的__len__(self)和__getitem__(self, index)(输出index对应的一个数据对)和batch_size和sampler,输出一个list列表类型的数据([input_1, label_1], [input_2, label_2]…),然后将这个list传递给collate_fn,collate_fn返回一个batch的tensor数据(即inputs.tensor,labels.tensor)
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler
train_lm_seq_dataset = LmSeqsDataset(params=args, data=data)
if params.n_gpu <= 1:
sampler = RandomSampler(dataset)
else:
sampler = DistributedSampler(dataset)
my_dataloader = DataLoader(
dataset=train_lm_seq_dataset,
batch_sampler=sampler,
collate_fn=dataset.batch_sequences
)
class LmSeqsDataset(Dataset):
"""Custom Dataset wrapping language modeling sequences.
Each sample will be retrieved by indexing the list of token_ids and their corresponding lengths.
Input:
------
params: `NameSpace` parameters
data: `List[np.array[int]]
"""
def __init__(self, params, data):
self.params = params
self.token_ids = np.array(data)
self.lengths = np.array([len(t) for t in data])
def __getitem__(self, index):
return (self.token_ids[index], self.lengths[index])
def __len__(self):
return len(self.lengths)
def batch_sequences(self, batch):
"""
Do the padding and transform into torch.tensor.
"""
token_ids = [t[0] for t in batch]
lengths = [t[1] for t in batch]
assert len(token_ids) == len(lengths)
# Max for paddings
max_seq_len_ = max(lengths)
# Pad token ids
if self.params.mlm:
pad_idx = self.params.special_tok_ids["pad_token"]
else:
pad_idx = self.params.special_tok_ids["unk_token"]
tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
assert len(tk_) == len(token_ids)
assert all(len(t) == max_seq_len_ for t in tk_)
tk_t = torch.tensor(tk_) # (bs, max_seq_len_)
lg_t = torch.tensor(lengths) # (bs)
return tk_t, lg_t
todo tqdm
这是一个看起来最高大上,用起来也最高大上的,学起来最简单的一个工具包呀
和for aa in a:是相同的,不同的是可以定义desc 可以定义disable=True则不显示进度条
from tqdm import tqdm
char_test = [11,22,33,44,55,66,77,88,99]
import time
tchar = tqdm(char_test,desc="desc", disable=False)
for t in tchar: # 或者for t in trange(100):
time.sleep(1.0)
print(f"processing {t}")
# desc: 22%|████████▋ | 2/9 [00:02<00:07, 1.00s/it]processing 33
todo optimizer
本文地址:https://blog.csdn.net/weixin_43526074/article/details/107636816
上一篇: 列表、元组、字符串