知识图谱源码详解【八】__init__.py
程序员文章站
2022-03-04 13:09:45
...
import torch
from src.model.DKN.KCNN import KCNN
from src.model.DKN.attention import Attention
from src.model.general.click_predictor.DNN import DNNClickPredictor
# 就是把整个模型框架梳理到一块了!
class DKN(torch.nn.Module):
"""
Deep knowledge-aware network.
Input 1 + K candidate news and a list of user clicked news, produce the click probability.
"""
#纯定义,如果看过前面的内容,这里不难理解
def __init__(self,
config,
pretrained_word_embedding=None,
pretrained_entity_embedding=None,
pretrained_context_embedding=None):
super(DKN, self).__init__()
self.config = config
self.kcnn = KCNN(config, pretrained_word_embedding,
pretrained_entity_embedding,
pretrained_context_embedding)
self.attention = Attention(config)
self.click_predictor = DNNClickPredictor(
len(self.config.window_sizes) * 2 * self.config.num_filters)
def forward(self, candidate_news, clicked_news):
"""
Args:
candidate_news:
[
{
"title": batch_size * num_words_title,
"title_entities": batch_size * num_words_title
} * (1 + K)
]
clicked_news:
[
{
"title": batch_size * num_words_title,
"title_entities": batch_size * num_words_title
} * num_clicked_news_a_user
]
Returns:
click_probability: batch_size
"""
# batch_size, 1 + K, len(window_sizes) * num_filters
candidate_news_vector = torch.stack(
[self.kcnn(x) for x in candidate_news], dim=1)
# batch_size, num_clicked_news_a_user, len(window_sizes) * num_filters
clicked_news_vector = torch.stack([self.kcnn(x) for x in clicked_news],
dim=1)
# batch_size, 1 + K, len(window_sizes) * num_filters
user_vector = torch.stack([
self.attention(x, clicked_news_vector)
for x in candidate_news_vector.transpose(0, 1)
],
dim=1)
size = candidate_news_vector.size()
# batch_size, 1 + K
click_probability = self.click_predictor(
candidate_news_vector.view(size[0] * size[1], size[2]),
user_vector.view(size[0] * size[1],
size[2])).view(size[0], size[1])
return click_probability
def get_news_vector(self, news):
"""
Args:
news:
{
"title": batch_size * num_words_title,
"title_entities": batch_size * num_words_title
}
Returns:
(shape) batch_size, len(window_sizes) * num_filters
"""
# batch_size, len(window_sizes) * num_filters
return self.kcnn(news)
def get_user_vector(self, clicked_news_vector):
"""
Args:
clicked_news_vector: batch_size, num_clicked_news_a_user, len(window_sizes) * num_filters
Returns:
(shape) batch_size, num_clicked_news_a_user, len(window_sizes) * num_filters
"""
# batch_size, num_clicked_news_a_user, len(window_sizes) * num_filters
return clicked_news_vector
def get_prediction(self, candidate_news_vector, clicked_news_vector):
"""
Args:
candidate_news_vector: candidate_size, len(window_sizes) * num_filters
clicked_news_vector: num_clicked_news_a_user, len(window_sizes) * num_filters
Returns:
click_probability: 0-dim tensor
"""
# candidate_size, len(window_sizes) * num_filters
user_vector = self.attention(candidate_news_vector,
clicked_news_vector.expand(candidate_news_vector.size(0), -1, -1))
# candidate_size
click_probability = self.click_predictor(candidate_news_vector,
user_vector)
return click_probability
上一篇: 大学选专业一般是怎么选的?大学选专业应该考虑什么因素?
下一篇: C#获取注册表指定键值
推荐阅读
-
jQuery 源码解析(八) 异步队列模块 Callbacks 回调函数详解
-
知识图谱视频|知识图谱完整项目实战(附源码) 知识图谱
-
知识图谱视频|知识图谱完整项目实战(附源码) 知识图谱
-
Vue 2.0 深入源码分析(八) 基础篇 依赖注入 provide/inject组合详解
-
知识图谱基本工具Neo4j使用笔记 二 :py2neo源码手册
-
jQuery 源码解析(八) 异步队列模块 Callbacks 回调函数详解
-
Vue 2.0 深入源码分析(八) 基础篇 依赖注入 provide/inject组合详解
-
知识图谱源码详解【八】__init__.py
-
知识图谱DKN源码详解(三)dataset.py【未完】
-
知识图谱DKN源码详解(二)attention.py