Skip to content

Commit c0386ed

Browse files
authored
fix file truncation
1 parent 155692a commit c0386ed

1 file changed

Lines changed: 83 additions & 1 deletion

File tree

atomgpt/inverse_models/hyperparameter_search.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,4 +522,86 @@ def main() -> None:
522522

523523
# Parent mode
524524
if not args.config_name:
525-
raise System
525+
raise SystemExit("--config_name is required")
526+
527+
train_cfg = TrainingPropConfig(**json.load(open(args.config_name)))
528+
hp_cfg = OptunaSearchConfig(**json.load(open(train_cfg.hp_cfg_path)))
529+
530+
obj = hp_cfg.objective_metric or "final_eval_loss"
531+
objective_metrics = [obj] if isinstance(obj, str) else list(obj)
532+
533+
dirs = hp_cfg.study_direction
534+
if dirs is None:
535+
directions = [_auto_direction(k) for k in objective_metrics]
536+
else:
537+
directions = [dirs] if isinstance(dirs, str) else list(dirs)
538+
539+
if len(directions) == 1 and len(objective_metrics) > 1:
540+
directions = directions * len(objective_metrics)
541+
542+
if _DEBUG:
543+
log.debug("Objectives: %s | Directions: %s", objective_metrics, directions)
544+
545+
# Build dataset JSONs once (shared across trials)
546+
data = _load_id_prop_data(train_cfg.id_prop_path, train_cfg)
547+
548+
train_ids, val_ids, test_ids = train_val_test_split_ids(
549+
data,
550+
train_cfg.id_tag,
551+
train_cfg.seed_val,
552+
train_cfg.val_ratio,
553+
train_cfg.test_ratio,
554+
)
555+
556+
tmp = Path(tempfile.mkdtemp(prefix="optuna_data_"))
557+
train_j = tmp / "train.json"
558+
val_j = tmp / "val.json"
559+
test_j = tmp / "test.json"
560+
dumpjson(make_alpaca_json(data, train_ids, config=train_cfg), train_j)
561+
dumpjson(make_alpaca_json(data, val_ids, config=train_cfg), val_j)
562+
dumpjson(make_alpaca_json(data, test_ids, config=train_cfg), test_j)
563+
564+
sampler = SearchSpaceSampler(hp_cfg.parameters)
565+
pruner = optuna.pruners.MedianPruner(n_warmup_steps=1)
566+
opt_sampler = TPESampler(
567+
multivariate=True,
568+
constraints_func=lambda t: (t.user_attrs.get("oom_violation", 0.0),),
569+
)
570+
study = optuna.create_study(directions=directions, pruner=pruner, sampler=opt_sampler)
571+
572+
wall = time.time()
573+
try:
574+
study.optimize(
575+
partial(
576+
objective,
577+
train_cfg=train_cfg,
578+
hp_cfg=hp_cfg,
579+
sampler=sampler,
580+
train_json=train_j,
581+
val_json=val_j,
582+
objective_metrics=objective_metrics,
583+
# GUARDS — adjust to taste
584+
max_micro_bs=256,
585+
max_eff_bs=4096,
586+
trial_timeout_s=None,
587+
),
588+
n_trials=hp_cfg.n_trials,
589+
# CRITICAL: this keeps the *study* alive while marking the trial FAIL (visible)
590+
catch=(TrialCrashed,),
591+
)
592+
finally:
593+
shutil.rmtree(tmp, ignore_errors=True)
594+
595+
runtime = time.time() - wall
596+
print(f"\nStudy finished in {runtime:.1f}s")
597+
if len(objective_metrics) == 1:
598+
print("Best value :", study.best_value)
599+
print("Best params:", study.best_params)
600+
else:
601+
print("Pareto front (top 5 shown):")
602+
for i, t in enumerate(study.best_trials[:5]):
603+
print(f" Trial {t.number}: values={t.values}, params={t.params}")
604+
605+
606+
# Script entry point: run the Optuna hyperparameter search driver when
# executed directly (not on import).
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)