zhaoyingjun/Tiny-R2

Tiny-R2: A better combination of DSA/SWA/MLA, mHC, and DSMoE

Tiny-R2 Model Architecture and Training Pipeline Documentation


📋 Table of Contents

  1. Project Overview
  2. Model Architecture Overview
  3. Core Components in Detail
  4. Training Pipeline
  5. Key Technical Features
  6. Appendix: Figure Index

Project Overview

Tiny-R2 is a compact but capable language model that combines several advanced deep-learning techniques:

  • Sparse attention (MLA-NSA hybrid attention)
  • Mixture of experts (DeepSeek MoE)
  • Hyper-Connections
  • A dual-optimizer strategy (Muon + AdamW)

Model Architecture Overview

               Input Tokens
                      ↓
Token Embedding + Positional Embedding
                      ↓
       Hyper-Connection Expand Stream
                      │
                      ▼
-------------------------------------------------------
                  RMSNorm ────────────┐               |
                      │               │               |
                      ▼               │ Residual      |
              ┌───────────────┐       │ Connection    |
              │  Attention    │       │               |
              │ (NSA/SWA/DSA) │       │               |
              │ (pick 1 of 3) │       │               |
              └───────────────┘       │               |
                      │               │               |
                      ▼               │               |
           Hyper-connection (hc_attn) │               |
                      │               │               |
                      ▼               │       N × Transformer Block
                  RMSNorm ────────────┤               |
                      │               │               |
                      ▼               │               |
              ┌───────────────┐       │               |
              │  Dense / MoE  │       │               |
              │  MLP / DSMoE  │       │               |
              │ (pick 1 of 2) │       │               |
              └───────────────┘       │               |
                      │               │               |
                      ▼               │               |
           Hyper-connection (hc_mlp)  │               |
                      │               │               |
                      └───────────────┘               |
                      │                               |
                      ▼                               |
            Output + router_weights                   |
                      ↓                               |
         Hyper-Connection Reduce Stream               |
-----------------------------------------------------
                      ↓
               RMSNorm + LM Head
                      ↓
                 Output Logits

Core Components in Detail

3.1 Attention Mechanisms

Tiny-R2 supports two attention types, switched flexibly via the attention_types configuration:

3.1.1 CausalSelfAttention (Full Attention)

The standard causal self-attention mechanism:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Projections: Q, K, V from a single linear layer
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        # Value residual connections
        self.v_residual = config.v_residual
        self.lamb1 = nn.Parameter(torch.tensor(0.5))
        self.lamb2 = nn.Parameter(torch.tensor(0.5))

        # Flash Attention support (PyTorch >= 2.0)
        self.flash = hasattr(F, "scaled_dot_product_attention")

Key features:

  • Flash Attention acceleration (when available)
  • Value residual connections
  • Standard causal masking
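The Flash Attention path described above can be sketched as follows (a minimal illustration, not the repository's exact forward code; the tensor sizes are made up):

```python
import torch
import torch.nn.functional as F

# q, k, v are (batch, n_head, seq_len, head_dim)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

if hasattr(F, "scaled_dot_product_attention"):
    # Fused kernel with an implicit causal mask
    y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
else:
    # Fallback: explicit causal mask on the attention scores
    att = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    mask = torch.triu(torch.ones(16, 16, dtype=torch.bool), diagonal=1)
    att = att.masked_fill(mask, float("-inf"))
    y = att.softmax(dim=-1) @ v

print(y.shape)  # torch.Size([2, 8, 16, 64])
```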

3.1.2 MLA-NSA Hybrid Attention

็ป“ๅˆ Multi-head Latent Attention (MLA) ๅ’Œ Native Sparse Attention (NSA) ็š„ๆททๅˆๆณจๆ„ๅŠ›ๆœบๅˆถใ€‚

ไธ‰็ง่ฟ่กŒๆจกๅผ๏ผš

ๆจกๅผ ๅˆ†ๆ”ฏ้…็ฝฎ ่ฏดๆ˜Ž
NSA [1, 1, 1] ๅฏ็”จๆ‰€ๆœ‰ไธ‰ไธชๅˆ†ๆ”ฏ
SWA [1, 0, 1] ๅŽ‹็ผฉๅˆ†ๆ”ฏ + ๆป‘ๅŠจ็ช—ๅฃๅˆ†ๆ”ฏ
DSA [1, 1, 0] ๅŽ‹็ผฉๅˆ†ๆ”ฏ + ้€‰ๆ‹ฉๅˆ†ๆ”ฏ

3.2 MLA-NSA Hybrid Attention

MLA-NSA is one of Tiny-R2's core innovations, computing efficient sparse attention through three parallel branches.

Architecture flow

                       Input x
                          ↓
┌──────────────────────────────────────────────────────────────┐
│ Query Preparation (MLA style)                                │
│   compress_q → q_norm → decompress_q → RoPE → Query          │
└──────────────────────────────────────────────────────────────┘
                          ↓
┌──────────────────┬──────────────────┬──────────────────────┐
│   Branch 1       │   Branch 2       │   Branch 3           │
│   Compression    │   Selection      │   Sliding Window     │
│   (MLA)          │   (DSA)          │   (SWA)              │
├──────────────────┼──────────────────┼──────────────────────┤
│ compress_kv      │ importance_score │ window_k/v           │
│ kv_norm          │ topk selection   │ sliding_window       │
│ decompress_k/v   │ selection_k/v    │ RoPE                 │
│ k_rope           │ RoPE             │                      │
│ K/V Recombine    │ K/V Selected     │ K/V Window           │
└──────────────────┴──────────────────┴──────────────────────┘
    ↓                    ↓                    ↓
┌──────────────────────────────────────────────────────────────┐
│ Attention Computation                                        │
│   Attention 1: (Q @ K1.T) @ V1                               │
│   Attention 2: (Q @ K2.T) @ V2                               │
│   Attention 3: (Q @ K3.T) @ V3                               │
└──────────────────────────────────────────────────────────────┘
    ↓
branch_gate (Linear + Softmax) → Weighted Sum
    ↓
proj (Linear) → res_dropout → Output
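The final gating step can be sketched as follows (a hypothetical minimal version: the three branch outputs here are random stand-ins, and `branch_gate` is assumed to be a plain `nn.Linear` producing one logit per branch):

```python
import torch
import torch.nn as nn

n_embd, B, T = 512, 2, 16
branch_gate = nn.Linear(n_embd, 3)  # one logit per branch

x = torch.randn(B, T, n_embd)
out_cmp = torch.randn(B, T, n_embd)  # compression branch output (stand-in)
out_sel = torch.randn(B, T, n_embd)  # selection branch output (stand-in)
out_win = torch.randn(B, T, n_embd)  # sliding-window branch output (stand-in)

gates = branch_gate(x).softmax(dim=-1)                       # (B, T, 3)
branches = torch.stack([out_cmp, out_sel, out_win], dim=-1)  # (B, T, C, 3)
y = (branches * gates.unsqueeze(2)).sum(dim=-1)              # weighted sum -> (B, T, C)
```

Each token thus gets its own convex combination of the three branches before the output projection.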

Key parameters

# MLA parameters
self.v_head_dim = 32
self.kv_lora_rank = 32
self.q_lora_rank = 3 * self.kv_lora_rank
self.rope_head_dim = 64
self.nope_head_dim = 32

# NSA parameters
self.block_size = config.block_size      # token compression block size
self.window_size = config.window_size    # sliding window size
self.num_tokens_to_keep = config.num_tokens_to_keep  # tokens kept by selection

3.3 Feed-Forward Network and MoE

3.3.1 MLP

The standard feed-forward network, using the ReLU² activation function:

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()  # ReLU squared activation
        x = self.c_proj(x)
        return x

3.3.2 DSMoE (DeepSeek Mixture of Experts)

A DeepSeek-style mixture-of-experts model:

Input x [B, T, C]
    ↓
Gate Network (Linear + UnitCenteredNoise)
    ↓
Softmax → Top-k Selection
    ↓
┌─────────────────────────────────────────────────────────────┐
│                    Expert Networks                          │
│  ┌──────────────┐  ┌──────────┐  ┌──────────┐  ┌────────┐   │
│  │ Shared Exp 0 │  │ Expert 1 │  │ Expert 2 │  │  ...   │   │
│  │ (Always On)  │  │ (Top-k)  │  │ (Top-k)  │  │ (Top-k)│   │
│  └──────────────┘  └──────────┘  └──────────┘  └────────┘   │
└─────────────────────────────────────────────────────────────┘
    ↓
Weighted Sum of Expert Outputs
    ↓
Output [B, T, C]

Key features:

Feature            Description
Shared Expert      always-active shared expert that provides stability
Routed Experts     routed experts chosen by top-k selection
Load Balance Loss  load-balancing loss that prevents expert collapse
Expert Bias        learnable per-expert bias used to refine routing
UnitCenteredNoise  noise added during training to encourage exploration
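The shared-plus-routed pattern in the table can be sketched as follows (an illustrative top-k routing step, not the repo's DSMoE implementation; expert 0 is treated as the always-on shared expert):

```python
import torch

num_experts, top_k = 8, 2
num_tokens = 4
# Gate logits for the routed experts only (the shared expert bypasses the gate)
logits = torch.randn(num_tokens, num_experts - 1)

probs = logits.softmax(dim=-1)                       # routing probabilities
topk_w, topk_idx = probs.topk(top_k, dim=-1)         # pick top-k experts per token
topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)   # renormalize kept weights

# The per-token output would then be:
#   shared_expert(x) + sum_k topk_w[:, k] * experts[topk_idx[:, k]](x)
```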

Load Balance Loss computation:

def moe_load_balance_loss(router_weights, num_experts):
    # Fraction of total routing weight received by each expert
    load = router_weights.sum(dim=0)
    load = load / load.sum()
    # Penalize squared deviation from the uniform (ideal) load
    ideal = torch.full_like(load, 1.0 / num_experts)
    loss = num_experts * torch.sum((load - ideal) ** 2)
    return loss

3.4 Hyper-Connections

Hyper-Connections are Tiny-R2's other major innovation, enhancing information flow through a multi-stream routing mechanism.

Core concept:

# Initialize Hyper-Connections
self.init_hc, self.expand_stream, self.reduce_stream = \
    get_init_and_expand_reduce_stream_functions(
        config.hc_num_streams,
        num_fracs=config.hc_num_fracs,
        disable=config.hc_disable,
    )

# Used inside each Block
self.hc_attn = init_hc(
    dim=config.n_embd,
    branch=self.attn_branch,
    layer_index=index * 2,
    mhc=config.mhc,
    sinkhorn_iters=config.sinkhorn_iters,
    sinkhorn_tau=config.sinkhorn_tau,
)

Key parameters:

Parameter       Description
hc_num_streams  number of hyper-connection streams
hc_num_fracs    number of fractions
mhc             multi-hyper-connection (mHC) configuration
sinkhorn_iters  number of Sinkhorn algorithm iterations
sinkhorn_tau    Sinkhorn temperature parameter
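Conceptually, the expand/reduce stream pattern looks like the sketch below (greatly simplified: the real hyper-connection layers learn width and depth mixing weights, optionally with Sinkhorn-normalized routing, rather than the plain copy-and-sum shown here):

```python
import torch

S = 4                        # hc_num_streams
x = torch.randn(2, 16, 512)  # (B, T, C) hidden states

# Expand: replicate the hidden state into S parallel streams -> (S, B, T, C)
streams = x.unsqueeze(0).repeat(S, 1, 1, 1)

# ... each Transformer block would mix the streams, run its branch
# (attention or MLP), and write the result back into the streams ...

# Reduce: collapse the streams back to a single hidden state -> (B, T, C)
out = streams.sum(dim=0)
```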

Training Pipeline

4.1 Initialization Phase

Parse Arguments → Update Config → Init WandB → Setup Distributed → Setup AMP

4.2 Data Preparation

Load HF Dataset (flytech/python-codes-25k)
    ↓
Init GPT2 Tokenizer
    ↓
Create TokenBuffer

TokenBuffer features:

  • Streams the HuggingFace dataset
  • Dynamically refills the token buffer
  • Produces contiguous token batches
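The buffering behavior can be illustrated with a minimal stand-in class (hypothetical, not the repo's `TokenBuffer`): stream tokenized documents, refill an internal buffer, and emit contiguous fixed-length input/target batches.

```python
class TokenBuffer:
    def __init__(self, stream, batch_size, ctx_len):
        self.stream = iter(stream)  # iterable of tokenized documents
        self.batch_size = batch_size
        self.ctx_len = ctx_len
        self.buf = []

    def next_batch(self):
        need = self.batch_size * (self.ctx_len + 1)
        while len(self.buf) < need:              # dynamically refill
            self.buf.extend(next(self.stream))   # one tokenized document
        chunk, self.buf = self.buf[:need], self.buf[need:]
        rows = [chunk[i * (self.ctx_len + 1):(i + 1) * (self.ctx_len + 1)]
                for i in range(self.batch_size)]
        x = [r[:-1] for r in rows]  # inputs
        y = [r[1:] for r in rows]   # next-token targets
        return x, y

docs = ([i, i + 1, i + 2] for i in range(0, 300, 3))  # fake tokenized docs
tb = TokenBuffer(docs, batch_size=2, ctx_len=4)
x, y = tb.next_batch()
print(x[0], y[0])  # [0, 1, 2, 3] [1, 2, 3, 4]
```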

4.3 Model Initialization

Create Transformer
    ↓
Configure Optimizers (Muon + AdamW)
    ↓
Create LR Scheduler (Warmup + Cosine)
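The warmup + cosine schedule can be sketched in plain math (illustrative; `peak_lr`, `warmup_iters`, and `max_iters` follow the configuration in 5.3, while `min_lr` is an assumed floor):

```python
import math

def get_lr(it, peak_lr=1e-3, min_lr=1e-4, warmup_iters=1000, max_iters=100000):
    # Linear warmup to peak_lr over warmup_iters ...
    if it < warmup_iters:
        return peak_lr * (it + 1) / warmup_iters
    # ... then cosine decay from peak_lr down to min_lr at max_iters
    progress = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))  # goes 1 -> 0
    return min_lr + coeff * (peak_lr - min_lr)

print(get_lr(999))  # 0.001 (peak, end of linear warmup)
```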

4.4 Training Loop

For iter in range(max_iters):
    │
    ├── For step in grad_accum_steps:
    │       ├── Get Batch (TokenBuffer)
    │       ├── Forward Pass (model)
    │       ├── Backward Pass (scaler.scale)
    │       └── Collect Router Weights
    │
    ├── Gradient Clipping (clip_grad_norm_)
    ├── Optimizer Steps (Muon + AdamW)
    ├── Update Scaler (scaler.update)
    ├── LR Scheduler Step
    ├── Update Expert Biases (load balancing)
    └── Log Metrics (WandB)
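The inner gradient-accumulation pattern can be sketched as follows (AMP scaler, Muon, and router-weight collection are omitted for brevity; a tiny linear model stands in for the Transformer):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_accum_steps = 4

opt.zero_grad(set_to_none=True)
for step in range(grad_accum_steps):
    x = torch.randn(16, 8)
    y = torch.randn(16, 8)
    # Divide by grad_accum_steps so the summed gradients match one big batch
    loss = nn.functional.mse_loss(model(x), y) / grad_accum_steps
    loss.backward()  # gradients accumulate across micro-steps

torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
opt.step()  # one optimizer step per outer iteration
```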

4.5 Evaluation and Checkpointing

If iter % eval_interval == 0:
    ├── Estimate Loss (eval mode)
    ├── Save Checkpoint (if val_loss < 5.27)
    └── Log to WandB

4.6 Optimizer Configuration

Tiny-R2 uses a dual-optimizer strategy:

def configure_optimizers(self, weight_decay, learning_rate, device):
    muon_params = []    # โ‰ฅ2D parameters in blocks
    adamw_params = []   # Other parameters
    
    for name, param in self.named_parameters():
        if 'blocks' in name and param.ndim >= 2:
            muon_params.append(param)
        else:
            adamw_params.append(param)
    
    return [
        Muon(muon_params, lr=0.02, momentum=0.95),
        torch.optim.AdamW(adamw_params, lr=learning_rate, 
                          betas=(0.90, 0.95), weight_decay=weight_decay)
    ]

Key Technical Features

5.1 Attention Mechanism Comparison

Feature             CausalSelfAttention  MLA-NSA Hybrid
Compute complexity  O(n²)                O(n) ~ O(n log n)
Memory usage        high                 low
Best suited for     short sequences      long sequences
Branch count        1                    3 (configurable)

5.2 FFN Type Comparison

Feature             MLP       DSMoE
Parameter count     fixed     shared + routed
Compute cost        fixed     sparse activation
Expressive power    standard  stronger
Training stability  high      requires load balancing

5.3 Core Configuration Parameters

# Model architecture
n_embd = 512        # embedding dimension
n_head = 8          # number of attention heads
n_layer = 8         # number of layers
n_experts = 8       # number of experts
num_exp = 2         # experts activated per token

# Attention configuration
attention_types = ["FULL", "Spares", ...]  # attention type per layer
attention_mode = ["FULL", "SWA", "NSA"]    # sparse attention modes

# Hyper-Connections
hc = True           # enable hyper-connections
hc_num_streams = 4  # number of streams

# Training
batch_size = 32
ctx_len = 512       # context length
lr = 1e-3
warmup_iters = 1000
max_iters = 100000

Appendix: Figure Index

The figures accompanying this document are stored in the /mnt/okcomputer/output/ directory:

File                    Description
model_architecture.png  overall model architecture
mla_nsa_attention.png   detailed MLA-NSA hybrid attention structure
dsmoe_architecture.png  DSMoE mixture-of-experts structure
training_pipeline.png   full training pipeline
tinyr2_overview.png     Tiny-R2 overview

ๅ‚่€ƒ่ต„ๆ–™


Document generated: 2026-02-16
