知识图谱DKN源码详解(三)dataset.py【未完】
程序员文章站
2022-03-04 13:09:33
...
内容
文中用到的部分函数的说明见:https://blog.csdn.net/qq_35222729/article/details/119882362
# Resolve the model-specific configuration class: look up the attribute
# "<model_name>Config" on the top-level `config` module.
try:
    config = getattr(
        importlib.import_module('config'),
        f"{model_name}Config",
    )
except AttributeError:
    # The `config` module has no class for this model name: report and quit.
    print(f"{model_name} not included!")
    exit()
class BaseDataset(Dataset):
    """Dataset pairing parsed user-behavior rows with parsed news records.

    Each item is a dict built from one row of the behaviors table: the click
    labels, the candidate news records and the user's clicked-news history,
    left-padded to ``config.num_clicked_news_a_user`` entries.
    """

    def __init__(self, behaviors_path, news_path, roberta_embedding_dir):
        """Load the behaviors and news tables and precompute lookup structures.

        Args:
            behaviors_path: Path of the table read with ``pd.read_table``; its
                rows are the dataset items.
            news_path: Path of the news table, indexed by its ``id`` column.
            roberta_embedding_dir: Not used by the code visible here.
                NOTE(review): ``_news2dict`` reads ``self.roberta_embedding``
                which is never assigned in this constructor; presumably the
                embeddings should be loaded from this directory -- confirm
                against the full source.
        """
        super().__init__()
        news_attributes = config.dataset_attributes['news']
        record_attributes = config.dataset_attributes['record']

        # Every configured attribute must be one of the supported names.
        assert all(attribute in [
            'category', 'subcategory', 'title', 'abstract', 'title_entities',
            'abstract_entities', 'title_roberta', 'title_mask_roberta',
            'abstract_roberta', 'abstract_mask_roberta'
        ] for attribute in news_attributes)
        assert all(attribute in ['user', 'clicked_news_length']
                   for attribute in record_attributes)

        # Read the behaviors table as is.
        self.behaviors_parsed = pd.read_table(behaviors_path)

        # Read only the configured news columns; list-valued columns are
        # stored as strings, so parse them back into Python literals.
        literal_columns = {
            'title', 'abstract', 'title_entities', 'abstract_entities',
            'title_roberta', 'title_mask_roberta', 'abstract_roberta',
            'abstract_mask_roberta'
        }
        self.news_parsed = pd.read_table(
            news_path,
            index_col='id',
            usecols=['id'] + news_attributes,
            converters={
                attribute: literal_eval
                for attribute in set(news_attributes) & literal_columns
            })

        # News id -> row position in the news table.
        self.news_id2int = {
            news_id: position
            for position, news_id in enumerate(self.news_parsed.index)
        }

        # News id -> {attribute: tensor}.
        self.news2dict = {
            news_id: {
                attribute: torch.tensor(value)
                for attribute, value in record.items()
            }
            for news_id, record in self.news_parsed.to_dict('index').items()
        }

        # All-zero record used to pad short click histories, restricted to
        # the configured news attributes.
        title_zeros = [0] * config.num_words_title
        abstract_zeros = [0] * config.num_words_abstract
        padding_all = {
            'category': 0,
            'subcategory': 0,
            'title': title_zeros,
            'abstract': abstract_zeros,
            'title_entities': title_zeros,
            'abstract_entities': abstract_zeros,
            'title_roberta': title_zeros,
            'title_mask_roberta': title_zeros,
            'abstract_roberta': abstract_zeros,
            'abstract_mask_roberta': abstract_zeros,
        }
        self.padding = {
            attribute: torch.tensor(value)
            for attribute, value in padding_all.items()
            if attribute in news_attributes
        }

    def _news2dict(self, id):
        """Return the attribute dict for the news with the given id.

        For model ``Exp2`` without fine-tuning, the ``title`` / ``abstract``
        entries are replaced by the matching row of ``self.roberta_embedding``.
        The replacement is done on a shallow copy so the cached record in
        ``self.news2dict`` is left untouched.
        """
        ret = self.news2dict[id]
        if model_name == 'Exp2' and not config.fine_tune:
            # NOTE(review): self.roberta_embedding is not assigned anywhere in
            # the visible code; this branch needs it to be set beforehand.
            ret = dict(ret)
            position = self.news_id2int[id]
            for k in set(config.dataset_attributes['news']) & {
                    'title', 'abstract'
            }:
                ret[k] = self.roberta_embedding[k][position]
        return ret

    def __len__(self):
        """Return the number of behavior rows."""
        return len(self.behaviors_parsed)

    def __getitem__(self, idx):
        """Build the item dict for the behavior row at position ``idx``.

        Keys: ``clicked`` (ints parsed from the whitespace-separated column),
        ``candidate_news`` and ``clicked_news`` (lists of news records), plus
        ``user`` and ``clicked_news_length`` when configured.
        """
        record_attributes = config.dataset_attributes['record']
        row = self.behaviors_parsed.iloc[idx]

        item = {}
        if 'user' in record_attributes:
            item['user'] = row.user
        item["clicked"] = [int(label) for label in row.clicked.split()]
        item["candidate_news"] = [
            self._news2dict(x) for x in row.candidate_news.split()
        ]

        # Keep at most num_clicked_news_a_user entries of the click history.
        history_ids = row.clicked_news.split()[:config.num_clicked_news_a_user]
        item["clicked_news"] = [self._news2dict(x) for x in history_ids]
        if 'clicked_news_length' in record_attributes:
            # Length before padding.
            item['clicked_news_length'] = len(item["clicked_news"])

        # Left-pad the history with the all-zero record up to the fixed size.
        repeated_times = config.num_clicked_news_a_user - len(
            item["clicked_news"])
        assert repeated_times >= 0
        item["clicked_news"] = (
            [self.padding] * repeated_times + item["clicked_news"])
        return item
补充
1. ast.literal_eval
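在 Python 中,如果要把字符串形式的 list、tuple、dict 还原成原有的类型,该怎么做呢?这时你自然会想到 eval。eval 函数在 Python 中做数据类型转换还是很有用的:它会把字符串当作表达式求值,从而把数据还原成它本身所表示的、或者能够转化成的数据类型(例如 string <=> list)。不过 eval 会执行任意表达式,对不可信的输入并不安全;ast.literal_eval 只解析字面量(字符串、数字、元组、列表、字典、集合、布尔值和 None),所以上面的代码用它来还原这些列。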
In [1]: s = '[1, 2, 3, 4]'
In [2]: l = eval(s)
In [3]: s
Out[3]: '[1, 2, 3, 4]'
In [4]: l
Out[4]: [1, 2, 3, 4]
In [5]: type(s)
Out[5]: str
In [6]: type(l)
Out[6]: list