Skip to content
/ zh2en Public

中文至英文序列转导模型。从零严格复现《Attention Is All You Need》在中文→英文机器翻译(Zh→En)上的完整流程。

Notifications You must be signed in to change notification settings

caochk/zh2en

Repository files navigation

Transformer 中英机器翻译项目 (PyTorch 实现)

项目概述

  • 目标:从零严格复现《Attention Is All You Need》在中文→英文机器翻译(Zh→En)上的完整流程:数据处理、模型实现、训练、评估与解码、可视化与设备兼容。
  • 场景:企业可落地的机器翻译管线,支持 Mac M1/M2(MPS)与 Google Colab GPU。
  • 论文设定d_model=512n_heads=8encoder/decoder=6 层,d_ff=2048dropout=0.1Adam(β1=0.9, β2=0.98, ε=1e-9)Noam 学习率调度,label smoothing=0.1beam search + 长度惩罚 α=0.6正弦位置编码

🚀 一键运行 (One-Click Run)

如果你希望快速体验整个流程(数据生成 -> 冒烟测试 -> 训练 -> 评估 -> 推理),可以直接运行我们提供的全流程脚本:

# 赋予执行权限
chmod +x run_pipeline.sh

# 运行全流程
./run_pipeline.sh

该脚本会自动检测环境、生成演示数据(如果缺失)、运行冒烟测试、启动训练、执行评估并进行简单的解码演示。


📚 详细分步运行指南

如果你想深入了解每一步的操作,请按照以下流程进行:

1. 环境准备

首先确保你的 Python 环境已安装必要依赖。建议使用虚拟环境:

# 创建并激活虚拟环境
python3 -m venv .venv
source .venv/bin/activate

# 安装依赖
pip install torch sentencepiece sacrebleu wandb

2. 数据准备

本项目需要**中文(源)英文(目标)**的平行语料。

  • 自动生成演示数据: 如果你没有现成的数据,可以运行以下脚本生成简单的合成数据用于跑通流程:

    python3 generate_dummy_data.py

    这将生成 data/train.zh, data/train.en, data/val.zh, data/val.en

  • 使用真实数据(推荐): 你可以从 Kaggle (如 WMT, UM-Corpus) 下载高质量平行语料。 请将数据清洗后整理为以下格式(每行一句,中英严格对齐):

    • data/train.zh (训练集中文)
    • data/train.en (训练集英文)
    • data/val.zh (验证集中文)
    • data/val.en (验证集英文)
    • data/test.zh (测试集中文,可选)
    • data/test.en (测试集英文,可选)

    注意:文件路径可在 config.py 中修改。

3. 冒烟测试 (Smoke Test)

在开始昂贵的训练之前,先运行冒烟测试确保模型结构定义正确,能够跑通一次前向传播。

python3 smoke_test.py
# 期望输出: "logits shape: torch.Size([2, 15, 32000])" 等成功信息

4. 模型训练 (Training)

启动训练脚本。该脚本会自动处理以下事务:

  • 检查并训练 SentencePiece 子词模型(BPE 32k)。
  • 加载数据并构建 DataLoader。
  • 使用 Adam 优化器和 Noam 学习率调度策略进行训练。
  • 每个 epoch 结束后进行验证,并根据配置保存最佳模型。
  • 支持 WandB 可视化(默认开启,可在 config.py 关闭)。
python3 train.py
  • 输出
    • 子词模型:data/spm_zh_bpe.model, data/spm_en_bpe.model
    • 模型权重:checkpoints/transformer_zh_en_best.pt (验证集loss最低), checkpoints/transformer_zh_en.pt (最终)

5. 模型评估 (Evaluation)

使用训练好的模型在验证集上计算 BLEU 分数(机器翻译的标准评价指标)。我们使用 sacrebleu 库以确保评估的标准化。

python3 eval.py
# 输出示例: "val BLEU: 25.4"

6. 解码与推理 (Inference)

使用训练好的模型翻译新的中文句子。脚本使用 Beam Search 算法(束搜索)来寻找最优译文。

  • 交互式模式(手动输入中文):

    python3 decode.py

    输入句子后回车即可看到翻译结果。

  • 文件模式(翻译文件):

    python3 decode.py --src data/test.zh

    结果将保存到 outputs/pred.en


核心文件与结构说明

  • config.py控制中心。所有超参数(层数、维度、学习率、Batch Size、文件路径)都在这里修改。默认值严格遵循论文。
  • train.py训练主程序。包含训练循环、验证循环、梯度累积、早停机制。
  • models/transformer.py模型定义。包含 Encoder, Decoder, Multi-Head Attention, Positional Encoding 的完整实现。
  • data/dataset.py数据管道。负责读取文本、训练/加载分词器、构建 Batch 和 Mask。
  • decode.py解码脚本。实现 Beam Search 算法。
  • eval.py评估脚本。计算 BLEU 分数。
  • utils/device.py设备管理。自动适配 CUDA, MPS (Mac), CPU。

