欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

接上文(调整加载逻辑):Tensorflow2.3加载文本数据转为train_data,val_data,test_data,

程序员文章站 2022-04-02 10:55:35
上一篇文章:https://blog.csdn.net/qq_40974572/article/details/110875780,是参考官网实例,整体写下来还是觉得加载过程比较不清晰,所以有整理了一遍,功能全写为函数里面,看着更加清晰。第一部分是原文中的三个函数,基本不变,参数变量写在前头。import collectionsimport tensorflow as tfimport tensorflow_text as tf_textimport osBUFFER_SIZE = 5...

上一篇文章:https://blog.csdn.net/qq_40974572/article/details/110875780,是参考官网实例,整体写下来还是觉得加载过程比较不清晰,所以有整理了一遍,功能全写为函数里面,看着更加清晰。

第一部分是原文中的三个函数,基本不变,参数变量写在前头。

import collections
import tensorflow as tf
import tensorflow_text as tf_text
import os


BUFFER_SIZE = 50000
BATCH_SIZE = 64
VALIDATION_SIZE = 5000
tokenizer = tf_text.UnicodeScriptTokenizer()

# 加载数据,标签化数据
def labeler(example, index):
    return example, tf.cast(index, tf.int16)

# 分词器,在句子边界添加标记
def tokenize(text, unsued_label):
    lower_case = tf_text.case_fold_utf8(text)
    return tokenizer.tokenize(lower_case)

# 自动调整buffer_size
AUTOTUNE = tf.data.experimental.AUTOTUNE
def configure_dataset(dataset):
  return dataset.cache().prefetch(buffer_size=AUTOTUNE) # 预先读取数据进内存

第二部分,生成训练数据,验证数据,测试数据,函数接收数据路径,假设只有txt文件,文件中包含英文文本,本文章用到三个txt文件,标签为0,1,2,对句子分类。

先给每篇文章句子加标签,然后三篇文章连接成一个连续的句子,tokenize给句子边界加标签,接下来生成词汇表,用到了tf.lookup.KeyValueTensorInitializer和tf.lookup.StaticVocabularyTable,前者生成一个初始化器,用于后者生成一个词汇表;preprocess_text根据词汇表和输入的text,把单词转为数字,就是encoded之后的数据,最后就用encoded之后的数据分割成训练集、验证集和测试集。


def generate_train_val_test_data(datapath, **kwargs):
    file_list = os.listdir(datapath) # C:\Users\admin\.keras\datasets,假设该目录下有多个txt文件
    text_dir = []
    for name in file_list:
        tdir = os.path.join(os.path.dirname(datapath), name)
        text_dir.append(tdir)

    # 给每篇文章加标签
    labeled_datas_set = []
    for i, file_path in enumerate(text_dir):
        line_datas = tf.data.TextLineDataset(file_path)
        labeled_datas = line_datas.map(lambda dx: labeler(dx, i))
        labeled_datas_set.append(labeled_datas)

    # 多篇文章连接,然后打乱数据顺序
    # all_labeled_set = []
    all_labeled_set = labeled_datas_set[0]
    for labeled_data in labeled_datas_set[1:]:
        all_labeled_set.concatenate(labeled_data)
    all_labeled_set = all_labeled_set.shuffle(buffer_size=BUFFER_SIZE, reshuffle_each_iteration=False)


    tokendized = all_labeled_set.map(tokenize)
    # 词汇表
    vocab_dict = collections.defaultdict(lambda: 0)  # vocab_dict存储单词和对应的出现次数
    for toks in tokendized.as_numpy_iterator():
        for tok in toks:
            vocab_dict[tok] += 1  # 统计每个单词出现的次数
    VOCA_SIZE = 10000
    vocab = sorted(vocab_dict.items(), key=lambda x: x[1],
                   reverse=True)  # vocab为单词-该单词出现次数, vocab_dict.item()以元组形式返回(单词,出现次数)
    vocab = [token for token, count in vocab]
    vocab = vocab[:VOCA_SIZE]
    keys = vocab
    values = range(2, len(vocab) + 2)
    # key value初始化器
    init = tf.lookup.KeyValueTensorInitializer(
        keys, values, key_dtype=tf.string, value_dtype=tf.int64)
    num_oov_buckets = 1
    vocab_table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)

    # 文本数据转索引
    def preprocess_text(text, label):
        standardized = tf_text.case_fold_utf8(text)
        tokenized = tokenizer.tokenize(standardized)
        vectorized = vocab_table.lookup(tokenized)
        return vectorized, label


    all_encoded_sets = all_labeled_set.map(preprocess_text)
    # 生成训练集和验证集
    train_datas = all_encoded_sets.skip(VALIDATION_SIZE).shuffle(BUFFER_SIZE)
    val_datas = all_encoded_sets.take(VALIDATION_SIZE)

    # 测试集
    test_datas = all_encoded_sets.take(VALIDATION_SIZE).batch(batch_size=BATCH_SIZE)
    test_datas = configure_dataset(test_datas)

    return train_datas, val_datas, test_datas

测试通过,

if __name__ == '__main__':
    path = 'C:/Users/admin/.keras/datasets/test'
    train_data, val_data, test_data = generate_train_val_test_data(datapath=path)
    for t, l in train_data.take(1):
        print(t)
        print(l)

_

 

本文地址:https://blog.csdn.net/qq_40974572/article/details/110929281