PyTorch 构建 Transformer 模型(千字长文)

更新时间:

💡一则或许对你有用的小广告

欢迎加入小哈的星球 ,你将获得:专属的项目实战 / 1v1 提问 / Java 学习路线 / 学习打卡 / 每月赠书 / 社群讨论

  • 新项目:《从零手撸:仿小红书(微服务架构)》 正在持续爆肝中,基于 Spring Cloud Alibaba + Spring Boot 3.x + JDK 17...点击查看项目介绍 ;
  • 《从零手撸:前后端分离博客项目(全栈开发)》 2 期已完结,演示链接: http://116.62.199.48/ ;

截止目前, 星球 内专栏累计输出 82w+ 字,讲解图 3441+ 张,还在持续爆肝中.. 后续还会上新更多项目,目标是将 Java 领域典型的项目都整一波,如秒杀系统, 在线商城, IM 即时通讯,权限管理,Spring Cloud Alibaba 微服务等等,已有 2900+ 小伙伴加入学习 ,欢迎点击围观

在自然语言处理(Natural Language Processing, NLP)领域,Transformer 模型凭借其高效的信息处理能力和并行计算优势,已经成为序列建模的主流选择。而 PyTorch 作为灵活易用的深度学习框架,为构建 Transformer 提供了丰富的工具和简洁的接口。本文将带领编程初学者和中级开发者,通过理论结合实践的方式,逐步掌握如何使用 PyTorch 实现一个基础的 Transformer 模型。

无论你是对 Transformer 的工作机制感到好奇,还是希望将这一技术应用到实际项目中,本文都将提供清晰的思路和可复现的代码案例。


1. Transformer 的核心概念与原理

1.1 什么是 Transformer?

Transformer 是一种基于自注意力机制(Self-Attention)的深度学习架构,由 Vaswani 等人在 2017 年提出。与传统循环神经网络(RNN)或卷积神经网络(CNN)不同,Transformer 通过全局依赖建模和并行计算,显著提升了处理长序列数据的效率。

形象比喻
想象一个快递分拣中心,每个包裹(输入序列中的单词)需要根据地址标签(注意力权重)被分发到正确的运输通道。自注意力机制就像这个分拣系统,让每个元素都能“关注”到序列中其他元素的重要性,从而决定自身在整体语义中的角色。

1.2 Transformer 的关键组件

Transformer 的核心组件包括:

  • 自注意力层(Self-Attention Layer):计算序列中元素之间的相关性。
  • 前馈神经网络(Feed-Forward Network):对每个位置的特征进行非线性变换。
  • 位置编码(Positional Encoding):为序列添加位置信息,弥补 Transformer 对顺序敏感的不足。
  • 掩码(Masking):用于处理序列长度不一致或遮蔽未来信息(如解码器中的因果掩码)。

2. 使用 PyTorch 实现 Transformer 的基础步骤

2.1 环境准备与数据加载

在开始编码前,确保已安装 PyTorch 和相关库:

pip install torch torchvision torchaudio

接下来,我们以一个简单的文本分类任务为例,加载并预处理数据:

import torch  
from torch.utils.data import DataLoader  
from torchvision.datasets import IMDB  
from transformers import AutoTokenizer  

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  

def process_data(text, label):  
    encoding = tokenizer(text, truncation=True, padding=True)  
    return {  
        "input_ids": torch.tensor(encoding["input_ids"]),  
        "attention_mask": torch.tensor(encoding["attention_mask"]),  
        "labels": torch.tensor(label)  
    }  

dataset = IMDB(root="./data", split="train", transform=process_data)  
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)  

2.2 构建位置编码层

位置编码为序列中的每个位置添加唯一的向量,确保模型能区分顺序信息。常用的方法是正弦波编码:

import math  
import torch.nn as nn  

class PositionalEncoding(nn.Module):  
    def __init__(self, d_model, max_len=5000):  
        super().__init__()  
        pe = torch.zeros(max_len, d_model)  
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  
        pe[:, 0::2] = torch.sin(position * div_term)  
        pe[:, 1::2] = torch.cos(position * div_term)  
        pe = pe.unsqueeze(0)  # 增加 batch 维度  
        self.register_buffer("pe", pe)  

    def forward(self, x):  
        return x + self.pe[:, :x.size(1)]  

3. 构建 Transformer 编码器与解码器

3.1 编码器层(Encoder Layer)

编码器层是 Transformer 的核心计算单元,包含多头注意力(Multi-Head Attention)和前馈网络:

