feat: GKD with external teacher backends + on-policy + JSD mix #36
Open
marksverdhei wants to merge 2 commits into
Conversation
Adds a TeacherBackend abstraction so the trainer can distill from any of:

- HFTeacher — load a separate HuggingFace model in-process (fallback)
- VLLMTeacher — vLLM as OpenAI-compat server (primary, TRL pattern)
- OpenAIAPITeacher — any OpenAI-compat endpoint that exposes logprobs

All backends return TopKLogprobs — a sparse top-k view of the teacher's
distribution, renormalized over those K — so the KL math (`topk_forward_kl`)
is identical whether K covers the full vocab (HF) or K=20 (API). The dense
local-toggle path is unchanged. (Both the shape and the math are sketched
below.)

The trainer dispatches between dense and sparse KL via a unified
`_align_and_slice` helper shared by `_kl_from_logits` and `_kl_from_topk`.
`_generate_trajectory` also delegates to the external teacher when one is
configured.

New `TeacherConfig` exposes flat CLI flags (`teacher_backend`,
`teacher_top_k`, `teacher_model_name_or_path`, `teacher_api_base`,
`teacher_api_key`, etc.) — all prefixed with `teacher_` to avoid colliding
with student DataConfig fields.

Example `examples/gkd_gemma3.yaml` demonstrates the combined objective:
Gemma 3 1B → 270M with a prefix context baked alongside the capability
distillation, in one training sweep. LoRA stays on by default as a
regularizer.

vLLM/OpenAI score() endpoints are stubbed (raise NotImplementedError)
pending the vLLM logprobs-via-echo plumbing; generate() works for both.
HFTeacher covers the dev/test path and the first smoke runs.

Tests: 31 new (topk_kl math, teacher backends, trainer GKD integration).
Smoke: Gemma 3 270M student + 1B HF teacher on RTX 3090, end-to-end loss
and adapter save verified.
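A minimal sketch of that contract, assuming illustrative field names (the repo's actual `TopKLogprobs` and `topk_forward_kl` may differ in detail):

```python
# Sketch only, not the repo's exact code: the TopKLogprobs shape and the
# sparse forward KL it enables. Field names are illustrative.
from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class TopKLogprobs:
    indices: torch.LongTensor  # (batch, seq, K) teacher's top-k token ids
    logprobs: torch.Tensor     # (batch, seq, K) teacher logprobs, renormalized over the K


def topk_forward_kl(student_logits: torch.Tensor, teacher: TopKLogprobs) -> torch.Tensor:
    """KL(P_t || P_s) restricted to the teacher's top-k support.

    The student is renormalized over the same K tokens, so at K = vocab
    this reduces to the dense forward KL.
    """
    # Student log-probs over the full vocab, gathered at the teacher's top-k ids.
    student_logprobs = F.log_softmax(student_logits, dim=-1)
    student_topk = student_logprobs.gather(-1, teacher.indices)
    # Renormalize the student over the K-token support, in log-space.
    student_topk = student_topk - student_topk.logsumexp(dim=-1, keepdim=True)
    p_t = teacher.logprobs.exp()
    # Forward KL: sum_k p_t * (log p_t - log p_s), averaged over positions.
    return (p_t * (teacher.logprobs - student_topk)).sum(-1).mean()
```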
Adds the two pieces left from the GKD branch's first commit:
1. OpenAIAPITeacher.score(): wired against /v1/completions with
   prompt=token_ids, echo=true, max_tokens=0, logprobs=K (request sketched
   after this list). Returned token strings are re-encoded against the
   student tokenizer (same-tokenizer assumption from the design memo). The
   trainer registers the student tokenizer on the backend in __init__ via
   set_student_tokenizer.
2. On-policy GKD trajectories + reverse/JSD-style KL mix:
- gkd_on_policy_fraction (TeacherConfig): probability of sampling the
trajectory from the student (mode-seeking on-policy) vs. the teacher
(mode-covering off-policy). 0.0 keeps current behavior.
- gkd_jsd_beta (TeacherConfig): mix between forward and reverse KL.
loss = (1-β)·KL(P_t||P_s) + β·KL(P_s||P_t). β=0 → forward KL
(default, identical math to before), β=1 → reverse KL, β=0.5 →
symmetric average. Same convention as TRL's GKDTrainer.
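For reference, a hedged sketch of that score() request against an OpenAI-compatible server; only the payload fields come from the commit above, while the helper name and plumbing are illustrative:

```python
# Illustrative only: fetch per-position top-k logprobs for an existing
# sequence by echoing the prompt through /v1/completions (e.g. vLLM).
import requests


def fetch_prompt_logprobs(api_base: str, model: str, token_ids: list[int], k: int):
    resp = requests.post(
        f"{api_base}/v1/completions",
        json={
            "model": model,
            "prompt": token_ids,  # token ids, not text: same-tokenizer assumption
            "echo": True,         # score the prompt itself...
            "max_tokens": 0,      # ...without generating anything new
            "logprobs": k,        # top-k logprobs per position
        },
        timeout=60,
    )
    resp.raise_for_status()
    lp = resp.json()["choices"][0]["logprobs"]
    # lp["top_logprobs"] is one {token_string: logprob} dict per position;
    # the first entry is null (no context), which score() must skip.
    return lp["top_logprobs"]
```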
The trainer reads both from TeacherConfig and routes loss + trajectory
generation accordingly (routing sketched below). New `_sample_from_student`
mirrors the teacher sampler but with adapters enabled and the student's
trimmed view.
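A sketch of the routing decision, with a hypothetical helper name; the trainer's real control flow may differ:

```python
# Illustrative routing: with probability gkd_on_policy_fraction, sample the
# trajectory from the student (on-policy); otherwise from the teacher.
import random


def choose_sampler(on_policy_fraction: float, student_sampler, teacher_sampler):
    # fraction=0.0 -> always teacher (prior behavior); 1.0 -> always student.
    if random.random() < on_policy_fraction:
        return student_sampler  # mode-seeking, on-policy
    return teacher_sampler      # mode-covering, off-policy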
KL math:
- New `topk_jsd` in kl.py — operates on the teacher's top-k support,
renormalizes the student over the same K, computes forward and reverse
KL in log-space. Reduces exactly to topk_forward_kl at β=0.
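A minimal sketch of that mix, assuming `q_t`/`q_s` are log-probs already renormalized over the same top-k support (as in the `topk_forward_kl` sketch above); names are illustrative:

```python
# Sketch of the beta-mix on the teacher's top-k support (not the repo's
# exact code). q_t / q_s: (batch, seq, K) renormalized log-probs.
import torch


def topk_jsd(q_t: torch.Tensor, q_s: torch.Tensor, beta: float) -> torch.Tensor:
    forward = (q_t.exp() * (q_t - q_s)).sum(-1)  # KL(P_t || P_s)
    reverse = (q_s.exp() * (q_s - q_t)).sum(-1)  # KL(P_s || P_t)
    # beta=0 -> forward KL (identical to topk_forward_kl), beta=1 -> reverse,
    # beta=0.5 -> symmetric average. Same convention as TRL's GKDTrainer.
    return ((1.0 - beta) * forward + beta * reverse).mean()
```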
Tests (+13): JSD identities (β=0 equals forward KL, β=0.5 equals the
symmetric average, β=1 differs; identity check sketched below), on-policy
routing (fraction 1.0 → always the student sampler, 0.0 → never), trainer
guardrails for invalid β / fractions, and OpenAI score() against a mocked
vLLM payload (shape, renormalization, null-first-position handling,
left-padding strip, empty-top_logprobs robustness).
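For example, the β=0.5 identity can be checked like this (illustrative, not the repo's test code; it reuses the `topk_jsd` sketch above):

```python
# Illustrative identity check: the beta-mix is linear in beta, so the
# beta=0.5 loss must equal the average of the forward and reverse losses.
import torch


def test_jsd_beta_half_is_average():
    q_t = torch.log_softmax(torch.randn(2, 4, 8), dim=-1)
    q_s = torch.log_softmax(torch.randn(2, 4, 8), dim=-1)
    fwd = topk_jsd(q_t, q_s, beta=0.0)
    rev = topk_jsd(q_t, q_s, beta=1.0)
    mid = topk_jsd(q_t, q_s, beta=0.5)
    torch.testing.assert_close(mid, 0.5 * (fwd + rev))
```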
Smoke (GPU, Gemma 3 270M ← 1B): on-policy fraction 1.0, JSD β=0.5 —
loss 4.099 end-to-end, adapter saved.
Summary
Adds Generalized Knowledge Distillation (GKD) on top of the context-baking machinery from #35. The student can now distill from a separate teacher — a different HuggingFace model, a vLLM server, or any OpenAI-compatible API with logprobs — while continuing to bake an arbitrary prefix context into its weights in the same training sweep.
LoRA remains the default student adapter (it regularizes against capability degradation under the combined objective), but the trainer is no longer LoRA-gated — separate-teacher modes work with full FT too.
What's in
Teacher backend abstraction (`src/bakery/teachers/`)

- `TeacherBackend` ABC returning `TopKLogprobs` (a sparse top-k view of the teacher's distribution, renormalized over those K); interface sketched after this list
- `HFTeacher` — in-process Transformers model (dev/test path)
- `VLLMTeacher` — vLLM as OpenAI-compat server (production, TRL-style)
- `OpenAIAPITeacher` — any OpenAI-compat endpoint that exposes `logprobs` (Together, Fireworks, vLLM itself)
- The same `TopKLogprobs` shape works for both dense (K=vocab) and sparse (K=20) teachers; the trainer's KL math is identical across backends.
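A hedged sketch of the backend contract (the method names `score`/`generate` follow the PR text; the exact signatures are assumptions):

```python
# Sketch of the TeacherBackend interface; signatures are illustrative.
from abc import ABC, abstractmethod


class TeacherBackend(ABC):
    @abstractmethod
    def score(self, token_ids, top_k: int) -> "TopKLogprobs":
        """Top-k teacher logprobs for each position of an existing sequence."""

    @abstractmethod
    def generate(self, prompt_ids, max_new_tokens: int):
        """Sample an off-policy trajectory from the teacher."""
```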
Sparse KL math (`src/bakery/kl.py`)

- `topk_forward_kl` — forward KL over the teacher's top-k support. At K=V, identical to dense KL.
- `topk_jsd` — mixed forward/reverse KL: loss = (1-β)·KL(P_t||P_s) + β·KL(P_s||P_t). β=0 is forward (default), β=1 is reverse, β=0.5 is symmetric. Same convention as TRL's GKDTrainer.
Trainer integration (`src/bakery/trainer.py`)

- `teacher_backend` + `teacher_top_k` constructor args
- `gkd_on_policy_fraction` — probability of sampling the trajectory from the student (mode-seeking, on-policy GKD)
- `gkd_jsd_beta` — routes loss to `topk_jsd` instead of `topk_forward_kl`
- `_align_and_slice` helper shared by dense and sparse KL paths
- `_sample_from_student` for the on-policy path (adapters enabled, student's trimmed view)
Config (`src/bakery/config.py`)

New `TeacherConfig` exposed as a 5th HfArgumentParser dataclass (sketched below). All fields prefixed with `teacher_` to avoid CLI flag collisions with student `DataConfig`:

- `teacher_backend`: `local-toggle` (default), `hf`, `vllm`, `openai`
- `teacher_model_name_or_path`, `teacher_api_base`, `teacher_api_key`, `teacher_api_model`
- `teacher_top_k`, `teacher_torch_dtype`, `teacher_device`, `teacher_attn_implementation`
- `gkd_on_policy_fraction`, `gkd_jsd_beta`
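A sketch of that dataclass; field names follow the list above, and defaults beyond those the PR states (local-toggle backend, zero fraction and β) are assumptions:

```python
# Illustrative sketch of the flat-flag config dataclass.
from dataclasses import dataclass
from typing import Optional


@dataclass
class TeacherConfig:
    teacher_backend: str = "local-toggle"  # local-toggle | hf | vllm | openai
    teacher_model_name_or_path: Optional[str] = None
    teacher_api_base: Optional[str] = None
    teacher_api_key: Optional[str] = None
    teacher_api_model: Optional[str] = None
    teacher_top_k: int = 20                # assumed default; the API example uses K=20
    teacher_torch_dtype: Optional[str] = None
    teacher_device: Optional[str] = None
    teacher_attn_implementation: Optional[str] = None
    gkd_on_policy_fraction: float = 0.0    # 0.0 keeps prior (off-policy) behavior
    gkd_jsd_beta: float = 0.0              # 0.0 -> forward KL
```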
Example

`examples/gkd_gemma3.yaml` — Gemma 3 1B teacher → 270M student with prefix context baked in one sweep, LoRA default.

Backward compatibility

- `teacher_backend=local-toggle` (the default) keeps every prior bakery code path bit-identical: same dense KL, same adapter-toggle teacher, same trajectory generation.
- `gkd_jsd_beta=0` (default) reduces `topk_jsd` exactly to `topk_forward_kl`.
- `gkd_on_policy_fraction=0` (default) keeps trajectories sampled from the teacher.
- Existing example configs (`basic.yaml`, `multi_turn_prefix.yaml`, etc.) are unchanged.

Test plan

- `test_topk_kl.py` (incl. JSD identities), `test_teacher_backends.py`, `test_openai_teacher_score.py` (mocked vLLM payloads — runs offline), `test_trainer_gkd.py` (compute_loss + prediction_step end-to-end with external teacher + on-policy routing + guardrails)
- Offline suite: pytest -m "not gpu and not benchmark"
- GPU smoke with `gkd_on_policy_fraction=1.0` + `gkd_jsd_beta=0.5`: loss 4.099
- `test_local_toggle_path_unchanged`

Known limitations / follow-ups

- `OpenAIAPITeacher.score()` assumes the server's tokenizer matches the student's (Gemma family → Gemma student, Qwen → Qwen, etc.). Cross-tokenizer distillation (e.g. Llama→Qwen) is out of scope.
- `VLLMInProcessTeacher` is stubbed (raises NotImplementedError). Use `VLLMTeacher` (HTTP) or `HFTeacher` for now.