分布式TensorFlow模式
1. 简介
Tensorflow API提供了Cluster、 Server以及 Supervisor来支持模型的分布式训练。 关于Tensorflow的分布式训练介绍可以参考 Distributed Tensorflow。
简单的概括说明如下:
- Tensorflow分布式Cluster由多个Task组成,每个Task对应一个tf.train.Server实例, 作为Cluster的一个单独节点;
- 多个相同作用的Task可以被划分为一个job,例如ps job作为参数服务器只保存Tensorflow model的参数,而worker job则作为计算节点只执行计算密集型的Graph计算。
- Cluster中的Task会相对进行通信,以便进行状态同步、参数更新等操作。
Tensorflow分布式集群的所有节点执行的代码是相同的。分布式任务代码具有固定的模式:
1 | # 第1步:命令行参数解析,获取集群的信息ps_hosts和worker_hosts,以及当前节点的角色信息job_name和task_index |
复制
2. Tensorflow分布式训练代码框架
根据上面说到的Tensorflow分布式训练代码固定模式,如果要编写一个分布式的Tensorlfow代码,其框架如下所示:
1 | import tensorflow as tf |
复制
对于所有Tensorflow分布式代码,可变的只有两点:
- 构建tensorflow graph模型代码;
- 每一步执行训练的代码;
3. 分布式MNIST任务
我们通过修改tensorflow/tensorflow提供的mnist_softmax.py来构造分布式的MNIST样例来进行验证。 修改后的代码请参考如下:
1 | from tensorflow.examples.tutorials.mnist import input_data |
复制
我们同样通过tensorlfow的Docker image来启动一个容器来进行验证。
1 | $ docker run -d -v /path/to/your/code:/tensorflow/mnist --name tensorflow tensorflow/tensorflow |
复制
启动tensorflow之后,启动4个Terminal,然后通过下面命令进入tensorflow容器,切换到/tensorflow/mnist目录下
1 | $ docker exec -ti tensorflow /bin/bash |
复制
然后在四个Terminal中分别执行下面一个命令来启动Tensorflow cluster的一个task节点,
1 | # Start ps 0 |
复制