知识图谱DKN源码详解(四)train.py
程序员文章站
2022-03-04 13:13:27
...
内容
try: #不用多言, 获得该模块下的model_name函数
Model = getattr(importlib.import_module(f"model.{model_name}"), model_name)
config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:
print(f"{model_name} not included!")
exit()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class EarlyStopping
class EarlyStopping:
def __init__(self, patience=5):
self.patience = patience
self.counter = 0
self.best_loss = np.Inf
def __call__(self, val_loss):
"""
if you use other metrics where a higher value is better, e.g. accuracy,
call this with its corresponding negative value
"""
# 如果你使用的其他指标值越高越好,例如准确性,用它对应的负数来调用它
if val_loss < self.best_loss: #如果评测的损失小于最好的损失,那么就是最好的损失
early_stop = False
get_better = True
self.counter = 0
self.best_loss = val_loss # 最好的损失
else:
get_better = False #
self.counter += 1
if self.counter >= self.patience:
early_stop = True
else:
early_stop = False
return early_stop, get_better
def latest_checkpoint(directory):
看一看存储的模型路径名称:
def latest_checkpoint(directory): #最新的检查点!
if not os.path.exists(directory): #该路径在不在
return None
all_checkpoints = {
int(x.split('.')[-2].split('-')[-1]): x
for x in os.listdir(directory)
}
if not all_checkpoints:
return None
return os.path.join(directory,
all_checkpoints[max(all_checkpoints.keys())])
补充
os.listdir() 方法
概述
os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。(是该文件夹下所有的文件名)
它不包括 . 和 … 即使它在文件夹中。
只支持在 Unix, Windows 下使用。
语法
listdir()方法语法格式如下:
os.listdir(path)
参数
path – 需要列出的目录路径
返回值
返回指定路径下的文件和文件夹列表。
实例
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import os, sys
# 打开文件
path = "/var/www/html/"
dirs = os.listdir( path )
# 输出所有文件和文件夹
for file in dirs:
print (file)
上一篇: TypeScript环境搭建的实现步骤