TensorFlow2子类模型多输入多输出
程序员文章站
2022-05-26 19:30:33
...
在最近的一次项目中,因为需要模型具有多输入多输出,而且我的一个输出是一个包含张量的列表,所以无法使用函数式API或者容器去造模型,因为列表的添加操作不是一个层,而这两类的输出必须是层的结果,虽然可以用tf.keras.layers.Lambda将此操作变成层,但总归是牵强的,所以使用子类模型。
class Test(keras.Model):
def __init__(self):
super(Test, self).__init__()
filters = 64
initializer = tf.random_normal_initializer(0., 0.02)
self.conv1 = Conv2D(filters, 4, 2, 'same', use_bias=False,
kernel_initializer=initializer)
self.bn1 = BatchNormalization()
self.conv2 = Conv2D(filters*2, 4, 2, 'same', use_bias=False,
kernel_initializer=initializer)
self.bn2 = BatchNormalization()
self.conv3 = Conv2D(filters*4, 4, 2, 'same', use_bias=False,
kernel_initializer=initializer)
self.bn3 = BatchNormalization()
def call(self, inputs):
x1 = inputs[0]
x2 = inputs[1]
skips = [] # 存结果的列表
x1_1 = tf.nn.relu(self.bn1(self.conv1(x1)))
x2_1 = tf.nn.relu(self.bn1(self.conv1(x2)))
skips.append(x1_1)
x1_2 = tf.nn.relu(self.bn1(self.conv1(x1_1)))
x2_2 = tf.nn.relu(self.bn1(self.conv1(x2_1)))
skips.append(x1_2)
x1_3 = tf.nn.relu(self.bn1(self.conv1(x1_2)))
x2_3 = tf.nn.relu(self.bn1(self.conv1(x2_2)))
skips.append(x1_3)
return [skips, x2_3]
model = Test()
model.build(input_shape=[(batch_size, data_size), (batch_size, data_size)])
input1 = tf.random.normal([batch_size, data_size])
input2 = tf.random.normal([batch_size, data_size])
out_put1, out_put2 = model([input1, input2])
TF2用着真的是太难受了,网上的教程都比较泛,对一些细节的处理实例太难找了,找着了还大概率是tf.compat.v1。。。 做完这次我真的好好去看看Torch了。。。
另外,在TF2的图执行模式里,是无法使用for等循环的,但有专门的库函数tf.while_loop,反正我还不怎么会用,而且还要转sess,或者直接可以转eager模式就可以解决。
还有另外的思路@tf.function和tf.while
我的问题可能在一些大佬看来很低级,但确实给我造成了麻烦,我本以为教程上的东西就能解决一切问题的了,还是太弱。若要朋友想指正我的说法或者想要交流TF2里的坑,请私信我。
上一篇: keras 多输入多输出网络
下一篇: jQuery实现商城首页幻灯片的效果