tensorflow分布式训练
程序员文章站
2022-05-26 19:15:28
...
深度学习的并行训练分为模型并行和数据并行两大类:所谓模型并行,指的是模型大到一台机器无法放下时,将模型拆分到不同的机器上;而数据并行则是每台机器都持有完整的模型,但处理不同的数据。目前绝大多数深度学习框架支持的都是数据并行。参考博文1中详细介绍了它们的区别,参考博文2则实现了一个简单的例程,但由于其采用了过时的API,导致其无法在较新版本的tensorflow中运行。这里给出能够运行的代码(基于TensorFlow 1.x的API,如tf.app.flags、tf.placeholder、tf.train.Supervisor):
import tensorflow as tf
import numpy as np
# Command-line flags: optimisation settings first, then the cluster layout
# and the identity (job name + task index) of the task this process runs.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float("learning_rate", 0.00003, "Initial learning rate")
tf.app.flags.DEFINE_integer("steps_to_validate", 100, "tps of validation")
tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "worker")
tf.app.flags.DEFINE_string("job_name", "ps", "job_name")
tf.app.flags.DEFINE_integer("task_index", 0, "Index")

# Copy the flag values into module-level names used by the code below.
learning_rate, steps_to_validate = FLAGS.learning_rate, FLAGS.steps_to_validate
# Host flags are comma-separated "host:port" lists.
ps_hosts, worker_hosts = (FLAGS.ps_hosts.split(","), FLAGS.worker_hosts.split(","))

# Describe the whole cluster, then start the server for this process's task.
cluster = tf.train.ClusterSpec(dict(ps=ps_hosts, worker=worker_hosts))
server = tf.train.Server(
    cluster,
    job_name=FLAGS.job_name,
    task_index=FLAGS.task_index,
)
# Role dispatch: ps tasks only serve variables; worker tasks build the linear
# model pred = weight * x + bias and fit it to samples of y = 2x + 10 + noise.
if FLAGS.job_name == "ps":
    server.join()  # blocks; a ps task only hosts variables
elif FLAGS.job_name == "worker":
    # replica_device_setter puts Variables on the ps tasks and the remaining
    # ops on this worker.
    with tf.device(tf.train.replica_device_setter(
            worker_device="/job:worker/task:%d" % FLAGS.task_index,
            cluster=cluster)):
        global_step = tf.train.get_or_create_global_step()
        # Named x_input rather than `input` so the builtin is not shadowed.
        x_input = tf.placeholder("float")
        label = tf.placeholder("float")
        weight = tf.get_variable("weights", [1], tf.float32,
                                 initializer=tf.random_normal_initializer())
        bias = tf.get_variable("bias", [1], tf.float32,
                               initializer=tf.random_normal_initializer())
        pred = tf.multiply(x_input, weight) + bias
        loss_value = tf.reduce_mean(tf.square(label - pred))
        train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(
            loss_value, global_step=global_step)
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()
        tf.summary.scalar("cost", loss_value)
        summary_op = tf.summary.merge_all()

    # summary_op=None disables the Supervisor's background summary thread.
    # The "cost" summary depends on the placeholders above, and that thread
    # runs the summary op without a feed_dict, which would fail and stop
    # training. Summaries are written manually below, using the training feed.
    sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                             logdir="./check_point/",
                             init_op=init_op,
                             summary_op=None,
                             saver=saver,
                             global_step=global_step,
                             save_model_secs=60)
    with sv.managed_session(server.target) as sess:
        step = 0
        # Also honour stop requests raised by the Supervisor's services.
        while not sv.should_stop() and step < 1000000:
            train_x = np.random.randn(1)
            train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10
            feed = {x_input: train_x, label: train_y}
            _, loss_v, step = sess.run([train_op, loss_value, global_step],
                                       feed_dict=feed)
            if step % steps_to_validate == 0:
                w, b = sess.run([weight, bias])
                # weight and bias are shape-[1] variables; format the element.
                print("step:%d,weight %f,bias:%f,loss:%f" % (step, w[0], b[0], loss_v))
                if sv.is_chief:
                    # Only the chief owns the summary writer in ./check_point/.
                    summary = sess.run(summary_op, feed_dict=feed)
                    sv.summary_computed(sess, summary, global_step=step)
