一、获取数据集

 
 获取训练集和测试集,可直接在网上下载,如Kaggle上有许多数据集。数据集内每个数据包括标签和内容,此次用的是CIFAR10的数据集(包括飞机、火车等10个分类的图片)。

我们直接使用自带的torchvision.datasets.CIFAR10函数进行下载。
import torchvision #
root表示数据集存储路径,train表示是否是训练集,transform表示将图片转换成指定格式,download表 #示是否从网络中下载
#准备训练集、测试集 train = torchvision.datasets.CIFAR10(root = '../dataset2',train =
True,transform=torchvision.transforms.ToTensor(),download=True) test =
torchvision.datasets.CIFAR10(root = '../dataset2',train =
False,transform=torchvision.transforms.ToTensor(),download=True)
我们可以打印看下数据有多少
print("训练集长度为{}".format(len(train))) print("测试集长度为{}".format(len(test)))
 

 然后加载数据集,使用DataLoader

加载数据集,就是类似于将数据打包,例如将1000个数据,每100个打包成一个集合
from torch.utils.data import DataLoader train_loader =
DataLoader(train,batch_size=64) test_loader = DataLoader(test,batch_size=64)
二、神经网络的搭建 

在加载好数据之后,我们就可以来搭建神经网络了。我们先来看一下CIFAR10网络的结构模型。

搭建模型,在torch中就是继承Model类,重写forward方法,写法很简单

 
import torch.nn as nn class MyModule(nn.Module): def __init__(self):
super(MyModule, self).__init__() self.model = nn.Sequential( nn.Conv2d(3, 32,
5, 1, 2), nn.MaxPool2d(2), nn.Conv2d(32, 32, 5, 1, 2), nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, 1, 2), nn.MaxPool2d(2), nn.Flatten(), nn.Linear(64 * 4 *
4, 64),#注意这一步在图中没有画出来 nn.Linear(64, 10) ) def forward(self,x): x =
self.model(x) return x
测试网络模型搭建是否正确
import torch input = torch.ones((64,3,32,32)) mymodel = MyModule() output =
mymodel(input) print(output.shape)
训练模型

损失函数(此处使用交叉熵验证)以及其他参数设置,然后开始训练
mymodel = MyModule() #损失函数 loss = nn.CrossEntropyLoss() #优化器 learn_rate =
1e-2; opt = torch.optim.SGD(mymodel.parameters(),lr=learn_rate) #训练次数 test_step
= 0 #测试次数 train_step = 0 #训练轮数 epoch = 10 for i in range(epoch):
print("========第{}论训练=========".format(i+1)) for data in train_loader:
imgs,targets = data output = mymodel(imgs) lossf = loss(output,targets) #优化器调优
opt.zero_grad()#每一次梯度清零 loss.backward()#反向传播 opt.step() train_step += 1
print("第{}次训练 Loss:{}".format(train_step,lossf.item())) # 测试 total_accuacy = 0
with torch.no_grad():#总的准确数 for data in test_loader: img,targets = data output
= mymodel(img) accuracy = (output.argmax(1) == targets).sum() test_step += 1
total_accuacy += accuracy print("测试集上的准确率为{}".format(total_accuacy/len(test)))
#保存模型 torch.save(mymodel,"mine.pth")
代码
import torchvision #
root表示数据集存储路径,train表示是否是训练集,transform表示将图片转换成指定格式,download表示是否从网络中下载 #准备训练集、测试集
from torch.utils.tensorboard import SummaryWriter from torch.utils.data import
DataLoader import torch import torch.nn as nn class MyModule(nn.Module): def
__init__(self): super(MyModule, self).__init__() self.model = nn.Sequential(
nn.Conv2d(3, 32, 5, 1, 2), nn.MaxPool2d(2), nn.Conv2d(32, 32, 5, 1, 2),
nn.MaxPool2d(2), nn.Conv2d(32, 64, 5, 1, 2), nn.MaxPool2d(2), nn.Flatten(),
nn.Linear(64 * 4 * 4, 64),#注意这一步在图中没有画出来 nn.Linear(64, 10) ) def
forward(self,x): x = self.model(x) return x if __name__ == '__main__': train =
torchvision.datasets.CIFAR10(root='./dataset2', train=True,
transform=torchvision.transforms.ToTensor(), download=False) test =
torchvision.datasets.CIFAR10(root='./dataset2', train=True,
transform=torchvision.transforms.ToTensor(), download=False) train_loader =
DataLoader(train, batch_size=64) test_loader = DataLoader(test, batch_size=64)
mymodel = MyModule() #损失函数 loss = nn.CrossEntropyLoss() #优化器 learn_rate = 1e-2;
opt = torch.optim.SGD(mymodel.parameters(),lr=learn_rate) #训练次数 test_step = 0
#测试次数 train_step = 0 #训练轮数 epoch = 10 total_loss = 0 for i in range(epoch):
print("========第{}论训练=========".format(i+1)) for data in train_loader:
imgs,targets = data output = mymodel(imgs) lossf = loss(output,targets) #优化器调优
opt.zero_grad()#每一次梯度清零 lossf.backward()#反向传播 opt.step() train_step += 1
if(train_step% 100 == 0): print("第{}次训练
Loss:{}".format(train_step,lossf.item())) # 测试 total_accuacy = 0 with
torch.no_grad():#总的准确数 for data in test_loader: img,targets = data output =
mymodel(img) accuracy = (output.argmax(1) == targets).sum() test_step += 1
total_accuacy += accuracy print("测试集上的准确率为{}".format(total_accuacy/len(test)))
#保存模型 torch.save(mymodel,"mine.pth")

技术
下载桌面版
GitHub
Microsoft Store
SourceForge
Gitee
百度网盘(提取码:draw)
云服务器优惠
华为云优惠券
京东云优惠券
腾讯云优惠券
阿里云优惠券
Vultr优惠券
站点信息
问题反馈
邮箱:[email protected]
吐槽一下
QQ群:766591547
关注微信