线性模型/使用matplotlib绘制3D图像
程序员文章站
2022-03-02 08:18:59
...
跟着刘老师:https://www.bilibili.com/video/av92862340 学习的pytorch课程,第一节课课后习题:
使用pytorch实现一个简单的线性模型,并调用matplotlib输出模型图像。
(本菜鸡注释真的多...)
import torch
import numpy as np
import matplotlib.pyplot as plt #绘图用的模块
from mpl_toolkits.mplot3d import Axes3D #绘制3D坐标的函数
x_data=[1.0, 2.0, 3.0]
y_data=[5.0, 8.0, 11.0]
#构建线性模型
def forward(x):
return x*w+b
#构建损失函数
def loss(x, y):
y_pred= forward(x)
return (y_pred-y)**2
W= np.arange(0.0, 4.1, 0.1) #arrange对象
B= np.arange(0.0, 4.1, 0.1)
[w, b]=np.meshgrid(W,B)#用两个arrange对象中的可能取值,映射扩充所有可能的取样点
#绘图的Z坐标必须是二维的,所以必须将这个过程放在一个函数里
def function(w, b):
for w in W:
for b in B:
l_sum= 0
for x_val, y_val in zip(x_data, y_data):
y_pred_val=forward(x_val)
loss_val= loss(x_val, y_val)
l_sum+= loss_val
return l_sum/3
fig= plt.figure() #创建一个绘图对象
ax= Axes3D(fig) #用上述创建的绘图对象创建一个Axes对象,带有3D对象
# 这个函数表示用取样点构建曲面, cmap表示曲面的颜色
ax.plot_surface(w, b, function(w, b),cmap=plt.cm.coolwarm)
plt.show()
输出结果:
上一篇: vueRouter 路由的配置
下一篇: 前端路由理解
推荐阅读
-
Python使用matplotlib模块绘制图像并设置标题与坐标轴等信息示例
-
Python使用matplotlib或pandas绘制图像中文乱码问题解决方案仅供参考
-
matplotlib函数库使用Axes3D绘制3D图形
-
Python中使用matplotlib绘制mqtt数据实时图像功能
-
matplotlib 3D模型绘制一朵小红花
-
matplotlib 3D模型绘制一朵小红花
-
Python使用matplotlib绘制3D图形(代码示例)
-
Python数据分析三大框架之matplotlib(三)3D图像绘制
-
使用Python的matplotlib包绘制三维图像
-
Python使用matplotlib绘制3D图形(代码示例)