基于TensorFlow.Keras实现的混合密度模型(MDN)
程序员文章站
2024-01-16 23:13:10
...
Note
- 全文及代码参考自:混合密度模型Mixture Density Networks
- 原文对模型解释得很清楚,这里不再过多解释。如需深入,建议学习高斯混合模型(GMM)。
- 原文的代码有些过时了,我精简了原始的代码,以Markdown的形式展示在自己的博客里。
1 The Most Common Neural Network
1.1 Import Modules
* TensorFlow v2.1自带了Keras, 直接从中导入Keras
import matplotlib.pyplot as plt
import numpy as np
import math
import time
import tensorflow.keras as keras
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K
# Default size (width, height) in inches for every figure created below
from matplotlib.pylab import rcParams
rcParams.update({'figure.figsize': (10, 4)})
# calculagraph: timestamped progress printing
def helper_print_with_time(*arg,sep=','):
    """Print the local wall-clock time as HH:MM:SS, then ``arg`` joined by ``sep``.

    Each argument is converted with ``str`` before joining; the timestamp and the
    joined message are separated by a single space (print's default separator).
    """
    stamp = time.strftime("%H:%M:%S", time.localtime())
    message = sep.join(str(item) for item in arg)
    print(stamp, message)
1.2 Produce Many-to-One Training Data
# Many-to-one toy data: y = 7*sin(0.75*x) + 0.5*x + N(0, 1) noise, x ~ U(-10.5, 10.5)
NSAMPLE = 1000
x_data = np.random.uniform(-10.5, 10.5, size=(NSAMPLE, 1))
r_data = np.random.normal(size=(NSAMPLE, 1))
y_data = 7.0 * np.sin(0.75 * x_data) + 0.5 * x_data + 1.0 * r_data
plot_out = plt.plot(x_data, y_data, 'r.', alpha=0.3)
1.3 Build Model
# Plain regression MLP: 1 input -> 20 tanh units -> 1 linear output
model = models.Sequential([
    layers.Dense(20, activation='tanh', input_shape=(1,)),
    layers.Dense(1, activation='linear'),
])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_1 (Dense) (None, 20) 40
_________________________________________________________________
dense_2 (Dense) (None, 1) 21
=================================================================
Total params: 61
Trainable params: 61
Non-trainable params: 0
_________________________________________________________________
1.4 Train Model and Testing
# --- Train: full-batch MSE regression ---
helper_print_with_time('>>> begin')
model.compile(loss='mse', optimizer='rmsprop', metrics=['mse'])
model.fit(x_data, y_data, batch_size=NSAMPLE, epochs=10000, verbose=0)
helper_print_with_time('>>> end')
# --- Test: predict on an evenly spaced column of inputs, overlay on training data ---
x_test = np.arange(-10.5, 10.5, 0.1)[:, np.newaxis]
y_test = model.predict(x_test)
plot_out = plt.plot(x_data, y_data, 'ro', x_test, y_test, 'b.', alpha=0.3)
11:18:21 >>> begin
11:18:33 >>> end
1.5 Exchange Input and Output
# Exchange roles: former targets become inputs, turning many-to-one into one-to-many.
_previous_x = x_data
x_data = y_data
y_data = _previous_x
del _previous_x
plot_out = plt.plot(x_data, y_data, 'ro', alpha=0.3)
1.6 Retrain Model and Testing
# --- Retrain on the swapped data ---
# The same `model` object is reused; compile() does not re-initialize layer weights,
# so fitting continues from the weights learned before the swap.
helper_print_with_time('>>> begin')
model.compile(loss='mse', optimizer='rmsprop', metrics=['mse'])
model.fit(x_data, y_data, batch_size=NSAMPLE, epochs=10000, verbose=0)
helper_print_with_time('>>> end')
# --- Test: a single-valued MSE regressor can only average the multiple branches ---
x_test = np.arange(-10.5, 10.5, 0.1)[:, np.newaxis]
y_test = model.predict(x_test)
plot_out = plt.plot(x_data, y_data, 'ro', x_test, y_test, 'bo', alpha=0.3)
11:19:19 >>> begin
11:19:31 >>> end
2 MDN (Mixture Density Networks)
2.1 Produce One-to-Many Training Data
# One-to-many data: the curve from section 1 with x and y exchanged, so a single
# x value can correspond to several y values.
NSAMPLE = 2500
y_data = np.random.uniform(-10.5, 10.5, size=(NSAMPLE, 1))
r_data = np.random.normal(size=(NSAMPLE, 1))
x_data = 7.0 * np.sin(0.75 * y_data) + 0.5 * y_data + 1.0 * r_data
plot_out = plt.plot(x_data, y_data, 'ro', alpha=0.3)
2.2 Build Model
- 设置20个混合元,因此输出为60维,包括20个混合比例(pct)、20个均值(mu)、20个标准差(std)。其中std分支输出的是对数标准差,在`get_mixture_coef`中取指数后才得到(恒为正的)标准差。
# MDN: one shared tanh hidden layer feeding three parallel heads of KMIX units each.
NHIDDEN, KMIX = 128, 20  # KMIX: number of Gaussian components in the mixture
NOUT = KMIX * 3  # total output width: KMIX weights + KMIX means + KMIX log-stds
Input = keras.Input(shape=(1,))
hidden = layers.Dense(NHIDDEN, activation='tanh')(Input)
# Mixing weights: unconstrained logits, normalized by softmax so each row sums to 1
op_logits = layers.Dense(KMIX, activation='linear', name='op')(hidden)
op = layers.Softmax()(op_logits)
# Component means: unconstrained
ou = layers.Dense(KMIX, activation='linear', name='ou')(hidden)
# Log standard deviations: exponentiated later in get_mixture_coef
os = layers.Dense(KMIX, activation='linear', name='os')(hidden)
# Output layout along the last axis: [pct | mu | log_std]
Output = layers.Concatenate()([op, ou, os])
model = keras.Model(inputs=Input, outputs=Output)
model.summary()
Model: "model_2"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_3 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
dense_4 (Dense) (None, 128) 256 input_3[0][0]
__________________________________________________________________________________________________
op (Dense) (None, 20) 2580 dense_4[0][0]
__________________________________________________________________________________________________
softmax_2 (Softmax) (None, 20) 0 op[0][0]
__________________________________________________________________________________________________
ou (Dense) (None, 20) 2580 dense_4[0][0]
__________________________________________________________________________________________________
os (Dense) (None, 20) 2580 dense_4[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 60) 0 softmax_2[0][0]
ou[0][0]
os[0][0]
==================================================================================================
Total params: 7,996
Trainable params: 7,996
Non-trainable params: 0
__________________________________________________________________________________________________
2.3 DIY the Loss Function
def get_mixture_coef(output, Training=True, kmix=None):
    """Split the MDN output into mixture weights, means and standard deviations.

    The model's last layer concatenates three ``kmix``-wide groups along the last
    axis: ``[pct | mu | log_std]``.

    Args:
        output: 2-D output of shape (batch, 3*kmix). Inside the Keras loss this is
            a symbolic tensor; after ``model.predict`` it is a NumPy array.
        Training: True -> exponentiate with ``K.exp`` (TensorFlow tensors);
            False -> exponentiate with ``np.exp`` (NumPy arrays).
        kmix: number of mixture components. Defaults to the module-level ``KMIX``.

    Returns:
        (out_pct, out_mu, out_std), each of shape (batch, kmix).
    """
    if kmix is None:
        kmix = KMIX
    out_pct = output[:, :kmix]
    out_mu = output[:, kmix:2 * kmix]
    # Exponentiate the raw head so the standard deviation (not the variance) is > 0.
    raw_std = output[:, 2 * kmix:]
    out_std = K.exp(raw_std) if Training else np.exp(raw_std)
    return out_pct, out_mu, out_std
