@@ -522,4 +522,86 @@ def main() -> None:
522522
# Parent mode: validate CLI args, build the shared datasets once, then run
# the Optuna study over `objective` (child trials reuse the same JSON files).
if not args.config_name:
    raise SystemExit("--config_name is required")

# Load configs with context managers so the file handles are closed
# promptly (json.load(open(...)) leaks the descriptor until GC).
with open(args.config_name) as cfg_f:
    train_cfg = TrainingPropConfig(**json.load(cfg_f))
with open(train_cfg.hp_cfg_path) as hp_f:
    hp_cfg = OptunaSearchConfig(**json.load(hp_f))

# Normalize the objective metric(s) to a list; default to eval loss.
obj = hp_cfg.objective_metric or "final_eval_loss"
objective_metrics = [obj] if isinstance(obj, str) else list(obj)

# Normalize optimization direction(s): infer per-metric when unspecified.
dirs = hp_cfg.study_direction
if dirs is None:
    directions = [_auto_direction(k) for k in objective_metrics]
else:
    directions = [dirs] if isinstance(dirs, str) else list(dirs)

# A single explicit direction is broadcast across all objectives.
if len(directions) == 1 and len(objective_metrics) > 1:
    directions = directions * len(objective_metrics)

# Fail fast on a length mismatch instead of letting Optuna raise a
# harder-to-read error at create_study time.
if len(directions) != len(objective_metrics):
    raise SystemExit(
        f"study_direction count ({len(directions)}) does not match "
        f"objective_metric count ({len(objective_metrics)})"
    )

if _DEBUG:
    log.debug("Objectives: %s | Directions: %s", objective_metrics, directions)

# Build dataset JSONs once (shared across trials)
data = _load_id_prop_data(train_cfg.id_prop_path, train_cfg)

train_ids, val_ids, test_ids = train_val_test_split_ids(
    data,
    train_cfg.id_tag,
    train_cfg.seed_val,
    train_cfg.val_ratio,
    train_cfg.test_ratio,
)

# Materialize the splits in a throwaway directory; removed in `finally`.
tmp = Path(tempfile.mkdtemp(prefix="optuna_data_"))
train_j = tmp / "train.json"
val_j = tmp / "val.json"
test_j = tmp / "test.json"
dumpjson(make_alpaca_json(data, train_ids, config=train_cfg), train_j)
dumpjson(make_alpaca_json(data, val_ids, config=train_cfg), val_j)
dumpjson(make_alpaca_json(data, test_ids, config=train_cfg), test_j)

sampler = SearchSpaceSampler(hp_cfg.parameters)
pruner = optuna.pruners.MedianPruner(n_warmup_steps=1)
# Constrained TPE: trials that set the "oom_violation" user attr > 0 are
# treated as infeasible by the sampler.
opt_sampler = TPESampler(
    multivariate=True,
    constraints_func=lambda t: (t.user_attrs.get("oom_violation", 0.0),),
)
study = optuna.create_study(directions=directions, pruner=pruner, sampler=opt_sampler)

wall = time.time()
try:
    study.optimize(
        partial(
            objective,
            train_cfg=train_cfg,
            hp_cfg=hp_cfg,
            sampler=sampler,
            train_json=train_j,
            val_json=val_j,
            objective_metrics=objective_metrics,
            # GUARDS — adjust to taste
            max_micro_bs=256,
            max_eff_bs=4096,
            trial_timeout_s=None,
        ),
        n_trials=hp_cfg.n_trials,
        # CRITICAL: this keeps the *study* alive while marking the trial FAIL (visible)
        catch=(TrialCrashed,),
    )
finally:
    # Always clean up the temp dataset dir, even on crash/interrupt.
    shutil.rmtree(tmp, ignore_errors=True)

runtime = time.time() - wall
print(f"\nStudy finished in {runtime:.1f} s")
if len(objective_metrics) == 1:
    # Single-objective: best_value/best_params are well defined.
    print("Best value :", study.best_value)
    print("Best params:", study.best_params)
else:
    # Multi-objective: report the Pareto front instead (index was unused,
    # so plain iteration replaces enumerate).
    print("Pareto front (top 5 shown):")
    for t in study.best_trials[:5]:
        print(f"  Trial {t.number}: values={t.values}, params={t.params}")
605+
# Script entry point: only run the search when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
0 commit comments