diff --git a/langfun/core/eval/v2/runners/ckpt_monitor.py b/langfun/core/eval/v2/runners/ckpt_monitor.py index 174df2f3..fee1275d 100644 --- a/langfun/core/eval/v2/runners/ckpt_monitor.py +++ b/langfun/core/eval/v2/runners/ckpt_monitor.py @@ -110,6 +110,7 @@ def start(self): # This is not precise, but we at least notify example start. if not self.current_run.filter or self.current_run.filter(evaluation): self.on_experiment_start(evaluation) + self._set_prior_elapse_from_checkpoints(evaluation) # Signal the start of the examples if we are not monitoring in-progress # files. @@ -353,6 +354,31 @@ def _mark_example_started( # HTML could show remotely in-progress examples. evaluation.state.update(example, in_progress=True) + def _set_prior_elapse_from_checkpoints( + self, + evaluation: evaluation_lib.Evaluation, + ) -> None: + output_dir = self.current_run.output_dir(evaluation) + ckpt_file_pattern = os.path.join(output_dir, self.checkpoint_pattern) + total_elapse = 0.0 + for filepath in pg.io.glob(ckpt_file_pattern): + last_modified_time = pg.io.getmtime(filepath) + if last_modified_time >= self.ckpt_start_time: + continue + try: + loaded_examples = evaluation.state.load( + filepath, + example_input_by_id=evaluation.example_input_by_id, + load_example_metadata=False, + ) + for example in loaded_examples: + if example.start_time is not None and example.end_time is not None: + total_elapse += example.end_time - example.start_time + except Exception: # pylint: disable=broad-except + pass + if total_elapse > 0: + evaluation.progress.add_prior_elapse(total_elapse) + def _run(self, evaluations: list[evaluation_lib.Evaluation]): raise NotImplementedError('Not needed in checkpoint monitor.') diff --git a/langfun/core/eval/v2/runners/ckpt_monitor_test.py b/langfun/core/eval/v2/runners/ckpt_monitor_test.py index 5e1e29e3..cb906795 100644 --- a/langfun/core/eval/v2/runners/ckpt_monitor_test.py +++ b/langfun/core/eval/v2/runners/ckpt_monitor_test.py @@ -208,6 +208,33 @@ def on_experiment_complete( ckpt_start_time=ckpt_start_time, ).run() + def test_prior_elapse_accumulated_from_preexisting_checkpoints(self): + exp = eval_test_helper.test_evaluation() + root_dir = os.path.join(self.test_dir, 'test_prior_elapse') + run = exp.run( + root_dir, + runner='sequential', + progress_tracker=None, + plugins=[ + checkpointing.PerExampleCheckpointer( + checkpoint_filename='checkpoint.jsonl' + ) + ], + use_cache='no', + ) + + ckpt_start_time = time.time() + monitor = ckpt_monitor.CheckpointMonitor( + run, + plugins=[], + checkpoint_pattern='checkpoint_*.jsonl', + ckpt_start_time=ckpt_start_time, + bypass_old_ckpt_files_with_non_oop_errors=False, + ) + monitor.run() + for leaf in run.experiment.leaf_nodes: + self.assertGreater(leaf.progress.prior_elapse, 0.0) + if __name__ == '__main__': unittest.main()