diff --git a/DeepSDFStruct/deep_sdf/workspace.py b/DeepSDFStruct/deep_sdf/workspace.py index 77b5418..bb034b3 100644 --- a/DeepSDFStruct/deep_sdf/workspace.py +++ b/DeepSDFStruct/deep_sdf/workspace.py @@ -148,7 +148,7 @@ def load_model_parameters( data = torch.load(filename, map_location=device, weights_only=True) - decoder.load_state_dict(data["model_state_dict"]) + decoder.load_state_dict(data["model_state_dict"], strict=False) return data["epoch"] @@ -317,13 +317,13 @@ def load_trained_model( data = torch.load(filename, map_location=device) decoder = init_decoder(experiment_specs, device, data_parallel) try: - decoder.load_state_dict(data["model_state_dict"]) + decoder.load_state_dict(data["model_state_dict"], strict=False) except RuntimeError: state_dict = {} for k, v in data["model_state_dict"].items(): new_key = k.replace("module.", "", 1) if k.startswith("module.") else k state_dict[new_key] = v - decoder.load_state_dict(state_dict) + decoder.load_state_dict(state_dict, strict=False) decoder = decoder.to(device) return decoder