获得tensorflow模型的参数量
程序员文章站
2024-02-25 22:00:51
...
tensorflow模型训练好之后,通常会保存为.ckpt文件,有时我们想了解一下模型保存了多少参数,这固然可以手动计算,但是速度太慢,这里我写了一个程序可以直接获得模型的参数量,下面是源代码:
转自: https://blog.csdn.net/chnguoshiwushuang/article/details/81588794
from tensorflow.python import pywrap_tensorflow
import os
import numpy as np
model_dir = "models_pretrained/"
checkpoint_path = os.path.join(model_dir, "model.ckpt-82798")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
total_parameters = 0
for key in var_to_shape_map:#list the keys of the model
# print(key)
# print(reader.get_tensor(key))
shape = np.shape(reader.get_tensor(key)) #get the shape of the tensor in the model
shape = list(shape)
# print(shape)
# print(len(shape))
variable_parameters = 1
for dim in shape:
# print(dim)
variable_parameters *= dim
# print(variable_parameters)
total_parameters += variable_parameters
print(total_parameters)
上一篇: 获取url中的query string