欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

使用pytorch实现fasttext

程序员文章站 2022-03-26 22:17:47
使用pytorch实现fasttextimport torchimport torch.nn as nnclass FastText(nn.Module): """创建fasttext模型的类""" def __init__(self, vocab_size, embed_dim, num_class): """ :param vocab_size: 语料中不重复的词汇总数 :param embed_dim: 词嵌入维度...

使用pytorch实现fasttext

import torch
import torch.nn as nn
class FastText(nn.Module):
    """创建fasttext模型的类"""
    def __init__(self, vocab_size, embed_dim, num_class):
        """
        :param vocab_size: 语料中不重复的词汇总数
        :param embed_dim: 词嵌入维度
        :param num_class: 目标的类别数
        """
        super().__init__()
        # 初始化EmbeddingBag层
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        # 初始化全连接层
        self.fc = nn.Linear(embed_dim, num_class)
        # 初始化权重
        self.init_weights()
    def init_weights(self):
        """初始化权重函数"""
        # 均匀分布参数
        initrange = 0.5
        # embedding层使用均匀分布,如果是采用迁移词向量,这里则初始化迁移词向量
        self.embedding.weight.data.uniform_(-initrange, initrange)
        # 全连接层也是均匀分布
        self.fc.weight.data.uniform_(-initrange, initrange)
        # 偏置初始化为0
        self.fc.bias.data.zero_()
    def forward(self, text, offsets):
        """
        正向传播逻辑
        :param text: 输入的文本的数值映射
        :param offsets: Bag的起始位置
        """
        # 先通过EmbeddingBag层
        embedded = self.embedding(text, offsets)
        # 再通过全连接层
        return self.fc(embedded)

if __name__ == '__main__':
    input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
    # 这里我们还需要定义不同Bag起始位置(输入可以分为不同的Bag)
    # 这里的[0, 4]代表将输入[1,2,4,5,4,3,2,9]分成了[1,2,4,5]和[4,3,2,9]
    # 这样第一个Bag[1,2,4,5]将通过Embedding层得到4个1x3的张量,
    # 因为他们属于一个Bag,要做'mean'运算,即张量相加再除4得到一个1x3的张量
    # 同理,第二个Bag也得到一个1x3的张量
    offsets = torch.LongTensor([0, 4])

    ft = FastText(10, 3, 4)
    # 假设语料中不重复的词汇总数为10, 词嵌入维度为3, 目标的类别数为4
    ret = ft(input, offsets)
    print(ret)

工作中会直接使用fasttext工具,因此这个实现仅仅作为学习使用。

本文地址:https://blog.csdn.net/huyidu/article/details/112671907