自定义模型中自定义损失函数的添加问题
程序员文章站
2024-03-14 12:24:28
...
结合focal loss 函数讲解
== 引入工具包 ==
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model
import numpy as np
print(tf.__version__)
print(np.__version__)
== step 0 参数设置 ==
EPOCHS = 5
batchsize = 32
== step 1 数据 ==
mnist = np.load("mnist.npz")
x_train, y_train, x_test, y_test = mnist['x_train'],mnist['y_train'],mnist['x_test'],mnist['y_test']
x_train, x_test = x_train / 255.0, x_test / 255.0
fig, ax = plt.subplots(
nrows=2,
ncols=5,
sharex=True,
sharey=True, )
ax = ax.flatten()
for i in range(10):
img = x_train[y_train == i][0].reshape(28, 28)
ax[i].imshow(img, cmap='Greys', interpolation='nearest')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
# Add a channels dimension
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
y_train = tf.one_hot(y_train,depth=10)
y_test = tf.one_hot(y_test,depth=10)
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(batchsize)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batchsize)
== step 2 模型 ==
class MyModel(Model):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = Conv2D(32, 3, activation='relu')
self.flatten = Flatten()
self.d1 = Dense(128, activation='relu')
self.d2 = Dense(10, activation='softmax')
def call(self, x):
x = self.conv1(x)
x = self.flatten(x)
x = self.d1(x)
return self.d2(x)
== step 3 损失函数 ==
#多分类的focal loss损失函数
#类的实现
class FocalLoss(tf.keras.losses.Loss):
def __init__(self,gamma=2.0,alpha=0.25):
self.gamma = gamma
self.alpha = alpha
super(FocalLoss, self).__init__()
def call(self,y_true,y_pred):
y_pred = tf.nn.softmax(y_pred,axis=-1)
epsilon = tf.keras.backend.epsilon()#1e-7
y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
y_true = tf.cast(y_true,tf.float32)
loss = - y_true * tf.math.pow(1 - y_pred, self.gamma) * tf.math.log(y_pred)
loss = tf.math.reduce_sum(loss,axis=1)
return loss
#函数的实现
def FocalLoss(gamma=2.0,alpha=0.25):
def focal_loss_fixed(y_true, y_pred):
y_pred = tf.nn.softmax(y_pred,axis=-1)
epsilon = tf.keras.backend.epsilon()
y_pred = tf.clip_by_value(y_pred, epsilon, 1.0)
y_true = tf.cast(y_true,tf.float32)
loss = - y_true * tf.math.pow(1 - y_pred, gamma) * tf.math.log(y_pred)
loss = tf.math.reduce_sum(loss,axis=1)
return loss
return focal_loss_fixed
#loss_object = tf.keras.losses.CategoricalCrossentropy()
loss_object = FocalLoss(gamma=2.0,alpha=0.25)
== step 4 优化器 ==
optimizer = tf.keras.optimizers.Adam()
== step 5 评测函数==
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')
== step 6 训练 ==
model = MyModel()
@tf.function
def train_step(images, labels):
with tf.GradientTape() as tape:
predictions = model(images)
loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_loss(loss)
train_accuracy(labels, predictions)
@tf.function
def test_step(images, labels):
predictions = model(images)
t_loss = loss_object(labels, predictions)
test_loss(t_loss)
test_accuracy(labels, predictions)
for epoch in range(EPOCHS):
# 在下一个epoch开始时,重置评估指标
train_loss.reset_states()
train_accuracy.reset_states()
test_loss.reset_states()
test_accuracy.reset_states()
for images, labels in train_ds:
train_step(images, labels)
for test_images, test_labels in test_ds:
test_step(test_images, test_labels)
template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
print(template.format(epoch + 1,
train_loss.result(),
train_accuracy.result() * 100,
test_loss.result(),
test_accuracy.result() * 100))
== step 7 训练可视化 ==
== inference ==
推荐阅读
-
自定义模型中自定义损失函数的添加问题
-
tf2.0中自定义模型类、自定义层、自定义损失函数及评估指标
-
php自定义函数br2nl实现将html中br换行符转换为文本输入中换行符的方法【与函数nl2br功能相反】
-
php usort 使用用户自定义的比较函数对二维数组中的值进行排序
-
Laravel 5.4向IoC容器中添加自定义类的方法示例
-
Laravel 5.4向IoC容器中添加自定义类的方法示例
-
php usort 使用用户自定义的比较函数对二维数组中的值进行排序
-
php自定义函数br2nl实现将html中br换行符转换为文本输入中换行符的方法【与函数nl2br功能相反】
-
基于自定义Unity生存期模型PerCallContextLifeTimeManager的问题
-
解决在Web.config或App.config中添加自定义配置的方法详解