- Project Overview
- Model Architecture Overview
- Core Components in Detail
  - 3.1 Attention Mechanisms
  - 3.2 MLA-NSA Hybrid Attention
  - 3.3 Feed-Forward Network and MoE
  - 3.4 Hyper-Connections
- Training Pipeline
- Key Technical Features
- Appendix: Chart Index
Tiny-R2 is a compact yet capable language model that combines several recent deep-learning techniques:

- Sparse attention (MLA-NSA hybrid attention)
- Mixture of Experts (DeepSeek MoE)
- Hyper-Connections
- Dual-optimizer strategy (Muon + AdamW)
Input Tokens
      │
Token Embedding + Positional Embedding
      │
Hyper-Connection Expand Stream
      │
      ▼
┌──── N × Transformer Block ─────────────────────────────
│
│  RMSNorm ────────────────────┐
│      │                       │
│      ▼                       │  Residual
│  ┌───────────────┐           │  Connection
│  │   Attention   │           │
│  │ (NSA/SWA/DSA) │           │
│  │  pick 1 of 3  │           │
│  └───────────────┘           │
│      │                       │
│      ▼                       │
│  Hyper-connection (hc_attn)  │
│      │                       │
│      ▼                       │
│  RMSNorm ────────────────────┤
│      │                       │
│      ▼                       │
│  ┌───────────────┐           │
│  │  Dense / MoE  │           │
│  │  MLP / DSMoE  │           │
│  │  pick 1 of 2  │           │
│  └───────────────┘           │
│      │                       │
│      ▼                       │
│  Hyper-connection (hc_mlp)   │
│      │                       │
│      └───────────────────────┘
│      │
│      ▼
│  Output + router_weights
│      │
│  Hyper-Connection Reduce Stream
│
└────────────────────────────────────────────────────────
      │
RMSNorm + LM Head
      │
Output Logits
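To make the block layout concrete, here is a minimal sketch of one Transformer block's forward pass under this design. It only mirrors the diagram above, not the repository's exact code; it assumes that the `hc_attn` / `hc_mlp` wrappers (introduced in the Hyper-Connections section) manage the residual streams themselves.

```python
# Minimal sketch of one block, mirroring the diagram (not the repository's exact code).
class Block(nn.Module):
    def forward(self, x):
        # Attention sub-layer: RMSNorm -> Attention (NSA / SWA / DSA) -> hyper-connection update
        x = self.hc_attn(x)
        # Feed-forward sub-layer: RMSNorm -> Dense MLP or DSMoE -> hyper-connection update.
        # When the MoE branch is active, its router weights are collected separately
        # so the training loop can apply the load-balance loss.
        x = self.hc_mlp(x)
        return x
```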
Tiny-R2 supports two attention types, which can be switched flexibly via the attention_types configuration.

The standard causal self-attention mechanism:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Projections: Q, K, V from a single linear layer
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # Value residual connections
        self.v_residual = config.v_residual
        self.lamb1 = nn.Parameter(torch.tensor(0.5))
        self.lamb2 = nn.Parameter(torch.tensor(0.5))
        # Flash Attention support
        self.flash = hasattr(F, "scaled_dot_product_attention")

Key features:
- Uses Flash Attention for acceleration (when available)
- Supports value residual connections
- Standard causal masking
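To show how these pieces interact, here is a minimal sketch of a forward pass with value residual mixing and a non-Flash fallback. It is illustrative only: `self.n_head` is assumed to be set in `__init__`, `v_first` (the value tensor from the first layer) is an assumed argument name, and dropout is omitted.

```python
# Illustrative forward pass (not the repository's exact code).
def forward(self, x, v_first=None):
    B, T, C = x.shape
    q, k, v = self.c_attn(x).split(C, dim=2)
    # Reshape to (B, n_head, T, head_dim); self.n_head is assumed to be set in __init__
    q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
    k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
    v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
    # Value residual: blend this layer's values with the first layer's values
    if self.v_residual and v_first is not None:
        v = self.lamb1 * v + self.lamb2 * v_first
    if self.flash:
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    else:
        att = (q @ k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
        causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        att = att.masked_fill(~causal, float("-inf")).softmax(dim=-1)
        y = att @ v
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    return self.c_proj(y)
```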
A hybrid attention mechanism that combines Multi-head Latent Attention (MLA) and Native Sparse Attention (NSA).

Three run modes:
| Mode | Branch config | Description |
|---|---|---|
| NSA | [1, 1, 1] | All three branches enabled |
| SWA | [1, 0, 1] | Compression branch + sliding-window branch |
| DSA | [1, 1, 0] | Compression branch + selection branch |
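Read as code, this table amounts to a per-branch on/off mask; the dictionary below is only an illustration (its name is not taken from the repository):

```python
# Illustrative mapping from run mode to per-branch gates:
# [compression (MLA), selection (DSA), sliding window (SWA)]
BRANCH_CONFIGS = {
    "NSA": [1, 1, 1],  # all three branches
    "SWA": [1, 0, 1],  # compression + sliding window
    "DSA": [1, 1, 0],  # compression + selection
}
```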
MLA-NSA is one of Tiny-R2's core innovations: it implements efficient sparse attention through three parallel branches.
Input x
โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Query Preparation (MLA style) โ
โ compress_q โ q_norm โ decompress_q โ RoPE โ Query โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ
โโโโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโโโโฌโโโโโโโโโโโโโโโโโโโโโโโ
โ Branch 1 โ Branch 2 โ Branch 3 โ
โ Compression โ Selection โ Sliding Window โ
โ (MLA) โ (DSA) โ (SWA) โ
โโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโผโโโโโโโโโโโโโโโโโโโโโโโค
โ compress_kv โ importance_score โ window_k/v โ
โ kv_norm โ topk selection โ sliding_window โ
โ decompress_k/v โ selection_k/v โ RoPE โ
โ k_rope โ RoPE โ โ
โ K/V Recombine โ K/V Selected โ K/V Window โ
โโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโดโโโโโโโโโโโโโโโโโโโโโโโ
โ โ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Attention Computation โ
โ Attention 1: (Q @ K1.T) @ V1 โ
โ Attention 2: (Q @ K2.T) @ V2 โ
โ Attention 3: (Q @ K3.T) @ V3 โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ
branch_gate (Linear + Softmax) โ Weighted Sum
โ
proj (Linear) โ res_dropout โ Output
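The final stage of the diagram (branch gate, weighted sum, projection) can be sketched as follows. The tensor names `attn_cmp`, `attn_sel`, and `attn_win` are assumptions; `branch_gate`, `proj`, and `res_dropout` come from the diagram above.

```python
# Illustrative sketch of the branch-gating stage (not the repository's exact code).
# attn_cmp, attn_sel, attn_win: per-branch attention outputs, each of shape [B, T, C]
gate = F.softmax(self.branch_gate(x), dim=-1)   # [B, T, 3], one weight per branch
out = (
    gate[..., 0:1] * attn_cmp                   # compression branch (MLA)
    + gate[..., 1:2] * attn_sel                 # selection branch (DSA)
    + gate[..., 2:3] * attn_win                 # sliding-window branch (SWA)
)
out = self.res_dropout(self.proj(out))          # final projection + dropout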
The main MLA and NSA hyper-parameters:

# MLA parameters
self.v_head_dim = 32
self.kv_lora_rank = 32
self.q_lora_rank = 3 * self.kv_lora_rank
self.rope_head_dim = 64
self.nope_head_dim = 32

# NSA parameters
self.block_size = config.block_size                   # token compression block size
self.window_size = config.window_size                 # sliding window size
self.num_tokens_to_keep = config.num_tokens_to_keep   # number of tokens kept by selection

The standard feed-forward network, which uses the ReLU² activation:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU squared activation
        x = self.c_proj(x)
        return x

The DeepSeek-style Mixture-of-Experts layer:
Input x [B, T, C]
โ
Gate Network (Linear + UnitCenteredNoise)
โ
Softmax โ Top-k Selection
โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ Expert Networks โ
โ โโโโโโโโโโโโโโโโ โโโโโโโโโโโโ โโโโโโโโโโโโ โโโโโโโโโโ โ
โ โ Shared Exp 0 โ โ Expert 1 โ โ Expert 2 โ โ ... โ โ
โ โ (Always On) โ โ (Top-k) โ โ (Top-k) โ โ (Top-k)โ โ
โ โโโโโโโโโโโโโโโโ โโโโโโโโโโโโ โโโโโโโโโโโโ โโโโโโโโโโ โ
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
โ
Weighted Sum of Expert Outputs
โ
Output [B, T, C]
Key features:
| Feature | Description |
|---|---|
| Shared Expert | Always-active shared expert that provides stability |
| Routed Experts | Routed experts selected per token via top-k |
| Load Balance Loss | Load-balancing loss that prevents expert collapse |
| Expert Bias | Learnable expert bias used to improve routing |
| UnitCenteredNoise | Noise added during training to encourage exploration |
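Putting the diagram and table together, a forward pass through such a layer might look roughly like the sketch below. It reuses the `MLP` class from the previous section, omits the `UnitCenteredNoise` routing noise, and folds the expert bias directly into the router probabilities; the class and attribute names are assumptions, not the repository's code.

```python
class DSMoE(nn.Module):
    """Simplified sketch of a DeepSeek-style MoE layer (illustrative only)."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.n_experts                       # routed experts
        self.top_k = config.num_exp                               # experts activated per token
        self.gate = nn.Linear(config.n_embd, self.num_experts)
        self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
        self.shared_expert = MLP(config)                          # always-on shared expert
        self.expert_bias = nn.Parameter(torch.zeros(self.num_experts))

    def forward(self, x):
        B, T, C = x.shape
        flat = x.reshape(B * T, C)
        # Router: probabilities over experts, with a learnable per-expert bias
        probs = F.softmax(self.gate(flat) + self.expert_bias, dim=-1)   # [B*T, E]
        topk_w, topk_idx = probs.topk(self.top_k, dim=-1)               # [B*T, k]
        topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)              # renormalize
        routed = torch.zeros_like(flat)
        for e, expert in enumerate(self.experts):
            sel = topk_idx == e                                         # [B*T, k]
            rows = sel.any(dim=-1)                                      # tokens routed to expert e
            if rows.any():
                w = (topk_w * sel).sum(dim=-1, keepdim=True)[rows]      # routing weight per token
                idx = rows.nonzero(as_tuple=True)[0]
                routed = routed.index_add(0, idx, w * expert(flat[rows]))
        out = self.shared_expert(flat) + routed
        # Router probabilities are returned so the load-balance loss can be computed
        return out.reshape(B, T, C), probs
```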
Load-balance loss computation:
def moe_load_balance_loss(router_weights, num_experts):
    # Fraction of total routing mass assigned to each expert
    load = router_weights.sum(dim=0)
    load = load / load.sum()
    # Penalize squared deviation from a uniform load across experts
    ideal = torch.full_like(load, 1.0 / num_experts)
    loss = num_experts * torch.sum((load - ideal) ** 2)
    return loss
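Besides this loss, the training loop also has an "Update Expert Biases (load balancing)" step (described in the training pipeline later in this document). A rough sketch of the idea, in the spirit of bias-based routing adjustment; the update rule and the `bias_update_speed` value are assumptions, not the repository's code:

```python
# Rough sketch of bias-based expert load balancing (assumed update rule).
@torch.no_grad()
def update_expert_biases(moe_layers, router_weights_per_layer, bias_update_speed=1e-3):
    for layer, router_weights in zip(moe_layers, router_weights_per_layer):
        load = router_weights.sum(dim=0)        # routing mass per expert this step
        # Lower the bias of over-loaded experts, raise it for under-loaded ones
        layer.expert_bias += bias_update_speed * torch.sign(load.mean() - load)
```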
Hyper-Connections are another major innovation in Tiny-R2: they enhance information flow through a multi-stream routing mechanism.

Core concept:
# Initialize Hyper-Connections
self.init_hc, self.expand_stream, self.reduce_stream = \
    get_init_and_expand_reduce_stream_functions(
        config.hc_num_streams,
        num_fracs=config.hc_num_fracs,
        disable=config.hc_disable,
    )

# Used inside every Block
self.hc_attn = init_hc(
    dim=config.n_embd,
    branch=self.attn_branch,
    layer_index=index * 2,
    mhc=config.mhc,
    sinkhorn_iters=config.sinkhorn_iters,
    sinkhorn_tau=config.sinkhorn_tau,
)

Key parameters:
| Parameter | Description |
|---|---|
| hc_num_streams | Number of hyper-connection streams |
| hc_num_fracs | Number of fractions (stream subdivisions) |
| mhc | Multi-hyper-connection configuration |
| sinkhorn_iters | Number of Sinkhorn algorithm iterations |
| sinkhorn_tau | Sinkhorn temperature parameter |
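For context, here is a minimal sketch of how the expand/reduce functions wrap the stack of blocks at the model level, mirroring the "Expand Stream / Reduce Stream" stages of the architecture diagram; module names such as `tok_emb`, `pos_emb`, `blocks`, `norm`, and `lm_head` are assumptions, not the repository's exact attributes.

```python
# Minimal sketch of the model-level usage (not the repository's exact forward).
def forward(self, idx):
    B, T = idx.shape
    pos = torch.arange(T, device=idx.device)
    x = self.tok_emb(idx) + self.pos_emb(pos)   # token + positional embeddings
    x = self.expand_stream(x)                   # replicate into hc_num_streams streams
    for block in self.blocks:
        x = block(x)                            # hc_attn / hc_mlp are applied inside each block
    x = self.reduce_stream(x)                   # merge the streams back into one
    return self.lm_head(self.norm(x))           # RMSNorm + LM head -> logits
```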
Parse Arguments โ Update Config โ Init WandB โ Setup Distributed โ Setup AMP
Load HF Dataset (flytech/python-codes-25k)
โ
Init GPT2 Tokenizer
โ
Create TokenBuffer
TokenBuffer responsibilities (sketched below):

- Streams the HuggingFace dataset
- Dynamically refills the token buffer
- Produces contiguous token batches
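A minimal sketch of what such a buffer can look like; the constructor arguments and method names below are assumptions rather than the repository's API, and the `"text"` field name is likewise an assumption about the dataset schema.

```python
class TokenBuffer:
    """Illustrative streaming token buffer (assumed interface)."""

    def __init__(self, dataset, tokenizer, ctx_len, batch_size, text_key="text"):
        self.it = iter(dataset)                 # streaming HuggingFace dataset iterator
        self.tokenizer = tokenizer
        self.ctx_len = ctx_len
        self.batch_size = batch_size
        self.text_key = text_key
        self.buffer = []                        # flat list of token ids

    def _refill(self, min_tokens):
        # Pull documents from the stream until enough tokens are buffered
        while len(self.buffer) < min_tokens:
            doc = next(self.it)
            self.buffer.extend(self.tokenizer.encode(doc[self.text_key]))

    def get_batch(self, device="cpu"):
        need = self.batch_size * (self.ctx_len + 1)
        self._refill(need)
        chunk = torch.tensor(self.buffer[:need], dtype=torch.long)
        self.buffer = self.buffer[need:]
        chunk = chunk.view(self.batch_size, self.ctx_len + 1)
        x, y = chunk[:, :-1], chunk[:, 1:]      # inputs and next-token targets
        return x.to(device), y.to(device)
```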
Create Transformer
โ
Configure Optimizers (Muon + AdamW)
โ
Create LR Scheduler (Warmup + Cosine)
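The warmup-plus-cosine schedule can be expressed as a single multiplier on the base learning rate; the sketch below uses the `warmup_iters` and `max_iters` values from the configuration listed later, while the function name and the `min_scale` floor are assumptions.

```python
import math

def lr_scale(it, warmup_iters=1000, max_iters=100000, min_scale=0.1):
    """Linear warmup followed by cosine decay, as a multiplier on the base lr."""
    if it < warmup_iters:
        return (it + 1) / warmup_iters                              # linear warmup
    progress = (it - warmup_iters) / max(1, max_iters - warmup_iters)
    return min_scale + 0.5 * (1.0 - min_scale) * (1.0 + math.cos(math.pi * progress))
```

A multiplier of this form can be attached to both optimizer groups, for example via `torch.optim.lr_scheduler.LambdaLR`.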
For iter in range(max_iters):
โ
โโโ For step in grad_accum_steps:
โ โโโ Get Batch (TokenBuffer)
โ โโโ Forward Pass (model)
โ โโโ Backward Pass (scaler.scale)
โ โโโ Collect Router Weights
โ
โโโ Gradient Clipping (clip_grad_norm_)
โโโ Optimizer Steps (Muon + AdamW)
โโโ Update Scaler (scaler.update)
โโโ LR Scheduler Step
โโโ Update Expert Biases (load balancing)
โโโ Log Metrics (WandB)
If iter % eval_interval == 0:
โโโ Estimate Loss (eval mode)
โโโ Save Checkpoint (if val_loss < 5.27)
โโโ Log to WandB
Tiny-R2 uses a dual-optimizer strategy:
def configure_optimizers(self, weight_decay, learning_rate, device):
    muon_params = []   # >=2D parameters inside the transformer blocks
    adamw_params = []  # all other parameters
    for name, param in self.named_parameters():
        if 'blocks' in name and param.ndim >= 2:
            muon_params.append(param)
        else:
            adamw_params.append(param)
    return [
        Muon(muon_params, lr=0.02, momentum=0.95),
        torch.optim.AdamW(adamw_params, lr=learning_rate,
                          betas=(0.90, 0.95), weight_decay=weight_decay)
    ]

| Feature | CausalSelfAttention | MLA-NSA Hybrid |
|---|---|---|
| Computational complexity | O(n²) | O(n) ~ O(n log n) |
| Memory usage | High | Low |
| Suitable scenarios | Short sequences | Long sequences |
| Number of branches | 1 | 3 (configurable) |
| Feature | MLP | DSMoE |
|---|---|---|
| Parameter count | Fixed | Shared + routed |
| Compute per token | Fixed | Sparse activation |
| Expressive power | Standard | Stronger |
| Training stability | High | Requires load balancing |
# Model architecture
n_embd = 512        # embedding dimension
n_head = 8          # number of attention heads
n_layer = 8         # number of layers
n_experts = 8       # number of experts
num_exp = 2         # experts activated per token

# Attention configuration
attention_types = ["FULL", "Spares", ...]   # attention type per layer
attention_mode = ["FULL", "SWA", "NSA"]     # sparse attention mode

# Hyper-Connections
hc = True           # enable hyper-connections
hc_num_streams = 4  # number of streams

# Training
batch_size = 32
ctx_len = 512       # context length
lr = 1e-3
warmup_iters = 1000
max_iters = 100000

The charts accompanying this document are stored in the /mnt/okcomputer/output/ directory:
| File | Description |
|---|---|
| model_architecture.png | Overall model architecture diagram |
| mla_nsa_attention.png | Detailed structure of the MLA-NSA hybrid attention |
| dsmoe_architecture.png | DSMoE mixture-of-experts structure |
| training_pipeline.png | Complete training pipeline |
| tinyr2_overview.png | Tiny-R2 overview |
- Tiny-R2 GitHub Repository
- DeepSeek-V2 Technical Report
- Native Sparse Attention (NSA)
- Hyper-Connections Paper
Document generated: 2026-02-16