Skip to content
This repository was archived by the owner on Mar 12, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[pytest]
testpaths = tests
pythonpath = .
48 changes: 48 additions & 0 deletions tests/test_adaptive_cooldown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
from types import SimpleNamespace

import torch
import pytest

from config.hyperparameters import Config
from training.adaptive import AdaptiveController, CLIP_FRAC_COOLDOWN


class DummyAgent:
def save(self, path, step):
pass


def make_trainer(tmp_path):
return SimpleNamespace(
global_step=0,
config=SimpleNamespace(video_save_path=str(tmp_path)),
agent=DummyAgent(),
)


def test_cooldown_prevents_repeated_scaling(tmp_path):
cfg = Config()
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=cfg.learning_rate)
logger = logging.getLogger("test")
controller = AdaptiveController(cfg, opt, logger, state_dir=str(tmp_path))
trainer = make_trainer(tmp_path)

metrics = {"clip_fraction": 0.6, "mean_reward": 0.0, "explained_variance": 0.0}

# Fill history to trigger initial scaling
for _ in range(5):
controller.update(metrics, trainer)

first_scale = controller.state["lr_scale"]
assert first_scale < 1.0

# Within cooldown period, repeated high clip fraction should not rescale
for _ in range(CLIP_FRAC_COOLDOWN - 1):
controller.update(metrics, trainer)
assert controller.state["lr_scale"] == pytest.approx(first_scale)

# After cooldown expires, scaling can happen again
controller.update(metrics, trainer)
assert controller.state["lr_scale"] < first_scale
23 changes: 14 additions & 9 deletions training/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

ENTROPY_FLOOR = 2e-2
COOLDOWN_ENABLED = True
CLIP_FRAC_COOLDOWN = 4 # updates to skip after scaling


class AdaptiveController:
Expand Down Expand Up @@ -82,13 +83,19 @@ def _save(self) -> None:
# Public API – call once per PPO update
# ────────────────────────────────────────────────
def update(self, metrics: Dict[str, float], trainer) -> None:
"""Update internal state based on recent training metrics.

Args:
metrics: Mapping containing ``clip_fraction``, ``mean_reward``,
and ``explained_variance`` keys describing the latest
training statistics.
trainer: Trainer object used for checkpointing when
hyper-parameters change.
"""
metrics = {
"clip_fraction": float,
"mean_reward": float,
"explained_variance": float,
}
"""
# Decay cooldown timer each call so it persists across updates
if COOLDOWN_ENABLED and self.state.get("clip_cooldown", 0) > 0:
self.state["clip_cooldown"] -= 1

h = self.state["history"]
h["clip_fraction"].append(metrics["clip_fraction"])
h["mean_reward"].append(metrics["mean_reward"])
Expand All @@ -106,9 +113,7 @@ def update(self, metrics: Dict[str, float], trainer) -> None:
f"[ADAPT] clip_fraction={avg_clip:.3f} ▶ dampening LR & clip_eps by 25 %"
)
if COOLDOWN_ENABLED:
self.state["clip_cooldown"] = 1
elif COOLDOWN_ENABLED:
self.state["clip_cooldown"] = 0
self.state["clip_cooldown"] = CLIP_FRAC_COOLDOWN

# ── Rule 2: plateau detector / entropy annealing ───────────────
if len(h["mean_reward"]) == h["mean_reward"].maxlen:
Expand Down