PyTorch 构建 Transformer 模型(千字长文)
💡一则或许对你有用的小广告
欢迎加入小哈的星球 ,你将获得:专属的项目实战 / 1v1 提问 / Java 学习路线 / 学习打卡 / 每月赠书 / 社群讨论
- 新项目:《从零手撸:仿小红书(微服务架构)》 正在持续爆肝中,基于
Spring Cloud Alibaba + Spring Boot 3.x + JDK 17...
,点击查看项目介绍 ;- 《从零手撸:前后端分离博客项目(全栈开发)》 2 期已完结,演示链接: http://116.62.199.48/ ;
截止目前, 星球 内专栏累计输出 82w+ 字,讲解图 3441+ 张,还在持续爆肝中.. 后续还会上新更多项目,目标是将 Java 领域典型的项目都整一波,如秒杀系统, 在线商城, IM 即时通讯,权限管理,Spring Cloud Alibaba 微服务等等,已有 2900+ 小伙伴加入学习 ,欢迎点击围观
在自然语言处理(Natural Language Processing, NLP)领域,Transformer 模型凭借其高效的信息处理能力和并行计算优势,已经成为序列建模的主流选择。而 PyTorch 作为灵活易用的深度学习框架,为构建 Transformer 提供了丰富的工具和简洁的接口。本文将带领编程初学者和中级开发者,通过理论结合实践的方式,逐步掌握如何使用 PyTorch 实现一个基础的 Transformer 模型。
无论你是对 Transformer 的工作机制感到好奇,还是希望将这一技术应用到实际项目中,本文都将提供清晰的思路和可复现的代码案例。
1. Transformer 的核心概念与原理
1.1 什么是 Transformer?
Transformer 是一种基于自注意力机制(Self-Attention)的深度学习架构,由 Vaswani 等人在 2017 年提出。与传统循环神经网络(RNN)或卷积神经网络(CNN)不同,Transformer 通过全局依赖建模和并行计算,显著提升了处理长序列数据的效率。
形象比喻:
想象一个快递分拣中心,每个包裹(输入序列中的单词)需要根据地址标签(注意力权重)被分发到正确的运输通道。自注意力机制就像这个分拣系统,让每个元素都能“关注”到序列中其他元素的重要性,从而决定自身在整体语义中的角色。
1.2 Transformer 的关键组件
Transformer 的核心组件包括:
- 自注意力层(Self-Attention Layer):计算序列中元素之间的相关性。
- 前馈神经网络(Feed-Forward Network):对每个位置的特征进行非线性变换。
- 位置编码(Positional Encoding):为序列添加位置信息,弥补 Transformer 对顺序敏感的不足。
- 掩码(Masking):用于处理序列长度不一致或遮蔽未来信息(如解码器中的因果掩码)。
2. 使用 PyTorch 实现 Transformer 的基础步骤
2.1 环境准备与数据加载
在开始编码前,确保已安装 PyTorch 和相关库:
pip install torch torchvision torchaudio
接下来,我们以一个简单的文本分类任务为例,加载并预处理数据:
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import IMDB
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def process_data(text, label):
encoding = tokenizer(text, truncation=True, padding=True)
return {
"input_ids": torch.tensor(encoding["input_ids"]),
"attention_mask": torch.tensor(encoding["attention_mask"]),
"labels": torch.tensor(label)
}
dataset = IMDB(root="./data", split="train", transform=process_data)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
2.2 构建位置编码层
位置编码为序列中的每个位置添加唯一的向量,确保模型能区分顺序信息。常用的方法是正弦波编码:
import math
import torch.nn as nn
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # 增加 batch 维度
self.register_buffer("pe", pe)
def forward(self, x):
return x + self.pe[:, :x.size(1)]
3. 构建 Transformer 编码器与解码器
3.1 编码器层(Encoder Layer)
编码器层是 Transformer 的核心计算单元,包含多头注意力(Multi-Head Attention)和前馈网络:
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.activation = nn.ReLU()
def forward(self, src, src_mask=None, src_key_padding_mask=None):
# 自注意力机制
attn_out, _ = self.self_attn(
src, src, src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask
)
src = src + self.dropout(attn_out) # 残差连接
src = self.norm1(src)
# 前馈网络
ff_out = self.linear2(self.activation(self.linear1(src)))
src = src + self.dropout(ff_out)
src = self.norm2(src)
return src
3.2 解码器层(Decoder Layer)
解码器层包含自注意力、编码器-解码器注意力和前馈网络,常用于序列生成任务(如机器翻译):
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# ... 其他层与编码器类似 ...
4. 完整 Transformer 模型实现
4.1 整合编码器与解码器
通过堆叠多个编码器层和解码器层,可以构建完整的 Transformer 模型:
class TransformerModel(nn.Module):
def __init__(self, ntoken, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward):
super().__init__()
self.embedding = nn.Embedding(ntoken, d_model)
self.pos_encoder = PositionalEncoding(d_model)
encoder_layers = TransformerEncoderLayer(d_model, nhead, dim_feedforward)
self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_encoder_layers)
# 解码器部分(根据任务需求添加)
def forward(self, src, src_mask=None):
src = self.embedding(src) * math.sqrt(d_model)
src = self.pos_encoder(src)
output = self.transformer_encoder(src, mask=src_mask)
return output
4.2 训练与评估
定义损失函数和优化器,启动训练循环:
model = TransformerModel(ntoken=10000, d_model=256, nhead=8, num_encoder_layers=3, ...)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(10):
model.train()
for batch in dataloader:
optimizer.zero_grad()
outputs = model(batch["input_ids"], batch["attention_mask"])
loss = criterion(outputs, batch["labels"])
loss.backward()
optimizer.step()
5. 实际案例:文本分类任务
5.1 任务场景与数据准备
假设我们希望使用 Transformer 对电影评论进行情感分类(正面/负面)。数据集可选用 IMDb 数据集,包含 50,000 条标注文本。
5.2 完整代码示例
import torch.nn.functional as F
class SentimentTransformer(TransformerModel):
def __init__(self, ntoken, d_model, nhead, num_encoder_layers, num_classes):
super().__init__(ntoken, d_model, nhead, num_encoder_layers, ...)
self.classifier = nn.Linear(d_model, num_classes)
def forward(self, src, src_mask=None):
features = super().forward(src, src_mask)
# 取平均或最后一个 token 的输出
pooled = features.mean(dim=1)
return self.classifier(pooled)
model = SentimentTransformer(ntoken=10000, d_model=256, nhead=8, num_encoder_layers=3, num_classes=2)
6. 常见问题与调试技巧
6.1 注意力权重的可视化
通过保存自注意力层的输出权重,可以直观分析模型对输入序列的关注模式:
attn_weights = self.self_attn(...)
print("Attention weights shape:", attn_weights.shape) # (batch_size, nhead, seq_len, seq_len)
6.2 解决过拟合问题
- 正则化:增加 dropout 层或使用权重衰减(Weight Decay)。
- 早停法(Early Stopping):监控验证集损失,提前终止训练。
- 数据增强:对文本进行随机掩码或回译(Back Translation)。
结论
通过本文的讲解,读者应已掌握 PyTorch 构建 Transformer 模型 的核心步骤:从理论概念到代码实现,再到实际案例的训练与调试。无论是文本分类、机器翻译还是序列生成任务,Transformer 的灵活性和高效性使其成为 NLP 开发者的必备工具。
建议读者尝试以下扩展实验:
- 将模型迁移到图像或时间序列任务;
- 使用 PyTorch 的
torch.compile
加速推理; - 探索不同的位置编码方法(如学习型位置编码)。
掌握 Transformer 的构建方法,你将能够应对更复杂的深度学习挑战,并为未来探索大模型(如 BERT、GPT)打下坚实基础。