# Note: the Gaussian densities are evaluated at the target value y (the quantity the
# network predicts), not at the network input x.
def get_loss(pct, mu, std, y):
    """Mean negative log-likelihood of ``y`` under a 1-D Gaussian mixture.

    Computes  -mean_i log( sum_k pct[i,k] * N(y[i]; mu[i,k], std[i,k]^2) ).

    The sum over components is evaluated in log space with logsumexp. Forming the
    densities directly can underflow to 0 in float32 when y is many standard
    deviations from every component, which makes the loss +inf and its gradient NaN.

    Args:
        pct: mixture weights, (batch, KMIX); softmax output, rows sum to 1.
        mu: component means, (batch, KMIX).
        std: component standard deviations (> 0), (batch, KMIX).
        y: targets; assumed (batch, 1) so it broadcasts against the KMIX components.

    Returns:
        Scalar tensor: the batch-mean negative log-likelihood.
    """
    # log N(y; mu, std^2) = -0.5*((y - mu)/std)^2 - log(std) - 0.5*log(2*pi)
    log_norm_const = -0.5 * math.log(2 * math.pi)
    log_pdf = -0.5 * K.square((y - mu) / std) - K.log(std) + log_norm_const
    # Clip weights away from 0 so log(pct) is finite and its gradient is defined.
    log_pct = K.log(K.clip(pct, K.epsilon(), 1.0))
    # Log of the mixture likelihood, summed over components stably.
    log_likelihood = K.logsumexp(log_pct + log_pdf, axis=1)
    # Negate so that gradient descent maximizes the likelihood; averaging (rather than
    # summing) over the batch keeps the loss scale independent of batch size.
    return K.mean(-log_likelihood)