class TransformerEncoderLayer(nn.Module):  
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):  
        super().__init__()  
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)  
        self.norm1 = nn.LayerNorm(d_model)  
        self.norm2 = nn.LayerNorm(d_model)  
        self.dropout = nn.Dropout(dropout)  
        self.linear1 = nn.Linear(d_model, dim_feedforward)  
        self.linear2 = nn.Linear(dim_feedforward, d_model)  
        self.activation = nn.ReLU()  

    def forward(self, src, src_mask=None, src_key_padding_mask=None):  
        # 自注意力机制  
        attn_out, _ = self.self_attn(  
            src, src, src,  
            attn_mask=src_mask,  
            key_padding_mask=src_key_padding_mask  
        )  
        src = src + self.dropout(attn_out)  # 残差连接  
        src = self.norm1(src)  

        # 前馈网络  
        ff_out = self.linear2(self.activation(self.linear1(src)))  
        src = src + self.dropout(ff_out)  
        src = self.norm2(src)  
        return src  

3.2 解码器层(Decoder Layer)

解码器层包含自注意力、编码器-解码器注意力和前馈网络,常用于序列生成任务(如机器翻译):

class TransformerDecoderLayer(nn.Module):  
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):  
        super().__init__()  
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)  
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)  
        # ... 其他层与编码器类似 ...  

4. 完整 Transformer 模型实现

4.1 整合编码器与解码器

通过堆叠多个编码器层和解码器层,可以构建完整的 Transformer 模型:

class TransformerModel(nn.Module):  
    def __init__(self, ntoken, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward):  
        super().__init__()  
        self.embedding = nn.Embedding(ntoken, d_model)  
        self.pos_encoder = PositionalEncoding(d_model)  
        encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward)  
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)  
        # 解码器部分(根据任务需求添加)  

    def forward(self, src, src_mask=None):  
        src = self.embedding(src) * math.sqrt(d_model)  
        src = self.pos_encoder(src)  
        output = self.transformer_encoder(src, mask=src_mask)  
        return output  

4.2 训练与评估

定义损失函数和优化器,启动训练循环:

model = TransformerModel(ntoken=10000, d_model=256, nhead=8, num_encoder_layers=3, ...)  
criterion = nn.CrossEntropyLoss()  
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  

for epoch in range(10):  
    model.train()  
    for batch in dataloader:  
        optimizer.zero_grad()  
        outputs = model(batch["input_ids"], batch["attention_mask"])  
        loss = criterion(outputs, batch["labels"])  
        loss.backward()  
        optimizer.step()  

5. 实际案例:文本分类任务

5.1 任务场景与数据准备

假设我们希望使用 Transformer 对电影评论进行情感分类(正面/负面)。数据集可选用 IMDb 数据集,包含 50,000 条标注文本。

5.2 完整代码示例

import torch.nn.functional as F  

class SentimentTransformer(TransformerModel):  
    def __init__(self, ntoken, d_model, nhead, num_encoder_layers, num_classes):  
        super().__init__(ntoken, d_model, nhead, num_encoder_layers, ...)  
        self.classifier = nn.Linear(d_model, num_classes)  

    def forward(self, src, src_mask=None):  
        features = super().forward(src, src_mask)  
        # 取平均或最后一个 token 的输出  
        pooled = features.mean(dim=1)  
        return self.classifier(pooled)  

model = SentimentTransformer(ntoken=10000, d_model=256, nhead=8, num_encoder_layers=3, num_classes=2)  

6. 常见问题与调试技巧

6.1 注意力权重的可视化

通过保存自注意力层的输出权重,可以直观分析模型对输入序列的关注模式:

attn_weights = self.self_attn(...)  
print("Attention weights shape:", attn_weights.shape)  # (batch_size, nhead, seq_len, seq_len)  

6.2 解决过拟合问题

  • 正则化:增加 dropout 层或使用权重衰减(Weight Decay)。
  • 早停法(Early Stopping):监控验证集损失,提前终止训练。
  • 数据增强:对文本进行随机掩码或回译(Back Translation)。

结论

通过本文的讲解,读者应已掌握 PyTorch 构建 Transformer 模型 的核心步骤:从理论概念到代码实现,再到实际案例的训练与调试。无论是文本分类、机器翻译还是序列生成任务,Transformer 的灵活性和高效性使其成为 NLP 开发者的必备工具。

建议读者尝试以下扩展实验:

  1. 将模型迁移到图像或时间序列任务;
  2. 使用 PyTorch 的 torch.compile 加速推理;
  3. 探索不同的位置编码方法(如学习型位置编码)。

掌握 Transformer 的构建方法,你将能够应对更复杂的深度学习挑战,并为未来探索大模型(如 BERT、GPT)打下坚实基础。

最新发布