diff --git a/flame/utils/convert_hf_to_dcp.py b/flame/utils/convert_hf_to_dcp.py index bab94eb..5ceb001 100644 --- a/flame/utils/convert_hf_to_dcp.py +++ b/flame/utils/convert_hf_to_dcp.py @@ -21,7 +21,7 @@ def convert_hf_weights(model: str, checkpoint: str): logger.info(f"Writing to DCP at '{checkpoint}'") checkpoint.mkdir(parents=True, exist_ok=True) storage_writer = DCP.filesystem.FileSystemWriter(checkpoint, thread_count=8) - DCP.save({"model": state_dict}, storage_writer=storage_writer) + DCP.save(state_dict, storage_writer=storage_writer) if __name__ == "__main__":