Skip to content

Commit 9ce3f31

Browse files
committed
fix: fix PR issues
1 parent 35ff15e commit 9ce3f31

1 file changed

Lines changed: 149 additions & 160 deletions

File tree

checkpoint_engine/ps.py

Lines changed: 149 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class ParameterMeta(BaseModel):
9393
name: str
9494
dtype: _TorchDtype
9595
shape: _TorchSize
96-
manually_aligned: bool = True
96+
aligned_size: int
9797

9898

9999
class BucketRange(NamedTuple):
@@ -142,11 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
142142
def _to_named_tensor(metas: list[ParameterMeta], offset: int = 0) -> list[dict]:
143143
ret = []
144144
for meta in metas:
145-
size = (
146-
_align_size(meta.dtype, meta.shape)
147-
if meta.manually_aligned
148-
else meta.dtype.itemsize * meta.shape.numel()
149-
)
145+
size = meta.aligned_size
150146
ret.append(
151147
{
152148
"name": meta.name,
@@ -428,6 +424,7 @@ class TPMeta(BaseModel):
428424
name=parameter_name,
429425
shape=meta["shape"],
430426
dtype=meta["dtype"],
427+
aligned_size=_align_size(meta["dtype"], meta["shape"]),
431428
)
432429
tp_meta = tp_metas[parameter_name]
433430
if tp_meta.concat_dim != -1:
@@ -437,7 +434,10 @@ class TPMeta(BaseModel):
437434
shape = list(parameter_metas[name].shape)
438435
shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
439436
parameter_metas[name] = ParameterMeta(
440-
name=name, shape=torch.Size(shape), dtype=parameter_metas[name].dtype
437+
name=name,
438+
shape=torch.Size(shape),
439+
dtype=parameter_metas[name].dtype,
440+
aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
441441
)
442442
weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
443443
# TODO: here concat is serial, which may be slow
@@ -455,20 +455,15 @@ class TPMeta(BaseModel):
455455
return parameters
456456

457457

458-
def _register_checkpoint(
459-
*,
460-
files: list[str],
461-
named_tensors: dict[str, torch.Tensor],
462-
rank: int | None = None,
463-
) -> list[MemoryBuffer]:
464-
logger.info(
465-
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
466-
)
467-
if not files and not named_tensors:
468-
return []
469-
memory_buffers: list[MemoryBuffer] = []
458+
def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
459+
def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
460+
"""
461+
safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
462+
We load the safetensors file as bytes, then parse the header manually to get parameter metas.
463+
The actual tensor data is in the remaining bytes and is naturally aligned.
464+
We pin the remaining bytes as the buffer, making pinning faster.
465+
"""
470466

471-
def inplace_pin_memory(files: list[str]) -> list[MemoryBuffer]:
472467
def _pin(t: torch.Tensor):
473468
"""
474469
Pin the memory of tensor in-place.
@@ -478,138 +473,142 @@ def _pin(t: torch.Tensor):
478473
r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
479474
assert r == 0, f"pin memory error, error code: {r}"
480475

481-
def _inplace_pin_memory(file_path: str) -> MemoryBuffer:
482-
"""
483-
safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
484-
We load the safetensors file as bytes, then parse the header manually to get parameter metas.
485-
The actual tensor data is in the remaining bytes and is naturally aligned.
486-
We pin the remaining bytes as the buffer, making pinning faster.
487-
"""
488-
# TODO: should only support /dev/shm? but we found files in disk also work?
489-
size = os.stat(file_path).st_size
490-
flag_size = 8
491-
t = torch.from_file(file_path, True, size, dtype=torch.uint8)
492-
assert t.nbytes > flag_size, (
493-
f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
494-
)
495-
os.remove(file_path)
496-
start_pos = (
497-
int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
498-
+ flag_size
499-
)
500-
header_tensor = t[flag_size:start_pos]
501-
header = json.loads(header_tensor.numpy().tobytes())
502-
if "__metadata__" in header:
503-
header.pop("__metadata__")
476+
# TODO: should only support /dev/shm? but we found files in disk also work?
477+
size = os.stat(file_path).st_size
478+
flag_size = 8
479+
t = torch.from_file(file_path, True, size, dtype=torch.uint8)
480+
assert t.nbytes > flag_size, (
481+
f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
482+
)
483+
start_pos = (
484+
int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
485+
+ flag_size
486+
)
487+
header_tensor = t[flag_size:start_pos]
488+
header = json.loads(header_tensor.numpy().tobytes())
489+
if "__metadata__" in header:
490+
header.pop("__metadata__")
504491

505-
metas: list[ParameterMeta] = []
506-
offset = 0
507-
try:
508-
for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
509-
start, end = meta["data_offsets"]
510-
# safetensors format ensures offsets are aligned
511-
assert offset == start, f"offset {offset} should be equal to start {start}"
512-
metas.append(
513-
ParameterMeta(
514-
name=name,
515-
dtype=_getdtype(meta["dtype"]),
516-
shape=torch.Size(meta["shape"]),
517-
manually_aligned=False,
518-
)
519-
)
520-
offset = end
521-
except Exception as e:
522-
logger.error(f"fail to parse safetensors header from {file_path}: {e}")
523-
raise
524-
525-
buffer = t[start_pos:]
526-
assert offset == buffer.nbytes, (
527-
f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
528-
)
529-
_pin(buffer)
530-
return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
531-
532-
local_memory_buffers: list[MemoryBuffer] = []
533-
lock = threading.Lock()
534-
idx = 0
535-
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
536-
futures = [executor.submit(_inplace_pin_memory, file) for file in files]
537-
for future in concurrent.futures.as_completed(futures):
538-
memory_buffer = future.result()
539-
with lock:
540-
local_memory_buffers.append(memory_buffer)
541-
logger.info(
542-
f"[rank{rank}] register pin_memory for file in /dev/shm {idx + 1}/{len(files)} finished"
492+
metas: list[ParameterMeta] = []
493+
offset = 0
494+
try:
495+
for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
496+
start, end = meta["data_offsets"]
497+
# safetensors format ensures offsets are aligned
498+
assert offset == start, f"offset {offset} should be equal to start {start}"
499+
metas.append(
500+
ParameterMeta(
501+
name=name,
502+
dtype=_getdtype(meta["dtype"]),
503+
shape=torch.Size(meta["shape"]),
504+
aligned_size=end - start,
543505
)
544-
idx += 1
545-
return local_memory_buffers
546-
547-
def normal_pin_memory(
548-
files: list[str], named_tensors: dict[str, torch.Tensor]
549-
) -> list[MemoryBuffer]:
550-
parameters = _load_checkpoint(files)
551-
if named_tensors:
552-
parameters.update(named_tensors)
553-
bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
554-
555-
class MemoryBucket(BaseModel):
556-
size: int
557-
metas: list[ParameterMeta]
558-
559-
buckets: list[MemoryBucket] = []
560-
buckets.append(MemoryBucket(size=0, metas=[]))
561-
for name, tensor in sorted(parameters.items()):
562-
size = _align_size(tensor.dtype, tensor.shape)
563-
if buckets[-1].size + size > bucket_size:
564-
assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
565-
buckets.append(MemoryBucket(size=0, metas=[]))
566-
buckets[-1].metas.append(
567-
ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype)
568-
)
569-
buckets[-1].size += size
506+
)
507+
offset = end
508+
except Exception as e:
509+
logger.error(f"fail to parse safetensors header from {file_path}: {e}")
510+
raise
570511

571-
local_memory_buffers = [
572-
MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
573-
for bucket in buckets
574-
]
512+
buffer = t[start_pos:]
513+
assert offset == buffer.nbytes, (
514+
f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
515+
)
516+
# Remove the file after successfully loading. This will avoid doubling the memory usage.
517+
# We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
518+
os.remove(file_path)
519+
_pin(buffer)
520+
logger.info(
521+
f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
522+
)
523+
return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas)
575524

576-
def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
577-
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
578-
return idx, buffer
579-
580-
def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
581-
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
582-
583-
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
584-
futures = [
585-
executor.submit(register_pin_memory, idx, bucket.size)
586-
for idx, bucket in enumerate(buckets)
587-
]
588-
new_futures = []
589-
for future in concurrent.futures.as_completed(futures):
590-
idx, buffer = future.result()
591-
assert buffer.numel() == buckets[idx].size, (
592-
f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
593-
)
594-
local_memory_buffers[idx].buffer = buffer
595-
logger.info(
596-
f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
597-
f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
525+
local_memory_buffers: list[MemoryBuffer] = []
526+
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
527+
local_memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
528+
return local_memory_buffers
529+
530+
531+
def _normal_pin_memory(
532+
files: list[str],
533+
named_tensors: dict[str, torch.Tensor],
534+
rank: int | None = None,
535+
) -> list[MemoryBuffer]:
536+
parameters = _load_checkpoint(files)
537+
if named_tensors:
538+
parameters.update(named_tensors)
539+
bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))
540+
541+
class MemoryBucket(BaseModel):
542+
size: int
543+
metas: list[ParameterMeta]
544+
545+
buckets: list[MemoryBucket] = []
546+
buckets.append(MemoryBucket(size=0, metas=[]))
547+
for name, tensor in sorted(parameters.items()):
548+
size = _align_size(tensor.dtype, tensor.shape)
549+
if buckets[-1].size + size > bucket_size:
550+
assert buckets[-1], f"buckets[{len(buckets) - 1}] should not be empty"
551+
buckets.append(MemoryBucket(size=0, metas=[]))
552+
buckets[-1].metas.append(
553+
ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
554+
)
555+
buckets[-1].size += size
556+
557+
local_memory_buffers = [
558+
MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
559+
for bucket in buckets
560+
]
561+
562+
def register_pin_memory(idx: int, size: int) -> tuple[int, torch.Tensor]:
563+
buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
564+
return idx, buffer
565+
566+
def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
567+
buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)
568+
569+
with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
570+
futures = [
571+
executor.submit(register_pin_memory, idx, bucket.size)
572+
for idx, bucket in enumerate(buckets)
573+
]
574+
new_futures = []
575+
for future in concurrent.futures.as_completed(futures):
576+
idx, buffer = future.result()
577+
assert buffer.numel() == buckets[idx].size, (
578+
f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
579+
)
580+
local_memory_buffers[idx].buffer = buffer
581+
logger.info(
582+
f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
583+
f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
584+
)
585+
offset = 0
586+
for meta in buckets[idx].metas:
587+
name = meta.name
588+
tensor = parameters[name]
589+
size = _align_size(tensor.dtype, tensor.shape)
590+
assert size == _align_size(meta.dtype, meta.shape), (
591+
f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
598592
)
599-
offset = 0
600-
for meta in buckets[idx].metas:
601-
name = meta.name
602-
tensor = parameters[name]
603-
size = _align_size(tensor.dtype, tensor.shape)
604-
assert size == _align_size(meta.dtype, meta.shape), (
605-
f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
606-
)
607-
new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
608-
offset += size
609-
for future in concurrent.futures.as_completed(new_futures):
610-
future.result()
611-
return local_memory_buffers
593+
new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
594+
offset += size
595+
for future in concurrent.futures.as_completed(new_futures):
596+
future.result()
597+
return local_memory_buffers
598+
612599

600+
def _register_checkpoint(
601+
*,
602+
files: list[str],
603+
named_tensors: dict[str, torch.Tensor],
604+
rank: int | None = None,
605+
) -> list[MemoryBuffer]:
606+
logger.info(
607+
f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
608+
)
609+
if not files and not named_tensors:
610+
return []
611+
memory_buffers: list[MemoryBuffer] = []
613612
files_to_inplace_pin = [
614613
file
615614
for file in files
@@ -618,11 +617,10 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
618617
files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
619618
if files_to_normal_pin or named_tensors:
620619
memory_buffers.extend(
621-
normal_pin_memory(files=files_to_normal_pin, named_tensors=named_tensors)
620+
_normal_pin_memory(files=files_to_normal_pin, named_tensors=named_tensors, rank=rank)
622621
)
623622
if files_to_inplace_pin:
624-
memory_buffers.extend(inplace_pin_memory(files_to_inplace_pin))
625-
623+
memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
626624
return memory_buffers
627625

628626

@@ -671,11 +669,7 @@ def _gen_h2d_buckets(
671669
for idx, metas in enumerate(items.memory_buffer_metas_list):
672670
start_offset, offset = 0, 0
673671
for meta in metas.metas:
674-
s = (
675-
_align_size(meta.dtype, meta.shape)
676-
if meta.manually_aligned
677-
else meta.dtype.itemsize * meta.shape.numel()
678-
)
672+
s = meta.aligned_size
679673
if buckets[-1][1].size + s > bucket_size:
680674
if offset - start_offset > 0:
681675
buckets[-1][1].ranges.append(
@@ -1159,12 +1153,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
11591153
for items in self._current_global_parameter_metas.values():
11601154
for metas_list in items.memory_buffer_metas_list:
11611155
for meta in metas_list.metas:
1162-
max_tensor_bytes = max(
1163-
max_tensor_bytes,
1164-
_align_size(meta.dtype, meta.shape)
1165-
if meta.manually_aligned
1166-
else meta.dtype.itemsize * meta.shape.numel(),
1167-
)
1156+
max_tensor_bytes = max(max_tensor_bytes, meta.aligned_size)
11681157
free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE) * _ALIGN_SIZE
11691158
if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer:
11701159
self._logger_rank0(f"[rank{self._rank}] use h2d buffer")

0 commit comments

Comments
 (0)