🧠 教学说明与原理 (核心概念速览)

本项目不仅是代码复现,更是《Attention Is All You Need》的学习笔记。代码中包含大量教学级注释。

1. 注意力机制 (Attention)

  • 直觉:类似“对话时的注意力集中”。
    • Query (Q): “我在问什么”
    • Key (K): “你有哪些信息标签”
    • Value (V): “具体信息内容”
  • 计算:计算 Q 和 K 的相似度(点积),经过 Softmax 归一化后,作为权重对 V 进行加权求和。
  • 缩放点积 (Scaled Dot-Product)softmax(QK^T / sqrt(d_k)) V。除以 sqrt(d_k) 是为了防止点积过大导致 Softmax 梯度消失,稳定训练。
  • 多头注意力 (Multi-Head):并行多个注意力子空间(如 8 个头),类似“用多种视角同时观察句子结构”,捕捉不同的语法和语义特征。

2. 位置编码 (Positional Encoding)

  • 问题:Self-Attention 是并行计算的,天然不具备序列顺序信息(“我爱你”和“你爱我”在它看来是一样的)。
  • 解决:使用不同频率的正弦/余弦函数生成位置向量,直接加到词嵌入上。
  • 优势:可以让模型泛化到比训练时更长的序列,且具备相对位置感知能力。

3. 训练策略 (Training Dynamics)

  • Label Smoothing (标签平滑):不仅仅将正确标签设为 1,错误标签设为 0,而是给错误标签也分配一点点概率(如 0.1)。这能降低过拟合,防止模型对自己过于自信,通常能提升 BLEU 分数。
  • Noam Learning Rate Scheduler
    • 预热 (Warmup):训练初期学习率线性增加,防止模型参数剧烈波动。
    • 衰减 (Decay):后期按步数的负平方根衰减,精细调整权重。
    • 公式lr = d_model^{-0.5} * min(step^{-0.5}, step * warmup^{-1.5})

4. 解码策略 (Decoding)

  • Beam Search (束搜索)
    • 贪婪搜索 (Greedy) 每次只选概率最大的词,容易陷入局部最优。
    • Beam Search 在每一步维护前 K 个(如 Beam Size=4)最优候选路径,最后选总分最高的。
  • 长度惩罚 (Length Penalty)
    • 因为概率是连乘(log 概率相加),长句子总分天然更低。需要除以长度的 $\alpha$ 次方(论文取 $\alpha=0.6$)来平衡,避免模型倾向于生成极短的句子。

📊 超参数与论文对照表

本项目严格遵循 Base 模型配置:

组件 参数 论文值 本项目配置 (config.py)
维度 d_model 512 512
多头 n_heads 8 8
层数 N 6 6 (Encoder & Decoder)
前馈网络 d_ff 2048 2048
Dropout P_drop 0.1 0.1
优化器 Adam β1=0.9, β2=0.98, ε=10^-9 一致
学习率 Warmup Steps 4000 4000
标签平滑 ε_ls 0.1 0.1
Batch Size Tokens 25000 / Batch 可配置 (默认按句子数 Batch)

❓ 常见问题与排查

  1. 显存不足 (OOM)

    • config.py 中减小 batch_size (例如从 64 减到 32 或 16)。
    • 减小 max_seq_len
    • 增大 grad_accum_steps 以保持有效 Batch Size 不变。
  2. Mac M1/M2 训练慢

    • 确保 config.pynum_workers=0 (MPS 目前对多进程支持不稳定)。
    • 确保已安装最新版 PyTorch。
  3. WandB 登录问题

    • 如果不想使用 WandB,在 config.py 设置 use_wandb = False
    • 如果在离线环境,设置 wandb_mode_offline = True
  4. 安装 sentencepiece 失败

    • 尝试升级 pip: pip install --upgrade pip
    • 确保已安装 C++ 编译工具 (如 Xcode Command Line Tools)。

后续扩展建议

  • 共享权重:实现源/目标词嵌入与输出层权重的共享(论文变体常用)。
  • 动态批处理:按 Token 数而非句子数构建 Batch,更接近论文训练策略并提升吞吐。
  • 数据增强:在数据加载时引入随机替换、删除等增强策略。

希望这个项目能帮助你深入理解 Transformer!

About

中文至英文序列转导模型。从零严格复现《Attention Is All You Need》在中文→英文机器翻译(Zh→En)上的完整流程。

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published