Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 75 additions & 4 deletions comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,76 @@ def load_safetensors(ckpt):
return sd, header.get("__metadata__", {}),


def load_safetensors_no_mmap(ckpt, device=None, return_metadata=False):
    """Load a .safetensors / .sft file without ever mmap'ing it.

    The safetensors header is parsed by hand and every tensor is read
    straight from disk into its own ``bytearray`` via ``readinto``; the
    buffer is then zero-copy-wrapped as a torch tensor with
    ``torch.frombuffer``. Peak memory is one model copy (plus, if a non-CPU
    device is requested, the bytes of a single tensor in flight while it is
    being moved).

    Args:
        ckpt: Path of the safetensors file to read.
        device: Target device for the loaded tensors. ``None`` means CPU.
            Anything accepted by ``torch.device()`` (e.g. ``"cuda"``) is
            normalised to a ``torch.device``.
        return_metadata: When true, the header's ``__metadata__`` entry
            (``{}`` if absent) is returned as the second tuple element;
            otherwise that element is ``None``.

    Returns:
        A ``(state_dict, metadata)`` tuple, where ``state_dict`` maps tensor
        names to tensors.

    Raises:
        ValueError: With a ``HeaderTooLarge`` prefix when the file is shorter
            than the 8-byte length field or the declared header size is
            implausibly large; with a ``MetadataIncompleteBuffer`` prefix
            when the header or a tensor's data is truncated.
    """
    # safetensors.safe_open() (and therefore safetensors.torch.load_file)
    # always mmaps the underlying file in Rust. On systems with unified
    # CPU/GPU memory like NVIDIA Grace Blackwell / DGX Spark, Apple Silicon,
    # AMD APUs, etc. this is fatal for large models: the OS page-cache pages
    # backing the mmap and any subsequent device copy both reside in the same
    # physical memory pool, doubling peak memory and causing OOM well before
    # the hardware limit is reached.
    # See: https://github.com/Comfy-Org/ComfyUI/issues/10896
    #      https://github.com/safetensors/safetensors/issues/758
    #      https://github.com/safetensors/safetensors/pull/759
    #
    # This is a temporary workaround until upstream safetensors exposes a
    # public ``mmap=False`` option.
    if device is None:
        device = torch.device("cpu")
    elif not isinstance(device, torch.device):
        # Accept "cuda", "cuda:0", ... so that ``device.type`` below is valid.
        device = torch.device(device)

    # Upper bound on the JSON header size; intended to mirror the limit in
    # the upstream safetensors parser. Without it, a file that is not really
    # a safetensors file yields a garbage length and ``f.read()`` fails with
    # OverflowError/MemoryError instead of a ValueError carrying the
    # "HeaderTooLarge" prefix.
    max_header_size = 100_000_000

    sd = {}
    metadata = None
    with open(ckpt, "rb") as f:
        header_bytes = f.read(8)
        if len(header_bytes) != 8:
            raise ValueError("HeaderTooLarge: file is too small to be a valid safetensors file: {}".format(ckpt))
        # The first 8 bytes are the little-endian u64 length of the header.
        header_size = struct.unpack("<Q", header_bytes)[0]
        if header_size > max_header_size:
            raise ValueError("HeaderTooLarge: header size {} exceeds the limit of {} bytes in {}".format(header_size, max_header_size, ckpt))
        header_data = f.read(header_size)
        if len(header_data) != header_size:
            raise ValueError("MetadataIncompleteBuffer: truncated header in {}".format(ckpt))
        header = json.loads(header_data.decode("utf-8"))
        # Tensor offsets in the header are relative to the end of the header.
        data_base_offset = 8 + header_size

        if return_metadata:
            metadata = header.get("__metadata__", {})

        for name, info in header.items():
            if name == "__metadata__":
                continue

            dtype = _TYPES[info["dtype"]]
            shape = info["shape"]
            start, end = info["data_offsets"]
            num_bytes = end - start

            if num_bytes == 0:
                # torch.frombuffer rejects empty buffers; build the
                # zero-element tensor directly instead.
                tensor = torch.empty(shape, dtype=dtype)
            else:
                buf = bytearray(num_bytes)
                f.seek(data_base_offset + start)
                view = memoryview(buf)
                offset = 0
                # readinto may return fewer bytes than requested, so keep
                # filling the remaining slice until the buffer is complete.
                while offset < num_bytes:
                    n = f.readinto(view[offset:])
                    if not n:
                        raise ValueError("MetadataIncompleteBuffer: unexpected EOF reading tensor {!r} from {}".format(name, ckpt))
                    offset += n
                # Shares memory with ``buf``; no second copy is made.
                tensor = torch.frombuffer(buf, dtype=dtype).reshape(shape)

            if device.type != "cpu":
                tensor = tensor.to(device=device)
            sd[name] = tensor

    return sd, metadata


def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
Expand All @@ -129,14 +199,15 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
sd, metadata = load_safetensors(ckpt)
if not return_metadata:
metadata = None
elif DISABLE_MMAP:
sd, metadata = load_safetensors_no_mmap(ckpt, device=device, return_metadata=return_metadata)
if not return_metadata:
metadata = None
else:
with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
sd = {}
for k in f.keys():
tensor = f.get_tensor(k)
if DISABLE_MMAP: # TODO: Not sure if this is the best way to bypass the mmap issues
tensor = tensor.to(device=device, copy=True)
sd[k] = tensor
sd[k] = f.get_tensor(k)
if return_metadata:
metadata = f.metadata()
except Exception as e:
Expand Down