欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

CNN中做归一化用到的相关API(自己的小总结:tf.nn.moments()函数理解) 以及CNN中NHWC转NCHW的方法

程序员文章站 2024-02-29 21:17:04
...

Note1:CNN中NHWC转NCHW的方法:

比如卷积层输出的net形状为:[2, 3, 3, 4]
即:NHWC为[2, 3, 3, 4]
N:一个batch内图片的数量。
H:垂直高度方向的像素个数。
W:水平宽度方向的像素个数。
C:通道数

现为了做BN,想先将NHWC转为NCWH=[2, 4, 3, 3]
方法呢?可以使用TensorFlow中的tf.transpose函数实现!
n = tf.transpose(net, (0, 3, 2, 1)),,,其中第二个参数是转换后的张量中,原始张量的维度编号(原来是(0, 1, 2, 3))
得到n的形状是[2, 4, 3, 3]
即:NCWH为[2, 4, 3, 3]
下面举个例子:三维的,供参考

import tensorflow as tf
import numpy as np

x = [[[1,   2],
      [3,   4]],
     [[11,  22],
      [33,  44]],
     [[111, 222],
      [333, 444]]]  # x的shape为(3, 2, 2),通道数为3
y = tf.transpose(x, (1, 2, 0))  # 其中第二个参数是转换后的张量中,原始张量的维度编号。编号0原本在首位,现在处于末位。
with tf.Session() as sess:
    # print(y.eval())
    '''
    [[[  1  11 111]
      [  2  22 222]]

     [[  3  33 333]
     [  4  44 444]]]
    '''
    # print(x[0, :, :])
    # 出现报错:TypeError: list indices must be integers or slices, not tuple
    # 这是因为此时矩阵存储在列表(list)中,而列表中的每一个元素大小可能不同,因此不能直接取其某一列进行操作
    # 解决方案
    # 可以利用numpy.array函数将其转变为标准矩阵,再对其进行取某一列的操作:若下所示:
    # print(np.array(x)[0, :, :])
    '''
    取第一个维度首元素如下:
    [[1 2]
     [3 4]]
    '''
    # print(np.array(x)[:, :, 0])
    '''
    [[  1   3]
     [ 11  33]
     [111 333]]
    '''
    # print(y[0, :, :])  # Tensor("strided_slice:0", shape=(2, 3), dtype=int32)
    # print(y[0, :, :].eval())  # 等价于print(sess.run(y[0, :, :]))
    '''
    取y第一个维度首元素如下:
    [[  1  11 111]
     [  2  22 222]]
    '''
    # print(np.shape(y))  # (2, 2, 3)
    print(y[:, :, 0].eval())
    '''
    取y最后一个维度首元素如下:
    [[1 2]
     [3 4]]
    '''

    # 可见x[0, :, :] = y[:, :, 0]。张量已经由NCHW转换为NHWC格式。

Note2:tf.nn.moments()函数理解

import numpy as np
import tensorflow as tf
net = tf.constant(np.reshape(np.asarray(range(0, 72)), (2, 3, 3, 4)))
n = tf.transpose(net, (0, 3, 2, 1))
# BN中
m0, n0 = tf.nn.moments(n, axes=(0, 2, 3))
m1, v1 = tf.nn.moments(net, axes=(0, 1, 2))
m2, v2 = tf.nn.moments(net, axes=(0, 2, 1))
with tf.Session():
    '''查看一下net和n的值'''
    print(net.eval())
    print(10*'-')
    print(n.eval())
    
    '''查看一下transpose后值的情况'''
    print(net[:, :, :, 0].eval())
    print(10*'-')
    print(n[:, 0, :, :].eval())
    
    '''验证效果:发现transpose后,方便了计算,而且计算结果正确'''
    print(m1.eval())  # [34 35 36 37]
    print(m0.eval())  # [34 35 36 37]
    print(m2.eval())  # [34 35 36 37]

手写解释代码计算过程
CNN中做归一化用到的相关API(自己的小总结:tf.nn.moments()函数理解) 以及CNN中NHWC转NCHW的方法CNN中做归一化用到的相关API(自己的小总结:tf.nn.moments()函数理解) 以及CNN中NHWC转NCHW的方法