- 目标:从零严格复现《Attention Is All You Need》在中文→英文机器翻译(Zh→En)上的完整流程:数据处理、模型实现、训练、评估与解码、可视化与设备兼容。
- 场景:企业可落地的机器翻译管线,支持 Mac M1/M2(MPS)与 Google Colab GPU。
- 论文设定:
d_model=512,n_heads=8,encoder/decoder=6层,d_ff=2048,dropout=0.1,Adam(β1=0.9, β2=0.98, ε=1e-9),Noam学习率调度,label smoothing=0.1,beam search + 长度惩罚 α=0.6,正弦位置编码。
如果你希望快速体验整个流程(数据生成 -> 冒烟测试 -> 训练 -> 评估 -> 推理),可以直接运行我们提供的全流程脚本:
# 赋予执行权限
chmod +x run_pipeline.sh
# 运行全流程
./run_pipeline.sh该脚本会自动检测环境、生成演示数据(如果缺失)、运行冒烟测试、启动训练、执行评估并进行简单的解码演示。
如果你想深入了解每一步的操作,请按照以下流程进行:
首先确保你的 Python 环境已安装必要依赖。建议使用虚拟环境:
# 创建并激活虚拟环境
python3 -m venv .venv
source .venv/bin/activate
# 安装依赖
pip install torch sentencepiece sacrebleu wandb本项目需要**中文(源)到英文(目标)**的平行语料。
-
自动生成演示数据: 如果你没有现成的数据,可以运行以下脚本生成简单的合成数据用于跑通流程:
python3 generate_dummy_data.py
这将生成
data/train.zh,data/train.en,data/val.zh,data/val.en。 -
使用真实数据(推荐): 你可以从 Kaggle (如 WMT, UM-Corpus) 下载高质量平行语料。 请将数据清洗后整理为以下格式(每行一句,中英严格对齐):
data/train.zh(训练集中文)data/train.en(训练集英文)data/val.zh(验证集中文)data/val.en(验证集英文)data/test.zh(测试集中文,可选)data/test.en(测试集英文,可选)
注意:文件路径可在
config.py中修改。
在开始昂贵的训练之前,先运行冒烟测试确保模型结构定义正确,能够跑通一次前向传播。
python3 smoke_test.py
# 期望输出: "logits shape: torch.Size([2, 15, 32000])" 等成功信息启动训练脚本。该脚本会自动处理以下事务:
- 检查并训练 SentencePiece 子词模型(BPE 32k)。
- 加载数据并构建 DataLoader。
- 使用 Adam 优化器和 Noam 学习率调度策略进行训练。
- 每个 epoch 结束后进行验证,并根据配置保存最佳模型。
- 支持 WandB 可视化(默认开启,可在
config.py关闭)。
python3 train.py- 输出:
- 子词模型:
data/spm_zh_bpe.model,data/spm_en_bpe.model - 模型权重:
checkpoints/transformer_zh_en_best.pt(验证集loss最低),checkpoints/transformer_zh_en.pt(最终)
- 子词模型:
使用训练好的模型在验证集上计算 BLEU 分数(机器翻译的标准评价指标)。我们使用 sacrebleu 库以确保评估的标准化。
python3 eval.py
# 输出示例: "val BLEU: 25.4"使用训练好的模型翻译新的中文句子。脚本使用 Beam Search 算法(束搜索)来寻找最优译文。
-
交互式模式(手动输入中文):
python3 decode.py
输入句子后回车即可看到翻译结果。
-
文件模式(翻译文件):
python3 decode.py --src data/test.zh
结果将保存到
outputs/pred.en。
config.py:控制中心。所有超参数(层数、维度、学习率、Batch Size、文件路径)都在这里修改。默认值严格遵循论文。train.py:训练主程序。包含训练循环、验证循环、梯度累积、早停机制。models/transformer.py:模型定义。包含 Encoder, Decoder, Multi-Head Attention, Positional Encoding 的完整实现。data/dataset.py:数据管道。负责读取文本、训练/加载分词器、构建 Batch 和 Mask。decode.py:解码脚本。实现 Beam Search 算法。eval.py:评估脚本。计算 BLEU 分数。utils/device.py:设备管理。自动适配 CUDA, MPS (Mac), CPU。
本项目不仅是代码复现,更是《Attention Is All You Need》的学习笔记。代码中包含大量教学级注释。
- 直觉:类似“对话时的注意力集中”。
- Query (Q): “我在问什么”
- Key (K): “你有哪些信息标签”
- Value (V): “具体信息内容”
- 计算:计算 Q 和 K 的相似度(点积),经过 Softmax 归一化后,作为权重对 V 进行加权求和。
- 缩放点积 (Scaled Dot-Product):
softmax(QK^T / sqrt(d_k)) V。除以sqrt(d_k)是为了防止点积过大导致 Softmax 梯度消失,稳定训练。 - 多头注意力 (Multi-Head):并行多个注意力子空间(如 8 个头),类似“用多种视角同时观察句子结构”,捕捉不同的语法和语义特征。
- 问题:Self-Attention 是并行计算的,天然不具备序列顺序信息(“我爱你”和“你爱我”在它看来是一样的)。
- 解决:使用不同频率的正弦/余弦函数生成位置向量,直接加到词嵌入上。
- 优势:可以让模型泛化到比训练时更长的序列,且具备相对位置感知能力。
- Label Smoothing (标签平滑):不仅仅将正确标签设为 1,错误标签设为 0,而是给错误标签也分配一点点概率(如 0.1)。这能降低过拟合,防止模型对自己过于自信,通常能提升 BLEU 分数。
- Noam Learning Rate Scheduler:
- 预热 (Warmup):训练初期学习率线性增加,防止模型参数剧烈波动。
- 衰减 (Decay):后期按步数的负平方根衰减,精细调整权重。
- 公式:
lr = d_model^{-0.5} * min(step^{-0.5}, step * warmup^{-1.5})
-
Beam Search (束搜索):
- 贪婪搜索 (Greedy) 每次只选概率最大的词,容易陷入局部最优。
- Beam Search 在每一步维护前 K 个(如 Beam Size=4)最优候选路径,最后选总分最高的。
-
长度惩罚 (Length Penalty):
- 因为概率是连乘(log 概率相加),长句子总分天然更低。需要除以长度的
$\alpha$ 次方(论文取$\alpha=0.6$ )来平衡,避免模型倾向于生成极短的句子。
- 因为概率是连乘(log 概率相加),长句子总分天然更低。需要除以长度的
本项目严格遵循 Base 模型配置:
| 组件 | 参数 | 论文值 | 本项目配置 (config.py) |
|---|---|---|---|
| 维度 | d_model |
512 | 512 |
| 多头 | n_heads |
8 | 8 |
| 层数 | N |
6 | 6 (Encoder & Decoder) |
| 前馈网络 | d_ff |
2048 | 2048 |
| Dropout | P_drop |
0.1 | 0.1 |
| 优化器 | Adam | β1=0.9, β2=0.98, ε=10^-9 | 一致 |
| 学习率 | Warmup Steps | 4000 | 4000 |
| 标签平滑 | ε_ls | 0.1 | 0.1 |
| Batch Size | Tokens | 25000 / Batch | 可配置 (默认按句子数 Batch) |
-
显存不足 (OOM):
- 在
config.py中减小batch_size(例如从 64 减到 32 或 16)。 - 减小
max_seq_len。 - 增大
grad_accum_steps以保持有效 Batch Size 不变。
- 在
-
Mac M1/M2 训练慢:
- 确保
config.py中num_workers=0(MPS 目前对多进程支持不稳定)。 - 确保已安装最新版 PyTorch。
- 确保
-
WandB 登录问题:
- 如果不想使用 WandB,在
config.py设置use_wandb = False。 - 如果在离线环境,设置
wandb_mode_offline = True。
- 如果不想使用 WandB,在
-
安装
sentencepiece失败:- 尝试升级 pip:
pip install --upgrade pip - 确保已安装 C++ 编译工具 (如 Xcode Command Line Tools)。
- 尝试升级 pip:
- 共享权重:实现源/目标词嵌入与输出层权重的共享(论文变体常用)。
- 动态批处理:按 Token 数而非句子数构建 Batch,更接近论文训练策略并提升吞吐。
- 数据增强:在数据加载时引入随机替换、删除等增强策略。
希望这个项目能帮助你深入理解 Transformer!