Commit f3ef193

Add files via upload
1 parent d716b7f commit f3ef193

2 files changed

Lines changed: 44 additions & 50 deletions

README.md

Lines changed: 22 additions & 25 deletions
@@ -137,13 +137,13 @@ flowchart LR
 Before M2, the lookahead head was just a head with no real supervision. M2 adds a proper soft-target objective:
 
 $$
-\mathcal{L}_{\text{lookahead}}
-= \frac{1}{|\mathcal{K}_{\text{valid}}|}
-\sum_{k \in \mathcal{K}_{\text{valid}}}
+\mathcal{L}_{\mathrm{lookahead}}
+= \frac{1}{|\mathcal{K}_{\mathrm{valid}}|}
+\sum_{k \in \mathcal{K}_{\mathrm{valid}}}
 \mathbb{E}_{b,t}
 \left[
 - \sum_{e=1}^{E}
-\operatorname{stopgrad}\!\left(P_{b,t+k,e}\right)
+\mathrm{sg}\!\left(P_{b,t+k,e}\right)
 \log Q_{b,t,e}^{(k)}
 \right].
 $$
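
For readers checking the formula in this hunk, here is a minimal pure-Python sketch of the lookahead objective. It is an illustration under stated assumptions, not the repository's implementation: the function name `lookahead_loss`, the nested-list layout, the `eps` guard, and the reading of $\mathcal{K}_{\mathrm{valid}}$ as "every horizon $k$ with $0 < k < T$" are all hypothetical. In plain Python there is no autograd, so $\mathrm{sg}(\cdot)$ is represented simply by treating the target as a constant.

```python
import math


def lookahead_loss(P, Q, eps=1e-12):
    """Soft-target cross entropy between future routing and lookahead predictions.

    P[b][t][e]    : router probabilities for batch b, position t, expert e.
    Q[k][b][t][e] : lookahead prediction made at position t for position t + k.
    """
    T = len(P[0])
    # Assumption: a horizon k is "valid" when at least one position t has t + k < T.
    valid_ks = [k for k in sorted(Q) if 0 < k < T]
    if not valid_ks:
        return 0.0

    total = 0.0
    for k in valid_ks:
        terms = []
        for b, seq in enumerate(P):
            for t in range(T - k):
                target = seq[t + k]  # constant target: the sg(.) in the formula
                pred = Q[k][b][t]
                terms.append(-sum(p * math.log(q + eps) for p, q in zip(target, pred)))
        total += sum(terms) / len(terms)  # expectation over (b, t) for this k
    return total / len(valid_ks)  # average over valid horizons


# One sequence, three positions, two experts, a uniform one-step prediction.
P = [[[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]]
Q = {1: [[[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]]}
loss = lookahead_loss(P, Q)
```

With a uniform prediction over two experts and one-hot targets, each term reduces to $-\log 0.5$, so the loss comes out very close to $\log 2$.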
@@ -262,45 +262,42 @@ Honest note: upstream PyTorch does not ship a real OpenCL backend, and Vulkan su
 ## Objective
 
 $$
-\begin{aligned}
-\mathcal{L}_{\text{total}}
-&= \mathcal{L}_{\text{base}}
-+ \lambda_{\text{bal}} \mathcal{L}_{\text{aux-raw}}
-+ \lambda_{\text{tmp}} \mathcal{L}_{\text{temporal}} \\
-&\quad
-+ \lambda_{\text{LA}} \mathcal{L}_{\text{lookahead}}
-+ \lambda_{\text{anc}} \mathcal{L}_{\text{router-KL-anchor}} .
-\end{aligned}
+\mathcal{L}_{\mathrm{total}}
+= \mathcal{L}_{\mathrm{base}}
++ \lambda_{\mathrm{bal}} \mathcal{L}_{\mathrm{aux}}
++ \lambda_{\mathrm{tmp}} \mathcal{L}_{\mathrm{temporal}}
++ \lambda_{\mathrm{LA}} \mathcal{L}_{\mathrm{lookahead}}
++ \lambda_{\mathrm{anc}} \mathcal{L}_{\mathrm{routerKL}}
 $$
 
 $$
-\mathcal{L}_{\text{aux-raw}}
-= E \sum_{e=1}^{E} \operatorname{load}_e \cdot \overline{p}_e
+\mathcal{L}_{\mathrm{aux}}
+= E \sum_{e=1}^{E} \mathit{load}_e \cdot \overline{p}_e
 $$
 
 $$
-\mathcal{L}_{\text{temporal}}
+\mathcal{L}_{\mathrm{temporal}}
 = \mathbb{E}_{b,t}
 \left[
 \left\| P_{b,t,:} - P_{b,t-1,:} \right\|_2^2
 \right]
 $$
 
 $$
-\mathcal{L}_{\text{router-KL-anchor}}
+\mathcal{L}_{\mathrm{routerKL}}
 = D_{\mathrm{KL}}
 \left(
-\pi_{\theta}^{\text{router}}
-\,\middle\|\,
-\pi_{\text{ref}}^{\text{router}}
+\pi_{\theta}^{\mathrm{router}}
+\|
+\pi_{\mathrm{ref}}^{\mathrm{router}}
 \right)
 $$
 
-- $\mathcal{L}_{\text{base}}$: stage-specific objective (`CE`, `DPO`, `ORPO`, `GRPO`, or distillation).
-- $\mathcal{L}_{\text{aux-raw}}$: the unscaled MoE load-balance auxiliary term; Chronos applies $\lambda_{\text{bal}}$ once in `chronos_loss_term`.
-- $\mathcal{L}_{\text{temporal}}$: encourages adjacent tokens to reuse similar expert distributions.
-- $\mathcal{L}_{\text{lookahead}}$: soft-target cross entropy from the real future router distribution to the lookahead prediction.
-- $\mathcal{L}_{\text{router-KL-anchor}}$: keeps alignment-stage updates from destroying the routing layout captured at stage start.
+- $\mathcal{L}_{\mathrm{base}}$: stage-specific objective (`CE`, `DPO`, `ORPO`, `GRPO`, or distillation).
+- $\mathcal{L}_{\mathrm{aux}}$: the unscaled MoE load-balance auxiliary term; Chronos applies $\lambda_{\mathrm{bal}}$ once in `chronos_loss_term`.
+- $\mathcal{L}_{\mathrm{temporal}}$: encourages adjacent tokens to reuse similar expert distributions.
+- $\mathcal{L}_{\mathrm{lookahead}}$: soft-target cross entropy from the real future router distribution to the lookahead prediction. Here $\mathrm{sg}(\cdot)$ means stop-gradient.
+- $\mathcal{L}_{\mathrm{routerKL}}$: keeps alignment-stage updates from destroying the routing layout captured at stage start.
 
 All lambda terms are searchable with Optuna TPE, together with structural hyperparameters such as `hidden_size`, `num_experts`, and `kv_latent_dim`.
 
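
The objective in the README.md hunk above can be sketched in pure Python to make the weighting explicit. This is a hedged illustration only: the helper names (`aux_load_balance`, `temporal_loss`, `router_kl`, `total_loss`) and the `lam` dictionary keys are hypothetical and are not taken from the Chronos code base. The sketch also assumes that $\mathit{load}_e$ is the fraction of tokens dispatched to expert $e$, that $\overline{p}_e$ is the mean router probability for expert $e$, that the temporal expectation runs over positions $t \ge 1$, and that the KL term is evaluated on a single pair of distributions.

```python
import math


def aux_load_balance(load, mean_prob):
    """Unscaled load-balance term: E * sum_e load_e * mean_prob_e."""
    num_experts = len(load)
    return num_experts * sum(l * p for l, p in zip(load, mean_prob))


def temporal_loss(P):
    """Mean squared L2 distance between routing distributions of adjacent tokens."""
    diffs = [
        sum((cur - prev) ** 2 for cur, prev in zip(seq[t], seq[t - 1]))
        for seq in P
        for t in range(1, len(seq))  # assumption: t = 0 has no predecessor and is skipped
    ]
    return sum(diffs) / len(diffs) if diffs else 0.0


def router_kl(pi_theta, pi_ref, eps=1e-12):
    """KL(pi_theta || pi_ref) for one pair of routing distributions."""
    return sum(p * math.log((p + eps) / (q + eps)) for p, q in zip(pi_theta, pi_ref))


def total_loss(base, aux, temporal, lookahead, anchor, lam):
    """Weighted sum of the five terms; each lambda is applied exactly once, here."""
    return (
        base
        + lam["bal"] * aux
        + lam["tmp"] * temporal
        + lam["LA"] * lookahead
        + lam["anc"] * anchor
    )


aux = aux_load_balance([0.5, 0.5], [0.5, 0.5])  # perfectly balanced, two experts
tmp = temporal_loss([[[1.0, 0.0], [0.0, 1.0]]])  # routing flips between tokens
kl = router_kl([0.5, 0.5], [0.5, 0.5])           # identical distributions
lam = {"bal": 0.1, "tmp": 0.01, "LA": 0.2, "anc": 1.0}
total = total_loss(1.0, aux, tmp, 0.5, kl, lam)
```

Keeping the auxiliary term unscaled and multiplying by $\lambda_{\mathrm{bal}}$ only inside the final sum mirrors the README's note that the weight is applied once; the lambda values used here are placeholders, not recommended settings.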
README_zh.md

Lines changed: 22 additions & 25 deletions
@@ -131,13 +131,13 @@ flowchart LR
 Before M2, LookaheadRouter had no supervision at all; it was just an untrained head. M2 introduces:
 
 $$
-\mathcal{L}_{\text{lookahead}}
-= \frac{1}{|\mathcal{K}_{\text{valid}}|}
-\sum_{k \in \mathcal{K}_{\text{valid}}}
+\mathcal{L}_{\mathrm{lookahead}}
+= \frac{1}{|\mathcal{K}_{\mathrm{valid}}|}
+\sum_{k \in \mathcal{K}_{\mathrm{valid}}}
 \mathbb{E}_{b,t}
 \left[
 - \sum_{e=1}^{E}
-\operatorname{stopgrad}\!\left(P_{b,t+k,e}\right)
+\mathrm{sg}\!\left(P_{b,t+k,e}\right)
 \log Q_{b,t,e}^{(k)}
 \right].
 $$
@@ -254,45 +254,42 @@ d.describe() # human-readable capability overview
 ## Loss function (full form)
 
 $$
-\begin{aligned}
-\mathcal{L}_{\text{total}}
-&= \mathcal{L}_{\text{base}}
-+ \lambda_{\text{bal}} \mathcal{L}_{\text{aux-raw}}
-+ \lambda_{\text{tmp}} \mathcal{L}_{\text{temporal}} \\
-&\quad
-+ \lambda_{\text{LA}} \mathcal{L}_{\text{lookahead}}
-+ \lambda_{\text{anc}} \mathcal{L}_{\text{router-KL-anchor}} .
-\end{aligned}
+\mathcal{L}_{\mathrm{total}}
+= \mathcal{L}_{\mathrm{base}}
++ \lambda_{\mathrm{bal}} \mathcal{L}_{\mathrm{aux}}
++ \lambda_{\mathrm{tmp}} \mathcal{L}_{\mathrm{temporal}}
++ \lambda_{\mathrm{LA}} \mathcal{L}_{\mathrm{lookahead}}
++ \lambda_{\mathrm{anc}} \mathcal{L}_{\mathrm{routerKL}}
 $$
 
 $$
-\mathcal{L}_{\text{aux-raw}}
-= E \sum_{e=1}^{E} \operatorname{load}_e \cdot \overline{p}_e
+\mathcal{L}_{\mathrm{aux}}
+= E \sum_{e=1}^{E} \mathit{load}_e \cdot \overline{p}_e
 $$
 
 $$
-\mathcal{L}_{\text{temporal}}
+\mathcal{L}_{\mathrm{temporal}}
 = \mathbb{E}_{b,t}
 \left[
 \left\| P_{b,t,:} - P_{b,t-1,:} \right\|_2^2
 \right]
 $$
 
 $$
-\mathcal{L}_{\text{router-KL-anchor}}
+\mathcal{L}_{\mathrm{routerKL}}
 = D_{\mathrm{KL}}
 \left(
-\pi_{\theta}^{\text{router}}
-\,\middle\|\,
-\pi_{\text{ref}}^{\text{router}}
+\pi_{\theta}^{\mathrm{router}}
+\|
+\pi_{\mathrm{ref}}^{\mathrm{router}}
 \right)
 $$
 
-- $\mathcal{L}_{\text{base}}$: stage-specific objective (CE / DPO / ORPO / GRPO / KD).
-- $\mathcal{L}_{\text{aux-raw}}$: the unscaled MoE load-balance auxiliary term; Chronos multiplies by $\lambda_{\text{bal}}$ exactly once, in `chronos_loss_term`.
-- $\mathcal{L}_{\text{temporal}}$: keeps the routing distributions of adjacent tokens from jumping abruptly, improving expert reuse and cache locality.
-- $\mathcal{L}_{\text{lookahead}}$: soft-target cross entropy from the real future routing distribution to the lookahead prediction.
-- $\mathcal{L}_{\text{router-KL-anchor}}$: during alignment stages, anchors to the reference routing distribution captured at stage start, preventing RL/DPO/ORPO/GRPO gradients from destroying the clustered layout.
+- $\mathcal{L}_{\mathrm{base}}$: stage-specific objective (CE / DPO / ORPO / GRPO / KD).
+- $\mathcal{L}_{\mathrm{aux}}$: the unscaled MoE load-balance auxiliary term; Chronos multiplies by $\lambda_{\mathrm{bal}}$ exactly once, in `chronos_loss_term`.
+- $\mathcal{L}_{\mathrm{temporal}}$: keeps the routing distributions of adjacent tokens from jumping abruptly, improving expert reuse and cache locality.
+- $\mathcal{L}_{\mathrm{lookahead}}$: soft-target cross entropy from the real future routing distribution to the lookahead prediction. Here $\mathrm{sg}(\cdot)$ denotes stop-gradient.
+- $\mathcal{L}_{\mathrm{routerKL}}$: during alignment stages, anchors to the reference routing distribution captured at stage start, preventing RL/DPO/ORPO/GRPO gradients from destroying the clustered layout.
 
 All `λ` values support automatic search with Optuna TPE (including structural hyperparameters such as `hidden_size` / `num_experts` / `kv_latent_dim`).
 