欢迎您访问程序员文章站!本站旨在为大家提供并分享程序员计算机编程知识。
您现在的位置是: 首页

知识图谱DKN源码详解(三)dataset.py【未完】

程序员文章站 2022-03-04 13:09:33
...

内容

代码中用到的部分函数的说明见这里:https://blog.csdn.net/qq_35222729/article/details/119882362


# Resolve the configuration class for the selected model by looking up the
# attribute named "<model_name>Config" on the `config` module.
try:
    config = getattr(importlib.import_module('config'), f"{model_name}Config")
except AttributeError:
    # The config module has no class for this model: report it and stop.
    # SystemExit is raised directly rather than through the `exit()` helper,
    # which is injected by the `site` module for interactive use and may be
    # missing (e.g. `python -S`). The non-zero status marks the run as failed.
    print(f"{model_name} not included!")
    raise SystemExit(1)


class BaseDataset(Dataset):
    """Dataset that joins user-behavior rows with parsed news attributes.

    One item corresponds to one row of the behaviors table. News IDs found in
    that row are resolved to per-news dicts whose values are tensors.

    NOTE(review): `roberta_embedding_dir` is accepted by `__init__` but not
    used in this class, and `self.roberta_embedding` (read by `_news2dict`)
    is never assigned here -- presumably it is set up by code that is not
    part of this class; confirm before relying on the 'Exp2' path.
    """

    def __init__(self, behaviors_path, news_path, roberta_embedding_dir):
        """Load both tables and precompute lookup structures.

        Args:
            behaviors_path: Path handed to `pd.read_table` for the behaviors
                table.
            news_path: Path handed to `pd.read_table` for the news table; it
                must contain an `id` column plus the configured attributes.
            roberta_embedding_dir: Unused in this method (see class note).
        """
        super().__init__()

        # Every configured attribute has to be one this class knows about.
        supported_news = (
            'category', 'subcategory', 'title', 'abstract', 'title_entities',
            'abstract_entities', 'title_roberta', 'title_mask_roberta',
            'abstract_roberta', 'abstract_mask_roberta',
        )
        supported_record = ('user', 'clicked_news_length')
        assert all(name in supported_news
                   for name in config.dataset_attributes['news'])
        assert all(name in supported_record
                   for name in config.dataset_attributes['record'])

        # Behaviors table: one row per sample.
        self.behaviors_parsed = pd.read_table(behaviors_path)

        news_attributes = config.dataset_attributes['news']
        # These columns are stored as Python-literal text in the file, so they
        # are parsed back into Python objects with `literal_eval` on load.
        literal_columns = set(news_attributes) & {
            'title', 'abstract', 'title_entities', 'abstract_entities',
            'title_roberta', 'title_mask_roberta', 'abstract_roberta',
            'abstract_mask_roberta',
        }
        self.news_parsed = pd.read_table(
            news_path,
            index_col='id',
            usecols=['id'] + news_attributes,
            converters=dict.fromkeys(literal_columns, literal_eval),
        )

        # News ID -> row position in the news table.
        news_index = self.news_parsed.index
        self.news_id2int = dict(zip(news_index, range(len(news_index))))

        # News ID -> {attribute name: tensor}.
        self.news2dict = {
            news_id: {
                name: torch.tensor(value)
                for name, value in fields.items()
            }
            for news_id, fields in self.news_parsed.to_dict('index').items()
        }

        # All-zero stand-in for a missing news item, restricted to the
        # configured attributes.
        title_len = config.num_words_title
        abstract_len = config.num_words_abstract
        blank_values = (
            ('category', 0),
            ('subcategory', 0),
            ('title', [0] * title_len),
            ('abstract', [0] * abstract_len),
            ('title_entities', [0] * title_len),
            ('abstract_entities', [0] * abstract_len),
            ('title_roberta', [0] * title_len),
            ('title_mask_roberta', [0] * title_len),
            ('abstract_roberta', [0] * abstract_len),
            ('abstract_mask_roberta', [0] * abstract_len),
        )
        all_padding = {
            name: torch.tensor(blank)
            for name, blank in blank_values
        }
        self.padding = {
            name: tensor
            for name, tensor in all_padding.items()
            if name in news_attributes
        }

    def _news2dict(self, id):
        """Return the attribute dict stored for news `id`.

        When `model_name` is 'Exp2' and `config.fine_tune` is falsy, the
        'title' / 'abstract' entries (if configured) are overwritten in the
        stored dict with the matching row of `self.roberta_embedding`.
        """
        news = self.news2dict[id]
        if model_name != 'Exp2' or config.fine_tune:
            return news

        text_fields = set(config.dataset_attributes['news']) & {
            'title', 'abstract'
        }
        for field in text_fields:
            news[field] = self.roberta_embedding[field][self.news_id2int[id]]
        return news

    def __len__(self):
        """Return the number of rows in the behaviors table."""
        return len(self.behaviors_parsed)

    def __getitem__(self, idx):
        """Build the sample for behaviors row `idx`.

        Returns a dict with keys `clicked` (list of ints), `candidate_news`
        and `clicked_news` (lists of per-news dicts), plus `user` and
        `clicked_news_length` when those are configured record attributes.
        `clicked_news` is truncated to `config.num_clicked_news_a_user`
        entries and left-padded with `self.padding` up to that length.
        """
        row = self.behaviors_parsed.iloc[idx]
        record_attributes = config.dataset_attributes['record']

        sample = {}
        if 'user' in record_attributes:
            sample['user'] = row.user
        sample["clicked"] = [int(flag) for flag in row.clicked.split()]
        sample["candidate_news"] = list(
            map(self._news2dict, row.candidate_news.split()))

        history_limit = config.num_clicked_news_a_user
        history = [
            self._news2dict(news_id)
            for news_id in row.clicked_news.split()[:history_limit]
        ]
        missing = history_limit - len(history)
        assert missing >= 0

        # Padding goes in front so the real clicks sit at the end.
        sample["clicked_news"] = [self.padding] * missing + history
        if 'clicked_news_length' in record_attributes:
            # Length counts real clicks only, not the padding entries.
            sample['clicked_news_length'] = len(history)

        return sample

补充

1. ast.literal_eval

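在 Python 中,如果要把字符串形式的 list、tuple、dict 还原成原有的类型,该怎么办?这时你自然会想到 eval。eval 函数在做数据类型转换时很有用:它把字符串当作表达式求值,从而将数据还原成它本身或者能够转化成的数据类型。不过 eval 会执行任意表达式,对不可信的输入并不安全;ast.literal_eval 只接受字符串、数字、元组、列表、字典、集合、布尔值和 None 等字面量,因此上面的代码用它来解析表格中以文本形式保存的列。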
string <=> list

In [1]: s = '[1, 2, 3, 4]'

In [2]: l = eval(s)

In [3]: s
Out[3]: '[1, 2, 3, 4]'

In [4]: l
Out[4]: [1, 2, 3, 4]

In [5]: type(s)
Out[5]: str

In [6]: type(l)
Out[6]: list