def loss_func(y_true, y_pred):
    """Keras loss adapter: mixture NLL of ``y_true`` under the MDN parameters in ``y_pred``.

    ``y_pred`` is the concatenated [pct | mu | log_std] model output; it is split
    (in training mode) by get_mixture_coef and scored by get_loss.
    """
    return get_loss(*get_mixture_coef(y_pred), y_true)
2.4 Train Model
# Train the MDN full-batch with Adam; the same mixture NLL is also tracked as a metric.
helper_print_with_time('>>>begin')
model.compile(loss=loss_func, optimizer='adam', metrics=[loss_func])
history = model.fit(x_data, y_data, batch_size=NSAMPLE, epochs=10000, verbose=0)
helper_print_with_time('>>>end')
11:49:32 >>>begin
11:50:40 >>>end
2.5 Show the History of Loss
# Plot the training-loss curve with a small vertical margin. Margins are taken from
# |value| because a mixture NLL can be negative (densities may exceed 1); scaling a
# negative minimum by 0.90 would move the lower limit ABOVE the curve and clip it.
loss_curve = history.history['loss']
lowest, highest = min(loss_curve), max(loss_curve)
plt.ylim(lowest - 0.10 * abs(lowest), highest + 0.01 * abs(highest))
plot_out = plt.plot(loss_curve, 'r-')
2.6 View the shape of Input and Output
# 输入测试数据
# Evaluate the trained MDN on an evenly spaced input grid from -15 to 15 (step 0.1).
x_test = np.arange(-15, 15, 0.1)[:, np.newaxis]  # 2-D (N, 1): a batch of 1-feature rows
mdn_out = model.predict(x_test)
out_pct_test, out_mu_test, out_std_test = get_mixture_coef(mdn_out, Training=False)
# Shape of each array (shown when run as a notebook cell)
x_test.shape, out_pct_test.shape, out_mu_test.shape, out_std_test.shape
((300, 1), (300, 20), (300, 20), (300, 20))
2.7 View Input vs. mu and Input vs. pct for Different Mixtures
- 可以看出,对于不同的输入,每个混合元的pct都随输入变化;pct是softmax的输出,取值在0到1之间,且同一输入下各混合元的pct之和为1。
plt.figure(figsize=(12, 10))
# Top panel: training data (blue) overlaid with each component's mean vs. input x
plt.subplot(2, 1, 1)
plt.plot(x_data, y_data, 'b.', alpha=0.8)
plt.plot(x_test, out_mu_test, '.', alpha=0.5)
# Bottom panel: each component's mixing weight vs. input x (one color per component)
plt.subplot(2, 1, 2)
plt.plot(x_test, out_pct_test, '.', alpha=0.8)
plt.show()
2.8 Predict the PDF for Given Input
- 给定输入,输出预测值的概率密度函数
def generate_PDF(model, x_new, n_std=3.0, step=0.1):
    """Plot and return the predicted mixture density p(y | x = x_new).

    Args:
        model: trained MDN whose ``predict`` output is split by get_mixture_coef.
        x_new: scalar input value.
        n_std: the y-range covers mu_k +/- n_std * std_k for every component k.
        step: spacing of the y grid.

    Returns:
        (y_label, GMM_PDF): the y grid as a column of shape (n, 1) and the mixture
        density at each grid point, shape (n,).
    """
    # Keras expects a 2-D batch of 1-feature rows: shape (1, 1).
    x_batch = np.asarray(x_new, dtype=float).reshape(1, 1)
    pct_new, mu_new, std_new = get_mixture_coef(model.predict(x_batch), Training=False)  # each (1, KMIX)
    # Take the extremes over ALL components, so a wide component whose mean is not the
    # smallest/largest is still covered. The results are plain floats, so np.arange
    # also works when several components share the extreme mean.
    y_min = float(np.min(mu_new - n_std * std_new))
    y_max = float(np.max(mu_new + n_std * std_new))
    # Column vector so it broadcasts against the (1, KMIX) coefficient rows.
    y_label = np.arange(y_min, y_max, step).reshape(-1, 1)
    factors = 1 / math.sqrt(2 * math.pi) / std_new
    exponent = np.exp(-0.5 * np.square((y_label - mu_new) / std_new))
    GMM_PDF = np.sum(pct_new * factors * exponent, axis=1)  # weighted sum over components
    plt.plot(y_label, GMM_PDF)
    return y_label, GMM_PDF
