Tensorflow中使用Keras自定义数据集 分布式训练 极简教程
前言
连滚带爬的摸索了一周,网上的教程参差不齐,都是在讲一大堆概念,而没有动手的教程。就算是动手的教程,也大都版本太旧,而且是在理想状态下的试验,不符合每个人的需求。最后经过摸索,终于跑通了!其实说白了,就是希望能有一个教程,从自己的数据集,自己的模型开始教分布式训练。
本篇教程基于Tensorflow 2.1 !很多分布式的东西在2.3才支持!!但是目前为止,还没发布!
分布式训练
介绍分布式训练的文章很多,这里就不重复了。要注意的是,分布式训练有如下几种分类方法
- 单机多卡,多机多卡
- 模型并行,数据并行
- Parameter Server结构和Ring All Reduce结构
- 异步和同步
这里就不逐个介绍了!接下来的教程,主要是同步式的数据并行多机训练方式!也是最简单的方式。
自定义数据集
Tensorflow的数据并行训练需要自定义数据集,直接用tensor会报一大堆错(真的蠢)!因此在这里简单介绍下怎么自定义数据集
indata = np.ones((1000, 5)).astype(np.float32)
outdata = np.ones((1000,1)).astype(np.float32)
indata = np.reshape(indata, (1000, 5))
outdata = np.reshape(outdata, (1000, 1))
dataset = tf.data.Dataset.from_tensor_slices((indata, outdata))
dataset = dataset.batch(20)
这样子就自定义了一个数据集了。主要是from_tensor_slices这个函数!自行百度这个函数的用法~最后我们就准备了一个数据集了,indata是他的输入,outdata是他的输出 batch_size是20。训练的时候直接扔dataset就ok,如Model.fit(dataset, epochs=5,steps_per_epoch=10)
分布式模型配置
然后我们就可以自定义自己的模型了!与以往不同,我们需要新增一个strategy,然后使用它,具体看代码!
os.environ['TF_CONFIG'] = json.dumps({
'cluster': {
'worker': ["localhost:20000","localhost:20001"]
},
'task': {'type': 'worker', 'index': 0}
})
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
inputs = Input(shape=(5,))
model = Dense(10)(inputs)
model = Dense(1)(model)
TargetModel = Model(inputs=inputs, outputs=model)
TargetModel.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss='mse'
)
这样我们就成功的定义了一个模型,用于分布式训练了。关于TF_CONFIG这个环境变量的内容,网上一大堆资料,你能搜到这篇博客,证明你肯定已经看过很多资料了! 这里就不过多的介绍了!关于MultiWorkerMirroredStrategy,这里给出一些参考资料。一般而言多机训练就选这个!
https://www.tensorflow.org/guide/distributed_training
注意支持的API!
分布式数据集配置
这里就是Tensorflow坑爹的地方了!以为有了上面的代码,就应该完美运行,可惜的是,Tensorflow依然会报一大堆错,说你的数据集并不是可分布式的。所以你需要设置一下,让你的数据集可以支持分布式!这里太坑爹了,网上都没有教程,最后在官方文档的一行小字中找到的。
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset = dataset.with_options(options)
加上这一段话!就可以支持分布式了!文档在这 https://www.tensorflow.org/api_docs/python/tf/data/experimental/DistributeOptions
当然你也可以使用dataset.shard()函数来设置分布式的分片,前提是每个worker使用的数据集是完全一模一样的,不然分片没有意义。
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shard
运行
至此!我们的代码就写完了!完整的代码如下!
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import os
import json
import time
import numpy as np
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "5" # 少打点log
os.environ['TF_CONFIG'] = json.dumps({
'cluster': {
'worker': ["localhost:20000", "localhost:20001"]
},
'task': {'type': 'worker', 'index': 0}
})
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
with strategy.scope():
inputs = Input(shape=(5,))
model = Dense(10)(inputs)
model = Dense(1)(model)
TargetModel = Model(inputs=inputs, outputs=model)
TargetModel.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss='mse'
)
indata = np.ones((1000, 5)).astype(np.float32)
outdata = np.ones((1000, 1)).astype(np.float32)
indata = np.reshape(indata, (1000, 5))
outdata = np.reshape(outdata, (1000, 1))
dataset = tf.data.Dataset.from_tensor_slices((indata, outdata))
dataset = dataset.batch(20)
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset = dataset.with_options(options)
for _ in range(20):
TargetModel.fit(dataset, epochs=5, steps_per_epoch=10)
我们在机器上,先运行一遍,然后把 ‘task’: {‘type’: ‘worker’, ‘index’: 0} 改为 ‘task’: {‘type’: ‘worker’, ‘index’: 1} 再运行一遍(你可以写成传入参数的形式)。然后就会发现双方同时训练,并且loss都是一样的了!如果你是在不同机器上运行的话,只需要把localhost改成对应ip就ok,端口随意。
坑与须知
以下是我本人做实验总结的一些东西
- 模型是同步的,即所有worker上模型都是一样的,他们采用All Recude的方式。如果要异步的话,应该要采用PS模式,目前而言,PS模式还不支持Keras!
- 分布式模型的所有的数据操作,如predict和fit,都要用数据集的形式!这里可太坑爹了。要解决的话,自己复制一遍模型,然后这个模型不要用with strategy.scope()括起来。
- 所有的worker代码都会执行,只有fit的时候会同步执行(好像predict也会)!所以最好自己定好一个主worker,一般是第0个。
- 多输入模型不兼容,不知道是自己代码问题还是别的问题,太坑爹了,我已经转战PyTorch了!
- 在本例子中,数据集是不共享的,但是他们的梯度是共享的。相当于有两份数据集,在不同的机器上跑了。这也是分布式训练,数据并行的意义 (这在强化学习中非常有用!) 如果要用相同的数据集,但是不同机器跑不同部分的话,可以设置dataset.shard()。