将上述代码保存为disapby.py后,可通过如下脚本启动一个参数服务器(ps)和两个worker:
# Launch one parameter server (no GPU) and two workers, each pinned to its
# own GPU. CUDA_VISIBLE_DEVICES is set per command rather than exported, so
# the calling shell's environment is left unchanged.
CUDA_VISIBLE_DEVICES='' python disapby.py --job_name=ps --task_index=0 &
CUDA_VISIBLE_DEVICES=0 python disapby.py --job_name=worker --task_index=0 &
CUDA_VISIBLE_DEVICES=1 python disapby.py --job_name=worker --task_index=1 &
训练MNIST的分布式代码如下(通过--is_sync参数选择同步或异步训练):
import tensorflow as tf
#from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
import tensorflow.examples.tutorials.mnist.input_data as input_data
# Command-line flags read by main().
# Cluster layout: comma-separated "host:port" lists (main() splits on ",").
tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps hosts")
tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "worker hosts")
# Identity of the task run by this process.
tf.app.flags.DEFINE_string("job_name", "worker", "'ps' or 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
# Only used with --is_sync, as SyncReplicasOptimizer's replica counts.
# It is not derived from --worker_hosts, so keep the two consistent.
tf.app.flags.DEFINE_integer("num_workers", 2, "Number of workers")
# False (default): asynchronous training. True: wrap the optimizer in
# tf.train.SyncReplicasOptimizer.
tf.app.flags.DEFINE_boolean("is_sync", False, "using synchronous training or not")
FLAGS = tf.app.flags.FLAGS
def model(images):
    """Build a small fully connected MNIST classifier.

    Two ReLU hidden layers of 500 units each feed a 10-unit output layer
    with no activation, so the returned tensor holds raw logits.

    Args:
        images: Input tensor fed to the first ``tf.layers.dense`` layer.

    Returns:
        The output tensor of the final 10-unit dense layer.
    """
    hidden = images
    for width in (500, 500):
        hidden = tf.layers.dense(hidden, width, activation=tf.nn.relu)
    # Linear output: the caller applies softmax inside its cross-entropy loss.
    return tf.layers.dense(hidden, 10, activation=None)
def main(_):
    """Start this process's task in the cluster and, for workers, train on MNIST.

    The cluster is built from the comma-separated ``--ps_hosts`` and
    ``--worker_hosts`` flags, and a ``tf.train.Server`` is started for
    ``--job_name``/``--task_index``:

    * ``ps`` tasks block in ``server.join()``.
    * ``worker`` tasks build the model and train it inside a
      ``MonitoredTrainingSession`` until the global step reaches 2,000,000.
      With ``--is_sync`` the optimizer is wrapped in
      ``SyncReplicasOptimizer`` (synchronous training). Otherwise each
      worker applies its own updates (asynchronous training).

    Args:
        _: Leftover command-line arguments passed in by ``tf.app.run``
            (unused).
    """
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    # Create the cluster configured by 'ps_hosts' and 'worker_hosts'.
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    # Create a server for the local task.
    server = tf.train.Server(cluster, job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == "ps":
        server.join()  # ps tasks only serve variables; this call blocks.
    elif FLAGS.job_name == "worker":
        is_chief = FLAGS.task_index == 0
        # Loading the dataset is plain Python work, so it does not need to
        # sit inside the device scope below (which only places graph ops).
        mnist = input_data.read_data_sets("/tmp/mnist/", one_hot=True)

        # replica_device_setter places the Variables on the ps tasks
        # (round-robin over ps tasks by default). It places the other ops
        # on this worker.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            images = tf.placeholder(tf.float32, [None, 784])
            # One-hot label rows (one_hot=True above), declared float32 to
            # match the float32 logits.
            labels = tf.placeholder(tf.float32, [None, 10])
            logits = model(images)
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                logits=logits, labels=labels))

            # StopAtStepHook stops the session once the global step reaches
            # last_step.
            hooks = [tf.train.StopAtStepHook(last_step=2000000)]
            global_step = tf.train.get_or_create_global_step()
            optimizer = tf.train.AdamOptimizer(learning_rate=1e-04)
            if FLAGS.is_sync:
                # Synchronous training: wrap the optimizer in
                # tf.train.SyncReplicasOptimizer.
                # ref: https://www.tensorflow.org/api_docs/python/tf/train/SyncReplicasOptimizer
                # NOTE: --num_workers is not derived from --worker_hosts;
                # keep the two consistent.
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=FLAGS.num_workers,
                    total_num_replicas=FLAGS.num_workers)
                # The hook handles the sync optimizer's initialization and queues.
                hooks.append(optimizer.make_session_run_hook(is_chief))
            train_op = optimizer.minimize(
                loss, global_step=global_step,
                aggregation_method=tf.AggregationMethod.ADD_N)

        # MonitoredTrainingSession takes care of session initialization,
        # restoring from a checkpoint, saving to a checkpoint, and closing
        # when done or when an error occurs.
        with tf.train.MonitoredTrainingSession(master=server.target,
                                               is_chief=is_chief,
                                               checkpoint_dir="./checkpoint_dir",
                                               hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                img_batch, label_batch = mnist.train.next_batch(32)
                # mon_sess.run handles AbortedError in case of a preempted ps.
                _, ls, step = mon_sess.run(
                    [train_op, loss, global_step],
                    feed_dict={images: img_batch, labels: label_batch})
                if step % 100 == 0:
                    print("Train step %d, loss: %f" % (step, ls))
if __name__ == "__main__":
    # Hand control to tf.app.run, naming main explicitly as the entry point.
    tf.app.run(main=main)
参考
上一篇: php导出csv文件方法封装
下一篇: Pytorch实现多输入图像分类
推荐阅读
-
分布式数据库HBase
-
给AI当老师国内外人工智能训练师相关从业人数将达500万
-
浅析Hadoop完全分布式集群搭建问题
-
python的分布式任务huey如何实现异步化任务讲解_PHP教程
-
redis/分布式文件存储系统/数据库 存储session,解决负载均衡集群中session不一致问题,redissession_PHP教程
-
Redis数据库中实现分布式锁的方法
-
教你在pycharm中使用tensorflow的方法
-
阿里巴巴开源项目:分布式数据库同步系统otter(解决中美异地机房)
-
阿里巴巴开源项目:分布式数据库同步系统otter(解决中美异地机房)
-
基于zookeeper的分布式lock实现