diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..4584de7 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +testpaths = tests +pythonpath = . diff --git a/tests/test_adaptive_cooldown.py b/tests/test_adaptive_cooldown.py new file mode 100644 index 0000000..81a08fc --- /dev/null +++ b/tests/test_adaptive_cooldown.py @@ -0,0 +1,48 @@ +import logging +from types import SimpleNamespace + +import torch +import pytest + +from config.hyperparameters import Config +from training.adaptive import AdaptiveController, CLIP_FRAC_COOLDOWN + + +class DummyAgent: + def save(self, path, step): + pass + + +def make_trainer(tmp_path): + return SimpleNamespace( + global_step=0, + config=SimpleNamespace(video_save_path=str(tmp_path)), + agent=DummyAgent(), + ) + + +def test_cooldown_prevents_repeated_scaling(tmp_path): + cfg = Config() + param = torch.nn.Parameter(torch.zeros(1)) + opt = torch.optim.Adam([param], lr=cfg.learning_rate) + logger = logging.getLogger("test") + controller = AdaptiveController(cfg, opt, logger, state_dir=str(tmp_path)) + trainer = make_trainer(tmp_path) + + metrics = {"clip_fraction": 0.6, "mean_reward": 0.0, "explained_variance": 0.0} + + # Feed five identical high-clip-fraction updates; the test expects this to be enough history to trigger the first LR scale-down + for _ in range(5): + controller.update(metrics, trainer) + + first_scale = controller.state["lr_scale"] + assert first_scale < 1.0 + + # During the next CLIP_FRAC_COOLDOWN - 1 updates the same high clip fraction is expected to leave lr_scale unchanged + for _ in range(CLIP_FRAC_COOLDOWN - 1): + controller.update(metrics, trainer) + assert controller.state["lr_scale"] == pytest.approx(first_scale) + + # The CLIP_FRAC_COOLDOWN-th update after the first scale-down is expected to scale lr_scale down again + controller.update(metrics, trainer) + assert controller.state["lr_scale"] < first_scale diff --git a/training/adaptive.py b/training/adaptive.py index bc1bfbb..863f294 100644 --- a/training/adaptive.py +++ b/training/adaptive.py @@ -9,6 +9,7 @@ ENTROPY_FLOOR = 2e-2 COOLDOWN_ENABLED = True +CLIP_FRAC_COOLDOWN = 4 # value the clip_cooldown timer is set to after a clip-fraction scale-down; decremented at the start of each update() call 
class AdaptiveController: @@ -82,13 +83,19 @@ def _save(self) -> None: # Public API – call once per PPO update # ──────────────────────────────────────────────── def update(self, metrics: Dict[str, float], trainer) -> None: + """Update internal state based on recent training metrics. + + Args: + metrics: Mapping containing ``clip_fraction``, ``mean_reward``, + and ``explained_variance`` keys describing the latest + training statistics. + trainer: Trainer object used for checkpointing when + hyper-parameters change. """ - metrics = { - "clip_fraction": float, - "mean_reward": float, - "explained_variance": float, - } - """ + # Decay cooldown timer each call so it persists across updates + if COOLDOWN_ENABLED and self.state.get("clip_cooldown", 0) > 0: + self.state["clip_cooldown"] -= 1 + h = self.state["history"] h["clip_fraction"].append(metrics["clip_fraction"]) h["mean_reward"].append(metrics["mean_reward"]) @@ -106,9 +113,7 @@ def update(self, metrics: Dict[str, float], trainer) -> None: f"[ADAPT] clip_fraction={avg_clip:.3f} ▶ dampening LR & clip_eps by 25 %" ) if COOLDOWN_ENABLED: - self.state["clip_cooldown"] = 1 - elif COOLDOWN_ENABLED: - self.state["clip_cooldown"] = 0 + self.state["clip_cooldown"] = CLIP_FRAC_COOLDOWN # ── Rule 2: plateau detector / entropy annealing ─────────────── if len(h["mean_reward"]) == h["mean_reward"].maxlen: