对使用keras代码的模型中间层可视化
程序员文章站
2022-07-12 08:13:17
...
keras或者tf.keras在部署神经网络结构时十分方便。但是很多时候也需要多余训练的中间过程可视化。
模型保存
对于keras可使用save()方法保存模型的参数。
model.save('model_32.h5')
这里 model 是网络需要训练的模型。keras使用HDF5文件系统来保存模型。可使用可视化工具HDFViewer来查看里面的层和数据。
在里面不仅包含需要训练的权重,也包含BN,ReLU等操作。
可使用python代码读取各个神经层之间的连接权重
import h5py
# 模型地址
MODEL_PATH = '.../model_32.h5'
print("读取模型中...")
with h5py.File(MODEL_PATH, 'r') as f:
layer_1 = f['/model_weights/C0_1dwconv']
# layer_1_bias = dense_1['bias:0'][:]
layer_1_kernel = layer_1 ['depthwise_kernel:0'][:]
print("第一层的连接权重矩阵:\n%s\n"%layer_1_kernel )
值得注意的是,我们得到的这些矩阵的数据类型都是numpy.ndarray。
中间层可视化
我们期望把图片经过训练后的网络的中间层也可视化。
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
import pickle, os, time, cv2
import numpy as np
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
from imrr_Net import *
def load_data():
'''
Test data
'''
test_dir = '.../imageNet32/val_data'
fo = open(test_dir, 'rb')
dicts = pickle.load(fo,encoding='iso-8859-1')
fo.close()
data = dicts['data']
label = np.array(dicts['labels'])
#label = label - 1
test_data = data[101:102,:]
test_label = label[101:102]
print(test_label)
img = test_data.reshape(32, 32, 3)
#### visualization
# cv2.imwrite("./cat.jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
# cv2.waitKey (0)
# cv2.destroyAllWindows()
plt.imshow(img, cmap='gray')
#plt.savefig(".../IMRR_32/%d.jpg"%(i+1))
plt.savefig(".../IMRR_32/img.jpg")
#print(data.shape)
#print(test_data.shape)
test_data = test_data.reshape((1, 32*32, 3), order = 'F')
test_data = test_data.reshape(1, 32, 32, 3)
test_data = test_data.astype(np.float32)
return test_data
def train():
# load network
model = imrr_net()
######## load h5 file
model.load_weights(r'.../model_32.h5')
###### load data
test_pic = load_data()
######## You should known the layers name
layer_1 = K.function([model.layers[0].input], [model.get_layer('IMRR_32_shuffle_2').output])
f1 = layer_1([test_pic])[0]
for i in range(128):
show_img = f1[0,:,:,i]
print(show_img.shape)
# show_img.shape = [32, 32]
# plt.figure(figsize=(32, 32), dpi=80)
######## 两种可视化方法, cv2和plt,显示的效果不同。
# plt.imshow(show_img, cmap='gray')
# 图片存储地址
# plt.savefig(".../IMRR_32_shuffle_2/fea_"+ str(i)+".jpg")
# plt.close()
# 图片存储地址
cv2.imwrite(".../IMRR_32/IMRR_32_shuffle_2_cv/fea_"+ str(i)+".jpg", show_img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
cv2.waitKey (0)
cv2.destroyAllWindows()
if __name__ == "__main__":
train()
中间层的可视化结果会以图片的形式保存在固定的文件夹里。
参考
Keras入门(二)模型的保存、读取及加载
使用keras框架编写的深度模型 输出及每一层的特征可视化
keras模型可视化,层可视化及kernel可视化
keras 特征图可视化(中间层)
上一篇: 微信小程序开发第三篇 数据绑定,点击事件,数组循环
下一篇: Keras问题总结
推荐阅读