欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

线性模型/使用matplotlib绘制3D图像

程序员文章站 2022-03-02 08:18:59
...

跟着刘老师:https://www.bilibili.com/video/av92862340 学习的pytorch课程,第一节课课后习题:

使用pytorch实现一个简单的线性模型,并调用matplotlib输出模型图像。

(本菜鸡注释真的多...)

import torch
import numpy as np
import matplotlib.pyplot as plt #绘图用的模块
from mpl_toolkits.mplot3d import Axes3D #绘制3D坐标的函数
x_data=[1.0, 2.0, 3.0]
y_data=[5.0, 8.0, 11.0]
#构建线性模型
def forward(x):
    return x*w+b
#构建损失函数
def loss(x, y):
    y_pred= forward(x)
    return (y_pred-y)**2
W= np.arange(0.0, 4.1, 0.1) #arrange对象
B= np.arange(0.0, 4.1, 0.1)
[w, b]=np.meshgrid(W,B)#用两个arrange对象中的可能取值,映射扩充所有可能的取样点
#绘图的Z坐标必须是二维的,所以必须将这个过程放在一个函数里
def function(w, b):
    for w in W:
        for b in B:
            l_sum= 0
            for x_val, y_val in zip(x_data, y_data):
                y_pred_val=forward(x_val)
                loss_val= loss(x_val, y_val)
                l_sum+= loss_val
    return l_sum/3

fig= plt.figure() #创建一个绘图对象
ax= Axes3D(fig) #用上述创建的绘图对象创建一个Axes对象,带有3D对象
# 这个函数表示用取样点构建曲面, cmap表示曲面的颜色
ax.plot_surface(w, b, function(w, b),cmap=plt.cm.coolwarm)
plt.show()

输出结果:

线性模型/使用matplotlib绘制3D图像