【技术向】简单的三层神经网络Matlab版
程序员文章站
2022-03-04 20:02:28
...
1.简介
先给出资源地址:地址
该资源是我之前写的,matlab实现的简单的三层神经网络,自带了输入数据,可实现batchBP和singleBP,并可输出图像对比,所有函数都在这里没有任何附加函数,在这边分享一下,有不对的地方欢迎一起讨论~
2.代码
2.1主函数
输入数据为3类,共有30个样本
x = [1.58 2.32 -5.8;0.67 1.58 -4.78;1.04 1.01 -3.63;-1.49 2.18 -3.39;-0.41 1.21 -4.73;...
1.39 3.16 2.87;1.2 1.4 -1.89;-0.92 1.44 -3.22;0.45 1.33 -4.38;-0.76 0.84 -1.96;...%第一类
0.21 0.03 -2.21;0.37 0.28 -1.8;0.18 1.22 0.16;-0.24 0.93 -1.01;-1.18 0.39 -0.39;...
0.74 0.96 -1.16;-0.38 1.94 -0.48;0.02 0.72 -0.17;0.44 1.31 -0.14;0.46 1.49 0.68;...%第二类
-1.54 1.17 0.64;5.41 3.45 -1.33;1.55 0.99 2.69;1.86 3.19 1.51;1.68 1.79 -0.87;...
3.51 -0.22 -1.39;1.4 -0.44 -0.92;0.44 0.83 1.97;0.25 0.68 -0.99;0.66 -0.45 0.08];%第三类
t = zeros(30,3);
t(1:10,1) = 1;
t(11:20,2) = 1;
t(21:30,3) = 1;
param.Nx = 30;
param.Nd = 3;
param.Ny = 3;
param.Nh = 6;%隐层节点个数
param.eta = 0.5;%权重更新率
param.theta = 1;%停止阈值
%初始化
Wh = 0.2*(rand(param.Nd,param.Nh)-0.5);
Wy = 0.2*(rand(param.Nh,param.Ny)-0.5);
w{1} = Wh;
w{2} = Wy;
[ J, z, yh ] = forward( x, w, t, param );
%批量方式更新权重
[ w1, count1, data1] = batchBP( x, z, yh, w, t, param );
%单样本方式更新权重
[ w2, count2, data2] = singleBP( x, z, yh, w, t, param );
%绘图
n1 = numel(find(data1 > 0));
data1 = data1(1:n1);
xx1 = 1:n1;
n2 = numel(find(data2 > 0));
data2 = data2(1:n2);
xx2 = 1:n2;
plot(xx1,data1,xx2,data2);
grid on
xlabel('迭代次数n')
ylabel('目标函数loss')
title('两种更新方法的loss随迭代次数的关系变化图')
legend('批量方式更新权重','单样本方式更新权重')
2.2前向传播
function [ J, z, yh ] = forward( x, w, t, param )
% x为输入样本, w为网络权重, t为输出真值,param为网络参数
% x = Nx*Nd; Wh = Nd*Nh ; Wy = Nh*Ny; t = Nx*Ny
% 三层BP网络前向过程
% J为能量损失, z为最后一层输出, yh为隐层输出
Nh = param.Nh;
Nx = param.Nx;
Nd = param.Nd;
Ny = param.Ny;
Wh = w{1};
Wy = w{2};
neth = x*Wh;
yh = tan_h(neth);
netj = yh*Wy;
z = sigmoid(netj);
J = zeros(Nx,1);
for k = 1:Nx
J(k) = 0;
for i = 1:Ny
J(k) = (z(k,i) - t(k,i))^2 + J(k);
end
end
J = 0.5*J;
end
2.3批量反传
function [ w, count, data] = batchBP( x, z, yh, w, t, param )
%BATCHBP 此处显示有关此函数的摘要
% 批量BP算法
Nh = param.Nh;
Nx = param.Nx;
Nd = param.Nd;
Ny = param.Ny;
eta = param.eta;
theta = param.theta;
flag = 0;
count = 0;
res = [0 0 0];
resid = 1;
m = 0;
data = zeros(30000,1);
q = 1;
while(flag == 0)
Wh = w{1};
Wy = w{2};
sj = zeros(Nx,Ny);
sh = zeros(Nx,Nh);
deltaj = zeros(Nh,Ny);
deltah = zeros(Nd,Nh);
for k = 1:Nx
for j = 1:Ny
sj(k,j) = z(k,j)*(1-z(k,j))*(t(k,j)-z(k,j));
for h = 1:Nh
deltaj(h,j) = eta*sj(k,j)*yh(k,h) + deltaj(h,j);
end
end
for h = 1:Nh
sh(k,h) = (1-yh(k,h)^2)*(Wy(h,:)*sj(k,:)');
for i = 1:Nd
deltah(i,h) = eta*sh(k,h)*x(k,i) + deltah(i,h);
end
end
end
Wy = Wy + deltaj;
Wh = Wh + deltah;
w{1} = Wh;
w{2} = Wy;
[ J, z, yh ] = forward( x, w, t, param );
JJ = sum(abs(J));
data(q) = JJ;
q = q + 1;
[ out ] = test( w, x, t );
res(resid) = JJ;
resid = resid + 1;
if resid == 4
resid = 1;
end
if sum(abs(J)) < theta || isnan(JJ) || (round(res(1),3) == round(res(2),3) && round(res(1),3) == round(res(3),3))
flag = 1;
end
count = count + 1;
disp(['batch迭代第' num2str(count) '次,正确率为:' num2str(out*100) '%, loss为:' num2str(JJ)]);
end
end
2.4单样本反传
function [ w, count, data] = singleBP( x, z, yh, w, t, param )
%BATCHBP 此处显示有关此函数的摘要
% 单样本更新BP算法
Nh = param.Nh;
Nx = param.Nx;
Nd = param.Nd;
Ny = param.Ny;
eta = param.eta;
theta = param.theta;
flag = 0;
count = 0;
res = [0 0 0];
resid = 1;
data = zeros(30000,1);
q = 1;
while(flag == 0)
for k = 1:Nx
Wh = w{1};
Wy = w{2};
sj = zeros(Nx,Ny);
sh = zeros(Nx,Nh);
deltaj = zeros(Nh,Ny);
deltah = zeros(Nd,Nh);
for j = 1:Ny
sj(k,j) = z(k,j)*(1-z(k,j))*(t(k,j)-z(k,j));
for h = 1:Nh
deltaj(h,j) = eta*sj(k,j)*yh(k,h);
end
end
for h = 1:Nh
sh(k,h) = (1-yh(k,h)^2)*(Wy(h,:)*sj(k,:)');
for i = 1:Nd
deltah(i,h) = eta*sh(k,h)*x(k,i);
end
end
Wy = Wy + deltaj;
Wh = Wh + deltah;
w{1} = Wh;
w{2} = Wy;
[ J, z, yh ] = forward( x, w, t, param );
JJ = sum(abs(J));
data(q) = JJ;
q = q + 1;
[ out ] = test( w, x, t );
res(resid) = JJ;
resid = resid + 1;
if resid == 4
resid = 1;
end
count = count + 1;
disp(['single迭代第' num2str(count) '次,正确率为:' num2str(out*100) '%, loss为:' num2str(JJ)]);
if sum(abs(J)) < theta || isnan(JJ) || (round(res(1),5) == round(res(2),5) && round(res(1),5) == round(res(3),5))
flag = 1;
break;
end
end
end
end
2.5测试函数
function [ out ] = test( w, x, t )
%TEST 此处显示有关此函数的摘要
% 此处显示详细说明
Wh = w{1};
Wy = w{2};
neth = x*Wh;
yh = tan_h(neth);
netj = yh*Wy;
z = sigmoid(netj);
[~,I] = max(z,[],2);
[~,It] = max(t,[],2);
tr = numel(find(I == It));
out = tr/size(x,1);
end
2.6**函数
双曲正切函数
function [ y ] = tan_h( x )
%TAN_H 此处显示有关此函数的摘要
% 双曲正切函数
[a,b] = size(x);
y = zeros(a,b);
for i = 1:a
for j = 1:b
y(i,j) = (exp(x(i,j))-exp(-x(i,j)))/(exp(x(i,j))+exp(-x(i,j)));
end
end
end
Sigmiod函数
function [ y ] = sigmoid( x )
%SIGMOID 此处显示有关此函数的摘要
% 此处显示详细说明
[a,b] = size(x);
y = zeros(a,b);
for i = 1:a
for j = 1:b
y(i,j) = 1/(1+exp(-x(i,j)));
end
end
end
2.7输出结果
BP和SP的对比结果如下图
上一篇: Matlab实现给黑白图片上色
下一篇: MATLAB月球坑识别和降落避障