Stanford CS20 Study Notes (3)
Conditionals in TensorFlow
Suppose we need to implement the Huber loss function, given by:

L(y, y_pred) = 0.5 * (y - y_pred)^2                   if |y - y_pred| < delta
L(y, y_pred) = delta * |y - y_pred| - 0.5 * delta^2   otherwise

Unless we are in eager mode, we cannot simply write if y - y_pred < delta:
(because of the nature of the dataflow graph). Instead we have to use tf.cond(pred, fn1, fn2, name=None):
def huber_loss(labels, predictions, delta=14.0):
    residual = tf.abs(labels - predictions)
    def f1(): return 0.5 * tf.square(residual)
    def f2(): return delta * residual - 0.5 * tf.square(delta)
    return tf.cond(residual < delta, f1, f2)
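A minimal usage sketch of the function above (the scalar values here are made up for illustration):

import tensorflow as tf

labels = tf.constant(3.0)
predictions = tf.constant(4.5)
loss = huber_loss(labels, predictions, delta=1.0)

with tf.Session() as sess:
    # residual = |3.0 - 4.5| = 1.5 >= delta, so f2 fires:
    # 1.0 * 1.5 - 0.5 * 1.0**2 = 1.0
    print(sess.run(loss))  # >> 1.0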
Control flow in TF:
Control Flow Ops | tf.group, tf.count_up_to, tf.cond, tf.case, tf.while_loop, ...
Comparison Ops   | tf.equal, tf.not_equal, tf.less, tf.greater, tf.where, ...
Logical Ops      | tf.logical_and, tf.logical_not, tf.logical_or, tf.logical_xor
Debugging Ops    | tf.is_finite, tf.is_inf, tf.is_nan, tf.Assert, tf.Print, ...
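As a quick taste of one of these ops, here is a minimal sketch (my own toy example, not from the course) that uses tf.while_loop to sum the integers 0..9 inside the graph:

import tensorflow as tf

i = tf.constant(0)
total = tf.constant(0)

# loop while i < 10, adding i to the running total on each iteration
cond = lambda i, total: tf.less(i, 10)
body = lambda i, total: (tf.add(i, 1), tf.add(total, i))
final_i, final_total = tf.while_loop(cond, body, [i, total])

with tf.Session() as sess:
    print(sess.run(final_total))  # >> 45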
The reason TF has these control flow ops: since TF builds the graph before computation, we have to specify all possible subgraphs beforehand. PyTorch's dynamic graphs and TF's eager execution help overcome this.
tf.data
Pros and cons of placeholders:
Pro: they put the data processing outside TensorFlow, making it easy to do in Python.
Con: users often end up processing their data in a single thread, creating a data bottleneck that slows execution down.
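For comparison, a minimal sketch of the placeholder + feed_dict pattern being weighed here (toy values, my own example):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None])
y = x * 2.0

with tf.Session() as sess:
    # the data lives in plain Python and is fed in at run time
    print(sess.run(y, feed_dict={x: [1.0, 2.0, 3.0]}))  # >> [2. 4. 6.]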
Some tf.data operations
Building a dataset:
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
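A quick sketch of what this call produces, assuming features and labels are 1-D NumPy arrays (the values here echo the birth-rate / life-expectancy sample below):

import numpy as np
import tensorflow as tf

features = np.array([1.822, 3.869, 3.911], dtype=np.float32)
labels = np.array([74.82825, 70.81949, 72.15066], dtype=np.float32)

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
print(dataset.output_types)   # >> (tf.float32, tf.float32)
print(dataset.output_shapes)  # >> (TensorShape([]), TensorShape([]))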
Iterating over a dataset:
● iterator = dataset.make_one_shot_iterator()
  Iterates through the dataset exactly once. No initialization needed.
● iterator = dataset.make_initializable_iterator()
  Iterates through the dataset as many times as we want. Needs to be initialized at the start of each epoch.
One pass over the dataset with a one-shot iterator (1 epoch):
iterator = dataset.make_one_shot_iterator()
X, Y = iterator.get_next()  # X is the birth rate, Y is the life expectancy
with tf.Session() as sess:
    print(sess.run([X, Y]))  # >> [1.822, 74.82825]
    print(sess.run([X, Y]))  # >> [3.869, 70.81949]
    print(sess.run([X, Y]))  # >> [3.911, 72.15066]
Multiple passes over the dataset with an initializable iterator (n epochs):
iterator = dataset.make_initializable_iterator()
...
for i in range(100):
    sess.run(iterator.initializer)  # re-initialize the iterator at each epoch
    total_loss = 0
    try:
        while True:
            sess.run([optimizer])
    except tf.errors.OutOfRangeError:
        pass
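Since the snippet above elides the model, here is a self-contained sketch of the same n-epoch pattern; the linear model and toy data are hypothetical stand-ins:

import numpy as np
import tensorflow as tf

features = np.random.rand(100).astype(np.float32)
labels = 2.0 * features + 1.0

dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(10)
iterator = dataset.make_initializable_iterator()
X, Y = iterator.get_next()

w = tf.get_variable('w', initializer=tf.constant(0.0))
b = tf.get_variable('b', initializer=tf.constant(0.0))
loss = tf.reduce_mean(tf.square(Y - (w * X + b)))
optimizer = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(10):
        sess.run(iterator.initializer)  # re-initialize at each epoch
        total_loss = 0.0
        try:
            while True:
                _, l = sess.run([optimizer, loss])
                total_loss += l
        except tf.errors.OutOfRangeError:
            pass  # one full pass over the dataset is done
        print('Epoch {}: total loss {:.4f}'.format(epoch, total_loss))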
Some methods on a Dataset instance:
dataset = dataset.shuffle(1000)
dataset = dataset.repeat(100)
dataset = dataset.batch(128)
dataset = dataset.map(lambda x: tf.one_hot(x, 10))
# convert each element of the dataset to a one-hot vector
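A runnable sketch chaining these transformations on hypothetical integer data (smaller sizes so the shapes are easy to check):

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))
dataset = dataset.shuffle(1000)                     # shuffle with a buffer of 1000
dataset = dataset.repeat(2)                         # two passes over the data
dataset = dataset.batch(4)                          # batches of 4 indices
dataset = dataset.map(lambda x: tf.one_hot(x, 10))  # each index -> one-hot of length 10

iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()

with tf.Session() as sess:
    print(sess.run(batch).shape)  # >> (4, 10)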
Why tf.data, and when to use tf.data
● For prototyping, feed_dict can be faster and easier to write (Pythonic).
● tf.data is tricky to use when you have complicated preprocessing or multiple data sources.
● NLP data is normally just a sequence of integers. In this case, transferring the data over to the GPU is pretty quick, so the speedup from tf.data isn't that large.
Summary
This section covered control flow in TF and the tf.data module. tf.data offers a lot of functionality; learn the details as the need arises. It is a good way to build and process datasets, because it automatically uses multiple threads to process the data.