2 changes: 1 addition & 1 deletion flame/utils/convert_hf_to_dcp.py
@@ -21,7 +21,7 @@ def convert_hf_weights(model: str, checkpoint: str):
     logger.info(f"Writing to DCP at '{checkpoint}'")
     checkpoint.mkdir(parents=True, exist_ok=True)
     storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8)
-    DCP.save({"model": state_dict}, storage_writer=storage_writer)
+    DCP.save(state_dict, storage_writer=storage_writer)


Severity: high

This change correctly flattens the checkpoint structure to align with the training loop's expectations. However, this is likely to break the flame/utils/convert_dcp_to_hf.py script.

That script expects a checkpoint with a top-level 'model' key, as seen on line 51 of that file:

model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['model'])

After this change, checkpoints created by convert_hf_to_dcp.py will no longer have this 'model' key, and torch.load(...)['model'] will likely raise a KeyError.

If convert_dcp_to_hf.py is intended to work with checkpoints converted from Hugging Face, it may need to be updated to handle the new flat structure. Could you clarify whether this is an intended side effect, or whether the other script should be updated as well?
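One way to keep convert_dcp_to_hf.py compatible with both checkpoint layouts would be a small unwrapping helper along these lines. This is only a sketch: the name `unwrap_model_state` is illustrative and not from the repo, and it assumes the loaded checkpoint is a plain dict.

```python
# Hypothetical helper for convert_dcp_to_hf.py: accept either the old
# nested layout ({"model": {...}}) or the new flat layout produced by
# the updated convert_hf_to_dcp.py.

def unwrap_model_state(state: dict) -> dict:
    """Return the raw parameter dict, tolerating a top-level 'model' wrapper."""
    inner = state.get("model")
    # Old checkpoints nest all weights under a 'model' key; in a flat
    # checkpoint the keys are parameter names mapping to tensors, so
    # state.get("model") is either absent or not a dict.
    if isinstance(inner, dict):
        return inner
    return state
```

The call site on line 51 would then become something like `model.load_state_dict(unwrap_model_state(torch.load(checkpoint_path, map_location='cpu')))`, which works for checkpoints written before and after this change.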



if __name__ == "__main__":