欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

对使用keras代码的模型中间层可视化

程序员文章站 2022-07-12 08:13:17
...

keras或者tf.keras在部署神经网络结构时十分方便。但是很多时候也需要多余训练的中间过程可视化。

模型保存

对于keras可使用save()方法保存模型的参数。

model.save('model_32.h5')

这里 model 是网络需要训练的模型。keras使用HDF5文件系统来保存模型。可使用可视化工具HDFViewer来查看里面的层和数据。
在里面不仅包含需要训练的权重,也包含BN,ReLU等操作。
对使用keras代码的模型中间层可视化
可使用python代码读取各个神经层之间的连接权重

import h5py
# 模型地址
MODEL_PATH = '.../model_32.h5'
print("读取模型中...")
with h5py.File(MODEL_PATH, 'r') as f:
    layer_1 = f['/model_weights/C0_1dwconv']
    # layer_1_bias =  dense_1['bias:0'][:]
    layer_1_kernel = layer_1 ['depthwise_kernel:0'][:]

print("第一层的连接权重矩阵:\n%s\n"%layer_1_kernel )

值得注意的是,我们得到的这些矩阵的数据类型都是numpy.ndarray。

中间层可视化

我们期望把图片经过训练后的网络的中间层也可视化。

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
import pickle, os, time, cv2
import numpy as np
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
from imrr_Net import *

def load_data():
    '''
    Test data
    '''
    test_dir = '.../imageNet32/val_data'
    fo = open(test_dir, 'rb')
    dicts = pickle.load(fo,encoding='iso-8859-1')
    fo.close()
    data = dicts['data']
    label = np.array(dicts['labels'])
    #label = label - 1
    test_data = data[101:102,:]
    test_label = label[101:102]
    print(test_label)
    img = test_data.reshape(32, 32, 3)
    #### visualization
    # cv2.imwrite("./cat.jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
    # cv2.waitKey (0)
    # cv2.destroyAllWindows()
    plt.imshow(img, cmap='gray')
    #plt.savefig(".../IMRR_32/%d.jpg"%(i+1))
    plt.savefig(".../IMRR_32/img.jpg")
    #print(data.shape)
    #print(test_data.shape)
    test_data = test_data.reshape((1, 32*32, 3), order = 'F')
    test_data = test_data.reshape(1, 32, 32, 3)
    test_data = test_data.astype(np.float32)
    
    return test_data

def train():
    # load network
    model = imrr_net()
    ######## load h5 file
    model.load_weights(r'.../model_32.h5')
    ###### load data
    test_pic = load_data()
    ######## You should known the layers name
    layer_1 = K.function([model.layers[0].input], [model.get_layer('IMRR_32_shuffle_2').output])
    f1 = layer_1([test_pic])[0]

    for i in range(128):
        show_img = f1[0,:,:,i]
        print(show_img.shape)
        # show_img.shape = [32, 32]
        # plt.figure(figsize=(32, 32), dpi=80)
        ######## 两种可视化方法, cv2和plt,显示的效果不同。
        # plt.imshow(show_img, cmap='gray')
        # 图片存储地址
        # plt.savefig(".../IMRR_32_shuffle_2/fea_"+ str(i)+".jpg")
        # plt.close()
        # 图片存储地址
        cv2.imwrite(".../IMRR_32/IMRR_32_shuffle_2_cv/fea_"+ str(i)+".jpg", show_img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
        cv2.waitKey (0) 
        cv2.destroyAllWindows()

if __name__ == "__main__":
    train()

中间层的可视化结果会以图片的形式保存在固定的文件夹里。

参考
Keras入门(二)模型的保存、读取及加载
使用keras框架编写的深度模型 输出及每一层的特征可视化
keras模型可视化,层可视化及kernel可视化
keras 特征图可视化(中间层)

相关标签: keras