Skip to content

Commit 1044207

Browse files
amcadmusHan Wang
andauthored
Pass scheduler as an artifact (#29)
* change interface of scheduler from parameter to artifact * adjust test for artifact scheduler Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
1 parent a1fd77b commit 1044207

2 files changed

Lines changed: 100 additions & 47 deletions

File tree

dpgen2/flow/dpgen_loop.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@ class SchedulerWrapper(OP):
3939
@classmethod
4040
def get_input_sign(cls):
4141
return OPIOSign({
42-
"exploration_scheduler" : ExplorationScheduler,
42+
"exploration_scheduler" : Artifact(Path),
4343
"exploration_report": ExplorationReport,
4444
"trajs": Artifact(List[Path]),
4545
})
4646

4747
@classmethod
4848
def get_output_sign(cls):
4949
return OPIOSign({
50-
"exploration_scheduler" : ExplorationScheduler,
50+
"exploration_scheduler" : Artifact(Path),
5151
"converged" : bool,
5252
"lmp_task_grp" : Artifact(Path),
5353
"conf_selector" : ConfSelector,
@@ -58,28 +58,36 @@ def execute(
5858
self,
5959
ip : OPIO,
6060
) -> OPIO:
61-
scheduler = ip['exploration_scheduler']
61+
scheduler_in = ip['exploration_scheduler']
6262
report = ip['exploration_report']
6363
trajs = ip['trajs']
64+
lmp_task_grp_file = Path('lmp_task_grp.dat')
65+
scheduler_file = Path('scheduler.dat')
66+
67+
with open(scheduler_in, 'rb') as fp:
68+
scheduler = pickle.load(fp)
6469

6570
conv, lmp_task_grp, selector = scheduler.plan_next_iteration(report, trajs)
6671

67-
with open('lmp_task_grp.dat', 'wb') as fp:
72+
with open(lmp_task_grp_file, 'wb') as fp:
6873
pickle.dump(lmp_task_grp, fp)
6974

75+
with open(scheduler_file, 'wb') as fp:
76+
pickle.dump(scheduler, fp)
77+
7078
return OPIO({
71-
"exploration_scheduler" : scheduler,
79+
"exploration_scheduler" : scheduler_file,
7280
"converged" : conv,
7381
"conf_selector" : selector,
74-
"lmp_task_grp" : Path('lmp_task_grp.dat'),
82+
"lmp_task_grp" : lmp_task_grp_file,
7583
})
7684

7785

7886
class MakeBlockId(OP):
7987
@classmethod
8088
def get_input_sign(cls):
8189
return OPIOSign({
82-
"exploration_scheduler" : ExplorationScheduler,
90+
"exploration_scheduler" : Artifact(Path),
8391
})
8492

8593
@classmethod
@@ -93,8 +101,11 @@ def execute(
93101
self,
94102
ip : OPIO,
95103
) -> OPIO:
96-
scheduler = ip['exploration_scheduler']
104+
scheduler_in = ip['exploration_scheduler']
97105

106+
with open(scheduler_in, 'rb') as fp:
107+
scheduler = pickle.load(fp)
108+
98109
stage = scheduler.get_stage()
99110
iteration = scheduler.get_iteration()
100111

@@ -121,18 +132,18 @@ def __init__(
121132
"conf_selector" : InputParameter(),
122133
"fp_inputs" : InputParameter(),
123134
"fp_config" : InputParameter(),
124-
"exploration_scheduler" : InputParameter(),
125135
}
126136
self._input_artifacts={
137+
"exploration_scheduler" : InputArtifact(),
127138
"init_models" : InputArtifact(),
128139
"init_data" : InputArtifact(),
129140
"iter_data" : InputArtifact(),
130141
"lmp_task_grp" : InputArtifact(),
131142
}
132143
self._output_parameters={
133-
"exploration_scheduler": OutputParameter(),
134144
}
135145
self._output_artifacts={
146+
"exploration_scheduler": OutputArtifact(),
136147
"models": OutputArtifact(),
137148
"iter_data" : OutputArtifact(),
138149
}
@@ -213,17 +224,17 @@ def __init__(
213224
"lmp_config" : InputParameter(),
214225
"fp_inputs" : InputParameter(),
215226
"fp_config" : InputParameter(),
216-
"exploration_scheduler" : InputParameter(),
217227
}
218228
self._input_artifacts={
229+
"exploration_scheduler" : InputArtifact(),
219230
"init_models" : InputArtifact(),
220231
"init_data" : InputArtifact(),
221232
"iter_data" : InputArtifact(),
222233
}
223234
self._output_parameters={
224-
"exploration_scheduler": OutputParameter(),
225235
}
226236
self._output_artifacts={
237+
"exploration_scheduler": OutputArtifact(),
227238
"models": OutputArtifact(),
228239
"iter_data" : OutputArtifact(),
229240
}
@@ -321,10 +332,10 @@ def _loop (
321332
python_packages = upload_python_package,
322333
),
323334
parameters={
324-
"exploration_scheduler": steps.inputs.parameters['exploration_scheduler'],
325335
"exploration_report": block_step.outputs.parameters['exploration_report'],
326336
},
327337
artifacts={
338+
"exploration_scheduler": steps.inputs.artifacts['exploration_scheduler'],
328339
"trajs" : block_step.outputs.artifacts['trajs'],
329340
},
330341
key = step_keys['scheduler'],
@@ -339,9 +350,9 @@ def _loop (
339350
python_packages = upload_python_package,
340351
),
341352
parameters={
342-
"exploration_scheduler": scheduler_step.outputs.parameters['exploration_scheduler'],
343353
},
344354
artifacts={
355+
"exploration_scheduler": scheduler_step.outputs.artifacts['exploration_scheduler'],
345356
},
346357
key = step_keys['id'],
347358
)
@@ -360,9 +371,9 @@ def _loop (
360371
"conf_selector" : scheduler_step.outputs.parameters["conf_selector"],
361372
"fp_inputs" : steps.inputs.parameters["fp_inputs"],
362373
"fp_config" : steps.inputs.parameters["fp_config"],
363-
"exploration_scheduler" : scheduler_step.outputs.parameters["exploration_scheduler"],
364374
},
365375
artifacts={
376+
"exploration_scheduler" : scheduler_step.outputs.artifacts["exploration_scheduler"],
366377
"lmp_task_grp" : scheduler_step.outputs.artifacts["lmp_task_grp"],
367378
"init_models" : block_step.outputs.artifacts['models'],
368379
"init_data" : steps.inputs.artifacts['init_data'],
@@ -372,11 +383,11 @@ def _loop (
372383
)
373384
steps.add(next_step)
374385

375-
steps.outputs.parameters['exploration_scheduler'].value_from_expression = \
386+
steps.outputs.artifacts['exploration_scheduler'].from_expression = \
376387
if_expression(
377388
_if = (scheduler_step.outputs.parameters['converged'] == True),
378-
_then = scheduler_step.outputs.parameters['exploration_scheduler'],
379-
_else = next_step.outputs.parameters['exploration_scheduler'],
389+
_then = scheduler_step.outputs.artifacts['exploration_scheduler'],
390+
_else = next_step.outputs.artifacts['exploration_scheduler'],
380391
)
381392
steps.outputs.artifacts['models'].from_expression = \
382393
if_expression(
@@ -411,10 +422,10 @@ def _dpgen(
411422
python_packages = upload_python_package,
412423
),
413424
parameters={
414-
"exploration_scheduler": steps.inputs.parameters['exploration_scheduler'],
415425
"exploration_report": None,
416426
},
417427
artifacts={
428+
"exploration_scheduler": steps.inputs.artifacts['exploration_scheduler'],
418429
"trajs" : None,
419430
},
420431
key = step_keys['scheduler'],
@@ -429,9 +440,9 @@ def _dpgen(
429440
python_packages = upload_python_package,
430441
),
431442
parameters={
432-
"exploration_scheduler": scheduler_step.outputs.parameters['exploration_scheduler'],
433443
},
434444
artifacts={
445+
"exploration_scheduler": scheduler_step.outputs.artifacts['exploration_scheduler'],
435446
},
436447
key = step_keys['id'],
437448
)
@@ -450,9 +461,9 @@ def _dpgen(
450461
"lmp_config" : steps.inputs.parameters['lmp_config'],
451462
"fp_inputs" : steps.inputs.parameters['fp_inputs'],
452463
"fp_config" : steps.inputs.parameters['fp_config'],
453-
"exploration_scheduler" : scheduler_step.outputs.parameters['exploration_scheduler'],
454464
},
455465
artifacts={
466+
"exploration_scheduler" : scheduler_step.outputs.artifacts['exploration_scheduler'],
456467
"lmp_task_grp" : scheduler_step.outputs.artifacts['lmp_task_grp'],
457468
"init_models": steps.inputs.artifacts["init_models"],
458469
"init_data": steps.inputs.artifacts["init_data"],
@@ -462,8 +473,8 @@ def _dpgen(
462473
)
463474
steps.add(loop_step)
464475

465-
steps.outputs.parameters["exploration_scheduler"].value_from_parameter = \
466-
loop_step.outputs.parameters["exploration_scheduler"]
476+
steps.outputs.artifacts["exploration_scheduler"]._from = \
477+
loop_step.outputs.artifacts["exploration_scheduler"]
467478
steps.outputs.artifacts["models"]._from = \
468479
loop_step.outputs.artifacts["models"]
469480
steps.outputs.artifacts["iter_data"]._from = \

0 commit comments

Comments
 (0)