欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

3D梯度下降

程序员文章站 2024-01-19 13:47:10
...

我的上一篇博客,两个变量没画出3D图
本次就将上一次代码进行改进,使结果的图像为3D

代码实现

#引入所需模块
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

#方程组为f(x1,x2) = (x1+x2)^2

#定义目标函数
def mubiaofunc(x1,x2):
    return (x1+x2)**2
    pass
    
#定义梯度函数,返回导数,我分别求了两个偏导,可以放在一个函数里
def tidufuncx1(x1,x2):
    return 2*x1+2*x2
    pass

def tidufuncx2(x1,x2):
    return 2*x2+2*x1
    pass

#定义空列表,存放x1,x2的元组
listx = []

#定义猜测的过程
def func(pointx1,pointx2,mubiaofunc,tidufuncx1,tidufuncx2,rating=0.1,tolent=0.000001,times=500):
	'''
    :param pointx1,pointx2: 猜测的点
    :param mubiaofunc: 目标函数
    :param tidufuncx1,tidufuncx2: 梯度函数
    :param rating: 步进系数(移动的距离)
    :param tolent: 收敛条件(极小)
    :param times:移动次数
    :return: 返回极值点的x值
    '''
    #先求一次
    mubiao = mubiaofunc(pointx1,pointx2)
    listx.append((pointx1,pointx2))
    newpointx1 = pointx1 - rating*tidufuncx1(pointx1,pointx2)
    newpointx2 = pointx2 - rating*tidufuncx2(pointx1,pointx2)
    newmubiao = mubiaofunc(newpointx1,newpointx2)
    remubiao = np.abs(mubiao-newmubiao)
    t = 0
    #然后判断前后目标函数之差是否小于等于步进系数并且移动次数是否大于给定值
    while remubiao > tolent and t < times:
        t += 1
        pointx1 = newpointx1
        pointx2 = newpointx2
        mubiao = newmubiao
        listx.append((pointx1,pointx2))     #列表中存的是元组
        newpointx1 = newpointx1 - rating*tidufuncx1(newpointx1,newpointx2)
        newpointx2 = newpointx2 - rating*tidufuncx2(newpointx2,newpointx1)
        newmubiao = mubiaofunc(newpointx1,newpointx2)
        remubiao = np.abs(mubiao - newmubiao)
        pass
    return pointx1,pointx2     #返回两个点
    pass

if __name__ == '__main__':
    print(func(5,5,mubiaofunc,tidufuncx1,tidufuncx2))     #返回最后一次的值
    print(listx)     #返回列表内容
    
    pl = plt.figure()    #建画布
    ax = Axes3D(pl)    #变成3D

    x1,x2 = np.linspace(-1,6,25),np.linspace(-1,6,25)     #linespace中的参数是从-1到6,分成25份,间距是24个
    x1,x2 = np.meshgrid(x1,x2)      #初始化数据
    fx1x2=mubiaofunc(x1,x2)      #z轴的值
    ax.plot_surface(x1,x2,fx1x2)     #坐标系
    ax.scatter(*(np.array(listx).T), np.array(listx).sum(1)**2,s=20,c='r')      #点图
    plt.show()
    pass

主要是新增了Axes3D模块