# Plot the predicted conditional density of y at input x = 10.
y_label, GMM_PDF = generate_PDF(model, x_new=10)
2.9 Draw Heatmap
def generate_heatmap(model, x_range=(-15, 15), y_range=(-15, 15), x_step=0.08, y_step=0.1):
    """Render the MDN's conditional density p(y | x) over a 2-D grid as an image.

    Args:
        model: trained MDN whose ``predict`` output is split by get_mixture_coef.
        x_range, y_range: (start, stop) of the input / target grids (stop excluded).
        x_step, y_step: grid spacings. The different defaults give grids of different
            sizes, which makes an x/y axis mix-up in the image easy to spot.
    """
    x_label = np.arange(*x_range, x_step).reshape(-1, 1)     # (N, 1): batch of inputs
    y_label = np.arange(*y_range, y_step).reshape(1, 1, -1)  # (1, 1, M): broadcasts over inputs and components
    out_new = model.predict(x_label)
    pct_new, mu_new, std_new = get_mixture_coef(out_new, Training=False)
    # (N, KMIX) -> (N, KMIX, 1) so each coefficient broadcasts against y_label.
    pct_new, mu_new, std_new = [c.reshape(c.shape[0], -1, 1) for c in (pct_new, mu_new, std_new)]
    factors = 1 / np.sqrt(2 * math.pi) / std_new
    exponent = np.exp(-0.5 * np.square((y_label - mu_new) / std_new))
    heatmap = np.sum(pct_new * factors * exponent, axis=1)  # (N, M): rows = x, columns = y
    plt.figure(figsize=(8, 8))
    extent = [x_label.min(), x_label.max(), y_label.min(), y_label.max()]
    # Transpose so y runs vertically; origin='lower' puts the smallest y at the bottom.
    plt.imshow(heatmap.T, origin='lower', extent=extent)
    plt.show()
# Draw the conditional-density heatmap p(y | x) for the trained MDN.
generate_heatmap(model)
上一篇: php 连接redis 数据库单利类
下一篇: Python2和Python3的区别详解