feat: GKD with external teacher backends + on-policy + JSD mix (#36)

Open: marksverdhei wants to merge 2 commits into main from feat/gkd-context-distillation
Conversation


@marksverdhei marksverdhei commented May 11, 2026

Summary

Adds Generalized Knowledge Distillation (GKD) on top of the context-baking machinery from #35. The student can now distill from a separate teacher — a different HuggingFace model, a vLLM server, or any OpenAI-compatible API with logprobs — while continuing to bake an arbitrary prefix context into its weights in the same training sweep.

LoRA remains the default student adapter (it regularizes against capability degradation under the combined objective), but the trainer is no longer LoRA-gated — separate-teacher modes work with full FT too.

What's in

Teacher backend abstraction (src/bakery/teachers/)

  • TeacherBackend ABC returning TopKLogprobs (sparse top-k view of the teacher's distribution, renormalized over those K)
  • HFTeacher — in-process Transformers model (dev/test path)
  • VLLMTeacher — vLLM as OpenAI-compat server (production, TRL-style)
  • OpenAIAPITeacher — any OpenAI-compat endpoint that exposes logprobs (Together, Fireworks, vLLM itself)

The same TopKLogprobs shape works for both dense (K=vocab) and sparse (K=20) teachers; the trainer's KL math is identical across backends.
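A minimal sketch of what such a sparse container and backend contract might look like (the class names follow the PR; the field names, `from_raw` helper, and `score` signature are assumptions, not the actual implementation):

```python
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class TopKLogprobs:
    """Sparse view of a teacher distribution at one position: token ids
    and log-probs for the top-K tokens, renormalized over those K."""
    token_ids: list[int]   # length K
    logprobs: list[float]  # length K, log-space, exponentiates to sum 1

    @classmethod
    def from_raw(cls, token_ids: list[int], raw_logprobs: list[float]) -> "TopKLogprobs":
        # Renormalize over the K entries so they form a proper distribution.
        log_z = math.log(sum(math.exp(lp) for lp in raw_logprobs))
        return cls(token_ids, [lp - log_z for lp in raw_logprobs])


class TeacherBackend(ABC):
    @abstractmethod
    def score(self, token_ids: list[int], top_k: int) -> list[TopKLogprobs]:
        """Return one TopKLogprobs per position of the given sequence."""
```

With K equal to the full vocab this is just the dense distribution in a different container, which is why the trainer's KL code can stay backend-agnostic.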

Sparse KL math (src/bakery/kl.py)

  • topk_forward_kl — forward KL over the teacher's top-k support. At K=V, identical to dense KL.
  • topk_jsd — interpolated forward/reverse KL: loss = (1-β)·KL(P_t||P_s) + β·KL(P_s||P_t). β=0 is forward KL (the default), β=1 is reverse KL, β=0.5 is the symmetric average. The β endpoints match TRL's GKDTrainer convention (TRL's intermediate β uses the generalized JSD mixture rather than this linear mix).
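In scalar form the mixed loss reduces to the formula above; a pure-Python sketch (a stand-in for the batched tensor code, assuming both inputs are log-probs already renormalized over the same top-k support):

```python
import math


def topk_forward_kl(teacher_lp: list[float], student_lp: list[float]) -> float:
    """KL(P_t || P_s) over the teacher's top-k support (log-space inputs)."""
    return sum(math.exp(t) * (t - s) for t, s in zip(teacher_lp, student_lp))


def topk_jsd(teacher_lp: list[float], student_lp: list[float], beta: float = 0.0) -> float:
    """loss = (1-β)·KL(P_t||P_s) + β·KL(P_s||P_t).
    β=0 reduces exactly to topk_forward_kl; β=1 is reverse KL."""
    fwd = topk_forward_kl(teacher_lp, student_lp)
    rev = topk_forward_kl(student_lp, teacher_lp)
    return (1.0 - beta) * fwd + beta * rev
```

At β=0 the reverse term drops out entirely, which is the backward-compatibility guarantee claimed below.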

Trainer integration (src/bakery/trainer.py)

  • teacher_backend + teacher_top_k constructor args
  • gkd_on_policy_fraction — probability of sampling the trajectory from the student (mode-seeking, on-policy GKD)
  • gkd_jsd_beta — routes loss to topk_jsd instead of topk_forward_kl
  • Unified _align_and_slice helper shared by dense and sparse KL paths
  • _sample_from_student for the on-policy path (adapters enabled, student's trimmed view)
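The on-policy routing amounts to a per-step coin flip; a hedged sketch (the function name and rng plumbing are illustrative, not the trainer's actual code):

```python
import random


def pick_trajectory_source(gkd_on_policy_fraction: float, rng=random) -> str:
    """Choose where this step's trajectory comes from: the student with
    probability gkd_on_policy_fraction (mode-seeking, on-policy GKD),
    otherwise the teacher (mode-covering, off-policy)."""
    if rng.random() < gkd_on_policy_fraction:
        return "student"  # would call _sample_from_student (adapters enabled)
    return "teacher"      # would delegate trajectory generation to the teacher
```

Since `random()` returns values in [0, 1), a fraction of 0.0 never samples the student and 1.0 always does, matching the defaults described below.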

Config (src/bakery/config.py)

New TeacherConfig exposed as a fifth HfArgumentParser dataclass. Backend fields are prefixed with teacher_ and the GKD knobs with gkd_, avoiding CLI flag collisions with the student DataConfig:

  • teacher_backend: local-toggle (default), hf, vllm, openai
  • teacher_model_name_or_path, teacher_api_base, teacher_api_key, teacher_api_model
  • teacher_top_k, teacher_torch_dtype, teacher_device, teacher_attn_implementation
  • gkd_on_policy_fraction, gkd_jsd_beta
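A sketch of how such a dataclass might look, with defaults taken from the bullet list and the backward-compatibility section (exact field types, the top-k default, and everything not listed above are assumptions):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TeacherConfig:
    # Backend selection: "local-toggle" (default), "hf", "vllm", "openai".
    teacher_backend: str = "local-toggle"
    teacher_model_name_or_path: Optional[str] = None
    teacher_api_base: Optional[str] = None
    teacher_api_key: Optional[str] = None
    teacher_api_model: Optional[str] = None
    teacher_top_k: int = 20  # assumed default; K=vocab for dense teachers
    teacher_torch_dtype: Optional[str] = None
    teacher_device: Optional[str] = None
    teacher_attn_implementation: Optional[str] = None
    # GKD knobs; these defaults keep every prior code path unchanged.
    gkd_on_policy_fraction: float = 0.0
    gkd_jsd_beta: float = 0.0
```

Flat fields (rather than nested config objects) are what makes the dataclass composable with HfArgumentParser's CLI flag generation.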

Example

examples/gkd_gemma3.yaml — Gemma 3 1B teacher → 270M student with prefix context baked in one sweep, LoRA default.

Backward compatibility

  • teacher_backend=local-toggle (the default) keeps every prior bakery code path bit-identical: same dense KL, same adapter-toggle teacher, same trajectory generation.
  • gkd_jsd_beta=0 (default) reduces topk_jsd exactly to topk_forward_kl.
  • gkd_on_policy_fraction=0 (default) keeps trajectories sampled from the teacher.
  • Older examples (basic.yaml, multi_turn_prefix.yaml, etc.) are unchanged.

Test plan

  • 50 new tests under test_topk_kl.py (incl. JSD identities), test_teacher_backends.py, test_openai_teacher_score.py (mocked vLLM payloads — runs offline), test_trainer_gkd.py (compute_loss + prediction_step end-to-end with external teacher + on-policy routing + guardrails)
  • Full CPU suite: 315 passing (pytest -m "not gpu and not benchmark")
  • GPU smoke (RTX 3090, Gemma 3 270M ← Gemma 3 1B HF teacher):
    • off-policy forward KL (default): loss 6.108
    • on-policy gkd_on_policy_fraction=1.0 + gkd_jsd_beta=0.5: loss 4.099
  • Backward compat: local-toggle path verified unchanged via test_local_toggle_path_unchanged

Known limitations / follow-ups

  • OpenAIAPITeacher.score() assumes the server's tokenizer matches the student's (Gemma family → Gemma student, Qwen → Qwen, etc.). Cross-tokenizer distillation (e.g. Llama→Qwen) is out of scope.
  • VLLMInProcessTeacher is stubbed (raises NotImplementedError). Use VLLMTeacher (HTTP) or HFTeacher for now.
  • True GKD with on-policy + sequence-level reward is partial: the student sampler is wired, but trajectory selection isn't yet tied to teacher-score quality.

Commit messages

Adds a TeacherBackend abstraction so the trainer can distill from any of:

- HFTeacher          — load a separate HuggingFace model in-process (fallback)
- VLLMTeacher        — vLLM as OpenAI-compat server (primary, TRL pattern)
- OpenAIAPITeacher   — any OpenAI-compat endpoint that exposes logprobs

All backends return TopKLogprobs — a sparse top-k view of the teacher's
distribution, renormalized over those K — so the KL math (`topk_forward_kl`)
is identical whether K covers the full vocab (HF) or K=20 (API). The dense
local-toggle path is unchanged.

Trainer dispatches between dense and sparse KL via a unified `_align_and_slice`
helper shared by `_kl_from_logits` and `_kl_from_topk`. `_generate_trajectory`
also delegates to the external teacher when one is configured.

New `TeacherConfig` exposes flat CLI flags (`teacher_backend`, `teacher_top_k`,
`teacher_model_name_or_path`, `teacher_api_base`, `teacher_api_key`, etc.) — all
prefixed with `teacher_` to avoid colliding with student DataConfig fields.

Example `examples/gkd_gemma3.yaml` demonstrates the combined objective:
Gemma 3 1B → 270M with a prefix context baked alongside the capability
distillation, in one training sweep. LoRA stays on by default as a regularizer.

vLLM/OpenAI score() endpoints are stubbed (raise NotImplementedError) pending
the vLLM logprobs-via-echo plumbing; generate() works for both. HFTeacher
covers the dev/test path and the first smoke runs.

Tests: 31 new (topk_kl math, teacher backends, trainer GKD integration).
Smoke: Gemma 3 270M student + 1B HF teacher on RTX 3090, end-to-end loss
and adapter save verified.

Adds the two pieces left from the GKD branch's first commit:

1. OpenAIAPITeacher.score(): wired against /v1/completions with
   prompt=token_ids, echo=true, max_tokens=0, logprobs=K. Returned token
   strings are re-encoded against the student tokenizer (same-tokenizer
   assumption from the design memo). Trainer registers the student
   tokenizer on the backend in __init__ via set_student_tokenizer.
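The request/response handling described in this item might look roughly like the following (the payload shape follows the legacy /v1/completions echo pattern; the helper names and exact field handling are assumptions, exercised here against a mocked response rather than a live server):

```python
import math


def build_score_payload(token_ids: list[int], model: str, top_k: int) -> dict:
    """Score a fixed sequence: echo the prompt, generate nothing,
    and request top-k logprobs at every echoed position."""
    return {
        "model": model,
        "prompt": token_ids,  # token ids passed directly, not text
        "echo": True,
        "max_tokens": 0,
        "logprobs": top_k,
    }


def parse_top_logprobs(response: dict) -> list[dict]:
    """Extract per-position {token: logprob} dicts, renormalized over K.
    The first echoed position has nothing conditioning it, so servers
    return null there; skip it (and any empty entries) for robustness."""
    out = []
    for pos in response["choices"][0]["logprobs"]["top_logprobs"]:
        if not pos:  # None at the first position, or an empty dict
            continue
        log_z = math.log(sum(math.exp(lp) for lp in pos.values()))
        out.append({tok: lp - log_z for tok, lp in pos.items()})
    return out
```

The null-first-position and empty-top_logprobs cases here mirror the guardrails the test plan says are covered by the mocked-payload tests.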

2. On-policy GKD trajectories + reverse/JSD-style KL mix:
   - gkd_on_policy_fraction (TeacherConfig): probability of sampling the
     trajectory from the student (mode-seeking on-policy) vs. the teacher
     (mode-covering off-policy). 0.0 keeps current behavior.
   - gkd_jsd_beta (TeacherConfig): mix between forward and reverse KL.
     loss = (1-β)·KL(P_t||P_s) + β·KL(P_s||P_t). β=0 → forward KL
     (default, identical math to before), β=1 → reverse KL, β=0.5 →
     symmetric average. The β endpoints match TRL's GKDTrainer convention.

   Trainer reads both from TeacherConfig and routes loss + trajectory
   generation accordingly. New `_sample_from_student` mirrors the teacher
   sampler but with adapters enabled and the student's trimmed view.

KL math:
- New `topk_jsd` in kl.py — operates on the teacher's top-k support,
  renormalizes the student over the same K, computes forward and reverse
  KL in log-space. Reduces exactly to topk_forward_kl at β=0.

Tests (+13): JSD identities (β=0 = forward KL, β=0.5 = avg, β=1 differs),
on-policy routing (β=1 → always student sampler, β=0 → never), trainer
guardrails for invalid β / fractions, OpenAI score() against a mocked
vLLM payload (shape, renormalization, null-first-position handling,
left-padding strip, empty-top_logprobs robustness).

Smoke (GPU, Gemma 3 270M ← 1B): on-policy fraction 1.0, JSD β=0.5 —
loss 4.099 end-to-end, adapter saved.
@marksverdhei force-pushed the feat/gkd-context-distillation branch from 9fc2b1e to ac85ce1 on May 11, 2026 at 15:37.
