Keras 使用 Lambda层详解
程序员文章站
2022-06-15 17:41:58
我就废话不多说了,大家还是直接看代码吧!from tensorflow.python.keras.models import sequential, modelfrom tensorflow.pyth...
我就废话不多说了,大家还是直接看代码吧!
from tensorflow.python.keras.models import sequential, model from tensorflow.python.keras.layers import dense, flatten, conv2d, maxpool2d, dropout, conv2dtranspose, lambda, input, reshape, add, multiply from tensorflow.python.keras.optimizers import adam def deconv(x): height = x.get_shape()[1].value width = x.get_shape()[2].value new_height = height*2 new_width = width*2 x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.resizemethod.nearest_neighbor) return x_resized def generator(scope='generator'): imgs_noise = input(shape=inputs_shape) x = conv2d(filters=32, kernel_size=(9,9), strides=(1,1), padding='same', activation='relu')(imgs_noise) x = conv2d(filters=64, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x) x = conv2d(filters=128, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x) x1 = conv2d(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x) x1 = conv2d(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x1) x2 = add()([x1, x]) x3 = conv2d(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x2) x3 = conv2d(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x3) x4 = add()([x3, x2]) x5 = conv2d(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x4) x5 = conv2d(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x5) x6 = add()([x5, x4]) x = maxpool2d(pool_size=(2,2))(x6) x = lambda(deconv)(x) x = conv2d(filters=64, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x) x = lambda(deconv)(x) x = conv2d(filters=32, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x) x = lambda(deconv)(x) x = conv2d(filters=3, kernel_size=(3, 3), strides=(1, 1), padding='same',activation='tanh')(x) x = lambda(lambda x: x+1)(x) y = lambda(lambda x: x*127.5)(x) model = model(inputs=imgs_noise, outputs=y) model.summary() return model my_generator = generator() my_generator.compile(loss='binary_crossentropy', optimizer=adam(0.7, decay=1e-3), metrics=['accuracy'])
补充知识:含有lambda自定义层keras模型,保存遇到的问题及解决方案
一,许多应用,keras含有的层已经不能满足要求,需要透过lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。保存时会报
typeerror: can't pickle _thread.rlock objects
二,解决方案,为了便于后续的部署,可以转成tensorflow的pb进行部署。
from keras.models import load_model import tensorflow as tf import os, sys from keras import backend as k from tensorflow.python.framework import graph_util, graph_io def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=true): if not os.path.exists(output_dir): os.mkdir(output_dir) h5_model = build_model() h5_model.load_weights(h5_weight_path) out_nodes = [] for i in range(len(h5_model.outputs)): out_nodes.append(out_prefix + str(i + 1)) tf.identity(h5_model.output[i], out_prefix + str(i + 1)) model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb' sess = k.get_session() init_graph = sess.graph.as_graph_def() main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes) graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=false) if log_tensorboard: from tensorflow.python.tools import import_pb_to_tensorboard import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir) def build_model(): inputs = input(shape=(784,), name='input_img') x = dense(64, activation='relu')(inputs) x = dense(64, activation='relu')(x) y = dense(10, activation='softmax')(x) h5_model = model(inputs=inputs, outputs=y) return h5_model if __name__ == '__main__': if len(sys.argv) == 3: # usage: python3 h5_to_pb.py h5_weight_path output_dir h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])
以上这篇keras 使用 lambda层详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
推荐阅读
-
Java8中Lambda表达式使用和Stream API详解
-
Android 中Lambda表达式的使用实例详解
-
vue-cli项目中使用公用的提示弹层tips或加载loading组件实例详解
-
Android 中Lambda表达式的使用实例详解
-
Android逆向之旅---Native层的Hook神器Cydia Substrate使用详解
-
Android页面中引导蒙层的使用方法详解
-
python lambda的使用详解
-
Java8新特性之Lambda表达式 使用详解
-
C# 表达式树 创建、生成、使用、lambda转成表达式树~表达式树的知识详解
-
java 使用BeanFactory实现service与dao层解耦合详解