欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Tensorflow 2 单GPU同时训练多个模型

程序员文章站 2022-05-26 17:54:45
...

Tensorflow 2 单GPU同时训练多个模型

问题

有时我们需要对多个模型进行性能对比。若一次只训练一个模型,我们需要时刻关注训练进度,非常耗费精力。同时进行多个模型的训练能够降低人力成本。

代码

这里对三个网络进行图像的二值分割训练,它们分别是Unet, Linknet, FPN。利用for循环对整训练集进行遍历。train函数是这段代码的核心,每调用一次train就进行一次迭代。with里面是前向传播形成计算图、xxx_tape.gradient是通过反向自动微分算法求网络的梯度、xxx_optimizer.apply_gradients是优化器更新网络参数。

import tensorflow as tf
from segmentation_models import Unet,Linknet,FPN
from segmentation_models import set_framework
set_framework('tf.keras')

Unet = Unet('resnet34',input_shape=(128, 128, 3), classes=1, activation='sigmoid', encoder_weights=None, encoder_freeze=False)
Linknet = Linknet('resnet34',input_shape=(128, 128, 3), classes=1, activation='sigmoid', encoder_weights=None, encoder_freeze=False)
FPN = FPN('resnet34',input_shape=(128, 128, 3), classes=1, activation='sigmoid', encoder_weights=None, encoder_freeze=False)

#Adam优化器,学习率0.001
Unet_optimizer = tf.keras.optimizers.Adam(1e-3)
Linknet_optimizer = tf.keras.optimizers.Adam(1e-3)
FPN_optimizer = tf.keras.optimizers.Adam(1e-3)

def loss_Bce(y_true, y_pred):
    costs=tf.reduce_mean(tf.keras.losses.binary_crossentropy(y_true, y_pred))
    return costs
    
@tf.function
def train(input_images,label):   
    with tf.GradientTape() as Unet_tape:
        predict_Unet = Unet(input_images, training=True)
        Unet_loss = loss_Bce(label,predict_Unet)
    Unet_gradients = Unet_tape.gradient(Unet_loss,Unet.trainable_variables)
    Unet_optimizer.apply_gradients(zip(Unet_gradients,
                                      Unet.trainable_variables))
    with tf.GradientTape() as Linknet_tape:
        predict_Linknet = Linknet(input_images, training=True)
        Linknet_loss = loss_Bce(label,predict_Linknet)
    Linknet_gradients = Linknet_tape.gradient(Linknet_loss,Linknet.trainable_variables)
    Linknet_optimizer.apply_gradients(zip(Linknet_gradients,
                                  Linknet.trainable_variables))  
    with tf.GradientTape() as FPN_tape: 
        predict_FPN = FPN(input_images, training=True)
        FPN_loss = loss_Bce(label,predict_FPN)
    FPN_gradients = FPN_tape.gradient(FPN_loss,FPN.trainable_variables)
    FPN_optimizer.apply_gradients(zip(FPN_gradients,
                                       FPN.trainable_variables))  
    return Unet_loss,Linknet_loss,FPN_loss


Epochs=100
Batch_size=32
Total_size=img_train.shape[0]#img_train是训练集的输入图像 shape=[size,height,width,channel]
for i in range(Epochs):
    for k1 in range(int(Total_size/Batch_size)):
        input_images=img_train[k1*Batch_size:k1*Batch_size+Batch_size]   
        label=label_train[k1*Batch_size:k1*Batch_size+Batch_size]   
        Unet_loss,Linknet_loss,FPN_loss=train(input_images,label)
        print("\rEpoch: {:d} batch: {:d} Unet_loss: {:.4f} Linknet_loss: {:.4f} Dice_loss: {:.4f} "
              .format(i+1,k1+1,Unet_loss, Linknet_loss, FPN_loss), end='',  flush=True)

本代码是借鉴了生成对抗网络网络(GAN)的训练方法,但略有不同。以下是GAN的train函数。

def train_step(input_image, target):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)
        
        disc_real_output = discriminator(tf.concat([input_image, target],axis=3), training=True)
        
        disc_generated_output = discriminator(tf.concat([tf.cast(input_image,dtype=tf.float32), gen_output],axis=3), training=True)
        

        gen_total_loss, gen_gan_loss, gen_l1_loss  = generator_loss(disc_generated_output, gen_output, target)
        
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))
    return gen_total_loss, gen_gan_loss, gen_l1_loss,disc_loss

可见GAN构建成生成器和判别器的计算图时,它们是放在一个with函数里面同时进行计算的。GAN只有两个网络,且判别器的结构通常比较简单,所以一般不会爆内存。
对于训练多个大型网络,则需要依次更新,所以就要多个with函数。经测试,本方法同时训练4个网络是没问题的。

谢谢观看!