欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

Keras之自定义损失(loss)函数用法说明

程序员文章站 2022-06-15 10:20:16
在keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在keras中是固定的,须如下形式:def my_loss(y_true, y_pred):# y...

在keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在keras中是固定的,须如下形式:

def my_loss(y_true, y_pred):
# y_true: true labels. tensorflow/theano tensor
# y_pred: predictions. tensorflow/theano tensor of the same shape as y_true
 .
 .
 .
 return scalar #返回一个标量值

然后在model.compile中指定即可,如:

model.compile(loss=my_loss, optimizer='sgd')

具体参考keras官方metrics的定义keras/metrics.py

"""built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import six
from . import backend as k
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
 
def binary_accuracy(y_true, y_pred):
 return k.mean(k.equal(y_true, k.round(y_pred)), axis=-1)
 
 
def categorical_accuracy(y_true, y_pred):
 return k.cast(k.equal(k.argmax(y_true, axis=-1),
       k.argmax(y_pred, axis=-1)),
     k.floatx())
 
def sparse_categorical_accuracy(y_true, y_pred):
 # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
 if k.ndim(y_true) == k.ndim(y_pred):
  y_true = k.squeeze(y_true, -1)
 # convert dense predictions to labels
 y_pred_labels = k.argmax(y_pred, axis=-1)
 y_pred_labels = k.cast(y_pred_labels, k.floatx())
 return k.cast(k.equal(y_true, y_pred_labels), k.floatx())
 
def top_k_categorical_accuracy(y_true, y_pred, k=5):
 return k.mean(k.in_top_k(y_pred, k.argmax(y_true, axis=-1), k), axis=-1)
 
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
 # if the shape of y_true is (num_samples, 1), flatten to (num_samples,)
 return k.mean(k.in_top_k(y_pred, k.cast(k.flatten(y_true), 'int32'), k),
     axis=-1)
 
# aliases
 
mse = mse = mean_squared_error
mae = mae = mean_absolute_error
mape = mape = mean_absolute_percentage_error
msle = msle = mean_squared_logarithmic_error
cosine = cosine_proximity
 
def serialize(metric):
 return serialize_keras_object(metric)
 
def deserialize(config, custom_objects=none):
 return deserialize_keras_object(config,
         module_objects=globals(),
         custom_objects=custom_objects,
         printable_module_name='metric function')
 
def get(identifier):
 if isinstance(identifier, dict):
  config = {'class_name': str(identifier), 'config': {}}
  return deserialize(config)
 elif isinstance(identifier, six.string_types):
  return deserialize(str(identifier))
 elif callable(identifier):
  return identifier
 else:
  raise valueerror('could not interpret '
       'metric function identifier:', identifier)

以上这篇keras之自定义损失(loss)函数用法说明就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。