Test Transfrom
In [1]:
from torchvision import transforms
🖼️ TorchVision 介绍¶
TorchVision 是 PyTorch 官方支持的、专门用于计算机视觉 (Computer Vision) 任务的库。它是 PyTorch 生态系统中不可或缺的一部分,极大地简化了计算机视觉研究和应用的开发过程。
TorchVision 主要由三个核心模块构成,共同为视觉任务提供高效的支持:
1. torchvision.datasets (数据集)¶
这个模块提供了大量预先准备好的、常用的公共数据集,你可以直接下载和使用,无需自己编写复杂的数据加载代码。
- 核心功能: 方便地加载标准数据集进行训练和测试。
- 常见数据集示例:
- Image Classification (图像分类):
MNIST,CIFAR10,ImageNet(需要手动下载和组织文件)。 - Object Detection (目标检测):
CocoDetection,Kitti. - Semantic Segmentation (语义分割):
VOCSegmentation.
- Image Classification (图像分类):
2. torchvision.models (预训练模型)¶
这个模块提供了大量预训练的、高性能的神经网络模型。这些模型通常在大型数据集(如 ImageNet)上训练过,可以直接用于特征提取或作为迁移学习 (Transfer Learning) 的起点。
- 核心功能: 快速获取和使用经过验证的经典或先进模型。
- 常见模型架构示例:
- Classification (分类):
ResNet(如resnet50),VGG,MobileNet,EfficientNet. - Object Detection (目标检测):
Faster R-CNN,SSD,RetinaNet. - Segmentation (分割):
FCN,DeepLabV3.
- Classification (分类):
3. torchvision.transforms (数据变换/预处理)¶
这个模块提供了各种图像预处理和数据增强 (Data Augmentation) 操作。在将原始图像输入到神经网络之前,通常需要进行这些处理。
- 核心功能: 对 PIL Image 或 Tensor 进行各种操作,如尺寸调整、裁剪、归一化等。
- 常见操作示例:
ToTensor(): 将 PIL Image 或 NumPy 数组转换成 PyTorch Tensor。Normalize(): 对 Tensor 进行标准化处理 (减去均值,除以标准差)。Resize(): 改变图像尺寸。RandomCrop(): 随机裁剪,用于数据增强。Compose(): 将多个变换操作串联起来。
总结¶
TorchVision 是 PyTorch 用户进行计算机视觉项目时的标配工具箱,它通过提供即插即用的数据集、模型和预处理工具,帮助开发者:
- 快速启动项目: 利用预训练模型和标准数据集快速构建原型。
- 保证数据处理一致性: 使用统一的
transformsAPI 进行数据预处理。 - 节省资源: 避免重复实现常见模型的代码。
transform¶
torchvision.transforms 介绍¶
torchvision.transforms 是 TorchVision 库中专门用于图像预处理 (Preprocessing) 和数据增强 (Data Augmentation) 的核心模块。
在将原始图像数据送入神经网络之前,我们几乎总是需要进行一系列的转换操作,而 transforms 模块就是为此设计的。
核心目的¶
- 统一输入格式: 将不同格式的输入(如 PIL Image, NumPy Array)转换为 PyTorch 模型所需的
Tensor格式。 - 标准化: 对像素值进行归一化,使其均值接近 0、标准差接近 1,有助于模型训练收敛。
- 数据增强: 通过随机变换(如裁剪、翻转、旋转),增加训练数据的多样性,从而提高模型的泛化能力,减少过拟合。
常见操作分类及示例¶
1. 基础转换 (Normalization & Format)¶
| 转换方法 | 作用 | 目的 |
|---|---|---|
ToTensor() |
将 PIL Image 或 NumPy ndarray 转换为 FloatTensor。 |
必须的操作,将数据转换为模型能处理的 PyTorch 张量。 |
Normalize(mean, std) |
根据给定的均值和标准差对 Tensor 进行标准化。 | 将像素值缩放到一个合适的范围,加速模型收敛。 |
Resize(size) |
将输入图像调整到指定的尺寸。 | 确保所有输入图像尺寸统一。 |
2. 数据增强 (Augmentation)¶
| 转换方法 | 作用 | 目的 |
|---|---|---|
RandomCrop(size) |
从图像中随机裁剪一块指定大小的区域。 | 增加训练样本的多样性和平移不变性。 |
RandomHorizontalFlip(p=0.5) |
以给定的概率水平随机翻转图像。 | 增加数据的对称性。 |
RandomRotation(degrees) |
随机旋转图像一个角度范围。 | 增加模型的旋转不变性。 |
ColorJitter(brightness=0.2, ...) |
随机改变图像的亮度、对比度、饱和度和色调。 | 使模型对光照变化更具鲁棒性。 |
3. 组合操作¶
| 转换方法 | 作用 | 目的 |
|---|---|---|
Compose([...]) |
将多个转换操作按顺序串联起来。 | 这是最常用的方法,用于定义一个完整的预处理流程。 |
Compose 示例 (标准的预处理流程)¶
在实际应用中,我们通常使用 transforms.Compose 将多个操作组合成一个流程:
from torchvision import transforms
# 定义一个标准的预处理流程
transform = transforms.Compose([
# 1. 调整图像尺寸到 256x256
transforms.Resize(256),
# 2. 从 256x256 的图像中随机裁剪 224x224 的区域 (用于数据增强)
transforms.RandomCrop(224),
# 3. 以 50% 的概率水平翻转图像 (用于数据增强)
transforms.RandomHorizontalFlip(),
# 4. 必须操作:将图像转换为 PyTorch Tensor
transforms.ToTensor(),
# 5. 必须操作:进行标准化 (使用 ImageNet 的均值和标准差作为示例)
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 假设 image 是一个 PIL 图像对象
# processed_tensor = transform(image)
总结¶
torchvision.transforms 模块通过提供一个清晰、灵活的 API,使得深度学习工程师可以方便地:
- 构建模型所需的输入张量。
- 应用各种数据增强技术来改进模型性能和鲁棒性。
In [2]:
# 转换图片
from PIL import Image
from torchvision import transforms
In [3]:
img_path = 'dataset/train/ants/0013035.jpg'
img = Image.open(img_path)
print(type(img))
<class 'PIL.JpegImagePlugin.JpegImageFile'>
In [4]:
tensor_trans = transforms.ToTensor() #实例化类
img_tensor = tensor_trans(img) #调用类(就是调用__call__函数)
print(type(img_tensor))
<class 'torch.Tensor'>
In [5]:
print(img_tensor)
tensor([[[0.3137, 0.3137, 0.3137, ..., 0.3176, 0.3098, 0.2980],
[0.3176, 0.3176, 0.3176, ..., 0.3176, 0.3098, 0.2980],
[0.3216, 0.3216, 0.3216, ..., 0.3137, 0.3098, 0.3020],
...,
[0.3412, 0.3412, 0.3373, ..., 0.1725, 0.3725, 0.3529],
[0.3412, 0.3412, 0.3373, ..., 0.3294, 0.3529, 0.3294],
[0.3412, 0.3412, 0.3373, ..., 0.3098, 0.3059, 0.3294]],
[[0.5922, 0.5922, 0.5922, ..., 0.5961, 0.5882, 0.5765],
[0.5961, 0.5961, 0.5961, ..., 0.5961, 0.5882, 0.5765],
[0.6000, 0.6000, 0.6000, ..., 0.5922, 0.5882, 0.5804],
...,
[0.6275, 0.6275, 0.6235, ..., 0.3608, 0.6196, 0.6157],
[0.6275, 0.6275, 0.6235, ..., 0.5765, 0.6275, 0.5961],
[0.6275, 0.6275, 0.6235, ..., 0.6275, 0.6235, 0.6314]],
[[0.9137, 0.9137, 0.9137, ..., 0.9176, 0.9098, 0.8980],
[0.9176, 0.9176, 0.9176, ..., 0.9176, 0.9098, 0.8980],
[0.9216, 0.9216, 0.9216, ..., 0.9137, 0.9098, 0.9020],
...,
[0.9294, 0.9294, 0.9255, ..., 0.5529, 0.9216, 0.8941],
[0.9294, 0.9294, 0.9255, ..., 0.8863, 1.0000, 0.9137],
[0.9294, 0.9294, 0.9255, ..., 0.9490, 0.9804, 0.9137]]])
In [7]:
print(img_tensor.shape)
torch.Size([3, 512, 768])
说明:维度0 深度(通道数);维度1 行数(高),维度2 列数(宽)