Bug description
The transformers4rec.torch.trainer.load_model_trainer_states_from_checkpoint() function can only load t4rec_model_class.pkl; all the other loads are broken. The utils.serialization.load() function is called 4 times in the function, but only one of those calls succeeds and the other three throw errors. I believe this is because load() expects an _io.BufferedReader object (as returned by open()) rather than a path.
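To illustrate the suspected cause, the sketch below uses a hypothetical stand-in, file_only_load(), that behaves the way this report assumes utils.serialization.load() does: it passes its argument straight to pickle.load(), which needs an object with read() and readline() methods. This is not the library's code. It only shows why an open() handle works while a plain path string raises a TypeError.

```python
import os
import pickle
import tempfile


def file_only_load(file_obj):
    """Hypothetical stand-in for a loader that only accepts open file objects."""
    # pickle.load() reads from a file-like object; a str path has no read() method.
    return pickle.load(file_obj)


def demo():
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "example.pkl")
        with open(path, "wb") as f:
            pickle.dump({"answer": 42}, f)

        # Works: open(..., "rb") returns an _io.BufferedReader,
        # as in the t4rec_model_class.pkl call.
        with open(path, "rb") as f:
            ok = file_only_load(f)

        # Fails: a plain path string, as in the other three calls.
        try:
            file_only_load(path)
            failed = False
        except TypeError:
            failed = True
    return ok, failed


if __name__ == "__main__":
    print(demo())  # ({'answer': 42}, True)
```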
Correctly implemented (an open file object is passed to load()):
model = load(
    open(os.path.join(checkpoint_path, "t4rec_model_class.pkl"), "rb")
)
Incorrectly implemented (a path is passed to load()):
load(os.path.join(checkpoint_path, "pytorch_model.bin"), torch_load=True)
checkpoint_rng_state = load(rng_file, torch_load=True)
load(os.path.join(checkpoint_path, "scaler.pt"), torch_load=True)
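One possible fix, sketched under the same assumption that load() needs a file object, is to open each path in binary mode before deserializing it, as the t4rec_model_class.pkl call already does, but inside a with block so the handle is closed afterwards. The helper name load_path() is hypothetical. For torch_load=True it defers to torch.load(), which accepts file-like objects; otherwise it uses pickle as a stand-in for whatever the library's own non-torch branch does.

```python
import os
import pickle
import tempfile


def load_path(path, torch_load=False):
    """Hypothetical helper: open a checkpoint file by path, then deserialize it."""
    with open(path, "rb") as f:  # f is an _io.BufferedReader
        if torch_load:
            import torch  # imported lazily so the pickle branch works without torch

            return torch.load(f)
        # Assumed stand-in for the non-torch branch of utils.serialization.load().
        return pickle.load(f)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as checkpoint_path:
        path = os.path.join(checkpoint_path, "example.pkl")
        with open(path, "wb") as f:
            pickle.dump({"step": 1}, f)
        print(load_path(path))  # {'step': 1}
```

In trainer.py this would mean calling, for example, load_path(os.path.join(checkpoint_path, "scaler.pt"), torch_load=True) instead of passing the bare path to load().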
Steps/Code to reproduce bug
Run the following code with a Model checkpoint at /checkpoint:
from transformers4rec import torch as tr
trainer = tr.trainer.Trainer(model=tr.model.base.Model(), args=None)
trainer.load_model_trainer_states_from_checkpoint('/checkpoint')
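Because the snippet above needs a real checkpoint, the following Transformers4Rec-free sketch mimics the four load calls against a temporary directory. It again assumes load() only accepts open file objects, modelled here by a hypothetical file_only_load(). The file name rng_state.pth is a hypothetical stand-in for rng_file, whose actual value is not given in this report. Under that assumption only the first call succeeds, which matches the reported behavior.

```python
import os
import pickle
import tempfile


def file_only_load(file_obj, torch_load=False):
    """Hypothetical loader that, like the suspected utils.serialization.load(),
    requires an open file object. torch_load is accepted but ignored here."""
    return pickle.load(file_obj)


def reproduce():
    results = {}
    with tempfile.TemporaryDirectory() as checkpoint_path:
        # "rng_state.pth" is a hypothetical stand-in for rng_file.
        names = ["t4rec_model_class.pkl", "pytorch_model.bin", "rng_state.pth", "scaler.pt"]
        for name in names:
            with open(os.path.join(checkpoint_path, name), "wb") as f:
                pickle.dump(name, f)

        # Call 1: an open file object, as in the t4rec_model_class.pkl call.
        with open(os.path.join(checkpoint_path, names[0]), "rb") as f:
            results[names[0]] = file_only_load(f) == names[0]

        # Calls 2-4: bare paths, as in the three failing calls.
        for name in names[1:]:
            try:
                file_only_load(os.path.join(checkpoint_path, name), torch_load=True)
                results[name] = True
            except TypeError:
                results[name] = False
    return results


if __name__ == "__main__":
    # {'t4rec_model_class.pkl': True, 'pytorch_model.bin': False,
    #  'rng_state.pth': False, 'scaler.pt': False}
    print(reproduce())
```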
Expected behavior
The function call should load the model and trainer states from the checkpoint without raising an error.
Environment details
- Transformers4Rec version: 23.12.0
- Platform: Linux
- Python version: 3.10.12
- Huggingface Transformers version: 4.27.1
- PyTorch version (GPU?): 2.1.2 (no)
Code references
Transformers4Rec/transformers4rec/torch/trainer.py at commit 5d59d14:
- Correctly implemented call: lines 744 to 746
- Incorrectly implemented calls: line 753, line 757 and line 766