Transfrom
from torchvision import transforms
🖼️ TorchVision 介绍¶
TorchVision 是 PyTorch 官方支持的、专门用于计算机视觉 (Computer Vision) 任务的库。它是 PyTorch 生态系统中不可或缺的一部分,极大地简化了计算机视觉研究和应用的开发过程。
TorchVision 主要由三个核心模块构成,共同为视觉任务提供高效的支持:
1. torchvision.datasets (数据集)¶
这个模块提供了大量预先准备好的、常用的公共数据集,你可以直接下载和使用,无需自己编写复杂的数据加载代码。
- 核心功能: 方便地加载标准数据集进行训练和测试。
- 常见数据集示例:
- Image Classification (图像分类):
MNIST,CIFAR10,ImageNet(需要手动下载和组织文件)。 - Object Detection (目标检测):
CocoDetection,Kitti. - Semantic Segmentation (语义分割):
VOCSegmentation.
- Image Classification (图像分类):
2. torchvision.models (预训练模型)¶
这个模块提供了大量预训练的、高性能的神经网络模型。这些模型通常在大型数据集(如 ImageNet)上训练过,可以直接用于特征提取或作为迁移学习 (Transfer Learning) 的起点。
- 核心功能: 快速获取和使用经过验证的经典或先进模型。
- 常见模型架构示例:
- Classification (分类):
ResNet(如resnet50),VGG,MobileNet,EfficientNet. - Object Detection (目标检测):
Faster R-CNN,SSD,RetinaNet. - Segmentation (分割):
FCN,DeepLabV3.
- Classification (分类):
3. torchvision.transforms (数据变换/预处理)¶
这个模块提供了各种图像预处理和数据增强 (Data Augmentation) 操作。在将原始图像输入到神经网络之前,通常需要进行这些处理。
- 核心功能: 对 PIL Image 或 Tensor 进行各种操作,如尺寸调整、裁剪、归一化等。
- 常见操作示例:
ToTensor(): 将 PIL Image 或 NumPy 数组转换成 PyTorch Tensor。Normalize(): 对 Tensor 进行标准化处理 (减去均值,除以标准差)。Resize(): 改变图像尺寸。RandomCrop(): 随机裁剪,用于数据增强。Compose(): 将多个变换操作串联起来。
总结¶
TorchVision 是 PyTorch 用户进行计算机视觉项目时的标配工具箱,它通过提供即插即用的数据集、模型和预处理工具,帮助开发者:
- 快速启动项目: 利用预训练模型和标准数据集快速构建原型。
- 保证数据处理一致性: 使用统一的
transformsAPI 进行数据预处理。 - 节省资源: 避免重复实现常见模型的代码。
transform¶
torchvision.transforms 介绍¶
torchvision.transforms 是 TorchVision 库中专门用于图像预处理 (Preprocessing) 和数据增强 (Data Augmentation) 的核心模块。
在将原始图像数据送入神经网络之前,我们几乎总是需要进行一系列的转换操作,而 transforms 模块就是为此设计的。
核心目的¶
- 统一输入格式: 将不同格式的输入(如 PIL Image, NumPy Array)转换为 PyTorch 模型所需的
Tensor格式。 - 标准化: 对像素值进行归一化,使其均值接近 0、标准差接近 1,有助于模型训练收敛。
- 数据增强: 通过随机变换(如裁剪、翻转、旋转),增加训练数据的多样性,从而提高模型的泛化能力,减少过拟合。
常见操作分类及示例¶
1. 基础转换 (Normalization & Format)¶
| 转换方法 | 作用 | 目的 |
|---|---|---|
ToTensor() |
将 PIL Image 或 NumPy ndarray 转换为 FloatTensor。 |
必须的操作,将数据转换为模型能处理的 PyTorch 张量。 |
Normalize(mean, std) |
根据给定的均值和标准差对 Tensor 进行标准化。 | 将像素值缩放到一个合适的范围,加速模型收敛。 |
Resize(size) |
将输入图像调整到指定的尺寸。 | 确保所有输入图像尺寸统一。 |
2. 数据增强 (Augmentation)¶
| 转换方法 | 作用 | 目的 |
|---|---|---|
RandomCrop(size) |
从图像中随机裁剪一块指定大小的区域。 | 增加训练样本的多样性和平移不变性。 |
RandomHorizontalFlip(p=0.5) |
以给定的概率水平随机翻转图像。 | 增加数据的对称性。 |
RandomRotation(degrees) |
随机旋转图像一个角度范围。 | 增加模型的旋转不变性。 |
ColorJitter(brightness=0.2, ...) |
随机改变图像的亮度、对比度、饱和度和色调。 | 使模型对光照变化更具鲁棒性。 |
3. 组合操作¶
| 转换方法 | 作用 | 目的 |
|---|---|---|
Compose([...]) |
将多个转换操作按顺序串联起来。 | 这是最常用的方法,用于定义一个完整的预处理流程。 |
Compose 示例 (标准的预处理流程)¶
在实际应用中,我们通常使用 transforms.Compose 将多个操作组合成一个流程:
from torchvision import transforms
# 定义一个标准的预处理流程
transform = transforms.Compose([
# 1. 调整图像尺寸到 256x256
transforms.Resize(256),
# 2. 从 256x256 的图像中随机裁剪 224x224 的区域 (用于数据增强)
transforms.RandomCrop(224),
# 3. 以 50% 的概率水平翻转图像 (用于数据增强)
transforms.RandomHorizontalFlip(),
# 4. 必须操作:将图像转换为 PyTorch Tensor
transforms.ToTensor(),
# 5. 必须操作:进行标准化 (使用 ImageNet 的均值和标准差作为示例)
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 假设 image 是一个 PIL 图像对象
# processed_tensor = transform(image)
总结¶
torchvision.transforms 模块通过提供一个清晰、灵活的 API,使得深度学习工程师可以方便地:
- 构建模型所需的输入张量。
- 应用各种数据增强技术来改进模型性能和鲁棒性。
# 转换图片
from PIL import Image
from torchvision import transforms
img_path = 'dataset/train/ants/0013035.jpg'
img = Image.open(img_path)
print(type(img))
<class 'PIL.JpegImagePlugin.JpegImageFile'>
tensor_trans = transforms.ToTensor() #实例化类
img_tensor = tensor_trans(img) #调用类(就是调用__call__函数)
print(type(img_tensor))
<class 'torch.Tensor'>
print(img_tensor)
tensor([[[0.3137, 0.3137, 0.3137, ..., 0.3176, 0.3098, 0.2980],
[0.3176, 0.3176, 0.3176, ..., 0.3176, 0.3098, 0.2980],
[0.3216, 0.3216, 0.3216, ..., 0.3137, 0.3098, 0.3020],
...,
[0.3412, 0.3412, 0.3373, ..., 0.1725, 0.3725, 0.3529],
[0.3412, 0.3412, 0.3373, ..., 0.3294, 0.3529, 0.3294],
[0.3412, 0.3412, 0.3373, ..., 0.3098, 0.3059, 0.3294]],
[[0.5922, 0.5922, 0.5922, ..., 0.5961, 0.5882, 0.5765],
[0.5961, 0.5961, 0.5961, ..., 0.5961, 0.5882, 0.5765],
[0.6000, 0.6000, 0.6000, ..., 0.5922, 0.5882, 0.5804],
...,
[0.6275, 0.6275, 0.6235, ..., 0.3608, 0.6196, 0.6157],
[0.6275, 0.6275, 0.6235, ..., 0.5765, 0.6275, 0.5961],
[0.6275, 0.6275, 0.6235, ..., 0.6275, 0.6235, 0.6314]],
[[0.9137, 0.9137, 0.9137, ..., 0.9176, 0.9098, 0.8980],
[0.9176, 0.9176, 0.9176, ..., 0.9176, 0.9098, 0.8980],
[0.9216, 0.9216, 0.9216, ..., 0.9137, 0.9098, 0.9020],
...,
[0.9294, 0.9294, 0.9255, ..., 0.5529, 0.9216, 0.8941],
[0.9294, 0.9294, 0.9255, ..., 0.8863, 1.0000, 0.9137],
[0.9294, 0.9294, 0.9255, ..., 0.9490, 0.9804, 0.9137]]])
print(img_tensor.shape)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[1], line 1 ----> 1 print(img_tensor.shape) NameError: name 'img_tensor' is not defined
说明:维度0 深度(通道数);维度1 行数(高),维度2 列数(宽)
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
# 将特征转换为归一化的张量,并将标签转换为独热编码的张量
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
# 把输入的y列表对应的位置都写 1,其他则为 0,并统一长度为 10
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)) #y should be Python List
)
这段代码中的 target_transform 部分是 PyTorch 数据预处理中一个简洁且高效的技巧,它实现了将标量标签(类别 ID)转换为独热编码 (One-Hot Encoding)。
让我们把这段代码拆解开来,一步步详细解释。
代码解析:目标转换 (Target Transform)¶
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
1. Lambda(...) 的作用¶
- 回顾:
Lambda是torchvision.transforms模块中的一个类,它允许您将任何 自定义的、单行可执行的函数 封装成一个转换操作。 - 输入: 这里的
Lambda接收一个匿名函数lambda y: ...。 y: 在 FashionMNIST 数据集中,y就是当前的目标标签 (Target Label),它是一个表示类别索引的 整数标量(例如:T恤衫是0,裤子是1,鞋子是9)。
2. Lambda 函数体:核心的独热编码逻辑¶
Lambda 函数体执行了三个步骤来完成独热编码:
A. 创建一个全零向量 (容器)¶
torch.zeros(10, dtype=torch.float)
torch.zeros(10, ...): 创建一个包含 10 个元素的 一维零张量。- 为什么是 10?因为 FashionMNIST 数据集有 10 个类别。
- 这个张量就是我们最终的独热编码容器,形状是
(10,)。
dtype=torch.float: 确保张量的数据类型是浮点数,这通常是深度学习中损失函数(如交叉熵损失)所要求的输入格式。
B. 核心操作:.scatter_()¶
这个操作是实现独热编码的关键。它将一个数值 分散 (scatter) 写入目标张量的指定索引位置。
- 操作对象: 刚才创建的那个全零张量。
- 语法回顾:
target.scatter_(dim, index, value)
C. 理解 .scatter_() 的参数¶
| 参数 | 表达式 | 实际值 (假设 $y=3$) | 作用 |
|---|---|---|---|
dim |
0 |
0 |
分散维度: 因为目标张量是 1 维的 (10,),所以只有 dim=0 可选。操作将沿着这个唯一的维度进行。 |
index |
torch.tensor(y) |
torch.tensor(3) |
索引张量: y 是一个 Python 整数(如 3),但 .scatter_ 要求索引必须是 PyTorch 张量。因此,我们将其包装成一个单元素张量。 |
value |
value=1 |
1 |
写入值: 要写入到目标张量指定位置的值。独热编码要求将对应类别位置的值设置为 1。 |
3. 完整过程演示(假设 $y=3$)¶
假设当前样本的标签 $y=3$(即第 4 个类别):
初始零张量 (Target): $$T = [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]$$ (形状 $(10$, $)$)
索引张量 (Index): $$I = [3]$$ (形状 $(1$, $)$)
执行
.scatter_(0, I, value=1): 这个操作意味着:将值1写入到 $T$ 中,位置由 $I$ 指定,沿着 $dim=0$。- $T[I[0]] = 1$
- $T[3] = 1$
最终独热编码结果: $$T_{final} = [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]$$
总结:这段代码的意义¶
这段 target_transform 代码的目的是将原始的 整数类别标签(如 3)转换成 长度为类别总数 (10) 的向量,其中只有对应类别的索引位置为 1,其他位置为 0。
为什么需要这样做?
- 交叉熵损失 (Cross-Entropy Loss): 在 PyTorch 中,虽然
nn.CrossEntropyLoss优化器可以直接接收整数 ID 标签,但当涉及到其他类型的损失函数、自定义激活函数(如 Sigmoid 输出)或执行某些特殊的梯度计算时,独热编码 的格式(One-Hot Vector)是更通用和标准的标签格式。