
[Technical] A Simple Three-Layer Neural Network in MATLAB


1. Introduction

First, the download link for the resource: (link)

This is a simple three-layer neural network I wrote some time ago in MATLAB. It ships with its own input data, implements both batch BP and single-sample BP weight updates, and plots the two loss curves for comparison. Every function it needs is included below, with no external dependencies. I'm sharing it here; corrections and discussion are welcome.


2. Code


2.1 Main Script

The input data consists of 3 classes with 30 samples in total (10 per class).

x = [1.58 2.32 -5.8;0.67 1.58 -4.78;1.04 1.01 -3.63;-1.49 2.18 -3.39;-0.41 1.21 -4.73;...
    1.39 3.16 2.87;1.2 1.4 -1.89;-0.92 1.44 -3.22;0.45 1.33 -4.38;-0.76 0.84 -1.96;...% class 1
    0.21 0.03 -2.21;0.37 0.28 -1.8;0.18 1.22 0.16;-0.24 0.93 -1.01;-1.18 0.39 -0.39;...
    0.74 0.96 -1.16;-0.38 1.94 -0.48;0.02 0.72 -0.17;0.44 1.31 -0.14;0.46 1.49 0.68;...% class 2
    -1.54 1.17 0.64;5.41 3.45 -1.33;1.55 0.99 2.69;1.86 3.19 1.51;1.68 1.79 -0.87;...
    3.51 -0.22 -1.39;1.4 -0.44 -0.92;0.44 0.83 1.97;0.25 0.68 -0.99;0.66 -0.45 0.08];% class 3
t = zeros(30,3);
t(1:10,1) = 1;
t(11:20,2) = 1;
t(21:30,3) = 1;

param.Nx = 30;% number of samples
param.Nd = 3;% input dimension
param.Ny = 3;% number of output classes

param.Nh = 6;% number of hidden nodes
param.eta = 0.5;% learning rate
param.theta = 1;% stopping threshold on the total loss

% initialization
Wh = 0.2*(rand(param.Nd,param.Nh)-0.5);
Wy = 0.2*(rand(param.Nh,param.Ny)-0.5);
w{1} = Wh;
w{2} = Wy;
[ J, z, yh ] = forward( x, w, t, param );

% update the weights in batch mode
[ w1, count1, data1] = batchBP( x, z, yh, w, t, param );
% update the weights sample by sample
[ w2, count2, data2] = singleBP( x, z, yh, w, t, param );
% plot the two loss curves
n1 = numel(find(data1 > 0));
data1 = data1(1:n1);
xx1 = 1:n1;
n2 = numel(find(data2 > 0));
data2 = data2(1:n2);
xx2 = 1:n2;
plot(xx1,data1,xx2,data2);
grid on
xlabel('iteration n')
ylabel('loss J')
title('Loss vs. iteration for the two update schemes')
legend('batch updates','single-sample updates')
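Note that Wh and Wy are drawn uniformly from [-0.1, 0.1] via 0.2*(rand-0.5), and that batchBP and singleBP both start from this same initialization. As an optional addition (not part of the original script), the random seed can be fixed before the initialization block so the loss curves are reproducible across runs:

rng(0);% optional: fix the seed for reproducible weight initialization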

2.2 Forward Propagation

function [ J, z, yh ] = forward( x, w, t, param )
%   x: input samples, w: network weights, t: target outputs, param: network parameters
%   x = Nx*Nd; Wh = Nd*Nh ; Wy = Nh*Ny; t = Nx*Ny
%   forward pass of the three-layer BP network
%   J: per-sample squared-error loss, z: output-layer activations, yh: hidden-layer activations
Nh = param.Nh;
Nx = param.Nx;
Nd = param.Nd;
Ny = param.Ny;

Wh = w{1};
Wy = w{2};

neth = x*Wh;
yh = tan_h(neth);

netj = yh*Wy;
z = sigmoid(netj);

J = zeros(Nx,1);
for k = 1:Nx
    J(k) = 0;
    for i = 1:Ny
        J(k) = (z(k,i) - t(k,i))^2 + J(k);
    end
end
J = 0.5*J;

end
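For reference, the forward pass implemented above can be written compactly (with x of size Nx*Nd, Wh of size Nd*Nh and Wy of size Nh*Ny, as in the comments) as

$$ y_h = \tanh(x W_h), \qquad z = \sigma(y_h W_y), \qquad J_k = \tfrac{1}{2} \sum_{i=1}^{N_y} \left( z_{ki} - t_{ki} \right)^2 , $$

where \sigma is the logistic sigmoid and J_k is the squared-error loss of sample k.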

2.3 Batch Backpropagation

function [ w, count, data] = batchBP( x, z, yh, w, t, param )
%BATCHBP Batch backpropagation
%   accumulate the gradient over all samples, then update the weights once per pass
Nh = param.Nh;
Nx = param.Nx;
Nd = param.Nd;
Ny = param.Ny;
eta = param.eta;
theta = param.theta;

flag = 0;
count = 0;
res = [0 0 0];
resid = 1;
m = 0;
data = zeros(30000,1);
q = 1;
while(flag == 0)
    Wh = w{1};
    Wy = w{2};
    sj = zeros(Nx,Ny);
    sh = zeros(Nx,Nh);
    deltaj = zeros(Nh,Ny);
    deltah = zeros(Nd,Nh);
    for k = 1:Nx
        for j = 1:Ny
            sj(k,j) = z(k,j)*(1-z(k,j))*(t(k,j)-z(k,j));
            for h = 1:Nh
                deltaj(h,j) = eta*sj(k,j)*yh(k,h) + deltaj(h,j);
            end
        end
        for h = 1:Nh
            sh(k,h) = (1-yh(k,h)^2)*(Wy(h,:)*sj(k,:)');
            for i = 1:Nd
                deltah(i,h) = eta*sh(k,h)*x(k,i) + deltah(i,h);
            end
        end
    end
    Wy = Wy + deltaj;
    Wh = Wh + deltah;
    w{1} = Wh;
    w{2} = Wy;
    [ J, z, yh ] = forward( x, w, t, param );
    JJ = sum(abs(J));
    data(q) = JJ;
    q = q + 1;
    [ out ] = test( w, x, t );
    res(resid) = JJ;
    resid = resid + 1;
    if resid == 4
        resid = 1;
    end
    if sum(abs(J)) < theta || isnan(JJ) || (round(res(1),3) == round(res(2),3) && round(res(1),3) == round(res(3),3))
        flag = 1;       
    end
    count = count + 1;
    disp(['batch iteration ' num2str(count) ', accuracy: ' num2str(out*100) '%, loss: ' num2str(JJ)]);

end
end
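The nested loops above accumulate the standard gradient-descent updates over all samples before the weights are changed. As a minimal vectorized sketch (assuming the same x, t, z, yh, Wh, Wy and eta as inside batchBP), one batch update is equivalent to

sj = z.*(1-z).*(t-z);% Nx*Ny output-layer sensitivities
sh = (1-yh.^2).*(sj*Wy');% Nx*Nh hidden-layer sensitivities
Wy = Wy + eta*(yh'*sj);% the matrix products sum over the samples
Wh = Wh + eta*(x'*sh);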

2.4 Single-Sample Backpropagation

function [ w, count, data] = singleBP( x, z, yh, w, t, param )
%SINGLEBP Single-sample (online) backpropagation
%   update the weights immediately after each sample
Nh = param.Nh;
Nx = param.Nx;
Nd = param.Nd;
Ny = param.Ny;
eta = param.eta;
theta = param.theta;

flag = 0;
count = 0;
res = [0 0 0];
resid = 1;
data = zeros(30000,1);
q = 1;
while(flag == 0)
    for k = 1:Nx
        Wh = w{1};
        Wy = w{2};
        sj = zeros(Nx,Ny);
        sh = zeros(Nx,Nh);
        deltaj = zeros(Nh,Ny);
        deltah = zeros(Nd,Nh);
        for j = 1:Ny
            sj(k,j) = z(k,j)*(1-z(k,j))*(t(k,j)-z(k,j));
            for h = 1:Nh
                deltaj(h,j) = eta*sj(k,j)*yh(k,h);
            end
        end
        for h = 1:Nh
            sh(k,h) = (1-yh(k,h)^2)*(Wy(h,:)*sj(k,:)');
            for i = 1:Nd
                deltah(i,h) = eta*sh(k,h)*x(k,i);
            end
        end
        Wy = Wy + deltaj;
        Wh = Wh + deltah;
        w{1} = Wh;
        w{2} = Wy;
        [ J, z, yh ] = forward( x, w, t, param );
        JJ = sum(abs(J));
        data(q) = JJ;
        q = q + 1;
        [ out ] = test( w, x, t );
        res(resid) = JJ;
        resid = resid + 1;
        if resid == 4
            resid = 1;
        end
        count = count + 1;
        disp(['single iteration ' num2str(count) ', accuracy: ' num2str(out*100) '%, loss: ' num2str(JJ)]);
        if sum(abs(J)) < theta || isnan(JJ) || (round(res(1),5) == round(res(2),5) && round(res(1),5) == round(res(3),5))
            flag = 1;
            break;
        end
    end
end
end
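batchBP and singleBP implement the same update rule and differ only in when it is applied. In terms of the quantities computed in the code,

$$ s_j(k,j) = z(k,j)\,[1 - z(k,j)]\,[t(k,j) - z(k,j)], \qquad s_h(k,h) = [1 - y_h(k,h)^2] \sum_{j=1}^{N_y} W_y(h,j)\, s_j(k,j), $$

batch BP accumulates the updates over all samples before applying them,

$$ \Delta W_y(h,j) = \eta \sum_{k=1}^{N_x} s_j(k,j)\, y_h(k,h), \qquad \Delta W_h(i,h) = \eta \sum_{k=1}^{N_x} s_h(k,h)\, x(k,i), $$

whereas single-sample BP applies the corresponding single-k terms right after each sample and then recomputes the forward pass.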

2.5 Test Function

function [ out ] = test( w, x, t )
%TEST Classification accuracy of the network
%   runs a forward pass and compares the argmax of the output z with the argmax of the targets t
Wh = w{1};
Wy = w{2};

neth = x*Wh;
yh = tan_h(neth);

netj = yh*Wy;
z = sigmoid(netj);

[~,I] = max(z,[],2);
[~,It] = max(t,[],2);
tr = numel(find(I == It));
out = tr/size(x,1);

end
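A minimal usage example, assuming the w, x and t already defined in the main script:

acc = test(w, x, t);
fprintf('classification accuracy: %.1f%%\n', acc*100);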

2.6 Activation Functions

Hyperbolic tangent function

function [ y ] = tan_h( x )
%TAN_H Hyperbolic tangent activation
%   element-wise tanh computed with explicit loops
[a,b] = size(x);
y = zeros(a,b);
for i = 1:a
    for j = 1:b
        y(i,j) = (exp(x(i,j))-exp(-x(i,j)))/(exp(x(i,j))+exp(-x(i,j)));
    end
end
end

Sigmoid function

function [ y ] = sigmoid( x )
%SIGMOID Logistic sigmoid activation
%   element-wise 1/(1+exp(-x)) computed with explicit loops
[a,b] = size(x);
y = zeros(a,b);
for i = 1:a
    for j = 1:b
        y(i,j) = 1/(1+exp(-x(i,j)));
    end
end
end
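For reference, both loops can be replaced by MATLAB's element-wise built-ins; the following sketch behaves identically:

y = tanh(x);% equivalent to tan_h(x)
y = 1./(1+exp(-x));% equivalent to sigmoid(x)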

2.7 Results

The loss curves of batch BP and single-sample BP are compared in the figure below. [Figure: loss vs. iteration for the two update schemes]
