
Word2Vec Study Notes




Introduction

Learning Word2Vec with the skip-gram method.
Reference: https://wmathor.com/index.php/archives/1443/


Code

import torch
import numpy as np 
import unicodedata
import string
import re
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

The dictionary (vocabulary) class

class Dictionary:
    def __init__(self, name):
        self.name = name
        self.words = []          # a word's position in this list serves as its id

    def addWord(self, word):
        if word not in self.words:
            self.words.append(word)

    def getSize(self):
        return len(self.words)
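
A quick sanity check (with made-up tokens) of how the dictionary deduplicates and indexes words:

d = Dictionary('demo')
for w in ['the', 'cat', 'the']:
    d.addWord(w)
print(d.getSize())           # 2 -- duplicates are ignored
print(d.words.index('cat'))  # 1 -- the list position doubles as the word id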

String-normalization helpers: they strip accents, drop punctuation other than .!?, and lowercase everything
(following the tutorial in the PyTorch Chinese documentation)


def unicodeToAscii(s):
    # Decompose accented characters and drop the combining marks (category Mn)
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalizeString(s):
    s = unicodeToAscii(s.lower())
    s = re.sub(r"([.!?])", r" \1", s)      # pad .!? with a leading space
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)  # collapse everything else to a single space
    s = s.strip()
    return s
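
For example (hypothetical input), accents and commas are removed, the text is lowercased, and sentence-final punctuation is padded:

print(normalizeString('Élan, vital!!'))  # -> 'elan vital ! !'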

Read the corpus from a file and flatten it into a list of tokens

lines = []
with open('data\\news.txt', encoding='utf-8') as f:
    phrases = f.read().strip().split('\n')
    for phrase in phrases:
        for p in phrase.strip().split('.'):  # split each line into sentences
            if p != '':
                lines += normalizeString(p).split(' ')

Build the centre/context word pairs

engDictionary = Dictionary('English')
for word in lines:
    engDictionary.addWord(word)

def make_list(lines):
    # Pair each centre word with every word in a +/-2 window around it
    words_list = []
    for i in range(2, len(lines) - 2):
        centre = engDictionary.words.index(lines[i])
        context = []
        for t in range(i - 2, i + 3):
            if t != i:
                context.append(engDictionary.words.index(lines[t]))
        for w in context:
            words_list.append([centre, w])
    return words_list

words_list = make_list(lines)
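
To see what make_list produces, here is a toy run (hypothetical tokens; run it in a scratch session, since it adds throwaway words to engDictionary):

toy = ['i', 'like', 'to', 'eat', 'apples']
for w in toy:
    engDictionary.addWord(w)         # make the index lookups succeed
idx = engDictionary.words.index
# Only 'to' has a full +/-2 window, so it is paired with all four neighbours:
assert make_list(toy) == [[idx('to'), idx(w)] for w in ['i', 'like', 'eat', 'apples']]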

Convert the pairs into training data.
Note: build the one-hot vectors as NumPy ndarrays; constructing them directly with torch.tensor raises an error here.

input_data = []
output_data = []

def make_data(words_list):
    for w in words_list:
        k = np.zeros(engDictionary.getSize())
        k[w[0]] = 1                  # one-hot vector for the centre word
        input_data.append(k)
        output_data.append(w[1])     # context word kept as a class index

make_data(words_list)

input_data = torch.Tensor(input_data)
output_data = torch.LongTensor(output_data)

dataset = data.TensorDataset(input_data, output_data)
dataloader = data.DataLoader(dataset, batch_size=8, shuffle=True)
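
A quick shape check on the first batch (assuming the corpus yields at least 8 pairs):

x, y = next(iter(dataloader))
print(x.shape)  # torch.Size([8, vocab_size]) -- one-hot centre words
print(y.shape)  # torch.Size([8])             -- context-word class indices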

The model
Note: no softmax at the output; nn.CrossEntropyLoss applies log-softmax internally.

class Skip_gram(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Skip_gram, self).__init__()
        self.hidden_size = hidden_size
        self.W = nn.Linear(input_size, hidden_size, bias=False)  # vocab -> embedding
        self.V = nn.Linear(hidden_size, input_size, bias=False)  # embedding -> vocab scores

    def forward(self, x):
        x = self.W(x)
        x = self.V(x)
        return x

model = Skip_gram(engDictionary.getSize(), 10)
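
Because the input is one-hot, applying W just selects one column of its weight matrix, so column i of model.W.weight (shape [hidden_size, vocab_size]) is the embedding of word i. A sketch of pulling one out ('news' is a hypothetical vocabulary word):

embeddings = model.W.weight.data.t()                  # [vocab_size, hidden_size]
vec = embeddings[engDictionary.words.index('news')]   # embedding of a hypothetical word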

Training
Notes: 1. Use Adam as the optimizer rather than SGD.
2. nn.CrossEntropyLoss takes raw class scores (logits) as its first input and an integer class label, not a one-hot vector, as the second.

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train(epochs):
    for epoch in range(1, epochs + 1):
        sumloss = 0
        for x, y in dataloader:
            optimizer.zero_grad()
            o = model(x)             # raw logits over the vocabulary
            loss = criterion(o, y)
            sumloss += loss.item()
            loss.backward()
            optimizer.step()
        if epoch % 100 == 0:
            print('epoch {}: loss: {:.2f}'.format(epoch, sumloss))

train(1000)
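
Once trained, the embeddings can be queried for nearest neighbours. A minimal sketch using cosine similarity (most_similar is not part of the original code):

import torch.nn.functional as F

def most_similar(word, topk=5):
    emb = model.W.weight.data.t()                # [vocab_size, hidden_size]
    v = emb[engDictionary.words.index(word)]
    sims = F.cosine_similarity(emb, v.unsqueeze(0), dim=1)
    idxs = sims.topk(topk + 1).indices.tolist()  # +1 because the word matches itself
    return [engDictionary.words[i] for i in idxs
            if engDictionary.words[i] != word][:topk]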

Save the trained weights

PATH = './word2vec.pth'
torch.save(model.state_dict(), PATH)
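
To restore the model later, instantiate the same architecture and load the saved state dict (standard PyTorch reloading):

model = Skip_gram(engDictionary.getSize(), 10)
model.load_state_dict(torch.load(PATH))
model.eval()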

Original article: https://blog.csdn.net/qq_32401661/article/details/109960321