欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

TensorFlow2子类模型多输入多输出

程序员文章站 2022-05-26 19:30:33
...

在最近的一次项目中,因为需要模型具有多输入多输出,而且我的一个输出是一个包含张量的列表,所以无法使用函数式API或者容器去造模型,因为列表的添加操作不是一个层,而这两类的输出必须是层的结果,虽然可以用tf.keras.layers.Lambda将此操作变成层,但总归是牵强的,所以使用子类模型。

class Test(keras.Model):
    def __init__(self):
        super(Test, self).__init__()
        filters = 64
        initializer = tf.random_normal_initializer(0., 0.02)
        self.conv1 = Conv2D(filters, 4, 2, 'same', use_bias=False, 
                                              kernel_initializer=initializer)
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters*2, 4, 2, 'same', use_bias=False, 
                                              kernel_initializer=initializer)
        self.bn2 = BatchNormalization()
        self.conv3 = Conv2D(filters*4, 4, 2, 'same', use_bias=False, 
                                              kernel_initializer=initializer)
        self.bn3 = BatchNormalization()

    def call(self, inputs):
        x1 = inputs[0]
        x2 = inputs[1]
        skips = []      # 存结果的列表
        x1_1 = tf.nn.relu(self.bn1(self.conv1(x1)))
        x2_1 = tf.nn.relu(self.bn1(self.conv1(x2)))
        skips.append(x1_1)

        x1_2 = tf.nn.relu(self.bn1(self.conv1(x1_1)))
        x2_2 = tf.nn.relu(self.bn1(self.conv1(x2_1)))
        skips.append(x1_2)

        x1_3 = tf.nn.relu(self.bn1(self.conv1(x1_2)))
        x2_3 = tf.nn.relu(self.bn1(self.conv1(x2_2)))
        skips.append(x1_3)

        return [skips, x2_3]


model = Test()
model.build(input_shape=[(batch_size, data_size), (batch_size, data_size)])
input1 = tf.random.normal([batch_size, data_size])
input2 = tf.random.normal([batch_size, data_size])
out_put1, out_put2 = model([input1, input2])

TF2用着真的是太难受了,网上的教程都比较泛,对一些细节的处理实例太难找了,找着了还大概率是tf.compat.v1。。。  做完这次我真的好好去看看Torch了。。。

另外,在TF2的图执行模式里,是无法使用for等循环的,但有专门的库函数tf.while_loop,反正我还不怎么会用,而且还要转sess,或者直接可以转eager模式就可以解决。

还有另外的思路@tf.function和tf.while

我的问题可能在一些大佬看来很低级,但确实给我造成了麻烦,我本以为教程上的东西就能解决一切问题的了,还是太弱。若要朋友想指正我的说法或者想要交流TF2里的坑,请私信我。

相关标签: tensorflow python