Tensorflow基础(必备TF框架知识)_5:高级API接口
程序员文章站
2022-07-08 09:38:50
...
本文取材自imooc课程《Python3+TensorFlow打造人脸识别智能小程序》
TensorFlow在深度学习中的高级封装——TF-Slim(注:本文代码基于 TensorFlow 1.x;tf.contrib 在 TensorFlow 2.x 中已被移除,TF-Slim 改为以独立的 tf_slim 包发布)
import tensorflow.contrib.slim as slim
slim.layers
net = slim.conv2d(inputs, num_outputs=64, kernel_size=[11, 11], stride=4, padding='SAME',
weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
weights_regularizer=slim.l2_regularizer(0.0005), scope='conv1')
slim.arg_scope(list_ops_or_scope, **kwargs)
- list_ops_or_scope:操作列表或作用域列表
- kwargs:以 keyword=value 形式传入的默认参数,会作用于 list_ops_or_scope 中的所有操作(在作用域内调用时仍可显式覆盖)
with slim.arg_scope(list_ops_or_scope=[slim.conv2d], padding='SAME',
weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
weights_regularizer=slim.l2_regularizer(0.0005)):
net = slim.conv2d(inputs, num_outputs=64, kernel_size=[11, 11], scope='conv1')
net = slim.conv2d(net, num_outputs=128, kernel_size=[11, 11], padding='VALID', scope='conv2')
net = slim.conv2d(net, num_outputs=256, kernel_size=[11, 11], scope='conv3')
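slim.batch_norm
- 在 slim.conv2d 等层中通过 normalizer_fn=slim.batch_norm, normalizer_params=batch_norm_params 启用
- 常用参数(非完整签名):slim.batch_norm(is_training, zero_debias_moving_mean, decay, epsilon, scale, updates_collections)
is_training:训练时置为True,测试时置为False
updates_collections:moving_mean/moving_variance 的更新操作所放入的集合,默认为 tf.GraphKeys.UPDATE_OPS;这些更新操作不会自动执行,训练时需要显式运行(见下方两种方法);若设为 None,则在前向计算时就地更新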
with slim.arg_scope(list_ops_or_scope=[slim.conv2d],
weights_initializer=slim.variance_scaling_initializer(),
weights_regularizer=slim.l2_regularizer(weight_decay),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params):
with slim.arg_scope(list_ops_or_scope=[slim.batch_norm], **batch_norm_params):
with slim.arg_scope(list_ops_or_scope=[slim.max_pool2d], padding='SAME') as arg_sc:
return arg_sc
# 方法一:让训练操作依赖 UPDATE_OPS 集合中的 moving_mean/moving_variance 更新操作,使每步训练都会先执行这些更新
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(total_loss)
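# 方法二:把更新操作 group 成一个 op,训练时需与训练操作一起执行,例如 tf.group(<训练操作>, batchnorm_updates_op)
# UPDATE_OPS_COLLECTION 指 batch_norm 的 updates_collections 所用的集合,在 tf.contrib.slim 中默认为 tf.GraphKeys.UPDATE_OPS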
batchnorm_updates = tf.get_collection(UPDATE_OPS_COLLECTION)
batchnorm_updates_op = tf.group(*batchnorm_updates)
slim nets
from tensorflow.contrib.slim.python.slim.nets import alexnet
from tensorflow.contrib.slim.python.slim.nets import inception
from tensorflow.contrib.slim.python.slim.nets import overfeat
from tensorflow.contrib.slim.python.slim.nets import resnet_utils
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
from tensorflow.contrib.slim.python.slim.nets import resnet_v2
from tensorflow.contrib.slim.python.slim.nets import vgg
slim losses
- 经验风险最小化
分类损失
平方损失
等等
- 正则化 LOSS
变量的 L2 正则化约束
net = slim.conv2d(inputs, num_outputs=64, kernel_size=[11, 11], stride=4, padding='SAME',
weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
weights_regularizer=slim.l2_regularizer(0.0005), scope='conv1')
# 获取正则化 loss(所有由 weights_regularizer 等添加的正则项之和)
regularization_loss = tf.add_n(slim.losses.get_regularization_losses())
# 获取全部loss的两种方法
# 方法一
total_loss_1 = classification_loss + sum_of_squares_loss + pose_loss + regularization_loss
# 方法二:slim 内置的损失函数会自动加入损失集合,自定义的 pose_loss 需用 add_loss 手动加入;get_total_loss() 默认还会加上正则化损失
slim.losses.add_loss(pose_loss)
total_loss_2 = slim.losses.get_total_loss()
slim learning
# 定义优化器
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
# 定义训练
train_tensor = slim.learning.create_train_op(total_loss, optimizer)
# 开始训练
slim.learning.train(train_op=train_tensor, logdir=train_log_dir)
- 学习率
tf.train.piecewise_constant 分段常数衰减
tf.train.inverse_time_decay 反时限衰减
tf.train.polynomial_decay 多项式衰减
tf.train.exponential_decay 指数衰减
- 优化器
tf.train.Optimizer
tf.train.GradientDescentOptimizer
tf.train.MomentumOptimizer
tf.train.AdamOptimizer
tf.train.RMSPropOptimizer
- 度量
slim.metrics
# value_op:幂等操作,返回度量的当前值;update_op:执行聚合步骤并返回度量值的操作
mae_value_op, mae_update_op = slim.metrics.streaming_mean_absolute_error(predictions, labels)
slim.evaluation
- slim.evaluation.evaluation_loop():按 eval_interval_secs 的间隔,周期性地评估 checkpoint_dir 中最新的 checkpoint
slim.evaluation.evaluation_loop( master, checkpoint_dir, logdir, num_evals, eval_op, summary_op, eval_interval_secs)
- slim.evaluation.evaluate_once():只对指定的单个 checkpoint 执行一次评估
slim data
from tensorflow.contrib.slim.python.slim.data import data_decoder
from tensorflow.contrib.slim.python.slim.data import data_provider
from tensorflow.contrib.slim.python.slim.data import dataset
from tensorflow.contrib.slim.python.slim.data import dataset_data_provider
from tensorflow.contrib.slim.python.slim.data import parallel_reader
from tensorflow.contrib.slim.python.slim.data import prefetch_queue
from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
除了TF-Slim外,比较主流的TensorFlow高层封装还有Keras、TFLearn、TensorLayer。