
Commit bd7065c
[fix] [torch_training] Resuming from checkpoints now first evaluates the entire dataset
The previous commit introduced sampling policies that can intentionally skip some data points or sample others multiple times. This can lead to checkpoint files whose error metrics are not reproducible. Hence, after resuming from a checkpoint, the error metrics are first re-evaluated with a deterministic full pass over the data, so they no longer depend on the sampling policy that was used.
1 parent 0c435ca · commit bd7065c
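Why this matters, as a minimal PyTorch sketch (illustration only, not aenet code; the dataset and weights below are made up): with a replacement-based sampling policy such as error-weighted sampling, one "epoch" can skip some data points and visit others several times, so a metric accumulated over that epoch is not reproducible, while a deterministic full pass (shuffle=False) visits every point exactly once.

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Toy dataset of 10 points with made-up per-sample weights, standing in
# for an error-weighted sampling policy.
data = TensorDataset(torch.arange(10).float().unsqueeze(1))
weights = torch.linspace(0.1, 1.0, 10)
sampler = WeightedRandomSampler(weights, num_samples=10, replacement=True)

# One "epoch" under the sampler: some indices are missing, others repeated,
# so metrics accumulated over these batches vary from run to run.
sampled = [int(x.item()) for (x,) in DataLoader(data, sampler=sampler)]
print(sorted(sampled))

# Deterministic full pass: every point exactly once, giving metrics that
# can be reproduced from a checkpoint.
full = [int(x.item()) for (x,) in DataLoader(data, shuffle=False)]
print(full)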

File tree

1 file changed: +74 -64

src/aenet/torch_training/trainer.py

Lines changed: 74 additions & 64 deletions
@@ -1945,6 +1945,79 @@ def train(
             == "error_weighted"
         )
 
+        def _replace_latest_metrics_with_full_pass() -> None:
+            history = self._metrics.get_history()
+            if len(history.get("train_energy_rmse", [])) == 0:
+                return
+
+            final_train_loader = DataLoader(
+                train_ds,
+                batch_size=batch_size,
+                shuffle=False,
+                collate_fn=_collate_fn,
+                **eval_dl_kwargs,
+            )
+            (
+                final_train_energy_rmse,
+                final_train_energy_mae,
+                final_train_force_rmse,
+                _,
+                _,
+            ) = training_loop.run_epoch(
+                loader=final_train_loader,
+                optimizer=None,
+                alpha=alpha,
+                atomic_energies_by_index=atomic_energies_by_index,
+                train=False,
+                show_batch_progress=False,
+                force_scale_unbiased=bool(
+                    getattr(config, "force_scale_unbiased", False)
+                ),
+                collect_structure_scores=False,
+            )
+
+            final_test_energy_rmse = float("nan")
+            final_test_energy_mae = float("nan")
+            final_test_force_rmse = float("nan")
+            if test_ds is not None:
+                final_test_loader = DataLoader(
+                    test_ds,
+                    batch_size=batch_size,
+                    shuffle=False,
+                    collate_fn=_collate_fn,
+                    **eval_dl_kwargs,
+                )
+                (
+                    final_test_energy_rmse,
+                    final_test_energy_mae,
+                    final_test_force_rmse,
+                    _,
+                    _,
+                ) = training_loop.run_epoch(
+                    loader=final_test_loader,
+                    optimizer=None,
+                    alpha=alpha,
+                    atomic_energies_by_index=atomic_energies_by_index,
+                    train=False,
+                    show_batch_progress=False,
+                    force_scale_unbiased=bool(
+                        getattr(config, "force_scale_unbiased", False)
+                    ),
+                    collect_structure_scores=False,
+                )
+
+            self._metrics.replace_latest(
+                train_energy_rmse=float(final_train_energy_rmse),
+                train_energy_mae=float(final_train_energy_mae),
+                train_force_rmse=float(final_train_force_rmse),
+                test_energy_rmse=float(final_test_energy_rmse),
+                test_energy_mae=float(final_test_energy_mae),
+                test_force_rmse=float(final_test_force_rmse),
+            )
+
+        if start_epoch > 0:
+            _replace_latest_metrics_with_full_pass()
+
         for epoch in range(start_epoch, end_epoch):
             t0 = time.time()
             save_best_checkpoint = False
@@ -2088,70 +2161,7 @@ def train(
             pbar.close()
 
         if n_epochs > 0:
-            final_train_loader = DataLoader(
-                train_ds,
-                batch_size=batch_size,
-                shuffle=False,
-                collate_fn=_collate_fn,
-                **eval_dl_kwargs,
-            )
-            (
-                final_train_energy_rmse,
-                final_train_energy_mae,
-                final_train_force_rmse,
-                _,
-                _,
-            ) = training_loop.run_epoch(
-                loader=final_train_loader,
-                optimizer=None,
-                alpha=alpha,
-                atomic_energies_by_index=atomic_energies_by_index,
-                train=False,
-                show_batch_progress=False,
-                force_scale_unbiased=bool(
-                    getattr(config, "force_scale_unbiased", False)
-                ),
-                collect_structure_scores=False,
-            )
-
-            final_test_energy_rmse = float("nan")
-            final_test_energy_mae = float("nan")
-            final_test_force_rmse = float("nan")
-            if test_ds is not None:
-                final_test_loader = DataLoader(
-                    test_ds,
-                    batch_size=batch_size,
-                    shuffle=False,
-                    collate_fn=_collate_fn,
-                    **eval_dl_kwargs,
-                )
-                (
-                    final_test_energy_rmse,
-                    final_test_energy_mae,
-                    final_test_force_rmse,
-                    _,
-                    _,
-                ) = training_loop.run_epoch(
-                    loader=final_test_loader,
-                    optimizer=None,
-                    alpha=alpha,
-                    atomic_energies_by_index=atomic_energies_by_index,
-                    train=False,
-                    show_batch_progress=False,
-                    force_scale_unbiased=bool(
-                        getattr(config, "force_scale_unbiased", False)
-                    ),
-                    collect_structure_scores=False,
-                )
-
-            self._metrics.replace_latest(
-                train_energy_rmse=float(final_train_energy_rmse),
-                train_energy_mae=float(final_train_energy_mae),
-                train_force_rmse=float(final_train_force_rmse),
-                test_energy_rmse=float(final_test_energy_rmse),
-                test_energy_mae=float(final_test_energy_mae),
-                test_force_rmse=float(final_test_force_rmse),
-            )
+            _replace_latest_metrics_with_full_pass()
 
         # Store optimizer and config for later saving
         self._optimizer = optimizer
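The patch folds the previously duplicated end-of-training evaluation into the _replace_latest_metrics_with_full_pass closure and additionally calls it when start_epoch > 0, i.e. immediately after resuming from a checkpoint. For context, a minimal sketch of the metrics container this relies on; the class name and internals here are assumptions, only get_history() and replace_latest(...) appear in the patch:

from collections import defaultdict

class MetricsHistory:
    # Hypothetical stand-in for the trainer's metrics object (self._metrics);
    # the real implementation in aenet's trainer.py may differ.
    def __init__(self) -> None:
        self._history = defaultdict(list)

    def append(self, **metrics):
        # Called once per epoch, e.g. append(train_energy_rmse=..., ...).
        for key, value in metrics.items():
            self._history[key].append(float(value))

    def get_history(self):
        return dict(self._history)

    def replace_latest(self, **metrics):
        # Overwrite the most recent entry of each metric; this is how the
        # full-pass values supersede the sampled-epoch values.
        for key, value in metrics.items():
            if self._history[key]:
                self._history[key][-1] = float(value)

This also explains the early return in the helper: if the history has no train_energy_rmse entries yet, there is no latest entry to replace, so the helper does nothing.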
