
Commit 51f3a51

committed
qwen2: model structure reverse-engineered from the weight directory names
1 parent c8303bb commit 51f3a51

1 file changed

Lines changed: 119 additions & 0 deletions

File tree

todo/qwen.py

@@ -0,0 +1,119 @@
# PyTorch model reverse-engineered by Cursor from the deepx model directory names exported from deepseek-qwen2 1.5b
# DeepSeek-R1-Distill-Qwen-1.5B
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput  # dataclass output so out.loss / out.logits work as attributes

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self._norm(x.float()).type_as(x)

class Qwen2Attention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size                   # 1536
        self.num_heads = config.num_attention_heads             # 12
        self.head_dim = self.hidden_size // self.num_heads      # 128
        self.num_key_value_heads = config.num_key_value_heads   # 2

        # projection dimensions as recorded in the shape files
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

    def forward(self, hidden_states, attention_mask=None):
        # TODO: implement grouped-query attention (GQA)
        # TODO: apply RoPE positional encoding (depending on config.use_mrope)
        # placeholder: currently returns the hidden states unchanged
        return hidden_states

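# --- Not part of the committed file: a minimal sketch of how the stub above could be
# --- filled in, assuming standard RoPE (base 10000, rotate-half convention), KV heads
# --- repeated to match the query heads, and no KV cache. The helper name
# --- gqa_forward_sketch is hypothetical.
def gqa_forward_sketch(attn, hidden_states, attention_mask=None):
    bsz, seq_len, _ = hidden_states.shape

    # project and reshape to (bsz, heads, seq_len, head_dim)
    q = attn.q_proj(hidden_states).view(bsz, seq_len, attn.num_heads, attn.head_dim).transpose(1, 2)
    k = attn.k_proj(hidden_states).view(bsz, seq_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
    v = attn.v_proj(hidden_states).view(bsz, seq_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)

    # rotary position embedding applied to q and k
    pos = torch.arange(seq_len, device=hidden_states.device, dtype=torch.float32)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, attn.head_dim, 2, device=hidden_states.device, dtype=torch.float32) / attn.head_dim))
    freqs = torch.outer(pos, inv_freq)                    # (seq_len, head_dim/2)
    cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)   # (seq_len, head_dim)
    sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)

    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin

    # GQA: repeat the 2 KV heads so they line up with the 12 query heads
    repeat = attn.num_heads // attn.num_key_value_heads
    k = k.repeat_interleave(repeat, dim=1)
    v = v.repeat_interleave(repeat, dim=1)

    # causal scaled-dot-product attention (mask, if given, assumed boolean or additive)
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attention_mask, is_causal=attention_mask is None
    )
    out = out.transpose(1, 2).reshape(bsz, seq_len, attn.hidden_size)
    return attn.o_proj(out)
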
class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size               # 1536
        self.intermediate_size = config.intermediate_size   # 8960

        # mlp layers as laid out in the weight directory structure
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        # SwiGLU-style gating: silu(gate) * up, then project back down
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = Qwen2Attention(config)

        # layernorms as laid out in the weight directory structure
        self.input_layernorm = RMSNorm(self.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(self.hidden_size, eps=config.rms_norm_eps)

        self.mlp = Qwen2MLP(config)

    def forward(self, hidden_states, attention_mask=None):
        # pre-norm attention block with residual connection
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(hidden_states, attention_mask)
        hidden_states = residual + hidden_states

        # pre-norm MLP block with residual connection
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states

class Qwen2Model(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)]  # 28 layers
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, input_ids, attention_mask=None):
        hidden_states = self.embed_tokens(input_ids)
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask)
        hidden_states = self.norm(hidden_states)
        return hidden_states

class Qwen2ForCausalLM(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # per the settings in config.yaml
        self.config.tie_word_embeddings = False  # embedding and lm_head weights are not tied

    def forward(self, input_ids, attention_mask=None, labels=None):
        hidden_states = self.model(input_ids, attention_mask)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # shift so that each position predicts the next token
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        return CausalLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
        )
