TensorFlow 分布式训练（MNIST 示例）
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2019/10/10 13:50
# @Site :
# @File : distributed_MNIST.py
# @Software: PyCharm
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
# Command-line flags for the distributed job.
# Host lists are comma-separated "host:port" entries; each process is started
# with its own --job_name / --task_index pair.
_flags = tf.app.flags

# Cluster layout.
_flags.DEFINE_string("ps_hosts", "10.10.50.1:2223", "ps hosts")
_flags.DEFINE_string("worker_hosts", "10.10.50.1:2225,10.10.50.2:2225", "worker hosts")

# Role of this particular process within the cluster.
_flags.DEFINE_string("job_name", "worker", "'ps' or 'worker'")
_flags.DEFINE_integer("task_index", 0, "Index of task within the job")

# Training mode.
_flags.DEFINE_integer("num_workers", 2, "Number of workers")
_flags.DEFINE_boolean("is_sync", False, "using synchronous training or not")

FLAGS = _flags.FLAGS
def model(images):
    """Build a fully connected MNIST classifier and return its logits.

    Two hidden ReLU layers of 500 units each feed a 10-unit dense output
    layer with no activation, so the result is unnormalized class scores.

    Args:
        images: Input tensor. In this file it is the ``[None, 784]`` float32
            placeholder of flattened images built in ``main``.

    Returns:
        The output tensor of the final 10-unit dense layer (logits).
    """
    hidden = images
    # Two identical hidden layers, created in the same order as before so the
    # auto-generated layer/variable names are unchanged.
    for _ in range(2):
        hidden = tf.layers.dense(hidden, 500, activation=tf.nn.relu)
    return tf.layers.dense(hidden, 10, activation=None)
def main(_):
    """Run this process as a parameter server or a worker and train on MNIST.

    The role comes from ``--job_name`` / ``--task_index``:

    * ``ps``: host the variables and block in ``server.join()``.
    * ``worker``: build the model with variables placed on the ps job, then
      train with Adam until the shared global step reaches 2000. Worker 0 is
      the chief, which ``MonitoredTrainingSession`` uses for initialization
      and checkpointing into ``./checkpoint_dir``.

    With ``--is_sync`` the optimizer is wrapped in ``SyncReplicasOptimizer``.
    Gradients from ``--num_workers`` replicas are aggregated per update, out
    of ``len(worker_hosts)`` worker replicas in total.

    Args:
        _: Unused argv list passed in by ``tf.app.run``.

    Raises:
        ValueError: If ``--job_name`` is neither ``"ps"`` nor ``"worker"``.
    """
    # Fail fast with a clear message before creating/binding a server.
    if FLAGS.job_name not in ("ps", "worker"):
        raise ValueError(
            "--job_name must be 'ps' or 'worker', got %r" % FLAGS.job_name)

    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    # Create the cluster configured by `ps_hosts` and `worker_hosts`.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    # Create the server for this local task.
    server = tf.train.Server(cluster, job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        # Parameter servers only serve variables; block until shutdown.
        server.join()
        return

    is_chief = FLAGS.task_index == 0

    # replica_device_setter places Variables on the ps hosts (round-robin over
    # ps tasks by default) and the remaining ops on this worker's device.
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
        # Load the MNIST dataset.
        mnist = read_data_sets("./dataset", one_hot=True)

        # The model.
        images = tf.placeholder(tf.float32, [None, 784])
        # One-hot labels. float32 to match the logits: the softmax
        # cross-entropy op documents that labels and logits share a dtype.
        labels = tf.placeholder(tf.float32, [None, 10])
        logits = model(images)
        # _v2 replaces the deprecated v1 op. It only differs in letting
        # gradients flow into `labels`, which is irrelevant for a placeholder.
        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                       labels=labels))

        # StopAtStepHook stops training once the global step reaches 2000.
        hooks = [tf.train.StopAtStepHook(last_step=2000)]
        global_step = tf.train.get_or_create_global_step()
        optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)
        if FLAGS.is_sync:
            # Synchronous training: aggregate gradients from
            # `replicas_to_aggregate` replicas before each update.
            # total_num_replicas is the real number of worker tasks in the
            # cluster, so it cannot drift from --worker_hosts.
            # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
            optimizer = tf.train.SyncReplicasOptimizer(
                optimizer,
                replicas_to_aggregate=FLAGS.num_workers,
                total_num_replicas=len(worker_hosts))
            # Hook that handles initialization and the token queues.
            hooks.append(optimizer.make_session_run_hook(is_chief))
        train_op = optimizer.minimize(
            loss, global_step=global_step,
            aggregation_method=tf.AggregationMethod.ADD_N)

        # MonitoredTrainingSession handles session initialization, restoring
        # from / saving to a checkpoint, and closing when done or on error.
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=is_chief,
                                               checkpoint_dir="./checkpoint_dir",
                                               hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                # mon_sess.run handles AbortedError in case of preempted PS.
                img_batch, label_batch = mnist.train.next_batch(32)
                _, ls, step = mon_sess.run(
                    [train_op, loss, global_step],
                    feed_dict={images: img_batch, labels: label_batch})
                # global_step is shared by all workers, so each worker only
                # prints the multiples of 100 that it happens to observe.
                if step % 100 == 0:
                    print("Train step %d, loss: %f" % (step, ls))
if __name__ == "__main__":
    # Parse the flags defined above, then invoke main(argv).
    tf.app.run(main=main)
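每台参与训练的机器上都需要有这份代码，然后分别启动以下三个进程（按默认的 --ps_hosts / --worker_hosts，ps 与 worker0 运行在 10.10.50.1 上，worker1 运行在 10.10.50.2 上）：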
python distributed_MNIST.py --job_name=ps --task_index=0
python distributed_MNIST.py --job_name=worker --task_index=0
python distributed_MNIST.py --job_name=worker --task_index=1
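运行效果（未指定 --is_sync，即默认的异步训练）：两个 worker 共享同一个 global_step，每个 worker 只在自己读到的 global_step 恰好是 100 的整数倍时打印，因此各 step 编号分散在两个 worker 的输出中（例如 step 100 和 1100 出现在 worker1 的输出里）：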
worker0:
Train step 0, loss: 2.435655
Train step 200, loss: 0.563108
Train step 300, loss: 0.260397
Train step 400, loss: 0.419546
Train step 500, loss: 0.179173
Train step 600, loss: 0.246722
Train step 700, loss: 0.256308
Train step 800, loss: 0.208148
Train step 900, loss: 0.176686
Train step 1000, loss: 0.184715
Train step 1200, loss: 0.132179
Train step 1300, loss: 0.234628
Train step 1400, loss: 0.136147
Train step 1500, loss: 0.073287
Train step 1600, loss: 0.125461
Train step 1700, loss: 0.222780
Train step 1800, loss: 0.089276
Train step 1900, loss: 0.050637
worker1:
Train step 100, loss: 0.661454
Train step 1100, loss: 0.185059
上一篇: Oracle 导入导出工具
下一篇: 导出CSV文件工具类
推荐阅读
-
64位CentOS系统下安装配置伪分布式hadoop 2.5.2
-
detectron2训练自己的数据实现目标检测和关键点检测(一) ubuntu18.04安装测试detectron2
-
An introduction to Generative Adversarial Networks (with code in TensorFlow)
-
caffe使用命令行方式训练预测mnist、cifar10及自己的数据集
-
TensorFlow实现Logistic回归
-
分布式事务之深入理解什么是2PC、3PC及TCC协议?
-
tensorflow三种模型的加载和保存的方法(.ckpt,.pb,SavedModel)
-
Twitter的分布式自增ID算法snowflake
-
Keras版MobileNetv2模型是否带口罩二分类问题训练
-
基于TensorFlow2.x的实时多人二维姿势估计