欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

GraphSage 代码阅读笔记

程序员文章站 2022-07-12 13:17:07
...

relation也就是边 没有embedding

supervised_train.py 是用节点分类的label来做loss训练,不能输出节点embedding,使用NodeMinibatchIterator

unsupervised_train.py 是用节点和节点的邻接信息做loss训练,训练好可以输出节点embedding,使用EdgeMinibatchIterator

NodeMinibatchIterator__init__方法最后加上

train_node_set = set(self.train_nodes)
valid_node_set = set(self.val_nodes)
print("train_node_set size", len(train_node_set))
print("valid_node_set size", len(valid_node_set))
print("train_node_set valid_node_set intersect size",len(train_node_set.intersection(valid_node_set)))

打印结果

train_node_set size 9716
valid_node_set size 1825
train_node_set valid_node_set intersect size 0

EdgeMinibatchIterator__init__方法最后加上

train_edge_set = set(self.train_edges)
valid_edge_set = set(self.val_edges)
print("train_edge_set size", len(train_edge_set))
print("valid_edge_set size", len(valid_edge_set))
print("train_edge_set valid_edge_set intersect size", len(train_edge_set.intersection(valid_edge_set)))

打印结果

train_edge_set size 1336764
valid_edge_set size 75407
train_edge_set valid_edge_set intersect size 0

EdgeMinibatchIterator__init__方法最后改成

train_nodes = [n for n in G.nodes() if not G.node[n]['test'] and not G.node[n]['val']]
print(len(train_nodes), 'train nodes')
test_nodes = [n for n in G.nodes() if G.node[n]['test'] or G.node[n]['val']]
print(len(test_nodes), 'test nodes')
print("train test node intersect number", len(set(test_nodes).intersection(set(train_nodes))))

打印结果

9716 train nodes
5039 test nodes
train test node intersect number 0

更多理解https://discuss.dgl.ai/t/graphsage-question-the-train-data-and-valid-data-have-no-intersection-then-how-does-the-valid-data-get-the-embedding-for-downstream-model/539/3