
gluonnlp.vocab: a brief walkthrough

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=consider-iterating-dictionary
"""Vocabulary."""
# The goal here is to show the approach most NLP toolkits take when building a Vocab class
# (a short usage sketch follows below):
# 1. Count token frequencies in the text with a Counter; the two parameters
#    max_size and min_freq constrain which tokens get indexed.
# 2. Support reserved tokens and other special tokens: the Vocab class first collects the
#    special tokens passed as keyword arguments and uses them when building the
#    token/index mappings.
# 3. Map tokens to indices and indices back to tokens, and attach each token's embedding.
# 4. ...
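# A minimal end-to-end sketch of those steps (illustrative only, not part of the
# original source; it uses the public API demonstrated in the class docstring below,
# and the exact indices depend on the token frequencies):
#   >>> import gluonnlp
#   >>> counter = gluonnlp.data.count_tokens(['hello', 'world', 'hello', 'world', 'world'])
#   >>> vocab = gluonnlp.Vocab(counter)                 # steps 1-2: count + special tokens
#   >>> vocab.idx_to_token
#   ['<unk>', '<pad>', '<bos>', '<eos>', 'world', 'hello']
#   >>> vocab[['hello', 'world']]                       # step 3: token -> index
#   [5, 4]
#   >>> vocab.to_tokens([5, 4])                         # step 3: index -> token
#   ['hello', 'world']
#   >>> # step 4 (sketch): vocab.set_embedding(gluonnlp.embedding.create('glove', source='glove.6B.50d'))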

__all__ = ['Vocab']

import collections
import json
import uuid
import warnings
import sys
from typing import Dict, Hashable, List, Optional

from mxnet import nd

from .. import _constants as C
from .. import embedding as emb
from ..data.utils import Counter, DefaultLookupDict, count_tokens

UNK_IDX = 0
_DEPR_PAD = object()
_DEPR_BOS = object()
_DEPR_EOS = object()


def _is_py35():
    return sys.version_info[0] == 3 and sys.version_info[1] == 5

class Vocab:
    """Indexing and embedding attachment for text tokens.
        (Builds an index over text tokens, including special tokens, and supports attaching an embedding layer.)
    Parameters
    ----------
    counter
        (Counts token frequencies in the text; `max_size` and `min_freq` limit how many and which tokens are indexed.)
        Counts text token frequencies in the text data. Its keys will be indexed according to
        frequency thresholds such as `max_size` and `min_freq`. Keys of `counter`,
        `unknown_token`, and values of `reserved_tokens` must be of the same hashable type.
        Examples: str, int, and tuple.
    max_size
        (`max_size` does not count reserved/special tokens; those are added to the token/index mappings separately.)
        The maximum possible number of the most frequent tokens in the keys of `counter` that can be
        indexed. Note that this argument does not count any token from `reserved_tokens`. Suppose
        that there are different keys of `counter` whose frequency are the same, if indexing all of
        them will exceed this argument value, such keys will be indexed one by one according to
        their __cmp__() order until the frequency threshold is met. If this argument is None or
        larger than its largest possible value restricted by `counter` and `reserved_tokens`, this
        argument has no effect.
    min_freq
        The minimum frequency required for a token in the keys of `counter` to be indexed.
    unknown_token
        The representation for any unknown token. If `unknown_token` is not
        `None`, looking up any token that is not part of the vocabulary and
        thus considered unknown will return the index of `unknown_token`. If
        None, looking up an unknown token will result in `KeyError`.
    reserved_tokens
        (Reserved tokens, not including the special tokens passed via keyword arguments.)
        A list specifying additional tokens to be added to the vocabulary.
        `reserved_tokens` must not contain the value of `unknown_token` or
        duplicate tokens. It must neither contain special tokens specified via
        keyword arguments.
    token_to_idx
        (Optionally pre-assigns specific indices to specific tokens via a token-to-index dict.)
        If not `None`, specifies the indices of tokens to be used by the
        vocabulary. Each token in `token_to_index` must be part of the Vocab
        and each index can only be associated with a single token.
        `token_to_idx` is not required to contain a mapping for all tokens. For
        example, it is valid to only set the `unknown_token` index to 10
        (instead of the default of 0) with `token_to_idx = {'<unk>': 10}`,
        assuming that there are at least 10 tokens in the vocabulary.
    `**kwargs`
        (Keyword arguments; each argument name must be of the form `xxx_token`.)
        Keyword arguments of the format `xxx_token` can be used to specify
        further special tokens that will be exposed as attribute of the
        vocabulary and associated with an index.
        For example, specifying `mask_token='<mask>'` as an additional keyword
        argument when constructing a vocabulary `v` leads to `v.mask_token`
        exposing the value of the special token: `<mask>`.
        If the specified token is not part of the Vocabulary, it will be added,
        just as if it was listed in the `reserved_tokens` argument. The
        specified tokens are listed together with reserved tokens in the
        `reserved_tokens` attribute of the vocabulary object.
    deprecated_padding_token
        (Deprecated way of passing the padding token positionally; use the keyword argument instead.)
        The representation for the special token of padding token. Default:
        '<pad>'. Specifying padding_token as positional argument is deprecated
        and support will be removed. Specify it as keyword argument instead
        (see documentation of `**kwargs` above)
    deprecated_bos_token
        The representation for the special token of beginning-of-sequence
        token. Default: '<bos>'. Specifying bos_token as positional argument is
        deprecated and support will be removed. Specify it as keyword argument
        instead (see documentation of `**kwargs` above)
    deprecated_eos_token
        The representation for the special token of end-of-sequence token.
        Default: '<eos>'. Specifying eos_token as positional argument is
        deprecated and support will be removed. Specify it as keyword argument
        instead (see documentation of `**kwargs` above)

    Attributes
    ----------
    embedding : instance of :class:`gluonnlp.embedding.TokenEmbedding`
        The embedding of the indexed tokens.
    idx_to_token : list of strs
        A list of indexed tokens where the list indices and the token indices are aligned.
    reserved_tokens : list of strs or None
        A list of reserved tokens that will always be indexed.
    token_to_idx : dict mapping str to int
        A dict mapping each token to its index integer.
    unknown_token : hashable object or None
        The representation for any unknown token. In other words, any unknown token will be indexed
        as the same representation.
    padding_token : hashable object or None
        The representation for padding token.
    bos_token : hashable object or None
        The representation for beginning-of-sentence token.
    eos_token : hashable object or None
        The representation for end-of-sentence token.


    Examples
    --------

    >>> text_data = ['hello', 'world', 'hello', 'nice', 'world', 'hi', 'world']
    >>> counter = gluonnlp.data.count_tokens(text_data)
    >>> my_vocab = gluonnlp.Vocab(counter)
    >>> fasttext = gluonnlp.embedding.create('fasttext', source='wiki.simple')
    -etc-
    >>> my_vocab.set_embedding(fasttext)
    >>> my_vocab.embedding[['hello', 'world']][:, :5]
    <BLANKLINE>
    [[ 0.39567   0.21454  -0.035389 -0.24299  -0.095645]
     [ 0.10444  -0.10858   0.27212   0.13299  -0.33165 ]]
    <NDArray 2x5 @cpu(0)>
    >>> my_vocab[['hello', 'world']]
    [5, 4]

    >>> input_dim, output_dim = my_vocab.embedding.idx_to_vec.shape
    >>> layer = gluon.nn.Embedding(input_dim, output_dim)
    >>> layer.initialize()
    >>> layer.weight.set_data(my_vocab.embedding.idx_to_vec)
    >>> layer(mx.nd.array([5, 4]))[:, :5]
    <BLANKLINE>
    [[ 0.39567   0.21454  -0.035389 -0.24299  -0.095645]
     [ 0.10444  -0.10858   0.27212   0.13299  -0.33165 ]]
    <NDArray 2x5 @cpu(0)>
    >>> glove = gluonnlp.embedding.create('glove', source='glove.6B.50d')
    -etc-
    >>> my_vocab.set_embedding(glove)
    >>> my_vocab.embedding[['hello', 'world']][:, :5]
    <BLANKLINE>
    [[-0.38497   0.80092   0.064106 -0.28355  -0.026759]
     [-0.41486   0.71848  -0.3045    0.87445   0.22441 ]]
    <NDArray 2x5 @cpu(0)>

    Extra keyword arguments of the format `xxx_token` are used to expose
    specified tokens as attributes.

    >>> my_vocab2 = gluonnlp.Vocab(counter, special_token='hi')
    >>> my_vocab2.special_token
    'hi'

    With the `token_to_idx` argument the order of the `Vocab`'s index can be
    adapted. For example, `Vocab` assigns the index `0` to the `unknown_token`
    by default. With the `token_to_idx` argument, the default can be
    overwritten. Here we assign index `3` to the unknown token representation
    `<unk>`.

    >>> tok2idx = {'<unk>': 3}
    >>> my_vocab3 = gluonnlp.Vocab(counter, token_to_idx=tok2idx)
    >>> my_vocab3.unknown_token
    '<unk>'
    >>> my_vocab3[my_vocab3.unknown_token]
    3
    >>> my_vocab[my_vocab.unknown_token]
    0

    """

    def __init__(self, counter: Optional[Counter] = None, max_size: Optional[int] = None,
                 min_freq: int = 1, unknown_token: Optional[Hashable] = C.UNK_TOKEN,
                 deprecated_padding_token: Optional[Hashable] = _DEPR_PAD,
                 deprecated_bos_token: Optional[Hashable] = _DEPR_BOS,
                 deprecated_eos_token: Optional[Hashable] = _DEPR_EOS,
                 reserved_tokens: Optional[List[Hashable]] = None,
                 token_to_idx: Optional[Dict[Hashable, int]] = None, *,
                 padding_token: Optional[Hashable] = C.PAD_TOKEN,
                 bos_token: Optional[Hashable] = C.BOS_TOKEN,
                 eos_token: Optional[Hashable] = C.EOS_TOKEN, **kwargs):

        # Sanity checks.
        assert min_freq > 0, '`min_freq` must be set to a positive value.'

        # Deprecation checks and warnings
        # Deprecation check for the special tokens: if a deprecated positional value was
        # supplied, record it in kwargs under its keyword name; otherwise fall back to
        # the keyword argument (e.g. the default '<pad>' for the padding token).
        combs = ((deprecated_padding_token, 'padding_token', _DEPR_PAD, padding_token),
                 (deprecated_bos_token, 'bos_token', _DEPR_BOS, bos_token),
                 (deprecated_eos_token, 'eos_token', _DEPR_EOS, eos_token))
        for depr_pos_arg, name, indicator, value in combs:
            if depr_pos_arg != indicator:
                warnings.warn(
                    'Specifying `{n}` as positional argument is deprecated and '
                    'support will be removed. Please specify `{n}` as keyword argument instead, '
                    'for example `Vocab(counter, {n}={v})`'.format(n=name, v=depr_pos_arg),
                    DeprecationWarning)
                # Store positional argument value in kwargs
                kwargs[name] = depr_pos_arg
            elif name not in kwargs:  # Store keyword argument value in kwargs
                kwargs[name] = value
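
        # Illustrative only (not in the original source): calling
        # `Vocab(counter, None, 1, '<unk>', '<PAD>')` passes '<PAD>' through the
        # `deprecated_padding_token` positional slot, so the loop above emits a
        # DeprecationWarning and records kwargs['padding_token'] = '<PAD>'. The
        # recommended spelling, `Vocab(counter, padding_token='<PAD>')`, leaves the
        # sentinel untouched and the keyword value is simply copied into kwargs.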

        # Set up idx_to_token and token_to_idx based on presence of unknown token
        self._unknown_token = unknown_token
        self._idx_to_token = [unknown_token] if unknown_token else []
        if unknown_token:
            self._token_to_idx = DefaultLookupDict(UNK_IDX)  # default index returned when a token cannot be found
        else:
            self._token_to_idx = {}
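
        # Illustrative behaviour of the two branches above (token value assumed): with
        # unknown_token='<unk>', a later lookup such as vocab['never-seen-token'] falls
        # back to UNK_IDX (0) via DefaultLookupDict; with unknown_token=None, the plain
        # dict makes the same lookup raise KeyError.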

        # Handle special tokens
        # Collect the special tokens passed via keyword arguments.
        special_tokens = []
        special_iter = kwargs.items()
        if _is_py35():
            special_iter = sorted(special_iter)
        for special_token_name, special_token in special_iter:
            # Test if kwarg specifies a special token
            if not special_token_name.endswith('_token'):
                raise ValueError('{} is invalid. Only keyword arguments '
                                 'that end in \'_token\' are supported '
                                 'to declare special tokens.'.format(special_token_name))

            if special_token is not None and special_token not in special_tokens:
                special_tokens.append(special_token)
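
        # Illustrative only: `Vocab(counter, mask_token='<mask>')` makes the loop above
        # append '<mask>' to special_tokens, so it is indexed among the special tokens
        # (its exact position depends on the keyword-argument order) and later shows up
        # in the vocabulary's `reserved_tokens` attribute.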

        if reserved_tokens is not None:
            special_tokens.extend(reserved_tokens)
            special_token_set = set(special_tokens)
            if unknown_token:
                assert unknown_token not in special_token_set, \
                    '`reserved_token` cannot contain `unknown_token`.'
            assert len(special_token_set) == len(special_tokens), \
                '`reserved_tokens` cannot contain duplicate reserved tokens or ' \
                'other special tokens.'
        # Put the special tokens into self._reserved_tokens; they are the first batch of
        # entries added to self._idx_to_token and self._token_to_idx.
        if not special_tokens:
            self._reserved_tokens = None
        else:
            self._reserved_tokens = special_tokens
            self._idx_to_token.extend(special_tokens)

        self._token_to_idx.update((token, idx) for idx, token in enumerate(self._idx_to_token))
        self._embedding = None

        if counter:
            # Index the ordinary tokens from the counter into the
            # self._idx_to_token / self._token_to_idx mappings.
            self._index_counter_keys(counter, unknown_token, special_tokens, max_size, min_freq)

        self._identifiers_to_tokens = kwargs
        if kwargs:
            # Expose the keyword-argument tokens as attributes of the instance for later use.
            self._expose_tokens_as_attributes(kwargs)

        if token_to_idx:
            self._sort_index_according_to_user_specification(token_to_idx)
            if unknown_token:
                self._token_to_idx._default = \
                    self._token_to_idx[unknown_token]  # pytype: disable=not-writable


    def _index_counter_keys(self, counter, unknown_token, special_tokens, max_size,
                            min_freq):
        """Indexes keys of `counter`.

        Builds the index-to-token list `_idx_to_token` and the token-to-index dict
        `_token_to_idx` while honoring the `min_freq` and `max_size` constraints.
        Indexes keys of `counter` according to frequency thresholds such as `max_size` and
        `min_freq`.
        """

        unknown_and_special_tokens = set(special_tokens) if special_tokens else set()

        if unknown_token:
            unknown_and_special_tokens.add(unknown_token)

        token_freqs = sorted(counter.items(), key=lambda x: x[0])
        token_freqs.sort(key=lambda x: x[1], reverse=True)

        token_cap = len(unknown_and_special_tokens) + (
            len(counter) if not max_size else max_size)

        for token, freq in token_freqs:
            if freq < min_freq or len(self._idx_to_token) == token_cap:
                break
            if token not in unknown_and_special_tokens:
                self._idx_to_token.append(token)
                self._token_to_idx[token] = len(self._idx_to_token) - 1
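
    # Illustrative effect of min_freq / max_size (counts assumed, not from the original
    # article): with counter = {'world': 3, 'hello': 2, 'nice': 1, 'hi': 1} and the
    # default special tokens, Vocab(counter, min_freq=2) indexes only 'world' and
    # 'hello' (idx_to_token == ['<unk>', '<pad>', '<bos>', '<eos>', 'world', 'hello']),
    # while Vocab(counter, max_size=1) keeps only the single most frequent token
    # 'world' beyond the special tokens; frequency ties such as 'hi'/'nice' are broken
    # by the preliminary sort on the token value itself.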

    def _expose_tokens_as_attributes(self, identifiers_to_tokens):
        # This method must not be called before internal attributes accessed by
        # @properties getters are set. Otherwise the @properties may raise
        # during the hasattr(self, identifier) check

        for identifier, token in identifiers_to_tokens.items():
            # Special tokens are automatically added to the vocab; assert, just to be sure
            assert token is None or token in self
            if identifier.startswith('_'):
                raise ValueError('It is not allowed to use identifiers starting with '
                                 'underscore. In Python identifier names beginning with '
                                 'underscore are internal.')
            if hasattr(self, identifier):
                raise ValueError('vocab.{} already exists. '
                                 'Please choose a different identifier for token {}'
                                 .format(identifier, token))
            setattr(self, identifier, token)
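
    # Illustrative only: after `v = Vocab(counter, mask_token='<mask>')`, `v.mask_token`
    # returns '<mask>'. Identifiers that start with an underscore (e.g. `_x_token`) or
    # that collide with an existing attribute raise ValueError in the loop above.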

    def _sort_index_according_to_user_specification(self, token_to_idx):
        # Sanity checks
        if not set(token_to_idx.keys()).issubset(self.token_to_idx.keys()):
            raise ValueError('User-specified token_to_idx mapping can only contain '
                             'tokens that will be part of the vocabulary.')
        if len(set(token_to_idx.values())) != len(token_to_idx):
            raise ValueError('User-specified indices must not contain duplicates.')
        if min(token_to_idx.values()) < 0 or max(token_to_idx.values()) >= len(self.token_to_idx):
            raise ValueError('User-specified indices must not be < 0 or >= the number of tokens '
                             'that will be in the vocabulary. The current vocab contains {}'
                             'tokens.'.format(len(self.token_to_idx)))

        # Update index ordering: swap entries so that both idx_to_token and token_to_idx
        # reflect the user-specified token_to_idx mapping.
        for token, new_idx in token_to_idx.items():
            old_idx = self.token_to_idx[token]
            ousted_token = self.idx_to_token[new_idx]

            self.token_to_idx[token] = new_idx
            self.token_to_idx[ousted_token] = old_idx
            self.idx_to_token[old_idx] = ousted_token
            self.idx_to_token[new_idx] = token
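
    # Illustrative swap (default indices assumed): `token_to_idx={'<unk>': 3}` moves
    # '<unk>' from index 0 to index 3 and moves the token that previously occupied
    # index 3 (by default '<eos>') to index 0, so idx_to_token and token_to_idx stay
    # consistent with each other.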

    @property
    def embedding(self):
        return self._embedding

    @property
    def idx_to_token(self):
        return self._idx_to_token

    @property
    def reserved_tokens(self):
        return self._reserved_tokens

    @property
    def token_to_idx(self):
        return self._token_to_idx

    @property
    def unknown_token(self):
        return self._unknown_token

    def __contains__(self, token):
        """Checks whether a text token exists in the vocabulary.


        Parameters
        ----------
        token : str
            A text token.


        Returns
        -------
        bool
            Whether the text token exists in the vocabulary (including `unknown_token`).
        """

        return token in self._token_to_idx

    def __getitem__(self, tokens):
        """Looks up indices of text tokens according to the vocabulary.

        If `unknown_token` of the vocabulary is None, looking up unknown tokens results in KeyError.

        Parameters
        ----------
        tokens : str or list of strs
            A source token or tokens to be converted.


        Returns
        -------
        int or list of ints
            A token index or a list of token indices according to the vocabulary.
        """

        if not isinstance(tokens, (list, tuple)):
            return self._token_to_idx[tokens]
        else:
            return [self._token_to_idx[token] for token in tokens]
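
    # Illustrative lookups (indices depend on the counter): with the vocabulary built in
    # the class docstring above, `'hello' in my_vocab` is True, `my_vocab['hello']`
    # returns 5 and `my_vocab[['hello', 'world']]` returns [5, 4]; an out-of-vocabulary
    # token maps to index 0, or raises KeyError when `unknown_token` is None.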

    def __len__(self):
        return len(self._idx_to_token)

    def set_embedding(self, *embeddings):
        """Attaches one or more embeddings to the indexed text tokens.

        Approach: attach embeddings to the indexed tokens. If several embeddings are
        given, first compute the total embedding width and allocate the combined matrix,
        then fill in each embedding's vectors column block by column block, in order.
        Parameters
        ----------
        embeddings : None or tuple of :class:`gluonnlp.embedding.TokenEmbedding` instances
            The embedding to be attached to the indexed tokens. If a tuple of multiple embeddings
            are provided, their embedding vectors will be concatenated for the same token.
        """

        if len(embeddings) == 1 and embeddings[0] is None:
            self._embedding = None
            return

        for embs in embeddings:
            assert isinstance(embs, emb.TokenEmbedding), \
                'The argument `embeddings` must be an instance or a list of instances of ' \
                '`gluonnlp.embedding.TokenEmbedding`.'
            assert embs.idx_to_vec is not None, \
                'For all specified `embeddings`, `embeddings.idx_to_vec` must be initialized. ' \
                'Use e.g. `emb[emb.unknown_token] = nd.zeros(emsize)` to initialize, ' \
                'where `emsize` is the desired embedding dimensionality.'

        assert all([embs.unknown_token for embs in embeddings]) or \
            all([not embs.unknown_token for embs in embeddings]), \
            'Either all or none of the TokenEmbeddings must have an ' \
            'unknown_token set.'

        new_vec_len = sum(embs.idx_to_vec.shape[1] for embs in embeddings)
        # TODO(leezu): Remove once np shape is used by default
        assert len(self), 'Empty vocab not yet supported'
        new_idx_to_vec = nd.zeros(shape=(len(self), new_vec_len))

        col_start = 0
        # Concatenate all the embedding vectors in embedding.
        for embs in embeddings:
            if embs and embs.idx_to_vec is not None:
                col_end = col_start + embs.idx_to_vec.shape[1]
                # Concatenate vectors of the unknown token.
                new_idx_to_vec[0, col_start:col_end] = embs.idx_to_vec[0]
                new_idx_to_vec[1:, col_start:col_end] = embs[self._idx_to_token[1:]]
                col_start = col_end

        self._embedding = emb.TokenEmbedding(self.unknown_token,
                                             init_unknown_vec=None,
                                             allow_extend=False,
                                             idx_to_token=self.idx_to_token,
                                             idx_to_vec=new_idx_to_vec)
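
    # Illustrative concatenation (embedding sizes assumed: 300-d fastText wiki.simple
    # and 50-d glove.6B.50d, with the 8-token docstring vocabulary):
    #   >>> my_vocab.set_embedding(fasttext, glove)
    #   >>> my_vocab.embedding.idx_to_vec.shape
    #   (8, 350)
    # i.e. each token's vectors from the two embeddings are concatenated column-wise;
    # tokens missing from one source typically receive that embedding's unknown/zero
    # vector in the corresponding columns.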

    def to_tokens(self, indices):
        """Converts token indices to tokens according to the vocabulary.


        Parameters
        ----------
        indices : int or list of ints
            A source token index or token indices to be converted.


        Returns
        -------
        str or list of strs
            A token or a list of tokens according to the vocabulary.
        """

        to_reduce = False
        if not isinstance(indices, (list, tuple)):
            indices = [indices]
            to_reduce = True

        max_idx = len(self._idx_to_token) - 1

        tokens = []
        for idx in indices:
            if not isinstance(idx, int) or idx > max_idx:
                raise ValueError('Token index {} in the provided `indices` is invalid.'.format(idx))
            tokens.append(self._idx_to_token[idx])

        return tokens[0] if to_reduce else tokens
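
    # Illustrative round trip (indices depend on the counter): with the docstring
    # vocabulary, `my_vocab.to_tokens(5)` returns 'hello' and `my_vocab.to_tokens([5, 4])`
    # returns ['hello', 'world']; a non-int index or an index larger than len(vocab) - 1
    # raises ValueError.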

    def to_indices(self, tokens):
        """Looks up indices of text tokens according to the vocabulary.


        Parameters
        ----------
        tokens : str or list of strs
            A source token or tokens to be converted.


        Returns
        -------
        int or list of ints
            A token index or a list of token indices according to the vocabulary.
        """

        return self[tokens]

    def __call__(self, tokens):
        """Looks up indices of text tokens according to the vocabulary.


        Parameters
        ----------
        tokens : str or list of strs
            A source token or tokens to be converted.


        Returns
        -------
        int or list of ints
            A token index or a list of token indices according to the vocabulary.
        """

        return self[tokens]

    def __repr__(self):
        unk = '"{}"'.format(self._unknown_token) if self._unknown_token else 'None'
        reserved = '"{}"'.format(self._reserved_tokens) if self._reserved_tokens else 'None'
        return 'Vocab(size={}, unk={}, reserved={})'.format(len(self), unk, reserved)

    def to_json(self):
        """Serialize Vocab object to json string.
        序列为json格式
        This method does not serialize the underlying embedding.
        """
        if self._embedding:
            warnings.warn('Serialization of attached embedding '
                          'to json is not supported. '
                          'You may serialize the embedding to a binary format '
                          'separately using vocab.embedding.serialize')
        vocab_dict = {}
        vocab_dict['idx_to_token'] = self._idx_to_token
        vocab_dict['token_to_idx'] = dict(self._token_to_idx)
        vocab_dict['reserved_tokens'] = self._reserved_tokens
        vocab_dict['unknown_token'] = self._unknown_token
        vocab_dict['identifiers_to_tokens'] = self._identifiers_to_tokens
        return json.dumps(vocab_dict)

    @classmethod
    def from_json(cls, json_str):
        """Deserialize Vocab object from json string.
        将json格式反序列成Vocab对象。
        Parameters
        ----------
        json_str : str
            Serialized json string of a Vocab object.


        Returns
        -------
        Vocab
        """
        vocab_dict = json.loads(json_str)
        token_to_idx = vocab_dict.get('token_to_idx')
        unknown_token = vocab_dict.get('unknown_token')
        reserved_tokens = vocab_dict.get('reserved_tokens')
        identifiers_to_tokens = vocab_dict.get('identifiers_to_tokens', dict())

        special_tokens = {unknown_token}

        # Backward compatibility for explicit serialization of padding_token,
        # bos_token, eos_token handling in the json string as done in older
        # versions of GluonNLP.
        # Compatibility with arguments from older versions.
        deprecated_arguments = ['padding_token', 'bos_token', 'eos_token']
        for token_name in deprecated_arguments:
            if token_name in vocab_dict:
                token = vocab_dict[token_name]
                assert token_name not in identifiers_to_tokens, 'Invalid json string. ' \
                    '{} was serialized twice.'.format(token_name)
                identifiers_to_tokens[token_name] = token

        # Separate reserved from special tokens
        # Keep only the reserved tokens that are not special tokens.
        special_tokens.update(identifiers_to_tokens.values())
        if reserved_tokens is not None:
            reserved_tokens = [
                t for t in reserved_tokens if t not in special_tokens
            ]

        # Backward compatibility code to deserialize corrupted vocabularies
        # created without bugfix https://github.com/dmlc/gluon-nlp/pull/749
        corrected_token_to_idx = collections.defaultdict(list)
        idx_to_token = vocab_dict.get('idx_to_token')
        if len(idx_to_token) > len(token_to_idx):  # Index is corrupt
            warnings.warn(
                'Detected a corrupted index in the deserialized vocabulary. '
                'For versions before GluonNLP v0.7 the index is corrupted '
                'by specifying the same token for different special purposes, '
                'for example eos_token == padding_token. '
                'Deserializing the vocabulary nevertheless.'
            )
            for token, count in collections.Counter(idx_to_token).items():
                if count == 1:
                    continue
                # Introduce new tokens to avoid invalid duplicates
                idx = -1
                while count > 0:
                    count -= 1
                    idx = idx_to_token.index(token, idx + 1)
                    if idx == token_to_idx[token]:
                        # Valid idx
                        continue
                    # Introduce temporary token
                    token_to_idx.update({str(uuid.uuid4()): idx})
                    corrected_token_to_idx[token].append(idx)

        vocab = cls(
            counter=count_tokens(token_to_idx.keys()),
            unknown_token=unknown_token,
            reserved_tokens=reserved_tokens,
            token_to_idx=token_to_idx,
            **identifiers_to_tokens)

        # Backward compatibility code to deserialize corrupted vocabularies
        # created without bugfix https://github.com/dmlc/gluon-nlp/pull/749
        for token, corrected_idxs in corrected_token_to_idx.items():
            for idx in corrected_idxs:
                # delete temporary tokens
                del vocab._token_to_idx[vocab._idx_to_token[idx]]
                vocab._idx_to_token[idx] = token

        return vocab 
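
# Illustrative JSON round trip (not in the original article; attached embeddings are
# not serialized):
#   >>> json_str = my_vocab.to_json()
#   >>> restored = gluonnlp.Vocab.from_json(json_str)
#   >>> restored.token_to_idx == my_vocab.token_to_idx
#   True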

Original article: https://blog.csdn.net/sinat_24395003/article/details/108244708
