@@ -1945,6 +1945,79 @@ def train(
19451945 == "error_weighted"
19461946 )
19471947
1948+ def _replace_latest_metrics_with_full_pass () -> None :
1949+ history = self ._metrics .get_history ()
1950+ if len (history .get ("train_energy_rmse" , [])) == 0 :
1951+ return
1952+
1953+ final_train_loader = DataLoader (
1954+ train_ds ,
1955+ batch_size = batch_size ,
1956+ shuffle = False ,
1957+ collate_fn = _collate_fn ,
1958+ ** eval_dl_kwargs ,
1959+ )
1960+ (
1961+ final_train_energy_rmse ,
1962+ final_train_energy_mae ,
1963+ final_train_force_rmse ,
1964+ _ ,
1965+ _ ,
1966+ ) = training_loop .run_epoch (
1967+ loader = final_train_loader ,
1968+ optimizer = None ,
1969+ alpha = alpha ,
1970+ atomic_energies_by_index = atomic_energies_by_index ,
1971+ train = False ,
1972+ show_batch_progress = False ,
1973+ force_scale_unbiased = bool (
1974+ getattr (config , "force_scale_unbiased" , False )
1975+ ),
1976+ collect_structure_scores = False ,
1977+ )
1978+
1979+ final_test_energy_rmse = float ("nan" )
1980+ final_test_energy_mae = float ("nan" )
1981+ final_test_force_rmse = float ("nan" )
1982+ if test_ds is not None :
1983+ final_test_loader = DataLoader (
1984+ test_ds ,
1985+ batch_size = batch_size ,
1986+ shuffle = False ,
1987+ collate_fn = _collate_fn ,
1988+ ** eval_dl_kwargs ,
1989+ )
1990+ (
1991+ final_test_energy_rmse ,
1992+ final_test_energy_mae ,
1993+ final_test_force_rmse ,
1994+ _ ,
1995+ _ ,
1996+ ) = training_loop .run_epoch (
1997+ loader = final_test_loader ,
1998+ optimizer = None ,
1999+ alpha = alpha ,
2000+ atomic_energies_by_index = atomic_energies_by_index ,
2001+ train = False ,
2002+ show_batch_progress = False ,
2003+ force_scale_unbiased = bool (
2004+ getattr (config , "force_scale_unbiased" , False )
2005+ ),
2006+ collect_structure_scores = False ,
2007+ )
2008+
2009+ self ._metrics .replace_latest (
2010+ train_energy_rmse = float (final_train_energy_rmse ),
2011+ train_energy_mae = float (final_train_energy_mae ),
2012+ train_force_rmse = float (final_train_force_rmse ),
2013+ test_energy_rmse = float (final_test_energy_rmse ),
2014+ test_energy_mae = float (final_test_energy_mae ),
2015+ test_force_rmse = float (final_test_force_rmse ),
2016+ )
2017+
2018+ if start_epoch > 0 :
2019+ _replace_latest_metrics_with_full_pass ()
2020+
19482021 for epoch in range (start_epoch , end_epoch ):
19492022 t0 = time .time ()
19502023 save_best_checkpoint = False
@@ -2088,70 +2161,7 @@ def train(
20882161 pbar .close ()
20892162
20902163 if n_epochs > 0 :
2091- final_train_loader = DataLoader (
2092- train_ds ,
2093- batch_size = batch_size ,
2094- shuffle = False ,
2095- collate_fn = _collate_fn ,
2096- ** eval_dl_kwargs ,
2097- )
2098- (
2099- final_train_energy_rmse ,
2100- final_train_energy_mae ,
2101- final_train_force_rmse ,
2102- _ ,
2103- _ ,
2104- ) = training_loop .run_epoch (
2105- loader = final_train_loader ,
2106- optimizer = None ,
2107- alpha = alpha ,
2108- atomic_energies_by_index = atomic_energies_by_index ,
2109- train = False ,
2110- show_batch_progress = False ,
2111- force_scale_unbiased = bool (
2112- getattr (config , "force_scale_unbiased" , False )
2113- ),
2114- collect_structure_scores = False ,
2115- )
2116-
2117- final_test_energy_rmse = float ("nan" )
2118- final_test_energy_mae = float ("nan" )
2119- final_test_force_rmse = float ("nan" )
2120- if test_ds is not None :
2121- final_test_loader = DataLoader (
2122- test_ds ,
2123- batch_size = batch_size ,
2124- shuffle = False ,
2125- collate_fn = _collate_fn ,
2126- ** eval_dl_kwargs ,
2127- )
2128- (
2129- final_test_energy_rmse ,
2130- final_test_energy_mae ,
2131- final_test_force_rmse ,
2132- _ ,
2133- _ ,
2134- ) = training_loop .run_epoch (
2135- loader = final_test_loader ,
2136- optimizer = None ,
2137- alpha = alpha ,
2138- atomic_energies_by_index = atomic_energies_by_index ,
2139- train = False ,
2140- show_batch_progress = False ,
2141- force_scale_unbiased = bool (
2142- getattr (config , "force_scale_unbiased" , False )
2143- ),
2144- collect_structure_scores = False ,
2145- )
2146-
2147- self ._metrics .replace_latest (
2148- train_energy_rmse = float (final_train_energy_rmse ),
2149- train_energy_mae = float (final_train_energy_mae ),
2150- train_force_rmse = float (final_train_force_rmse ),
2151- test_energy_rmse = float (final_test_energy_rmse ),
2152- test_energy_mae = float (final_test_energy_mae ),
2153- test_force_rmse = float (final_test_force_rmse ),
2154- )
2164+ _replace_latest_metrics_with_full_pass ()
21552165
21562166 # Store optimizer and config for later saving
21572167 self ._optimizer = optimizer
0 commit comments