tensorflow2.0------自定义损失函数和Layer
程序员文章站
2024-03-14 12:54:22
...
import matplotlib as mpl #画图用的库
import matplotlib.pyplot as plt
#下面这一句是为了可以在notebook中画图
%matplotlib inline
import numpy as np
import sklearn #机器学习算法库
import pandas as pd #处理数据的库
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras #使用tensorflow中的keras
#import keras #单纯的使用keras
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, sklearn, pd, tf, keras:
print(module.__name__, module.__version__)
layer = tf.keras.layers.Dense(10)#None表示不定长,input_shape所表示的意思就是 未知数量的样本,每个样本有5个输入单元
layer = tf.keras.layers.Dense(100, input_shape=[None,5])# input_shape只在第一层时才需要添加,不添加系统可自动推导出来
layer(tf.zeros([10,5]))#这里定义输入为10*5的矩阵,就是说有10个这样的样本
<tf.Tensor: id=29, shape=(10, 100), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0.]], dtype=float32)>
#Variables可打印出layer中的所有参数
# x*w + b
#layer.variables
#trainable_variables 打印可训练的参数 这一层有100个神经单元,上一层有5个输入,则总的参数为 5*100+100
layer.trainable_variables
#trainable_weights 打印可训练的权重
#layer.trainable_weights
[<tf.Variable 'dense_1/kernel:0' shape=(5, 100) dtype=float32, numpy=
array([[-1.19798876e-01, -2.13079855e-01, -1.02656543e-01,
1.85441867e-01, 4.94155735e-02, 6.16988540e-03,
1.53844759e-01, -1.81336433e-01, -1.03642046e-03,
5.85038960e-03, 6.24256879e-02, -2.19907671e-01,
1.83101669e-01, -1.65164471e-04, 1.83286175e-01,
-6.35699928e-03, 1.44905582e-01, -1.78429008e-01,
2.08230659e-01, 9.32715088e-02, 1.09333947e-01,
-2.11200222e-01, -8.16740841e-02, 1.59377173e-01,
1.87050387e-01, -2.32415795e-01, 2.03493580e-01,
-1.77162156e-01, -3.61523330e-02, 5.43355793e-02,
1.18782863e-01, 6.58839494e-02, 1.29797012e-02,
-7.31805861e-02, -1.64841652e-01, 2.62765586e-03,
2.74553746e-02, -2.06142843e-01, -7.72242248e-03,
-1.99526936e-01, 3.47671062e-02, -6.27684295e-02,
-6.43615574e-02, 6.72664791e-02, 3.23811024e-02,
1.10492334e-01, -1.85585946e-01, -7.10566938e-02,
-5.78716546e-02, 1.45465180e-01, -1.98468402e-01,
1.29277557e-02, 5.84578365e-02, 2.10349366e-01,
-1.16091035e-01, 2.19189212e-01, 9.42905992e-02,
4.18415517e-02, -5.85604757e-02, 2.27137461e-01,
-8.05416852e-02, 1.98801607e-02, -2.22806484e-02,
-5.09096831e-02, 1.39255926e-01, 1.72453120e-01,
5.94796687e-02, -1.94010302e-01, -1.29458040e-01,
-1.67704582e-01, -1.93801701e-01, 1.29134506e-02,
4.38458771e-02, 1.85908690e-01, -1.87387943e-01,
2.06867501e-01, -6.68620616e-02, -8.91843736e-02,
-2.37392247e-01, -1.34011492e-01, -2.35438019e-01,
2.25169495e-01, -5.90238869e-02, -3.62023115e-02,
-7.08049536e-02, -9.07327086e-02, 1.80093482e-01,
3.40920240e-02, -1.96182936e-01, 2.31930390e-01,
-1.24343611e-01, 5.49836010e-02, -2.27239236e-01,
-2.15709969e-01, 1.64134666e-01, -1.77111670e-01,
4.84552830e-02, -4.61576134e-02, -1.83183074e-01,
8.94475132e-02],
[ 1.93398938e-01, 1.36397049e-01, 7.62083381e-02,
-3.22984159e-03, -1.54330581e-02, 9.16190594e-02,
-1.49835646e-01, 1.23447612e-01, 8.96078497e-02,
1.57661691e-01, -1.11075133e-01, 2.09647104e-01,
1.70456603e-01, 7.91838914e-02, 2.08825901e-01,
1.82572678e-01, 1.95767805e-01, 7.79902935e-03,
1.98027804e-01, 1.86323211e-01, 2.93843001e-02,
1.00706235e-01, 5.45971245e-02, 6.51367009e-03,
1.20625392e-01, 2.35398397e-01, -1.44583061e-01,
9.71381515e-02, 2.00735196e-01, -1.82016537e-01,
-2.03217626e-01, 2.29307696e-01, -1.99740827e-01,
1.73530295e-01, 2.33968154e-01, -9.20783728e-02,
1.48191616e-01, -1.25100762e-01, 4.24706191e-02,
2.33314469e-01, -3.19331884e-03, -2.06792444e-01,
-1.71466410e-01, -4.59314734e-02, -4.27660197e-02,
-9.26681310e-02, -1.64626956e-01, 1.02817997e-01,
-1.55400887e-01, -1.20745704e-01, -6.18212074e-02,
1.27071634e-01, -2.18336537e-01, -6.66197389e-02,
7.29902834e-02, -1.60827935e-01, -2.38064080e-02,
1.88934609e-01, -1.09155729e-01, -2.90658325e-02,
-1.51838362e-02, -1.60760581e-02, 2.22714975e-01,
-2.19662994e-01, -1.01167545e-01, 6.25229627e-02,
-8.16874206e-02, 1.59866348e-01, 3.36323231e-02,
4.97549772e-05, -6.01041317e-03, 5.76255172e-02,
1.53653607e-01, -7.09600896e-02, 2.01412931e-01,
7.45818168e-02, -1.25227332e-01, -3.58315259e-02,
-7.37203658e-02, 7.37054497e-02, -4.70031798e-03,
1.21293589e-01, -1.40033484e-01, -2.44935155e-02,
-2.02377886e-02, -8.89493972e-02, -3.85637581e-02,
-1.17017962e-01, -1.54986203e-01, 2.01146439e-01,
1.70223281e-01, -7.02508092e-02, -6.59079999e-02,
2.73524970e-02, 1.71576589e-02, 3.22768092e-03,
-3.43136191e-02, -2.34245613e-01, -1.08609855e-01,
-1.99974671e-01],
[ 1.49224862e-01, -9.62817669e-02, -1.84434980e-01,
-9.43478197e-02, -7.78061897e-02, -1.41380519e-01,
-2.19036415e-01, 6.82868212e-02, 1.94785848e-01,
-9.73739773e-02, -2.09367737e-01, -1.71446055e-01,
2.15334728e-01, 1.59692004e-01, -4.41892445e-02,
1.65368274e-01, -1.25258297e-01, 3.53681594e-02,
1.67240217e-01, 1.25391930e-02, 1.24417022e-01,
7.86104649e-02, 2.17301652e-01, 5.11338264e-02,
1.49539217e-01, 3.26410979e-02, 3.23790461e-02,
-1.91050544e-01, 2.37367347e-01, -1.65161908e-01,
4.46816236e-02, 1.53735891e-01, 1.61214635e-01,
-1.78851366e-01, 6.62474334e-03, -1.60464942e-01,
-1.73395157e-01, 9.90249068e-02, 8.77296478e-02,
-1.61264986e-02, -1.75254315e-01, 1.20523423e-02,
2.61914581e-02, -5.92734069e-02, 1.72799513e-01,
-4.50387895e-02, 6.38738126e-02, -3.73772830e-02,
1.00026950e-01, -2.11596265e-01, -1.52270943e-02,
-5.68721741e-02, 9.41223055e-02, 5.17047495e-02,
1.99242249e-01, 1.42246112e-01, -2.29594246e-01,
-1.03673637e-01, -8.55330676e-02, -6.80788606e-02,
1.79324254e-01, 8.89710635e-02, 5.61997145e-02,
6.70184046e-02, -1.85485125e-01, -1.36590302e-01,
4.49251980e-02, -1.99818000e-01, 1.60398886e-01,
-2.13471681e-01, 2.15477839e-01, 1.14858165e-01,
-9.72904265e-03, -4.94042188e-02, 1.73236027e-01,
1.55743957e-03, 1.18652299e-01, 2.15957090e-01,
-1.57986939e-01, 1.29788026e-01, 1.06273189e-01,
1.85594425e-01, -7.64783174e-02, 1.57222435e-01,
-5.85600734e-04, -2.09712476e-01, 2.36654297e-01,
7.69105405e-02, -5.39526492e-02, 1.15425691e-01,
-2.03577191e-01, 1.61271915e-01, -3.52287591e-02,
-2.06974536e-01, 2.34036282e-01, -1.90731898e-01,
5.11476845e-02, -6.68352246e-02, 1.54985234e-01,
-1.00576073e-01],
[ 9.82330292e-02, -1.17788285e-01, -1.64985955e-02,
-2.20375121e-01, -2.27009207e-02, 4.55506295e-02,
1.50215611e-01, -1.06511310e-01, 1.80991217e-01,
9.07516927e-02, 8.77115577e-02, 2.16988727e-01,
-6.85292780e-02, -4.29446995e-03, 2.10644767e-01,
-7.10284859e-02, -8.33985656e-02, 2.07440242e-01,
2.24501938e-02, 6.17934614e-02, -9.74216759e-02,
2.12433785e-02, -3.45096290e-02, -2.13498011e-01,
-1.41982809e-01, 2.14598492e-01, -1.88461691e-01,
-7.90978819e-02, 1.52341321e-01, -4.15554941e-02,
2.29092702e-01, -4.19260561e-02, -1.91133752e-01,
-1.49677724e-01, 1.73151746e-01, -1.23825543e-01,
-3.35648656e-03, 9.36887711e-02, 8.27962607e-02,
-1.62343368e-01, -9.39139426e-02, -1.52234644e-01,
1.91828385e-01, 2.00211659e-01, -1.78918242e-03,
-1.33397788e-01, 1.32620350e-01, 2.15210244e-01,
-1.62174165e-01, -6.33318722e-02, -2.29889184e-01,
1.02371857e-01, 5.76548129e-02, 7.00682551e-02,
5.45155853e-02, 3.89488190e-02, -2.19435364e-01,
1.11161783e-01, 2.03933045e-01, -2.21788377e-01,
-5.48370630e-02, -1.85295686e-01, 1.66524306e-01,
-2.69961953e-02, 1.85335800e-01, -1.83955491e-01,
8.69494528e-02, 2.84251124e-02, -1.87801719e-01,
-1.06175631e-01, -1.65407091e-01, 1.84860483e-01,
-6.11513108e-02, 1.84147492e-01, 7.80433565e-02,
4.56521958e-02, 1.82224944e-01, -3.24423760e-02,
1.06075719e-01, 2.04735801e-01, 4.44191545e-02,
1.66268751e-01, -1.84311718e-02, 1.57610670e-01,
-9.12932307e-02, 1.04989901e-01, 1.47415563e-01,
2.24768922e-01, 1.00079611e-01, 1.55956462e-01,
-9.67906564e-02, 9.20642763e-02, 5.77013195e-03,
-6.64863139e-02, -9.70341861e-02, -2.28809565e-01,
-1.94292963e-02, 1.83736309e-01, -1.40318394e-01,
-1.26107663e-01],
[-2.16803998e-02, -1.80408135e-01, 1.03065744e-01,
2.20412865e-01, 8.55985433e-02, -2.06283450e-01,
-1.50228098e-01, 1.60772994e-01, -7.34403729e-04,
-2.38991186e-01, 2.57442147e-02, -5.39559573e-02,
8.43531340e-02, 1.49122730e-01, -1.76507264e-01,
7.43092746e-02, -1.61422133e-01, -4.64574546e-02,
-3.16567272e-02, -1.81297109e-01, 1.42134979e-01,
1.89695522e-01, 2.19301656e-01, 1.96553394e-01,
8.78056735e-02, 6.88405782e-02, 2.85918862e-02,
-6.20819628e-02, -2.01302141e-01, 9.91754085e-02,
1.38416246e-01, 1.93116322e-01, 2.01080546e-01,
-4.78256792e-02, 3.93381864e-02, -9.32268947e-02,
1.49945363e-01, 2.02513203e-01, -1.34237707e-02,
-7.41664022e-02, -3.37326378e-02, 5.03837019e-02,
-1.26262397e-01, -1.45604029e-01, 1.06270060e-01,
-1.16300881e-01, 6.08194619e-02, -6.81088418e-02,
-3.79134715e-03, -1.21684209e-01, -3.75699252e-02,
3.89467627e-02, -1.72224805e-01, 5.78877181e-02,
-1.39211655e-01, 1.22599110e-01, 5.07537574e-02,
-8.05236697e-02, -1.72095835e-01, 3.56161445e-02,
9.34672356e-03, 1.69605017e-03, -1.40235633e-01,
-9.40205157e-02, 1.44792780e-01, 1.81426957e-01,
-1.30601615e-01, -2.18807533e-01, -1.01545528e-01,
-1.24894843e-01, 2.31218085e-01, -1.61409378e-04,
2.04400972e-01, 2.19281301e-01, -1.26980454e-01,
-5.33272773e-02, 1.48247465e-01, -1.03203103e-01,
-2.27923319e-01, 2.34309331e-01, -8.20545107e-02,
5.46423346e-02, 9.31039602e-02, 3.61091942e-02,
1.77635834e-01, 1.10312253e-02, 4.05964702e-02,
-3.99166048e-02, -4.81580645e-02, -2.10754082e-01,
1.91807196e-01, 1.72180340e-01, 1.00455418e-01,
2.22950742e-01, 5.50290197e-02, -1.89168692e-01,
8.85924548e-02, 1.23825893e-01, -2.13536248e-01,
-1.49761781e-01]], dtype=float32)>,
<tf.Variable 'dense_1/bias:0' shape=(100,) dtype=float32, numpy=
array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
dtype=float32)>]
#查询layer相关使用
help(layer)
#引用位于sklearn数据集中的房价预测数据集
from sklearn.datasets import fetch_california_housing
housing = fetch_california_housing()
print(housing.DESCR) #数据集的描述
print(housing.data.shape) #相当于 x
print(housing.target.shape) #相当于 y
#用sklearn中专门用于划分训练集和测试集的方法
from sklearn.model_selection import train_test_split
#train_test_split默认将数据划分为3:1,我们可以通过修改test_size值来改变数据划分比例(默认0.25,即3:1)
#将总数乘以test_size就表示test测试集、valid验证集数量
#将数据集整体拆分为train_all和test数据集
x_train_all,x_test, y_train_all,y_test = train_test_split(housing.data, housing.target, random_state=7)
#将train_all数据集拆分为train训练集和valid验证集
x_train,x_valid, y_train,y_valid = train_test_split(x_train_all, y_train_all, random_state=11)
print(x_train_all.shape,y_train_all.shape)
print(x_test.shape, y_test.shape)
print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)
(15480, 8) (15480,)
(5160, 8) (5160,)
(11610, 8) (11610,)
(3870, 8) (3870,)
#训练数据归一化处理
# x = (x - u)/std u为均值,std为方差
from sklearn.preprocessing import StandardScaler #使用sklearn中的StandardScaler实现训练数据归一化
scaler = StandardScaler()#初始化一个scaler对象
x_train_scaler = scaler.fit_transform(x_train)#x_train已经是二维数据了,无需astype转换
x_valid_scaler = scaler.transform(x_valid)
x_test_scaler = scaler.transform(x_test)
#tf.nn.softplus: log(1+e^x)
#keras.layers.Lambda 对流经该层的数据做个变换,而这个变换本身没有什么需要学习的参数
customized_softplus=keras.layers.Lambda(lambda x : tf.nn.softplus(x))
print(customized_softplus([-10.,-5.,0.,5.,10.]))
tf.Tensor([4.5417706e-05 6.7153489e-03 6.9314718e-01 5.0067153e+00 1.0000046e+01], shape=(5,), dtype=float32)
#自定义损失函数
#这里的接口参数为 真实值,预测值
def customized_mse(y_true, y_pred):
return tf.reduce_mean(tf.square(y_pred-y_true))
#自定义全连接层dense layer,定义一个子类CustomizedDenseLayer,继承于tf.keras.layers.Layer
#重载 __init__、build、call三个方法
class CustomizedDenseLayer(keras.layers.Layer):
def __init__(self, units, activation=None, **kwargs):
self.units = units
self.activation = keras.layers.Activation(activation)
super(CustomizedDenseLayer, self).__init__(**kwargs)
def build(self,input_shape):
"""构建所需要的参数"""
# x * w + b. input_shape=[None, a] w:[a,b] output_shape=[None,b]
self.kernel=self.add_weight(name="kernel",
shape=(input_shape[1],self.units),#input_shape中的第二个值,units表示神经单元数
initializer="uniform",#表示如何初始化这个参数矩阵的,uniform表示使用均匀分布来初始化
trainable=True) #参数可训练
self.bias=self.add_weight(name="bias",
shape=(self.units, ),
initializer="zeros",
trainable=True)
def call(self,x):
"""完成正向计算"""
return self.activation(x @ self.kernel + self.bias)
#tf.keras.models.Sequential()建立模型
model = keras.models.Sequential([
#keras.layers.Dense(30, activation="relu",input_shape=x_train.shape[1:]),
#keras.layers.Dense(1),
#使用自定义的layer来构建模型
CustomizedDenseLayer(30, activation="relu",input_shape=x_train.shape[1:]),
CustomizedDenseLayer(1),
customized_softplus,
#keras.layers.Dense(1,activation="softplus"),
#keras.layers.Dense(1),keras.layers.Activation("softplus"),
])
#编译model。 loss目标函数为均方差,这里表面上是字符串"mean_squared_error",实际上tensorflow中会映射到对应的算法函数,我们也可以自定义
model.compile(loss=customized_mse, optimizer="adam",metrics=["mean_squared_error"])
#查看model的架构
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
customized_dense_layer (Cust (None, 30) 270
_________________________________________________________________
customized_dense_layer_1 (Cu (None, 1) 31
_________________________________________________________________
lambda (Lambda) (None, 1) 0
=================================================================
Total params: 301
Trainable params: 301
Non-trainable params: 0
#使用监听模型训练过程中的callbacks
logdir='./callbacks_regression'
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir,"regression_california_housing.h5")
#首先定义一个callback数组
callbacks = [
#keras.callbacks.TensorBoard(logdir),
#keras.callbacks.ModelCheckpoint(output_model_file,save_best_only=True),
keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3)
]
history=model.fit(x_train_scaler,y_train,epochs=100,
validation_data=(x_valid_scaler,y_valid),
callbacks=callbacks)
Train on 11610 samples, validate on 3870 samples
Epoch 1/100
11610/11610 [==============================] - 1s 97us/sample - loss: 1.3880 - mean_squared_error: 1.3880 - val_loss: 0.6174 - val_mean_squared_error: 0.6174
Epoch 2/100
11610/11610 [==============================] - 1s 64us/sample - loss: 0.4870 - mean_squared_error: 0.4870 - val_loss: 0.4603 - val_mean_squared_error: 0.4603
。。。
Epoch 42/100
11610/11610 [==============================] - 1s 63us/sample - loss: 0.3083 - mean_squared_error: 0.3083 - val_loss: 0.3200 - val_mean_squared_error: 0.3200
Epoch 43/100
11610/11610 [==============================] - 1s 62us/sample - loss: 0.3068 - mean_squared_error: 0.3068 - val_loss: 0.3203 - val_mean_squared_error: 0.3203
#打印模型训练过程中的相关曲线
def plot_learning_curves(history):
pd.DataFrame(history.history).plot(figsize=(8,5))
plt.grid(True)
plt.gca().set_ylim(0,1)
plt.show()
plot_learning_curves(history)
model.evaluate(x_test_scaler,y_test)
5160/1 [================================。。。=============================================================================] - 0s 31us/sample - loss: 0.4490 - mean_squared_error: 0.3328
[0.3328484084255012, 0.33284846]