欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

知识图谱DKN源码详解(四)train.py

程序员文章站 2022-03-04 13:13:27
...

内容

try:  #不用多言, 获得该模块下的model_name函数
    Model = getattr(importlib.import_module(f"model.{model_name}"), model_name)
    config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:
    print(f"{model_name} not included!")
    exit()
    
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EarlyStopping

class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience   
        self.counter = 0
        self.best_loss = np.Inf

    def __call__(self, val_loss):
        """
        if you use other metrics where a higher value is better, e.g. accuracy,
        call this with its corresponding negative value
        """
        # 如果你使用的其他指标值越高越好,例如准确性,用它对应的负数来调用它
        if val_loss < self.best_loss:   #如果评测的损失小于最好的损失,那么就是最好的损失
            early_stop = False
            get_better = True
            self.counter = 0
            self.best_loss = val_loss  # 最好的损失 
        else:
            get_better = False         #  
            self.counter += 1
            if self.counter >= self.patience:
                early_stop = True
            else:
                early_stop = False

        return early_stop, get_better  

def latest_checkpoint(directory):

看一看存储的模型路径名称:
知识图谱DKN源码详解(四)train.py

def latest_checkpoint(directory):   #最新的检查点! 
	if not os.path.exists(directory):  #该路径在不在
	        return None
   	all_checkpoints = {
        int(x.split('.')[-2].split('-')[-1]): x
        for x in os.listdir(directory)
    }
    if not all_checkpoints:
        return None
    return os.path.join(directory,
                        all_checkpoints[max(all_checkpoints.keys())])

补充

os.listdir() 方法

概述

os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。(是该文件夹下所有的文件名)

它不包括 . 和 … 即使它在文件夹中。

只支持在 Unix, Windows 下使用。

语法

listdir()方法语法格式如下:

os.listdir(path)

参数

path – 需要列出的目录路径

返回值

返回指定路径下的文件和文件夹列表。

实例

#!/usr/bin/python
# -*- coding: UTF-8 -*-

import os, sys

# 打开文件
path = "/var/www/html/"
dirs = os.listdir( path )

# 输出所有文件和文件夹
for file in dirs:
   print (file)

知识图谱DKN源码详解(四)train.py