欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tensorflow学习笔记-会话机制(session)

程序员文章站 2024-03-20 17:45:04
...

在TensorFlow中,有两种用于运行计算图(graph)的会话(session)

  • tf.Session( )

  • tf.InteractivesSession( )

1. tf.Session( )

要使用tf,我们必须先构建(定义)graph,之后才能运行graph。

1.1 非交互式会话中的例子

import tensorflow as tf

# 构建graph
a = tf.add(3, 5) 

# 运行graph
sess = tf.Session()  # 创建tf.Session的一个对象sess
print(sess.run(a)) 

sess.close()         # 关闭sess对象

一个session可能会占用一些资源,比如变量、队列和读取器(reader)。我们使用sess.close()关闭会话或者使用上下文管理器释放这些不再使用的资源。

1.2 建议的tf.Session( )写法

import tensorflow as tf  

# 构建graph
matrix1 = tf.constant([[3., 3.]])  
matrix2 = tf.constant([[2.], [2.]])  

product = tf.matmul(matrix1, matrix2)  

# 运行graph
with tf.Session() as sess:          # 使用"with"语句,自动关闭会话
    print(sess.run(product))  

1.3 Fetch(取回)

在使用sess.run( )运行图时,我们可以传入fetches,用于取回某些操作或tensor的输出内容。fetches可以是list,tuple,namedtuple,dict中的任意一个。fetches可以是一个列表,在op的一次运行中一起获得(而不是逐个去获取 tensor)多个tensor值。

import tensoflow as tf
from collections import namedtuple

a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
MyData = namedtuple('MyData', ['a', 'b'])

with tf.Session() as sess:
    c = sess.run(a)            # fetches可以为单个数a
    d = sess.run([a, b])       # fetches可以为一个列表[a, b]
    v = sess.run({'k1': MyData(a, b), 'k2': [b, a]}) 

    print(c)
    print(d)
    print(d[0])
    print(v) 
'''
v is a dict and v['k1'] is a MyData namedtuple with the numpy array [10, 20] and the numpy array [1.0, 2.0]. v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array [10, 20].
'''

1.4 Feed(注入)

TensorFlow提供了feed注入机制, 它可以临时替代graph中任意op操作的输入tensor,可以对graph中任何操作提交补丁(直接插入一个tensor)。

feed机制只在调用它的方法内有效,方法结束,feed就会消失。最常见的用例是把某些特殊操作为feed注入的对象。你可以提供数据feed_dict,作为sess.run( )调用的参数。使用tf.placeholder( ),为某些操作的输入创建占位符。

import tensorflow as tf
import numpy as np

x = np.ones((2, 3))
y = np.ones((3, 2)) 

input1 = tf.placeholder(tf.int32)
input2 = tf.placeholder(tf.int32)

output = tf.matmul(input1, input2)

with tf.Session() as sess:
    print(sess.run(output, feed_dict = {input1:x, input2:y}))

如果没有正确提供tf.placeholder( ),feed操作将产生错误。注意,feed注入的值不能是tf的tensor对象,应该是Python常量、字符串、列表、numpy ndarrays,或者TensorHandles。

1.5 分布式训练

从version 0.8之后,TensorFlow开始支持分布式计算的机器学习,而且TensorFlow会充分利用CPU、GPU等计算资源。如果检测到GPU,TensorFlow会优先使用GPU运行程序。用字符串标识设备,目前支持的设备包括:

“/cpu:0”:机器的第一个CPU。

“/gpu:0”:机器的第一个GPU, 如果有的话

“/gpu:1”:机器的第二个GPU, 以此类推

当计算机有多个GPU时,通过tf.device( ),我们可以指定用哪个GPU来执行。代码示例如下:

# 在with tf.device()下,构建graph
with tf.device("/gpu:0"):
    a = tf.constant([[3., 3.]])
    b = tf.constant([[2.], [2.]])
    product = tf.matmul(a, b)

# 运行graph
with tf.Session() as sess:    
    print(sess.run(product))

2. tf.InteractivesSession( )

当python编辑环境是shell、IPython等交互式环境时,我们使用类tf.InteractiveSession代替类tf.Session,用方法tensor.eval( ),operation.run( ) 代替sess.run( ),这样可避免用一个变量sess来持有会话。其中更多地使用 tensor.eval(),所有的表达式都可以看作是tensor。

// 进入python3交互式环境
# python3

>>> import tensorflow as tf  

// 进入一个交互式会话
>>> sess = tf.InteractiveSession()

>>> a = tf.constant(5.0)
>>> b = tf.constant(6.0)
>>> c = a * b

// We can just use 'c.eval()' without passing 'sess'
>>> print(c.eval()) 

>>> sess.close()   // 关闭交互式会话

>>> exit()        // 退出python3交互式环境