pytorch 6 batch_train 批训练操作
程序员文章站
2022-03-10 16:58:26
看代码吧~import torchimport torch.utils.data as datatorch.manual_seed(1) # reproducible# batch_size =...
看代码吧~
import torch import torch.utils.data as data torch.manual_seed(1) # reproducible # batch_size = 5 batch_size = 8 # 每次使用8个数据同时传入网路 x = torch.linspace(1, 10, 10) # this is x data (torch tensor) y = torch.linspace(10, 1, 10) # this is y data (torch tensor) torch_dataset = data.tensordataset(x, y) loader = data.dataloader( dataset=torch_dataset, # torch tensordataset format batch_size=batch_size, # mini batch size shuffle=false, # 设置不随机打乱数据 random shuffle for training num_workers=2, # 使用两个进程提取数据,subprocesses for loading data ) def show_batch(): for epoch in range(3): # 全部的数据使用3遍,train entire dataset 3 times for step, (batch_x, batch_y) in enumerate(loader): # for each training step # train your data... print('epoch: ', epoch, '| step: ', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy()) if __name__ == '__main__': show_batch()
batch_size = 8 , 所有数据利用三次
epoch: 0 | step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.] epoch: 0 | step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.] epoch: 1 | step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.] epoch: 1 | step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.] epoch: 2 | step: 0 | batch x: [1. 2. 3. 4. 5. 6. 7. 8.] | batch y: [10. 9. 8. 7. 6. 5. 4. 3.] epoch: 2 | step: 1 | batch x: [ 9. 10.] | batch y: [2. 1.]
补充:pytorch批训练bug
问题描述:
在进行pytorch神经网络批训练的时候,有时会出现报错
typeerror: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.variable'>
解决办法:
第一步:
检查(重点!!!!!):
train_dataset = data.tensordataset(train_x, train_y)
train_x,和train_y格式,要求是tensor类,我第一次出错就是因为传入的是variable
可以这样将数据变为tensor类:
train_x = torch.floattensor(train_x)
第二步:
train_loader = data.dataloader( dataset=train_dataset, batch_size=batch_size, shuffle=true )
实例化一个dataloader对象
第三步:
for epoch in range(epochs): for step, (batch_x, batch_y) in enumerate(train_loader): batch_x, batch_y = variable(batch_x), variable(batch_y)
这样就可以批训练了
需要注意的是:train_loader输出的是tensor,在训练网络时,需要变成variable
以上为个人经验,希望能给大家一个参考,也希望大家多多支持。
上一篇: 谈谈分库分表的几个核心流程
下一篇: MySQL 数据类型选择原则
推荐阅读