欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

Python绘制神经网络中常见激活函数的图形

程序员文章站 2022-06-09 22:44:30
前言需要绘制的激活函数有sigmoid,tanh,ReLU,softplus,swish共5个函数。各个函数的公式sigmoid:tanh:ReLU:softplus:swish:其中 ????(⋅) 为 Logistic 函数, β为可学习的参数或一个固定超参数上面5个激活函数对应的代码公式如下:def sigmoid(x): return 1 / (1 + np.exp(-x))def tanh(x): return (np.exp(x) - np.ex...

前言

需要绘制的激活函数有sigmoidtanhReLUsoftplusswish共5个函数。

各个函数的公式

sigmoid:
Python绘制神经网络中常见激活函数的图形
tanh:
Python绘制神经网络中常见激活函数的图形
ReLU:
Python绘制神经网络中常见激活函数的图形

softplus:
Python绘制神经网络中常见激活函数的图形
swish:
Python绘制神经网络中常见激活函数的图形
其中 ????(⋅) 为 Logistic 函数, β为可学习的参数或一个固定超参数

上面5个激活函数对应的代码公式如下:

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tanh(x):
    return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x)) 

def relu(x):
    return np.maximum(0, x)

def softplus(x):
    return np.log(1 + np.exp(x))

def swish(x, beta):
    return x * sigmoid(beta * x)

开始绘图

此处我们使用matplotlib来绘图,由于matplotlib默认的坐标系不是直角坐标系,需要写一个函数来获取直角坐标系,并在直角坐标系上绘图。

下面的代码可以获取直角坐标系的ax,此后便可以通过ax.plot等操作在ax上面绘图。

def get_central_ax():
    ax = plt.gca() # get current axis 获得坐标轴对象
	
	# 将右边 上边的两条边颜色设置为空 其实就相当于抹掉这两条边
    ax.spines['right'].set_color('none') 
    ax.spines['top'].set_color('none') 

	# 指定下边的边作为 x 轴 指定左边的边为 y 轴
    ax.xaxis.set_ticks_position('bottom') 
    ax.yaxis.set_ticks_position('left') 
    
    ax.spines['bottom'].set_position(('data', 0)) #指定 data 设置的bottom(也就是指定的x轴)绑定到y轴的0这个点上
    ax.spines['left'].set_position(('data', 0))
    
    return ax

下面是绘制sigmoidtanhReLUsoftplus的代码,swish的图像需要单独绘制:

x = np.arange(-6.0, 6.0, 0.1)
y1 = sigmoid(x)
y2 = tanh(x)
y3 = relu(x)
y4 = softplus(x)

ax = get_central_ax()

# ax = plt.subplot(111)
ax.plot(x, y1)
ax.plot(x, y2, linestyle='--')
ax.plot(x, y3, linestyle='--')
ax.plot(x, y4, linestyle='--')
ax.legend(['sigmoid', 'tanh', 'ReLU', 'softplus'])
plt.show()

绘制的图像如下:
Python绘制神经网络中常见激活函数的图形

下面是绘制swish函数图像的代码:

x = np.arange(-6.0, 6.0, 0.1)
ax = get_central_ax() 

legends = []
for beta in [0, 0.5, 1, 100]:
    y_s = swish(x, beta)
    ax.plot(x, y_s, linestyle='--')
    legends.append('β = '+str(beta))

ax.legend(legends)
plt.show()

图形如下:
Python绘制神经网络中常见激活函数的图形

完整代码

import numpy as np
import matplotlib.pyplot as plt
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
def tanh(x):
    return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x)) 

def relu(x):
    return np.maximum(0, x)
def softplus(x):
    return np.log(1 + np.exp(x))
def swish(x, beta):
    return x * sigmoid(beta * x)

def get_central_ax():
    ax = plt.gca() # get current axis 获得坐标轴对象

    ax.spines['right'].set_color('none') 
    ax.spines['top'].set_color('none')

    ax.xaxis.set_ticks_position('bottom') 
    ax.yaxis.set_ticks_position('left') 
    ax.spines['bottom'].set_position(('data', 0))
    ax.spines['left'].set_position(('data', 0))
    
    return ax

# 绘制`sigmoid`,`tanh`,`ReLU`,`softplus`
x = np.arange(-6.0, 6.0, 0.1)
y1 = sigmoid(x)
y2 = tanh(x)
y3 = relu(x)
y4 = softplus(x)

ax = get_central_ax()

# ax = plt.subplot(111)
ax.plot(x, y1)
ax.plot(x, y2, linestyle='--')
ax.plot(x, y3, linestyle='--')
ax.plot(x, y4, linestyle='--')
ax.legend(['sigmoid', 'tanh', 'ReLU', 'softplus'])
plt.show()

# 绘制`swish`函数
x = np.arange(-6.0, 6.0, 0.1)
ax = get_central_ax() 

legends = []
for beta in [0, 0.5, 1, 100]:
    y_s = swish(x, beta)
    ax.plot(x, y_s, linestyle='--')
    legends.append('β = '+str(beta))

ax.legend(legends)
plt.show()

本文地址:https://blog.csdn.net/weixin_44843824/article/details/110849612

相关标签: Python学习