PyTorch 卷积神经网络(长文讲解)

更新时间:

💡一则或许对你有用的小广告

欢迎加入小哈的星球 ,你将获得:专属的项目实战 / 1v1 提问 / Java 学习路线 / 学习打卡 / 每月赠书 / 社群讨论

截止目前, 星球 内专栏累计输出 90w+ 字,讲解图 3441+ 张,还在持续爆肝中.. 后续还会上新更多项目,目标是将 Java 领域典型的项目都整一波,如秒杀系统, 在线商城, IM 即时通讯,权限管理,Spring Cloud Alibaba 微服务等等,已有 3100+ 小伙伴加入学习 ,欢迎点击围观

在深度学习领域,PyTorch 卷积神经网络(CNN)是图像识别、目标检测等任务的核心工具。作为编程初学者和中级开发者,理解 CNN 的原理和 PyTorch 实现方法,是迈向计算机视觉领域的关键一步。本文将从基础概念讲起,通过代码示例和形象比喻,帮助读者逐步构建属于自己的卷积神经网络。


卷积神经网络基础:从概念到直觉

什么是卷积神经网络?

卷积神经网络是一种专门处理网格状数据(如图像)的深度学习模型。它的核心思想是通过局部感知参数共享,从图像中提取层次化的特征。例如,第一层可能检测边缘,第二层组合边缘形成形状,最终层识别整体模式。

卷积层:图像的“扫描仪”

卷积层通过卷积核(Kernel)在图像上滑动,逐块计算局部区域的特征。想象你用一个放大镜扫描地图,每次只观察一小块区域,记录下这块区域的关键信息。卷积核的大小(如 3×3 或 5×5)决定了“放大镜”的视野范围。

激活函数:为特征注入非线性

激活函数(如 ReLU)为卷积层的输出引入非线性,使模型能够学习复杂模式。例如,ReLU 将负值归零,保留正值,类似于“过滤掉不重要的信息,保留关键特征”。

池化层:信息的“压缩机”

池化层(如 MaxPool)通过缩小特征图的尺寸,减少计算量并增强模型的抗干扰能力。例如,将 2×2 的区域压缩为最大值,类似于在地图上合并相邻区域,保留核心信息。


PyTorch 中的卷积层实现:从代码到理解

在 PyTorch 中,卷积层由 torch.nn.Conv2d 实现。其核心参数包括:

  • in_channels:输入图像的通道数(如 RGB 图像为 3,灰度图 1)。
  • out_channels:输出特征图的数量(即卷积核的数量)。
  • kernel_size:卷积核的尺寸(如 3 表示 3×3)。

示例代码:定义一个简单的卷积层

import torch.nn as nn  

class SimpleCNN(nn.Module):  
    def __init__(self):  
        super(SimpleCNN, self).__init__()  
        # 输入通道 1(灰度图),输出通道 16,卷积核 3×3  
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)  
        self.relu = nn.ReLU()  # 激活函数  

    def forward(self, x):  
        x = self.conv1(x)  
        x = self.relu(x)  
        return x  

卷积层输出尺寸计算

假设输入图像尺寸为 H×W,卷积核尺寸为 K×K,且 padding=0,stride=1,则输出尺寸为:
[ \text{Output Size} = (H - K + 1) \times (W - K + 1) ]
例如,输入 28×28 的图像经过 3×3 卷积核后,输出为 26×26。


池化层与激活函数:PyTorch 实现细节

MaxPool2d:缩小特征图

self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  

此代码将特征图尺寸缩小为原来的一半(如 26×26 → 13×13)。

激活函数的选择

除了 ReLU,PyTorch 还支持 SigmoidTanh 等。例如:

self.sigmoid = nn.Sigmoid()  

但 ReLU 因为计算简单、不易梯度消失,成为 CNN 的主流选择。


构建第一个 CNN 模型:手写数字识别

目标:使用 MNIST 数据集训练一个简单 CNN

MNIST 是经典的图像数据集,包含 60,000 张 28×28 的手写数字图像。

模型结构设计

class MNISTCNN(nn.Module):  
    def __init__(self):  
        super(MNISTCNN, self).__init__()  
        # 卷积层 1:输入通道 1,输出 16,核 3×3  
        self.conv1 = nn.Conv2d(1, 16, 3)  
        self.pool = nn.MaxPool2d(2, 2)  
        self.relu = nn.ReLU()  
        # 卷积层 2:输入通道 16,输出 32,核 3×3  
        self.conv2 = nn.Conv2d(16, 32, 3)  
        # 全连接层:输入特征数 32×12×12 → 4608,输出 10 类  
        self.fc = nn.Linear(32 * 12 * 12, 10)  

    def forward(self, x):  
        x = self.conv1(x)  
        x = self.relu(x)  
        x = self.pool(x)  
        x = self.conv2(x)  
        x = self.relu(x)  
        x = self.pool(x)  
        # 展平张量为向量  
        x = x.view(-1, 32 * 12 * 12)  
        x = self.fc(x)  
        return x  

训练流程代码

import torch.optim as optim  
from torch.utils.data import DataLoader  
from torchvision import datasets, transforms  

transform = transforms.Compose([  
    transforms.ToTensor(),  
    transforms.Normalize((0.1307,), (0.3081,))  
])  

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)  
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  

model = MNISTCNN()  
criterion = nn.CrossEntropyLoss()  
optimizer = optim.Adam(model.parameters(), lr=0.001)  

for epoch in range(5):  
    for images, labels in train_loader:  
        optimizer.zero_grad()  
        outputs = model(images)  
        loss = criterion(outputs, labels)  
        loss.backward()  
        optimizer.step()  
    print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")  

实际案例:模型调优与进阶技巧

调整学习率与正则化

optimizer = optim.Adam(model.parameters(), lr=0.001)  

self.dropout = nn.Dropout(p=0.2)  
x = self.dropout(x)  # 在全连接层前添加  

模型保存与加载

torch.save(model.state_dict(), "mnist_cnn.pth")  

model = MNISTCNN()  
model.load_state_dict(torch.load("mnist_cnn.pth"))  
model.eval()  

结论:掌握 PyTorch 卷积神经网络的关键步骤

通过本文,读者已了解:

  1. 卷积层、池化层和激活函数的原理与 PyTorch 实现;
  2. 如何构建并训练一个完整的 CNN 模型;
  3. 模型调优和部署的基本方法。

PyTorch 卷积神经网络的强大之处在于其灵活性和高效性。建议读者从简单任务入手(如 MNIST),逐步尝试复杂数据集(如 CIFAR-10),并探索预训练模型(如 ResNet)的应用。通过实践,你将掌握这一技术的核心逻辑,并为更复杂的计算机视觉任务打下坚实基础。


本文通过循序渐进的讲解和代码示例,帮助开发者快速上手 PyTorch 卷积神经网络。如需进一步探讨模型优化或具体应用场景,欢迎在评论区留言。

最新发布