<>torchvision

torchvision 是 PyTorch 的一个图形图像库(专门用来处理图像和视觉的),主要用于构建计算机视觉模型。

torchvision 包含四个大类:

* torchvision.datasets:包含一些加载数据的函数及常用的数据集接口
* torchvision.models:包含常用的模型结构(含预训练模型),如 AlexNet、VGG 等
* torchvision.transforms:包含常用的图片变换操作,如裁剪、旋转等
* torchvision.utils:包含其他一些有用的方法
<>torchvision.transforms

torchvision.transforms 是 PyTorch 中的图像预处理包,包含了很多对图像数据进行变换的函数,主要用于常见的一些图形变换。

torchvision.transforms.Compose() 类,这个类的主要作用是串联多个图形变换的操作,它会对列表里面的变换操作进行遍历。

torchvision.transforms.ToTensor() 类,把 shape=(H*W*C) 的像素值范围为 [0, 255] 的
PIL.Image 或者 numpy.ndarray 转换成 shape=(C*H*W) 的像素值范围为 [0.0, 1.0] 的
torch.FloatTensor。

torchvision.transforms.ToPILImage() 类,把 shape=(C*H*W) 的 Tensor 或者
shape=(H*W*C) 的 numpy.ndarray 转换成 shape=(H*W*C) 的 PIL.Image,值不变。

torchvision.transforms.Normalize(mean, std) 类,用给定的均值和标准差分别对每个通道的数据进行规范化。
具体来说,给定均值 ( M 1 , M 2 , . . . , M n ) (M_1, M_2, ..., M_n) (M1​,M2​,...,Mn​)
和标准差 ( S 1 , S 2 , . . . , S n ) (S_1, S_2, ..., S_n) (S1​,S2​,...,Sn​),其中
n(一般为 3 (R, G, B))为通道数。用公式 channel = (channel - mean) / std 来进行规范化。
对每个通道进行如下操作:output[channel] = (input[channel] - mean[channel]) / std[channel]。
比如原来的 tensor 是三个维度的,数值在 [0, 1] 之间,经过变换之后数值范围就扩展到 [-1, 1] 之间。计算如下:((0, 1) -
0.5) / 0.5 = (-1, 1)。

torchvision.transforms.CenterCrop(size) 类,将给定的 PIL.Image 进行中心裁剪,得到指定的 size。参数
size 可以是一个整数,裁剪出来的是一个正方形图像;size 也可以是一个 tuple(target_height, target_width)。

torchvision.transforms.RandomCrop(size, padding=0) 类,对 PIL.Image
进行随机裁剪,即裁剪中心点的位置随机选取。参数 size 可以是一个整数,也可以是一个 tuple。

torchvision.transforms.RandomResizedCrop(size) 类,先对 PIL.Image 进行随机裁剪,然后再将其
resize 成给定 size 大小。

torchvision.transforms.RandomHorizontalFlip(p=0.5) 类,将给定的 PIL.Image
随机水平翻转,翻转的概率默认为 0.5。

torchvision.transforms.RandomVerticalFlip(p=0.5) 类,将给定的 PIL.Image
随机垂直翻转,翻转的概率默认为 0.5。

torchvision.transforms.Pad(padding, fill=0) 类,用给定的值填充 PIL.Image
的所有边。padding–各边要填充多少个像素,fill–用什么值填充。
import torchvision.transforms as transforms # 图像预处理 transform = transforms.
Compose([ transforms.Resize(96), # 将图像缩放到 96*96 大小 transforms.ToTensor(), #
将图像数据转换成张量 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 对张量数据进行归一化 ]
)
<>torchvision.datasets

torchvision.datasets 是用于数据加载的包,PyTorch 团队在这个包中帮我们提前处理好了很多图像数据集。如
MNIST、CocoCaptions、CocoDetection、LSUN、ImageFolder、ImageNet、CIFAR10、STL10、SVHN、PhotoTour
等。

所有数据集都是 torch.utils.data.Dataset 的子类,它们实现了 _getitem_ 和 __len__ 方法,因此它们都可以传递给
torch.utils.data.DataLoader。
import torchvision.datasets as datasets # 数据集准备 trainset = datasets.MNIST( root
='./data', # 根目录 train=True, # 用于指定需要载入数据集的哪个部分,这里载入的是训练集;如果为 False,则载入测试集
transform=transform, # 用于指定导入数据集时需要对数据进行哪些变换操作,需提前定义这些变换操作 download=True #
用于指定是否需要网上下载;如果为 True,则从网上下载数据集并将其放在根目录中;如果已下载数据集,则不会再次下载 ) # 数据集加载 trainloader
= torch.utils.data.DataLoader( trainset, # 准备的数据集 batch_size=4, # 设定图像数据的批次大小
shuffle=True, # 如果为 True,则每个 epoch 都会将数据集打乱 num_workers=2, # 设定加载数据时的线程数目;默认为
0,主线程加载数据 collate_fn=<function default_collate>, # 指定取样本的方式,可以自己定义函数来实现想要的功能
pin_memory=False, # 指定是否为锁页内存 drop_last=False # 用于指定对 len(trainset)/batch_size
余下的数据的处理方式;如果为 True,则将最后不够一个 batch_size 的数据抛弃;如果为 False,则保留 ) '''
主机中的内存有两种存在方式,一是锁页,二是不锁页。锁页内存存放的内容在任何情况下都不会与主机的虚拟内存(虚拟内存就是硬盘)进行交换;而不锁页内存在主机内存不足时,数据会存放在虚拟内存中。显卡中的显存全部是锁页内存。
当计算机的内存充足时,可以设置 pin_memory=True;当系统卡住,或者交换内存使用过多的时候,设置 pin_memory=False。因为
pin_memory 与电脑硬件性能有关,PyTorch 开发者不能确保每一个炼丹玩家都有高端设备,因此 pin_memory 默认为 False。 '''
<>torchvision.models

torchvision.models 中包含如 AlexNet、VGG、ResNet、SqueezeNet、DenseNet
等模型结构,同时为我们提供已经预训练好的模型,我们加载之后可以直接使用。

可以通过以下代码快速创建一个随机初始化权重的模型:
import torchvision.models as models alexnet = models.alexnet() vgg16 = models.
vgg16() resnet18 = models.resnet18()
也可以通过 pretrained=True 来加载一个预训练好的模型:
import torchvision.models as models alexnet = models.alexnet(pretrained=True)
vgg16= models.vgg16(pretrained=True) resnet18 = models.resnet18(pretrained=True)
我们把 torchvision 中的多个类组合起来使用:
# 初始的 MNIST 数据集图像大小为 28*28,我们把它们处理成 96*96 的 torch.Tensor 的格式 import torchvision
.transforms as transforms import torchvision.datasets as datasets from torch.
utils.data import DataLoader # 图像预处理 transform = transforms.Compose([ transforms
.Resize(96), # 缩放到 96*96 大小 transforms.ToTensor(), # 将图像转换成 Tensor transforms.
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化 ]) # 数据集准备 train_dataset =
datasets.MNIST(root='./data/', train=True, transform=transform, download=True)
# 数据集加载 train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=
True) print(len(train_dataset)) print(len(train_loader)) --------- 60000 7500

技术
今日推荐
下载桌面版
GitHub
百度网盘(提取码:draw)
Gitee
云服务器优惠
阿里云优惠券
腾讯云优惠券
华为云优惠券
站点信息
问题反馈
邮箱:[email protected]
QQ群:766591547
关注微信