diff --git a/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/README.md b/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/README.md new file mode 100644 index 000000000..9755cb279 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/README.md @@ -0,0 +1,153 @@ +# DART: Differential Attention Recurrent Transformer + +**Author:** Anand K S (https://github.com/anandks2006) +**Institution:** Independent, Kerala, India +**Track:** Non-record (unlimited compute) +**Date:** 2026-03-20 +**Score:** `val_bpb = 1.85221128` + + +## How I Found The Competition + +I came across this competition through my Google feed on March 18, 2026 the same day the competition was launched. I am currently in my 2nd year of BCA (Bachelor of Computer Applications) and I have no formal research background, no GPU, and I have never tried to train models before. But the architecture idea seemed promising and I decided to push forward. + +--- + +## What DART Is + +DART is my attempt at a lightweight alternative to the standard approach of stacking many independent layers. Most models work by passing data through 9, 12 or more completely separate blocks, each one a different set of weights doing a different transformation. DART takes a different approach and ie, one block used multiple times in a loop. + +The idea is to get maximum information out of minimum parameters. Instead of paying the full parameter cost for each layer, DART reuses the same block over multiple passes but makes sure each pass brings something new, so the model is not just repeating the same transformation over and over. + +To make each loop meaningfully different, DART uses a couple of methods, they are: + +- **Differential Attention V2** — instead of standard attention which can get distracted by irrelevant tokens, DART uses two attention calculations and subtracts one from the other to cancel out the noise. This is from Microsoft Research (2025). + +- **Per-loop low-rank Q delta** — each loop gets a small unique modification to how it processes queries, so loop 1 might focus on basic word patterns while loop 8 focuses on higher-level thinking. This was my own idea for this architecture.The total cost is 65,536 parameters across 4 loops and is a small price for genuine per-loop specialisation. This was my idea for this architecture and I have not seen it applied this way in the literature. + +- **resid_mix** — each loop applies a learned balance between the current hidden state and the original input. This prevents the hidden state from drifting too far from the input across multiple passes, which can cause the model to lose track of what the original text said. + +- **Loop position embeddings** — a small vector added to the hidden state at each loop telling the block which pass number it is on. + +- **U-Net skip connections** — the first half of loops save their hidden states, and the second half receives them in reverse order. This lets later loops directly access the raw early representations without them being overwritten. + +- **QAT (Quantization-Aware Training)** — during training, weights are fake-quantized to simulate the int8 compression that happens at submission time. Specifically, per-row 99.99984th percentile clipping is applied and exactly matching the competition's evaluation quantization. This means the model trains knowing it will be compressed, rather than being surprised by quantization at the end. + +- **From the competition baseline (unchanged)**: relu² MLP, QK RMSNorm before RoPE, per-head q_gain scalar, logit softcap, CastedLinear (fp32 weights with bf16 matmul), Muon optimizer, tokenizer-aware BPB evaluation. + +- **Loop dropout(not there in final training)** — during cpu testing, each batch randomly uses a different number of loops. This prevents a training problem where earlier and later loops fight each other over the shared weights. Without this, more loops actually made the model worse and finding and fixing this problem was the most important research discovery I was met with during development also this issue was such a headache to figure out a solution since when I trained initially with various loop number and block number configuration, always the single loop performed better that the one with more loops which made me confused since having more loop should automatically make it perform well and that was my idea in head. And at end I figured it out and unfortunately I missed it out on training and its absence likely costs 0.02-0.05 bpb on the final training. + +- **Global Memory tokens** — small learned vectors that carry information across loops, acting like a notepad the model can write to and read from across passes. + +- **Deep supervision** — loss is computed after every loop and not just at the final one, so every pass through the block is forced to be useful. + +## The result is a model with about 3.9 million parameters — roughly the same as the baseline's parameter count when you account for what those parameters achieve — compressed to 3.5MB, using only 22.5% of the 16MB budget. + +## I will be working on making the architecture even better down the line by implementing other existing excellent techniques that will be the model surpass the current issues. + +--- + +## How I Built It + +I used Claude, ChatGPT and Gemini throughout the project. I want to be completely honest about the use of AI. + +The AI assistants helped me find relevant research papers, and Claude wrote the code, and suggested ideas that support the architecture. But the decisions were mine. I questioned every suggestion, ran every experiment on my laptop cpu and several times disagreed simultaneously with all three AI systems or as I like to call them The Council. + +A clear example for the above statement is when early in cpu experiments showed that adding more loops made the model perform worse, all three AI systems told me to drop the recurrent approach entirely and just use a single-pass model. I thought the idea was still sound and kept investigating. Eventually we found the actual cause, it was a gradient conflict problem in shared-weight training and fixed it. Unfortunately the final architecture didn't keep the loop dropout since I forgot to implement it during final training. + +The AI council handled the code and finding research papers. My job was to design the architecture by combining existing approaches that I believed will work well together and to push back when the results did not match the theory and suggests fixes. My initial inspiration came from Samsung SAIL Montreal's TinyRecursiveModels and seeing that a tiny model with repeated passes could outperform much larger models on hard reasoning tasks made me want to apply the same philosophy to language modeling for this competition. + +--- + +## Training + +The hardest part of this project was compute. I spent hours running architecture experiments on my laptop CPU (Ryzen 5500U, 8GB RAM) with small configurations to validate that the design actually learns. Once I was confident the architecture worked, the only GPU option I had was Google Colab's free T4. + +The T4 free tier has no guaranteed uptime, disconnects randomly, and gave me around 2-3 hours per session. Without torch.compile (which caused indefinite hangs on T4 with my architecture), each training step took about 2.6 seconds. That meant I could only run 2,000 steps — roughly 65 million tokens of training — before the session limits ran out. + +Another issue was that the competition baseline used 10.5 billion tokens. I used about 160 times less. The score gap between DART (1.852) and the baseline (1.224) is almost entirely explained by this, not by the architecture being worse. + +If I had run the same number of steps as the baseline on equivalent hardware, our 87-minute T4 run would have taken about 16 seconds on 8×H100s. The 10-minute competition window would have allowed around 73,000 training steps. + +**Training configuration:** + +| Parameter | Value | +|---|---| +| Hardware | Google Colab T4, free tier | +| Steps | 2,000 | +| Sequence length | 256 tokens | +| Batch | 32,768 tokens/step | +| Total tokens | ~65M | +| Training time | ~87 minutes | +| Model parameters | 3,918,888 | +| Compressed size | 3.55MB (22.5% of 16MB) | + +--- + +## Results + +| Step | val_bpb | +|---|---| +| 0 | 4.1040 | +| 500 | 2.0876 | +| 1,000 | 1.9957 | +| 1,500 | 1.9294 | +| 2,000 | 1.8502 | +| **Final (int8 roundtrip)** | **1.85221128** | + +The score is improving consistently across the run and had not plateaued at step 2000, suggesting the architecture has room to improve with more training and loop dropout. + +--- + +## What Did Not Work + +**16 loops was too slow**. The architecture was designed to run 16 loops. On T4 without torch.compile, that required 18 seconds per step — too slow to train meaningfully. I reduced to 4 loops. Whether 16 loops outperform 4 loops at full training scale is something I was not able to verify. + +**torch.compile hung indefinitely** on T4 with gradient checkpointing enabled. Disabling it slowed training by about 4×. This is a known compatibility issue on older GPU architectures. + +**Loops did not clearly beat 1-loop** in CPU ablation experiments. Loop dropout reduced the performance gap from 0.28 nats to 0.03 nats, but I was not able to run enough steps to definitively prove the recurrent approach is better than a single-pass model. The compute constraints made this impossible to resolve on CPU alone. + +--- + +## Reproducibility + +**To reproduce the submitted result (2000 steps, reduced config):** + +```bash +DATA_PATH=/path/to/fineweb10B_sp1024 \ +TOKENIZER_PATH=/path/to/fineweb_1024_bpe.model \ +RUN_ID=dart_repro \ +N_LOOPS=4 N_MEMORY=16 MODEL_DIM=512 \ +TRAIN_SEQ_LEN=256 TRAIN_BATCH_TOKENS=32768 \ +ITERATIONS=2000 python train_gpt.py +``` + +Requires a CUDA GPU with at least 6GB VRAM. torch.compile is disabled in this submission due to compatibility issues on T4. + +**To evaluate the architecture at full scale (recommended for 8×H100):** + +```bash +DATA_PATH=/path/to/fineweb10B_sp1024 \ +TOKENIZER_PATH=/path/to/fineweb_1024_bpe.model \ +RUN_ID=dart_fullscale \ +N_LOOPS=16 N_MEMORY=32 MODEL_DIM=512 \ +TRAIN_SEQ_LEN=1024 TRAIN_BATCH_TOKENS=524288 \ +ITERATIONS=20000 MAX_WALLCLOCK_SECONDS=600 \ +python train_gpt.py +``` + +This uses the intended 16 loops, full sequence length, and standard batch size. +The submitted result used 0.6% of this training budget due to free-tier compute +constraints. The architecture was designed for this config and has not been +evaluated at full scale. + +--- + +## Files + +| File | Description | +|---|---| +| `train_gpt.py` | Training script | +| `submission.json` | Scores and metadata | +| `train_log.txt` | Full training log | +| `README.md` | This document | diff --git a/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/submission.json b/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/submission.json new file mode 100644 index 000000000..490c2775a --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/submission.json @@ -0,0 +1,22 @@ +{ + "run_id": "rlm_v1_t4", + "val_bpb": 1.85221128, + "val_loss": 3.09141515, + "authors": ["Anand K S"], + "github_ids": ["@anandks2006"], + "institution": "Independent, Kerala, India", + "date": "2026-03-20", + "model_dim": 512, + "n_loops": 4, + "n_memory": 16, + "seq_len": 256, + "vocab_size": 1024, + "iterations": 2000, + "hardware": "Google Colab T4 (free tier)", + "model_bytes": 3554683, + "code_bytes": 42197, + "total_bytes": 3596880, + "budget_pct": 22.5, + "summary": "DART: Differential Attention Recurrent Transformer. Single shared block looped 4x with per-loop low-rank Q delta, resid_mix, loop position embeddings, memory tokens, U-Net skips, deep supervision. 3.92M params, 3.55MB int8+zlib (22.5% of 16MB budget). Trained 2000 steps on FineWeb sp1024, Google Colab T4 free tier. Student submission, BCA 2nd year, Kerala, India." +} +``` diff --git a/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/train_gpt.py b/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/train_gpt.py new file mode 100644 index 000000000..4290a1d1f --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/train_gpt.py @@ -0,0 +1,845 @@ +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import checkpoint as grad_ckpt + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + val_seq_limit = min(seq_end, seq_start + 4096) # capped for T4 speed + with torch.inference_mode(): + for s in range(seq_start, val_seq_limit, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = grad_ckpt(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + global zeropower_via_newtonschulz5 + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = torch.compile(base_model, dynamic=False, fullgraph=False) + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if ((args.val_loss_every > 0 and step % args.val_loss_every == 0 and step >= 200) or last): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/train_log.txt b/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/train_log.txt new file mode 100644 index 000000000..f5734894c --- /dev/null +++ b/records/track_non_record_16mb/2026-03-20_DART_DiffAttn_Recurrent/train_log.txt @@ -0,0 +1,26303 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Test-time training (LoRA) hyperparameters. + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + qd = lora.q_loras[i] if lora else None + vd = lora.v_loras[i] if lora else None + x = self.blocks[i](x, x0, qd, vd) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd = lora.q_loras[bi] if lora else None + vd = lora.v_loras[bi] if lora else None + x = self.blocks[bi](x, x0, qd, vd) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + +# ----------------------------- +# TEST-TIME TRAINING (LoRA) +# ----------------------------- +# +# At evaluation time, we adapt per-document low-rank adapters on the validation data. +# Each document gets its own adapter, so there is no inter-document dependency. + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """LoRA for a linear layer, with independent weights per batch element. + Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) # kaiming-uniform + self.B.zero_() + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head and Q/V per block.""" + def __init__(self, bsz: int, model: GPT, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: # Fresh state. + continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + +def _build_ttt_optimizer(lora, args: Hyperparameters): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + +def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document, identified by BOS boundaries. + + If include_next_bos is True, include next document's BOS (to match continuous-stream + eval token count exactly). + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def _accumulate_bpb( + ptl: Tensor, x: Tensor, y: Tensor, + batch_i: int, chunk_offset: int, chunk_len: int, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, +): + """Add one doc-chunk's contribution to the running BPB accumulators.""" + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] + tok_bytes = base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" + # Load validation tokens and find document boundaries + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + + # Each rank takes a contiguous slice of documents + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank = args.ttt_lora_rank + + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, chunk_offset = chunk_stats[1], chunk_stats[2] + + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] # (chunk_offset, chunk_len) per doc + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1] + toks = chunk.to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + + # Forward pass (keep grad graph alive only when we need to train) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + # Score: accumulate loss and byte counts for BPB (before training on chunk) + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb( + ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + # Train: one Adam step on the LoRA params using this chunk's loss + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # LoRA test-time training evaluation (the competition score) + torch._dynamo.reset() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +Running PyTorch 2.10.0+cu128 +Fri Mar 20 06:10:15 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 38C P0 25W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 2706 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/content/drive/MyDrive/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=/content/drive/MyDrive/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26497096 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:65536 train_seq_len:1024 iterations:10000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Test-time training (LoRA) hyperparameters. + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + qd = lora.q_loras[i] if lora else None + vd = lora.v_loras[i] if lora else None + x = self.blocks[i](x, x0, qd, vd) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd = lora.q_loras[bi] if lora else None + vd = lora.v_loras[bi] if lora else None + x = self.blocks[bi](x, x0, qd, vd) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + +# ----------------------------- +# TEST-TIME TRAINING (LoRA) +# ----------------------------- +# +# At evaluation time, we adapt per-document low-rank adapters on the validation data. +# Each document gets its own adapter, so there is no inter-document dependency. + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """LoRA for a linear layer, with independent weights per batch element. + Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) # kaiming-uniform + self.B.zero_() + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head and Q/V per block.""" + def __init__(self, bsz: int, model: GPT, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: # Fresh state. + continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + +def _build_ttt_optimizer(lora, args: Hyperparameters): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + +def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document, identified by BOS boundaries. + + If include_next_bos is True, include next document's BOS (to match continuous-stream + eval token count exactly). + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def _accumulate_bpb( + ptl: Tensor, x: Tensor, y: Tensor, + batch_i: int, chunk_offset: int, chunk_len: int, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, +): + """Add one doc-chunk's contribution to the running BPB accumulators.""" + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] + tok_bytes = base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" + # Load validation tokens and find document boundaries + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + + # Each rank takes a contiguous slice of documents + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank = args.ttt_lora_rank + + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, chunk_offset = chunk_stats[1], chunk_stats[2] + + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] # (chunk_offset, chunk_len) per doc + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1] + toks = chunk.to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + + # Forward pass (keep grad graph alive only when we need to train) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + # Score: accumulate loss and byte counts for BPB (before training on chunk) + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb( + ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + # Train: one Adam step on the LoRA params using this chunk's loss + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # LoRA test-time training evaluation (the competition score) + torch._dynamo.reset() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +Running PyTorch 2.10.0+cu128 +Fri Mar 20 06:14:50 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 40C P0 26W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 3918 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/content/drive/MyDrive/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=/content/drive/MyDrive/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26497096 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:65536 train_seq_len:1024 iterations:10000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Test-time training (LoRA) hyperparameters. + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + qd = lora.q_loras[i] if lora else None + vd = lora.v_loras[i] if lora else None + x = self.blocks[i](x, x0, qd, vd) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd = lora.q_loras[bi] if lora else None + vd = lora.v_loras[bi] if lora else None + x = self.blocks[bi](x, x0, qd, vd) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + +# ----------------------------- +# TEST-TIME TRAINING (LoRA) +# ----------------------------- +# +# At evaluation time, we adapt per-document low-rank adapters on the validation data. +# Each document gets its own adapter, so there is no inter-document dependency. + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """LoRA for a linear layer, with independent weights per batch element. + Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) # kaiming-uniform + self.B.zero_() + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head and Q/V per block.""" + def __init__(self, bsz: int, model: GPT, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: # Fresh state. + continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + +def _build_ttt_optimizer(lora, args: Hyperparameters): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + +def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document, identified by BOS boundaries. + + If include_next_bos is True, include next document's BOS (to match continuous-stream + eval token count exactly). + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def _accumulate_bpb( + ptl: Tensor, x: Tensor, y: Tensor, + batch_i: int, chunk_offset: int, chunk_len: int, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, +): + """Add one doc-chunk's contribution to the running BPB accumulators.""" + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] + tok_bytes = base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" + # Load validation tokens and find document boundaries + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + + # Each rank takes a contiguous slice of documents + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank = args.ttt_lora_rank + + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, chunk_offset = chunk_stats[1], chunk_stats[2] + + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] # (chunk_offset, chunk_len) per doc + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1] + toks = chunk.to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + + # Forward pass (keep grad graph alive only when we need to train) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + # Score: accumulate loss and byte counts for BPB (before training on chunk) + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb( + ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + # Train: one Adam step on the LoRA params using this chunk's loss + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # LoRA test-time training evaluation (the competition score) + torch._dynamo.reset() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +Running PyTorch 2.10.0+cu128 +Fri Mar 20 06:15:45 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 41C P0 25W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 4208 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/content/drive/MyDrive/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=/content/drive/MyDrive/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26497096 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:65536 train_seq_len:1024 iterations:10000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Test-time training (LoRA) hyperparameters. + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + qd = lora.q_loras[i] if lora else None + vd = lora.v_loras[i] if lora else None + x = self.blocks[i](x, x0, qd, vd) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd = lora.q_loras[bi] if lora else None + vd = lora.v_loras[bi] if lora else None + x = self.blocks[bi](x, x0, qd, vd) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + +# ----------------------------- +# TEST-TIME TRAINING (LoRA) +# ----------------------------- +# +# At evaluation time, we adapt per-document low-rank adapters on the validation data. +# Each document gets its own adapter, so there is no inter-document dependency. + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """LoRA for a linear layer, with independent weights per batch element. + Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) # kaiming-uniform + self.B.zero_() + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head and Q/V per block.""" + def __init__(self, bsz: int, model: GPT, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: # Fresh state. + continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + +def _build_ttt_optimizer(lora, args: Hyperparameters): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + +def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document, identified by BOS boundaries. + + If include_next_bos is True, include next document's BOS (to match continuous-stream + eval token count exactly). + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def _accumulate_bpb( + ptl: Tensor, x: Tensor, y: Tensor, + batch_i: int, chunk_offset: int, chunk_len: int, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, +): + """Add one doc-chunk's contribution to the running BPB accumulators.""" + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] + tok_bytes = base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" + # Load validation tokens and find document boundaries + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + + # Each rank takes a contiguous slice of documents + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank = args.ttt_lora_rank + + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, chunk_offset = chunk_stats[1], chunk_stats[2] + + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] # (chunk_offset, chunk_len) per doc + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1] + toks = chunk.to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + + # Forward pass (keep grad graph alive only when we need to train) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + # Score: accumulate loss and byte counts for BPB (before training on chunk) + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb( + ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + # Train: one Adam step on the LoRA params using this chunk's loss + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # LoRA test-time training evaluation (the competition score) + torch._dynamo.reset() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +Running PyTorch 2.10.0+cu128 +Fri Mar 20 06:16:59 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 41C P0 25W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 4523 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/content/drive/MyDrive/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=/content/drive/MyDrive/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26497096 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:65536 train_seq_len:1024 iterations:10000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + + # Test-time training (LoRA) hyperparameters. + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. + for i in range(self.num_encoder_layers): + qd = lora.q_loras[i] if lora else None + vd = lora.v_loras[i] if lora else None + x = self.blocks[i](x, x0, qd, vd) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + qd = lora.q_loras[bi] if lora else None + vd = lora.v_loras[bi] if lora else None + x = self.blocks[bi](x, x0, qd, vd) + x = self.final_norm(x) + if self.tie_embeddings: + logits = F.linear(x, self.tok_emb.weight) + else: + logits = self.lm_head(x) + logits = logits + (lora.lm_head_lora(x) if lora else 0) + logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap) + if lora: + bsz, sl, V = logits.shape + return F.cross_entropy( + logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl) + return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean") + + +# ----------------------------- +# TEST-TIME TRAINING (LoRA) +# ----------------------------- +# +# At evaluation time, we adapt per-document low-rank adapters on the validation data. +# Each document gets its own adapter, so there is no inter-document dependency. + +BOS_ID = 1 + +class BatchedLinearLoRA(nn.Module): + """LoRA for a linear layer, with independent weights per batch element. + Computes x @ Aᵀ @ Bᵀ = x @ (BA)ᵀ, i.e. the LoRA delta is ΔW = BA.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) # (bsz, T, out) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) # kaiming-uniform + self.B.zero_() + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head and Q/V per block.""" + def __init__(self, bsz: int, model: GPT, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + +def _reset_ttt_optimizer(opt): + for group in opt.param_groups: + for p in group['params']: + s = opt.state.get(p) + if not s: # Fresh state. + continue + s['exp_avg'].zero_() + s['exp_avg_sq'].zero_() + s['step'].fill_(0) + +def _build_ttt_optimizer(lora, args: Hyperparameters): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + +def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document, identified by BOS boundaries. + + If include_next_bos is True, include next document's BOS (to match continuous-stream + eval token count exactly). + """ + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy() + docs = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel() + if include_next_bos and i + 1 < len(bos_positions): + end += 1 + assert end - start >= 2 + docs.append((start, end - start)) + return docs + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc.""" + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + +def _accumulate_bpb( + ptl: Tensor, x: Tensor, y: Tensor, + batch_i: int, chunk_offset: int, chunk_len: int, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor, +): + """Add one doc-chunk's contribution to the running BPB accumulators.""" + lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64) + prev = x[batch_i, chunk_offset:chunk_offset + chunk_len] + tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len] + tok_bytes = base_bytes_lut[tgt].to(torch.float64) + tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev] + loss_sum += lbl.sum() + byte_sum += tok_bytes.sum() + token_count += chunk_len + +def eval_val_ttt_lora( + args: Hyperparameters, + base_model: GPT, + rank: int, + world_size: int, + device: torch.device, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).""" + # Load validation tokens and find document boundaries + files = sorted(glob.glob(args.val_files)) + all_tokens = torch.cat([load_data_shard(Path(f)) for f in files]) + docs = _find_docs(all_tokens) + + # Each rank takes a contiguous slice of documents + rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size] + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + batch_size = args.ttt_batch_size + lora_rank = args.ttt_lora_rank + + rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size) + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + for bi in range(0, len(rank_docs), batch_size): + batch = rank_docs[bi:bi + batch_size] + bsz = len(batch) + + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + + pred_lens = [doc_len - 1 for _, doc_len in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + + for ci in range(max_nc): + chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + context_size, chunk_offset = chunk_stats[1], chunk_stats[2] + + active = [ci < nc for nc in num_chunks] + needs_train = any(ci < nc - 1 for nc in num_chunks) + + x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device) + doc_info = [] # (chunk_offset, chunk_len) per doc + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + chunk = all_tokens[ds + ws: ds + ws + wl + 1] + toks = chunk.to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + + # Forward pass (keep grad graph alive only when we need to train) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + + # Score: accumulate loss and byte counts for BPB (before training on chunk) + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + _accumulate_bpb( + ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut, + is_boundary_token_lut, loss_sum, byte_sum, token_count) + + # Train: one Adam step on the LoRA params using this chunk's loss + if needs_train: + mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device) + per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1) + cur_opt.zero_grad() + (per_doc * mask).sum().backward() + cur_opt.step() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + + val_loss = float(loss_sum.item() / token_count.item()) + val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item()) + return val_loss, val_bpb + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + if isinstance(module, Rotary): + module.inv_freq.data = module.inv_freq.data.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # LoRA test-time training evaluation (the competition score) + torch._dynamo.reset() + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora( + args, base_model, rank, world_size, device, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +Running PyTorch 2.10.0+cu128 +Fri Mar 20 06:20:55 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 41C P0 25W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 5579 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/content/drive/MyDrive/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:10 +val_loader:shards pattern=/content/drive/MyDrive/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26497096 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:65536 train_seq_len:1024 iterations:10000 warmup_steps:20 max_wallclock_seconds:0.000 +seed:1337 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import checkpoint as grad_ckpt + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = grad_ckpt(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + global zeropower_via_newtonschulz5 + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(False); enable_math_sdp(False) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = torch.compile(base_model, dynamic=False) + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 06:42:33 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 43C P0 25W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 11151 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:1024 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:1024 +matrix_params:3,276,800 scalar_params:454,696 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import checkpoint as grad_ckpt + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = grad_ckpt(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + global zeropower_via_newtonschulz5 + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(False); enable_math_sdp(False) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = torch.compile(base_model, dynamic=False) + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 06:44:15 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 42C P0 26W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 11583 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:1024 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:1024 +matrix_params:3,276,800 scalar_params:454,696 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import checkpoint as grad_ckpt + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = grad_ckpt(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + global zeropower_via_newtonschulz5 + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = torch.compile(base_model, dynamic=False) + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 06:45:50 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 40C P0 26W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 11994 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:1024 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:1024 +matrix_params:3,276,800 scalar_params:454,696 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import checkpoint as grad_ckpt + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = grad_ckpt(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + global zeropower_via_newtonschulz5 + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = torch.compile(base_model, dynamic=False) + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 06:49:17 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 60C P0 28W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 12904 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:1024 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:1024 +matrix_params:3,276,800 scalar_params:454,696 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import checkpoint as grad_ckpt + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = grad_ckpt(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + global zeropower_via_newtonschulz5 + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = torch.compile(base_model, dynamic=False) + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 06:52:10 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 62C P0 28W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 13665 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:1024 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:1024 +matrix_params:3,276,800 scalar_params:454,696 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import checkpoint as grad_ckpt + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = grad_ckpt(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + global zeropower_via_newtonschulz5 + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = torch.compile(base_model, dynamic=False) + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:03:16 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 61C P0 28W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 16530 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:1024 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:1024 +matrix_params:3,276,800 scalar_params:454,696 +warmup_complete:5 steps +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import checkpoint as grad_ckpt + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = grad_ckpt(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + global zeropower_via_newtonschulz5 + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:24:42 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 68C P0 29W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 21982 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:1024 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:1024 +matrix_params:3,276,800 scalar_params:454,696 +warmup_complete:5 steps +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import checkpoint as grad_ckpt + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = grad_ckpt(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:32:55 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 62C P0 28W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 24084 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:1024 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:1024 +matrix_params:3,276,800 scalar_params:454,696 +warmup_complete:5 steps +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = self.block(x, x0, i) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:37:41 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 67C P0 29W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 25339 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:1024 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:1024 +matrix_params:3,276,800 scalar_params:454,696 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:40:04 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 55C P0 27W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 25958 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:1024 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:1024 +matrix_params:3,276,800 scalar_params:454,696 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:42:50 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 69C P0 30W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 26691 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:512 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:512 +matrix_params:3,276,800 scalar_params:454,696 +warmup_complete:1 steps +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:47:20 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 63C P0 28W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 27835 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:256 +model_params:4255784 (~4.26M) +n_loops:16 n_memory:32 dim:512 seq_len:256 +matrix_params:3,276,800 scalar_params:454,696 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:48:49 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 69C P0 30W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 28227 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021760 train_seq_len:128 +model_params:342676 (~0.34M) +n_loops:1 n_memory:0 dim:128 seq_len:128 +matrix_params:204,800 scalar_params:6,804 +step:0/3 val_loss:6.9295 val_bpb:4.1040 train_time:0ms step_avg:0.03ms +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:52:04 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 54C P0 27W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 29064 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:512 +model_params:4239400 (~4.24M) +n_loops:16 n_memory:0 dim:512 seq_len:512 +matrix_params:3,276,800 scalar_params:438,312 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:54:36 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 72C P0 30W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 29719 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:512 +model_params:4020264 (~4.02M) +n_loops:8 n_memory:0 dim:512 seq_len:512 +matrix_params:3,276,800 scalar_params:219,176 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 07:59:25 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 72C P0 30W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 30923 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021760 train_seq_len:128 +model_params:1108756 (~1.11M) +n_loops:2 n_memory:0 dim:256 seq_len:128 +matrix_params:819,200 scalar_params:27,412 +step:1/2 train_loss:6.9318 train_time:1254ms step_avg:1254.35ms lr_scale:0.0033 +step:2/2 train_loss:6.9254 train_time:2020ms step_avg:1009.90ms lr_scale:0.0017 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 08:01:49 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 72C P0 29W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 31545 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:512 +model_params:4239400 (~4.24M) +n_loops:16 n_memory:0 dim:512 seq_len:512 +matrix_params:3,276,800 scalar_params:438,312 +step:1/3 train_loss:6.9345 train_time:21101ms step_avg:21101.38ms lr_scale:0.0050 +step:2/3 train_loss:6.9107 train_time:42747ms step_avg:21373.47ms lr_scale:0.0033 +step:3/3 train_loss:6.8906 train_time:64324ms step_avg:21441.24ms lr_scale:0.0017 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 08:06:05 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 72C P0 30W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 32680 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:256 +model_params:4239400 (~4.24M) +n_loops:16 n_memory:0 dim:512 seq_len:256 +matrix_params:3,276,800 scalar_params:438,312 +step:1/3 train_loss:6.9345 train_time:18213ms step_avg:18212.90ms lr_scale:0.0050 +step:2/3 train_loss:6.9107 train_time:36956ms step_avg:18477.98ms lr_scale:0.0033 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 08:09:59 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 51C P0 27W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 33668 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:512 +model_params:3910696 (~3.91M) +n_loops:4 n_memory:0 dim:512 seq_len:512 +matrix_params:3,276,800 scalar_params:109,608 +step:1/3 train_loss:6.9349 train_time:5237ms step_avg:5236.85ms lr_scale:0.0050 +step:2/3 train_loss:6.9121 train_time:10075ms step_avg:5037.37ms lr_scale:0.0033 +step:3/3 train_loss:6.8929 train_time:15151ms step_avg:5050.25ms lr_scale:0.0017 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 08:10:34 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 57C P0 27W / 70W | 105MiB / 15360MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 33831 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:512 +model_params:3855912 (~3.86M) +n_loops:2 n_memory:0 dim:512 seq_len:512 +matrix_params:3,276,800 scalar_params:54,824 +step:1/3 train_loss:6.9348 train_time:2958ms step_avg:2958.03ms lr_scale:0.0050 +step:2/3 train_loss:6.9122 train_time:5283ms step_avg:2641.67ms lr_scale:0.0033 +step:3/3 train_loss:6.8931 train_time:7851ms step_avg:2616.99ms lr_scale:0.0017 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for s in range(seq_start, seq_end, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = torch.utils.checkpoint.checkpoint(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + pass # torch.compile disabled for T4 + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = base_model # torch.compile disabled: T4 compile too slow + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if (args.val_loss_every > 0 and step % args.val_loss_every == 0) or last: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 08:11:11 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 57C P0 27W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 33997 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:512 +model_params:3828264 (~3.83M) +n_loops:1 n_memory:0 dim:512 seq_len:512 +matrix_params:3,276,800 scalar_params:27,176 +step:1/3 train_loss:6.9350 train_time:1626ms step_avg:1625.71ms lr_scale:0.0050 +step:2/3 train_loss:6.9124 train_time:2728ms step_avg:1364.17ms lr_scale:0.0033 +step:3/3 train_loss:6.8934 train_time:4016ms step_avg:1338.80ms lr_scale:0.0017 +""" +train_gpt.py — Shared-Weight Recurrent Memory Transformer (RLM) +OpenAI Parameter Golf Competition — Non-Record Submission + +Architecture: One transformer block looped N_LOOPS=16 times (weight tying). + - Differential Attention V2 + per-loop low-rank Q delta (weight specialization) + - resid_mix: learned carry-forward vs input-reset per loop + - Loop position embeddings (activation-level depth signal) + - Memory tokens: learned persistent cross-loop state (RoPE excluded) + - Deep supervision: gamma-weighted loss at every loop iteration + - U-Net skip connections between encoder/decoder loop halves + - Gradient checkpointing: memory-efficient backprop through 16 loops + - QAT: per-row percentile fake-quantization during training + - relu² MLP, logit softcap, QK RMSNorm, q_gain (from baseline) + +At dim=512, 16 loops: ~4.24M params, ~4.24MB int8 — 26.5% of 16MB budget. +Effective compute depth of a ~63M parameter standard transformer. + +Thermal protection: checkpoints saved every SAVE_EVERY steps, auto-resumed. +""" + +from __future__ import annotations +import copy, glob, io, math, os, subprocess, sys, time, uuid, zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.checkpoint import checkpoint as grad_ckpt + +# ── Hyperparameters ─────────────────────────────────────────────────────────── + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + save_every = int(os.environ.get("SAVE_EVERY", 500)) + + iterations = int(os.environ.get("ITERATIONS", 5000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 600)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 65_536)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 0.0)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + n_loops = int(os.environ.get("N_LOOPS", 16)) + n_memory = int(os.environ.get("N_MEMORY", 32)) + mlp_mult = int(os.environ.get("MLP_MULT", 4)) + lora_rank = int(os.environ.get("LORA_RANK", 16)) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + deep_sup_gamma= float(os.environ.get("DEEP_SUP_GAMMA", 0.9)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.05)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum= float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 300)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.1)) + +# ── Muon Optimizer (from modded-nanogpt / competition baseline) ─────────────── + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: continue + lr, momentum, backend_steps, nesterov = ( + group["lr"], group["momentum"], group["backend_steps"], group["nesterov"] + ) + total = sum(int(p.numel()) for p in params) + flat = torch.zeros(total, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + flat[curr:curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: dist.all_reduce(flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + p.add_(flat[curr:curr + p.numel()].view_as(p).to(p.dtype), alpha=-lr) + curr += p.numel() + return loss + +# ── Control tensor patterns (kept fp32, excluded from Muon) ────────────────── + +CONTROL_TENSOR_NAME_PATTERNS = tuple(p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,mlp_scale,resid_mix,q_gain,lambda_q1,lambda_k1,lambda_q2,lambda_k2," + "loop_embed,skip_weight,memory", +).split(",") if p) + +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +# ── Tokenizer-aware BPB (from competition baseline) ─────────────────────────── + +def build_sentencepiece_luts(sp, vocab_size: int, device: torch.device): + sp_vocab = int(sp.vocab_size()) + table = max(sp_vocab, vocab_size) + base_np = np.zeros((table,), dtype=np.int16) + space_np = np.zeros((table,), dtype=np.bool_) + bnd_np = np.ones((table,), dtype=np.bool_) + for tid in range(sp_vocab): + if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid): continue + bnd_np[tid] = False + if sp.is_byte(tid): base_np[tid] = 1; continue + piece = sp.id_to_piece(tid) + if piece.startswith("▁"): space_np[tid] = True; piece = piece[1:] + base_np[tid] = len(piece.encode("utf-8")) + return (torch.tensor(base_np, dtype=torch.int16, device=device), + torch.tensor(space_np, dtype=torch.bool, device=device), + torch.tensor(bnd_np, dtype=torch.bool, device=device)) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: raise FileNotFoundError(f"No val files: {pattern}") + tokens = torch.cat([load_data_shard(f) for f in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + return tokens[:usable + 1] + +@torch.no_grad() +def eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut): + local_tok = args.val_batch_size // (world_size * grad_accum_steps) + local_seqs = local_tok // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + tok_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + val_seq_limit = min(seq_end, seq_start + 4096) # capped for T4 speed + with torch.inference_mode(): + for s in range(seq_start, val_seq_limit, local_seqs): + se = min(s + local_seqs, seq_end) + raw = val_tokens[s * args.train_seq_len:(se * args.train_seq_len + 1)] + raw = raw.to(device=device, dtype=torch.int64, non_blocking=True) + x = raw[:-1].reshape(-1, args.train_seq_len) + y = raw[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = model(x, y).detach() + n = float(y.numel()) + loss_sum += loss.to(torch.float64) * n + tok_count += n + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + tb = base_lut[tgt_ids].to(torch.int16) + tb += (space_lut[tgt_ids] & ~bnd_lut[prev_ids]).to(torch.int16) + byte_count += tb.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + for t in (loss_sum, tok_count, byte_count): + dist.all_reduce(t, op=dist.ReduceOp.SUM) + model.train() + val_loss = float((loss_sum / tok_count).item()) + bpb = float((loss_sum / tok_count / math.log(2.0) * tok_count / byte_count).item()) + return val_loss, bpb + +# ── Quantization (from competition baseline) ────────────────────────────────── + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def _keep_float(name: str, t: Tensor, pt_dtypes: dict) -> Tensor: + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + pt_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def _quantize_tensor(t: Tensor): + t32 = t.float() + if t32.ndim == 2: + clip = torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) if t32.numel() else torch.empty(t32.shape[0]) + cl = torch.maximum(torch.minimum(t32, clip[:, None]), -clip[:, None]) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(cl / sc[:, None]), -127, 127).to(torch.int8).contiguous() + return q, sc.to(INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + sc = torch.tensor(clip / 127.0 if clip > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip, clip) / sc), -127, 127).to(torch.int8).contiguous() + return q, sc + +def quantize_state_dict_int8(sd: dict): + quant, scales, dtypes, passthrough, pt_dtypes, qmeta = {}, {}, {}, {}, {}, {} + stats = dict.fromkeys(("param_count","num_tensors","num_float_tensors", + "num_nonfloat_tensors","baseline_tensor_bytes","int8_payload_bytes"), 0) + for name, tensor in sd.items(): + t = tensor.detach().cpu().contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = _keep_float(name, t, pt_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = _quantize_tensor(t) + if s.ndim > 0: qmeta[name] = {"scheme": "per_row", "axis": 0} + quant[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict = {"__quant_format__": "int8_clean_per_row_v1", + "quantized": quant, "scales": scales, "dtypes": dtypes, "passthrough": passthrough} + if qmeta: obj["qmeta"] = qmeta + if pt_dtypes: obj["passthrough_orig_dtypes"] = pt_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict) -> dict: + out = {} + qm = obj.get("qmeta", {}) + ptod = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qm.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype).contiguous() + else: + out[name] = (q.float() * float(s.item())).to(dtype).contiguous() + for name, t in obj["passthrough"].items(): + ot = t.detach().cpu().contiguous() + if isinstance(ptod.get(name), str): + ot = ot.to(getattr(torch, ptod[name])).contiguous() + out[name] = ot + return out + +# ── Data loading ────────────────────────────────────────────────────────────── + +def load_data_shard(file: Path) -> Tensor: + hdr = np.fromfile(file, dtype=" Tensor: + chunks, rem = [], n + while rem > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: self._next(); continue + k = min(rem, avail) + chunks.append(self.tokens[self.pos:self.pos + k]) + self.pos += k; rem -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum: int): + local = global_tokens // (self.world_size * grad_accum) + span = local + 1 + chunk = self.stream.take(span * self.world_size) + s = self.rank * span + raw = chunk[s:s + span].to(torch.int64) + x = raw[:-1].reshape(-1, seq_len) + y = raw[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ── Model ───────────────────────────────────────────────────────────────────── + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + b = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), b) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__(); self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, p in module.named_parameters(): + if (p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS)) \ + and p.dtype != torch.float32: + p.data = p.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv, persistent=False) + self._sl, self._cos, self._sin = 0, None, None + def forward(self, sl: int, device, dtype): + if self._cos is None or self._sl != sl or self._cos.device != device: + t = torch.arange(sl, device=device, dtype=self.inv_freq.dtype) + f = torch.outer(t, self.inv_freq.to(device)) + self._cos = f.cos()[None, None, :, :]; self._sin = f.sin()[None, None, :, :] + self._sl = sl + return self._cos.to(dtype=dtype), self._sin.to(dtype=dtype) + +def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + h = x.size(-1) // 2 + x1, x2 = x[..., :h], x[..., h:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +def fake_quant_per_row(w: Tensor) -> Tensor: + """Per-row percentile int8 QAT with straight-through estimator.""" + if w.ndim < 2 or w.numel() == 0: return w + w32 = w.float() + clip = torch.quantile(w32.abs(), INT8_CLIP_Q, dim=1) + sc = (clip / 127.0).clamp_min(1.0 / 127.0) + wq = (w32.clamp(-clip[:, None], clip[:, None]) / sc[:, None]).round().clamp(-127, 127) * sc[:, None] + return w + (wq.to(w.dtype) - w).detach() + +class SharedDiffAttention(nn.Module): + """Differential Attention V2 + per-loop low-rank Q delta + GQA.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, rope_base: float, rank: int): + super().__init__() + assert dim % nh == 0 and nh % nkv == 0 + self.nh, self.nkv, self.n_rep = nh, nkv, nh // nkv + self.hd = dim // nh + self.nmem = n_mem + + self.c_q = CastedLinear(dim, dim * 2, bias=False) + self.c_k = CastedLinear(dim, nkv * self.hd * 2, bias=False) + self.c_v = CastedLinear(dim, nkv * self.hd, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + + li = 0.8 - 0.6 * math.exp(-0.3) + self.lambda_init = li + self.lambda_q1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k1 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_q2 = nn.Parameter(torch.randn(nh) * 0.1) + self.lambda_k2 = nn.Parameter(torch.randn(nh) * 0.1) + + self.q_gain = nn.Parameter(torch.ones(nh, dtype=torch.float32)) + self.delta_A = nn.Parameter(torch.zeros(n_loops, dim * 2, rank)) + self.delta_B = nn.Parameter(torch.zeros(n_loops, rank, dim)) + nn.init.normal_(self.delta_A, std=0.01) + + self.rotary = Rotary(self.hd, base=rope_base) + + def forward(self, x: Tensor, loop_idx: int) -> Tensor: + B, T, C = x.shape + H, KVH, D, NM = self.nh, self.nkv, self.hd, self.nmem + sl = T - NM + + # Q with per-loop low-rank delta + QAT on effective weight + wq = self.c_q.weight + (self.delta_A[loop_idx] @ self.delta_B[loop_idx]) + wq = fake_quant_per_row(wq) + qp = F.linear(x, wq.to(x.dtype)) + q1 = qp[..., :C].reshape(B, T, H, D).transpose(1, 2) + q2 = qp[..., C:].reshape(B, T, H, D).transpose(1, 2) + + # K, V: fresh every loop from updated hidden state + kv = self.c_k(x) + k1 = kv[..., :KVH*D].reshape(B, T, KVH, D).transpose(1, 2) + k2 = kv[..., KVH*D:].reshape(B, T, KVH, D).transpose(1, 2) + v = self.c_v(x).reshape(B, T, KVH, D).transpose(1, 2) + + # QK RMSNorm before RoPE (baseline technique) + q1 = F.rms_norm(q1, (D,)); q2 = F.rms_norm(q2, (D,)) + k1 = F.rms_norm(k1, (D,)); k2 = F.rms_norm(k2, (D,)) + + # RoPE: sequence positions ONLY — memory tokens excluded + cos, sin = self.rotary(sl, x.device, q1.dtype) + def rope_seq(q, k): + qm, qs = q[:, :, :NM], q[:, :, NM:] + km, ks = k[:, :, :NM], k[:, :, NM:] + return (torch.cat([qm, apply_rope(qs, cos, sin)], dim=2), + torch.cat([km, apply_rope(ks, cos, sin)], dim=2)) + q1, k1 = rope_seq(q1, k1) + q2, k2 = rope_seq(q2, k2) + + # Q gain + GQA + g = self.q_gain.to(q1.dtype)[None, :, None, None] + q1 = q1 * g; q2 = q2 * g + k1 = k1.repeat_interleave(self.n_rep, dim=1) + k2 = k2.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + # Differential attention: cancel noise via subtraction + lam = (torch.exp(self.lambda_q1) * torch.exp(self.lambda_k1) + - torch.exp(self.lambda_q2) * torch.exp(self.lambda_k2) + + self.lambda_init).to(q1.dtype)[None, :, None, None] + a1 = F.scaled_dot_product_attention(q1, k1, v, is_causal=True) + a2 = F.scaled_dot_product_attention(q2, k2, v, is_causal=True) + out = (a1 - lam * a2).transpose(1, 2).contiguous().view(B, T, C) + return self.proj(out) + +class ReluSquaredMLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + h = dim * mlp_mult + self.fc = CastedLinear(dim, h, bias=False) + self.proj = CastedLinear(h, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + return self.proj(torch.relu(self.fc(x)).square()) + +class RecurrentBlock(nn.Module): + """Shared block with all per-loop parameters.""" + def __init__(self, dim: int, nh: int, nkv: int, n_loops: int, + n_mem: int, mlp_mult: int, rope_base: float, rank: int): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = SharedDiffAttention(dim, nh, nkv, n_loops, n_mem, rope_base, rank) + self.mlp = ReluSquaredMLP(dim, mlp_mult) + + # Per-loop scale factors (learned attention/MLP gate per depth) + self.attn_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(n_loops, dim, dtype=torch.float32)) + + # resid_mix: learned interpolation between current x and original x0 + # [0] = weight on current state, [1] = weight on original input + # Initialised to carry forward (mix[0]=1, mix[1]=0), learns to reset selectively + self.resid_mix = nn.Parameter( + torch.stack([torch.ones(n_loops, dim), torch.zeros(n_loops, dim)]).float() + ) + + # Loop position embedding: activation-level depth signal + # Complementary to weight-level low-rank delta + self.loop_embed = nn.Embedding(n_loops, dim) + nn.init.normal_(self.loop_embed.weight, std=0.02) + + def forward(self, x: Tensor, x0: Tensor, loop_idx: int) -> Tensor: + # resid_mix: learned balance of current state vs original input (generalises input injection) + mix = self.resid_mix[:, loop_idx, :].to(x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + + # Loop position embedding (activation-level specialisation) + x = x + self.loop_embed.weight[loop_idx].to(x.dtype)[None, None, :] + + # Attention + MLP with per-loop scale + a = self.attn(self.attn_norm(x), loop_idx) + x = x + self.attn_scale[loop_idx].to(x.dtype)[None, None, :] * a + x = x + self.mlp_scale[loop_idx].to(x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class RLM(nn.Module): + """ + Recurrent Language Model. + Single shared block × N_LOOPS with deep supervision and U-Net skips. + """ + def __init__(self, args: Hyperparameters): + super().__init__() + dim, nl, nm = args.model_dim, args.n_loops, args.n_memory + self.n_loops, self.n_memory = nl, nm + self.dim = dim + self.logit_softcap = args.logit_softcap + self.deep_sup_gamma = args.deep_sup_gamma + + self.n_enc = nl // 2 + self.n_skip = min(self.n_enc, nl - self.n_enc) + + self.tok_emb = nn.Embedding(args.vocab_size, dim) + nn.init.normal_(self.tok_emb.weight, std=0.005) + + self.memory = nn.Parameter(torch.randn(1, nm, dim) * 0.02) + + self.block = RecurrentBlock( + dim, args.num_heads, args.num_kv_heads, nl, nm, + args.mlp_mult, args.rope_base, args.lora_rank + ) + + self.skip_weights = nn.Parameter( + torch.ones(self.n_skip, dim, dtype=torch.float32) + if self.n_skip > 0 else torch.zeros(0, dim, dtype=torch.float32) + ) + self.final_norm = RMSNorm() + + # Zero-init output projections for stability at loop depth + nn.init.zeros_(self.block.attn.proj.weight) + nn.init.zeros_(self.block.mlp.proj.weight) + + def _logits(self, x: Tensor) -> Tensor: + h = self.final_norm(x[:, self.n_memory:, :]) + l = F.linear(h.reshape(-1, self.dim), self.tok_emb.weight) + return self.logit_softcap * torch.tanh(l / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + B, T = input_ids.shape + tok = F.rms_norm(self.tok_emb(input_ids), (self.dim,)) + x = torch.cat([self.memory.expand(B, -1, -1), tok], dim=1) + x0 = x + + gamma = self.deep_sup_gamma + total_loss = torch.zeros((), device=input_ids.device) + weight_sum = 0.0 + targets = target_ids.reshape(-1) + skips = [] + + for i in range(self.n_loops): + # Gradient checkpointing: recompute activations during backward + x = grad_ckpt(self.block, x, x0, i, use_reentrant=False) + + # U-Net: encoder half saves, decoder half consumes in reverse + if i < self.n_enc: + skips.append(x) + elif self.n_skip > 0 and skips: + di = i - self.n_enc + if di < self.n_skip: + sw = self.skip_weights[di].to(x.dtype)[None, None, :] + x = x + sw * skips[self.n_enc - 1 - di] + + # Deep supervision: weighted loss at every loop + # Later loops weighted higher (gamma^(N-1-i)), but all contribute + w = gamma ** (self.n_loops - 1 - i) + total_loss = total_loss + w * F.cross_entropy( + self._logits(x).float(), targets, reduction="mean" + ) + weight_sum += w + + return total_loss / weight_sum + +# ── Training ────────────────────────────────────────────────────────────────── + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + global zeropower_via_newtonschulz5 + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ── Distributed + CUDA setup ────────────────────────────────────────────── + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required. For CPU testing use train_gpt_local.py") + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import (enable_cudnn_sdp, enable_flash_sdp, + enable_math_sdp, enable_mem_efficient_sdp) + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" if master else None + + def log0(msg: str, console: bool = True) -> None: + if not master: return + if console: print(msg) + if logfile: + with open(logfile, "a", encoding="utf-8") as f: print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Python {sys.version}", console=False) + log0(f"PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + + # ── Seed + tokenizer ────────────────────────────────────────────────────── + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"Tokenizer vocab {sp.vocab_size()} != VOCAB_SIZE {args.vocab_size}") + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_lut, space_lut, bnd_lut = build_sentencepiece_luts(sp, args.vocab_size, device) + log0(f"val_tokens:{val_tokens.numel()-1} train_seq_len:{args.train_seq_len}") + + # ── Model ───────────────────────────────────────────────────────────────── + base_model = RLM(args).to(device).bfloat16() + for m in base_model.modules(): + if isinstance(m, CastedLinear): m.float() + restore_low_dim_params_to_fp32(base_model) + compiled = torch.compile(base_model, dynamic=False, fullgraph=False) + model: nn.Module = (DDP(compiled, device_ids=[local_rank], broadcast_buffers=False) + if distributed else compiled) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} (~{n_params/1e6:.2f}M)") + log0(f"n_loops:{args.n_loops} n_memory:{args.n_memory} dim:{args.model_dim} " + f"seq_len:{args.train_seq_len}") + + # ── Optimizers ──────────────────────────────────────────────────────────── + block_named = list(base_model.block.named_parameters()) + matrix_params = [p for n, p in block_named + if p.ndim == 2 and not any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + scalar_params = ([p for n, p in block_named + if p.ndim != 2 or any(pat in n for pat in CONTROL_TENSOR_NAME_PATTERNS)] + + [base_model.skip_weights, base_model.memory, base_model.final_norm.eps + if hasattr(base_model.final_norm, 'eps') and + isinstance(base_model.final_norm.eps, nn.Parameter) else None]) + scalar_params = [p for p in scalar_params if p is not None] + + opt_embed = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": args.embed_lr, "base_lr": args.embed_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + opt_muon = Muon(matrix_params, lr=args.matrix_lr, + momentum=args.muon_momentum, backend_steps=args.muon_backend_steps) + for g in opt_muon.param_groups: g["base_lr"] = args.matrix_lr + opt_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers = [opt_embed, opt_muon, opt_scalar] + + log0(f"matrix_params:{sum(p.numel() for p in matrix_params):,} " + f"scalar_params:{sum(p.numel() for p in scalar_params):,}") + + # ── Checkpoint resume (thermal crash protection) ────────────────────────── + ckpt_path = f"checkpoint_{args.run_id}.pt" + start_step = 0 + training_time_ms = 0.0 + if os.path.exists(ckpt_path) and master: + log0(f"Resuming from {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location=device) + base_model.load_state_dict(ckpt["model"]) + for opt, sd in zip(optimizers, ckpt["optimizers"]): + opt.load_state_dict(sd) + start_step = ckpt["step"] + training_time_ms = ckpt.get("training_time_ms", 0.0) + log0(f"Resumed at step {start_step}") + if distributed: dist.barrier() + + # ── Data loader ─────────────────────────────────────────────────────────── + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad(): + for o in optimizers: o.zero_grad(set_to_none=True) + + max_wc_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_scale(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wc_ms is None: + ws = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) \ + if ws <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + wd_ms = args.warmdown_iters * step_ms + rem_ms = max(max_wc_ms - elapsed_ms, 0.0) + return rem_ms / max(wd_ms, 1e-9) if rem_ms <= wd_ms else 1.0 + + # ── Warmup-then-restore ─────────────────────────────────────────────────── + if args.warmup_steps > 0 and start_step == 0: + init_model = {n: t.detach().cpu().clone() for n, t in base_model.state_dict().items()} + init_opts = [copy.deepcopy(o.state_dict()) for o in optimizers] + model.train() + for ws in range(args.warmup_steps): + zero_grad() + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + (model(x, y) * grad_scale).backward() + for o in optimizers: o.step() + zero_grad() + log0(f"warmup_complete:{args.warmup_steps} steps") + base_model.load_state_dict(init_model, strict=True) + for o, sd in zip(optimizers, init_opts): o.load_state_dict(sd) + zero_grad() + if distributed: model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ── Main training loop ──────────────────────────────────────────────────── + stop_after: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + for step in range(start_step, args.iterations + 1): + last = step == args.iterations or (stop_after is not None and step >= stop_after) + + # Validation + if ((args.val_loss_every > 0 and step % args.val_loss_every == 0 and step >= 200) or last): + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + vl, vbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + log0(f"step:{step}/{args.iterations} val_loss:{vl:.4f} val_bpb:{vbpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/max(step,1):.2f}ms") + torch.cuda.synchronize(); t0 = time.perf_counter() + + if last: break + + # Checkpoint save (thermal crash protection) + if args.save_every > 0 and step > start_step and step % args.save_every == 0 and master: + torch.cuda.synchronize() + elapsed_now = training_time_ms + 1000.0 * (time.perf_counter() - t0) + torch.save({"step": step, "model": base_model.state_dict(), + "optimizers": [o.state_dict() for o in optimizers], + "training_time_ms": elapsed_now}, ckpt_path) + log0(f"checkpoint_saved:step={step}") + + # LR update + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_scale(step, elapsed_ms) + for o in optimizers: + for g in o.param_groups: g["lr"] = g["base_lr"] * scale + + # Muon momentum warmup + frac = min(step / max(args.muon_momentum_warmup_steps, 1), 1.0) + mum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for g in opt_muon.param_groups: g["momentum"] = mum + + # Forward + backward with gradient accumulation + zero_grad() + train_loss = torch.zeros((), device=device) + for ms in range(grad_accum_steps): + if distributed: model.require_backward_grad_sync = (ms == grad_accum_steps - 1) + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast("cuda", torch.bfloat16): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + for o in optimizers: o.step() + zero_grad() + + approx_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0): + log0(f"step:{step+1}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ms:.0f}ms step_avg:{approx_ms/(step+1):.2f}ms lr_scale:{scale:.4f}") + + reached = max_wc_ms is not None and approx_ms >= max_wc_ms + if distributed and max_wc_ms is not None: + rt = torch.tensor(int(reached), device=device) + dist.all_reduce(rt, op=dist.ReduceOp.MAX) + reached = bool(rt.item()) + if stop_after is None and reached: + stop_after = step + 1 + + log0(f"peak_memory:{torch.cuda.max_memory_allocated()//1024//1024}MiB " + f"reserved:{torch.cuda.max_memory_reserved()//1024//1024}MiB") + + # ── Serialisation + roundtrip validation ────────────────────────────────── + if master: + torch.save(base_model.state_dict(), "final_model.pt") + code_bytes = len(code.encode("utf-8")) + model_bytes = os.path.getsize("final_model.pt") + log0(f"raw_model:{model_bytes} code:{code_bytes} total:{model_bytes+code_bytes}") + + quant_obj, qstats = quantize_state_dict_int8(base_model.state_dict()) + buf = io.BytesIO() + torch.save(quant_obj, buf) + blob = zlib.compress(buf.getvalue(), level=9) + if master: + with open("final_model.int8.ptz", "wb") as f: f.write(blob) + qbytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = qstats["baseline_tensor_bytes"] / max(qstats["int8_payload_bytes"], 1) + log0(f"int8_zlib:{qbytes} code:{code_bytes} total:{qbytes+code_bytes} " + f"payload_ratio:{ratio:.2f}x budget_pct:{(qbytes+code_bytes)/16e6*100:.1f}%") + + if distributed: dist.barrier() + with open("final_model.int8.ptz", "rb") as f: blob_disk = f.read() + qs2 = torch.load(io.BytesIO(zlib.decompress(blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(qs2), strict=True) + torch.cuda.synchronize(); te = time.perf_counter() + qvl, qvbpb = eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_lut, space_lut, bnd_lut) + torch.cuda.synchronize() + log0(f"final_int8_zlib_roundtrip val_loss:{qvl:.4f} val_bpb:{qvbpb:.4f} " + f"eval_time:{1000.0*(time.perf_counter()-te):.0f}ms") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{qvl:.8f} val_bpb:{qvbpb:.8f}") + + if distributed: dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Python 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0] +PyTorch 2.10.0+cu128 +Fri Mar 20 09:24:02 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.07 Driver Version: 580.82.07 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | +| N/A 67C P0 29W / 70W | 105MiB / 15360MiB | 4% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 52685 C /usr/bin/python3 102MiB | ++-----------------------------------------------------------------------------------------+ + +val_tokens:62021632 train_seq_len:256 +model_params:3918888 (~3.92M) +n_loops:4 n_memory:16 dim:512 seq_len:256 +matrix_params:3,276,800 scalar_params:117,800 +step:1/2000 train_loss:6.9365 train_time:8943ms step_avg:8942.57ms lr_scale:1.0000 +step:2/2000 train_loss:18.1766 train_time:11464ms step_avg:5732.25ms lr_scale:1.0000 +step:3/2000 train_loss:10.7042 train_time:14012ms step_avg:4670.75ms lr_scale:1.0000 +step:4/2000 train_loss:6.9687 train_time:16566ms step_avg:4141.47ms lr_scale:1.0000 +step:5/2000 train_loss:6.4091 train_time:19132ms step_avg:3826.33ms lr_scale:1.0000 +step:6/2000 train_loss:6.3178 train_time:21708ms step_avg:3617.94ms lr_scale:1.0000 +step:7/2000 train_loss:6.2625 train_time:24293ms step_avg:3470.43ms lr_scale:1.0000 +step:8/2000 train_loss:6.1385 train_time:26894ms step_avg:3361.81ms lr_scale:1.0000 +step:9/2000 train_loss:6.0663 train_time:29504ms step_avg:3278.24ms lr_scale:1.0000 +step:10/2000 train_loss:6.0499 train_time:32121ms step_avg:3212.14ms lr_scale:1.0000 +step:11/2000 train_loss:6.0487 train_time:34749ms step_avg:3159.04ms lr_scale:1.0000 +step:51/2000 train_loss:4.4880 train_time:139053ms step_avg:2726.54ms lr_scale:1.0000 +step:101/2000 train_loss:3.9092 train_time:269296ms step_avg:2666.30ms lr_scale:1.0000 +step:151/2000 train_loss:3.7392 train_time:399676ms step_avg:2646.86ms lr_scale:1.0000 +step:201/2000 train_loss:3.5584 train_time:530024ms step_avg:2636.93ms lr_scale:1.0000 +step:251/2000 train_loss:3.5899 train_time:660149ms step_avg:2630.07ms lr_scale:1.0000 +step:301/2000 train_loss:3.6138 train_time:790468ms step_avg:2626.14ms lr_scale:1.0000 +step:351/2000 train_loss:3.6629 train_time:920683ms step_avg:2623.03ms lr_scale:1.0000 +step:401/2000 train_loss:3.6071 train_time:1051010ms step_avg:2620.97ms lr_scale:1.0000 +step:451/2000 train_loss:3.5002 train_time:1181500ms step_avg:2619.73ms lr_scale:1.0000 +step:500/2000 val_loss:3.4842 val_bpb:2.0876 train_time:1309306ms step_avg:2618.61ms +checkpoint_saved:step=500 +step:501/2000 train_loss:3.4314 train_time:1311839ms step_avg:2618.44ms lr_scale:1.0000 +step:551/2000 train_loss:3.5510 train_time:1442107ms step_avg:2617.25ms lr_scale:1.0000 +step:601/2000 train_loss:3.3796 train_time:1572574ms step_avg:2616.60ms lr_scale:1.0000 +step:651/2000 train_loss:3.3024 train_time:1703092ms step_avg:2616.12ms lr_scale:1.0000 +step:701/2000 train_loss:3.4731 train_time:1833373ms step_avg:2615.37ms lr_scale:1.0000 +step:751/2000 train_loss:3.4287 train_time:1963481ms step_avg:2614.49ms lr_scale:1.0000 +step:801/2000 train_loss:3.3459 train_time:2093660ms step_avg:2613.81ms lr_scale:1.0000 +step:851/2000 train_loss:3.3466 train_time:2223894ms step_avg:2613.27ms lr_scale:1.0000 +step:901/2000 train_loss:3.2686 train_time:2354185ms step_avg:2612.86ms lr_scale:1.0000 +step:951/2000 train_loss:3.2709 train_time:2484575ms step_avg:2612.59ms lr_scale:1.0000 +step:1000/2000 val_loss:3.3309 val_bpb:1.9957 train_time:2612320ms step_avg:2612.32ms +checkpoint_saved:step=1000 +step:1001/2000 train_loss:3.2584 train_time:2614884ms step_avg:2612.27ms lr_scale:1.0000 +step:1051/2000 train_loss:3.3233 train_time:2745277ms step_avg:2612.06ms lr_scale:1.0000 +step:1101/2000 train_loss:3.2620 train_time:2875668ms step_avg:2611.87ms lr_scale:1.0000 +step:1151/2000 train_loss:3.2854 train_time:3005985ms step_avg:2611.63ms lr_scale:1.0000 +step:1201/2000 train_loss:3.2910 train_time:3136292ms step_avg:2611.40ms lr_scale:1.0000 +step:1251/2000 train_loss:3.3130 train_time:3266403ms step_avg:2611.03ms lr_scale:1.0000 +step:1301/2000 train_loss:3.3112 train_time:3396377ms step_avg:2610.59ms lr_scale:1.0000 +step:1351/2000 train_loss:3.2463 train_time:3526411ms step_avg:2610.22ms lr_scale:1.0000 +step:1401/2000 train_loss:3.2729 train_time:3656791ms step_avg:2610.13ms lr_scale:1.0000 +step:1451/2000 train_loss:3.2517 train_time:3787183ms step_avg:2610.05ms lr_scale:0.9167 +step:1500/2000 val_loss:3.2203 val_bpb:1.9294 train_time:3915119ms step_avg:2610.08ms +checkpoint_saved:step=1500 +step:1501/2000 train_loss:3.1675 train_time:3917685ms step_avg:2610.05ms lr_scale:0.8333 +step:1551/2000 train_loss:3.2723 train_time:4047841ms step_avg:2609.83ms lr_scale:0.7500 +step:1601/2000 train_loss:3.0833 train_time:4177851ms step_avg:2609.53ms lr_scale:0.6667 +step:1651/2000 train_loss:3.2030 train_time:4308461ms step_avg:2609.61ms lr_scale:0.5833 +step:1701/2000 train_loss:3.2907 train_time:4438605ms step_avg:2609.41ms lr_scale:0.5000 +step:1751/2000 train_loss:2.9965 train_time:4568694ms step_avg:2609.19ms lr_scale:0.4167 +step:1801/2000 train_loss:3.1435 train_time:4698846ms step_avg:2609.02ms lr_scale:0.3333 +step:1851/2000 train_loss:3.0147 train_time:4828975ms step_avg:2608.85ms lr_scale:0.2500 +step:1901/2000 train_loss:3.1446 train_time:4959082ms step_avg:2608.67ms lr_scale:0.1667 +step:1951/2000 train_loss:3.1320 train_time:5089186ms step_avg:2608.50ms lr_scale:0.0833 +step:2000/2000 val_loss:3.0880 val_bpb:1.8502 train_time:5217020ms step_avg:2608.51ms +peak_memory:3390MiB reserved:3784MiB +raw_model:14437775 code:42197 total:14479972 +int8_zlib:3554683 code:42197 total:3596880 payload_ratio:3.53x budget_pct:22.5% +final_int8_zlib_roundtrip val_loss:3.0914 val_bpb:1.8522 eval_time:21983ms +final_int8_zlib_roundtrip_exact val_loss:3.09141515 val_bpb:1.85221128