@@ -93,7 +93,7 @@ class ParameterMeta(BaseModel):
9393 name : str
9494 dtype : _TorchDtype
9595 shape : _TorchSize
96- manually_aligned : bool = True
96+ aligned_size : int
9797
9898
9999class BucketRange (NamedTuple ):
@@ -142,11 +142,7 @@ def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
142142def _to_named_tensor (metas : list [ParameterMeta ], offset : int = 0 ) -> list [dict ]:
143143 ret = []
144144 for meta in metas :
145- size = (
146- _align_size (meta .dtype , meta .shape )
147- if meta .manually_aligned
148- else meta .dtype .itemsize * meta .shape .numel ()
149- )
145+ size = meta .aligned_size
150146 ret .append (
151147 {
152148 "name" : meta .name ,
@@ -428,6 +424,7 @@ class TPMeta(BaseModel):
428424 name = parameter_name ,
429425 shape = meta ["shape" ],
430426 dtype = meta ["dtype" ],
427+ aligned_size = _align_size (meta ["dtype" ], meta ["shape" ]),
431428 )
432429 tp_meta = tp_metas [parameter_name ]
433430 if tp_meta .concat_dim != - 1 :
@@ -437,7 +434,10 @@ class TPMeta(BaseModel):
437434 shape = list (parameter_metas [name ].shape )
438435 shape [tp_meta .concat_dim ] = shape [tp_meta .concat_dim ] * tp_meta .size
439436 parameter_metas [name ] = ParameterMeta (
440- name = name , shape = torch .Size (shape ), dtype = parameter_metas [name ].dtype
437+ name = name ,
438+ shape = torch .Size (shape ),
439+ dtype = parameter_metas [name ].dtype ,
440+ aligned_size = _align_size (parameter_metas [name ].dtype , torch .Size (shape )),
441441 )
442442 weights_in_cpu = [parameters_with_tp [name ][key ] for key in sorted (parameters_with_tp [name ])]
443443 # TODO: here concat is serial, which may be slow
@@ -455,20 +455,15 @@ class TPMeta(BaseModel):
455455 return parameters
456456
457457
458- def _register_checkpoint (
459- * ,
460- files : list [str ],
461- named_tensors : dict [str , torch .Tensor ],
462- rank : int | None = None ,
463- ) -> list [MemoryBuffer ]:
464- logger .info (
465- f"[rank{ rank } ] start to register checkpoint with { len (files )} files and { len (named_tensors )} named_tensors"
466- )
467- if not files and not named_tensors :
468- return []
469- memory_buffers : list [MemoryBuffer ] = []
458+ def _inplace_pin_memory (files : list [str ], rank : int | None = None ) -> list [MemoryBuffer ]:
459+ def _parse_and_pin_from_safetensors (file_path : str ) -> MemoryBuffer :
460+ """
461+ safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
462+ We load the safetensors file as bytes, then parse the header manually to get parameter metas.
463+ The actual tensor data is in the remaining bytes and is naturally aligned.
464+ We pin the remaining bytes as the buffer, making pinning faster.
465+ """
470466
471- def inplace_pin_memory (files : list [str ]) -> list [MemoryBuffer ]:
472467 def _pin (t : torch .Tensor ):
473468 """
474469 Pin the memory of tensor in-place.
@@ -478,138 +473,142 @@ def _pin(t: torch.Tensor):
478473 r = cudart .cudaHostRegister (t .data_ptr (), t .numel () * t .element_size (), 0 )
479474 assert r == 0 , f"pin memory error, error code: { r } "
480475
481- def _inplace_pin_memory (file_path : str ) -> MemoryBuffer :
482- """
483- safetensors format see https://huggingface.co/docs/safetensors/en/index#format.
484- We load the safetensors file as bytes, then parse the header manually to get parameter metas.
485- The actual tensor data is in the remaining bytes and is naturally aligned.
486- We pin the remaining bytes as the buffer, making pinning faster.
487- """
488- # TODO: should only support /dev/shm? but we found files in disk also work?
489- size = os .stat (file_path ).st_size
490- flag_size = 8
491- t = torch .from_file (file_path , True , size , dtype = torch .uint8 )
492- assert t .nbytes > flag_size , (
493- f"tensor nbytes { t .nbytes } should be greater than flag_size { flag_size } "
494- )
495- os .remove (file_path )
496- start_pos = (
497- int .from_bytes (t [0 :flag_size ].numpy ().tobytes (), byteorder = "little" , signed = False )
498- + flag_size
499- )
500- header_tensor = t [flag_size :start_pos ]
501- header = json .loads (header_tensor .numpy ().tobytes ())
502- if "__metadata__" in header :
503- header .pop ("__metadata__" )
476+ # TODO: should only support /dev/shm? but we found files in disk also work?
477+ size = os .stat (file_path ).st_size
478+ flag_size = 8
479+ t = torch .from_file (file_path , True , size , dtype = torch .uint8 )
480+ assert t .nbytes > flag_size , (
481+ f"tensor nbytes { t .nbytes } should be greater than flag_size { flag_size } "
482+ )
483+ start_pos = (
484+ int .from_bytes (t [0 :flag_size ].numpy ().tobytes (), byteorder = "little" , signed = False )
485+ + flag_size
486+ )
487+ header_tensor = t [flag_size :start_pos ]
488+ header = json .loads (header_tensor .numpy ().tobytes ())
489+ if "__metadata__" in header :
490+ header .pop ("__metadata__" )
504491
505- metas : list [ParameterMeta ] = []
506- offset = 0
507- try :
508- for name , meta in sorted (header .items (), key = lambda x : x [1 ]["data_offsets" ]):
509- start , end = meta ["data_offsets" ]
510- # safetensors format ensures offsets are aligned
511- assert offset == start , f"offset { offset } should be equal to start { start } "
512- metas .append (
513- ParameterMeta (
514- name = name ,
515- dtype = _getdtype (meta ["dtype" ]),
516- shape = torch .Size (meta ["shape" ]),
517- manually_aligned = False ,
518- )
519- )
520- offset = end
521- except Exception as e :
522- logger .error (f"fail to parse safetensors header from { file_path } : { e } " )
523- raise
524-
525- buffer = t [start_pos :]
526- assert offset == buffer .nbytes , (
527- f"offset { offset } should be equal to buffer.nbytes { buffer .nbytes } "
528- )
529- _pin (buffer )
530- return MemoryBuffer (buffer = buffer , size = buffer .nbytes , metas = metas )
531-
532- local_memory_buffers : list [MemoryBuffer ] = []
533- lock = threading .Lock ()
534- idx = 0
535- with concurrent .futures .ThreadPoolExecutor (max_workers = 32 ) as executor :
536- futures = [executor .submit (_inplace_pin_memory , file ) for file in files ]
537- for future in concurrent .futures .as_completed (futures ):
538- memory_buffer = future .result ()
539- with lock :
540- local_memory_buffers .append (memory_buffer )
541- logger .info (
542- f"[rank{ rank } ] register pin_memory for file in /dev/shm { idx + 1 } /{ len (files )} finished"
492+ metas : list [ParameterMeta ] = []
493+ offset = 0
494+ try :
495+ for name , meta in sorted (header .items (), key = lambda x : x [1 ]["data_offsets" ]):
496+ start , end = meta ["data_offsets" ]
497+ # safetensors format ensures offsets are aligned
498+ assert offset == start , f"offset { offset } should be equal to start { start } "
499+ metas .append (
500+ ParameterMeta (
501+ name = name ,
502+ dtype = _getdtype (meta ["dtype" ]),
503+ shape = torch .Size (meta ["shape" ]),
504+ aligned_size = end - start ,
543505 )
544- idx += 1
545- return local_memory_buffers
546-
547- def normal_pin_memory (
548- files : list [str ], named_tensors : dict [str , torch .Tensor ]
549- ) -> list [MemoryBuffer ]:
550- parameters = _load_checkpoint (files )
551- if named_tensors :
552- parameters .update (named_tensors )
553- bucket_size = max (4 << 30 , max (_align_size (x .dtype , x .shape ) for x in parameters .values ()))
554-
555- class MemoryBucket (BaseModel ):
556- size : int
557- metas : list [ParameterMeta ]
558-
559- buckets : list [MemoryBucket ] = []
560- buckets .append (MemoryBucket (size = 0 , metas = []))
561- for name , tensor in sorted (parameters .items ()):
562- size = _align_size (tensor .dtype , tensor .shape )
563- if buckets [- 1 ].size + size > bucket_size :
564- assert buckets [- 1 ], f"buckets[{ len (buckets ) - 1 } ] should not be empty"
565- buckets .append (MemoryBucket (size = 0 , metas = []))
566- buckets [- 1 ].metas .append (
567- ParameterMeta (name = name , shape = tensor .shape , dtype = tensor .dtype )
568- )
569- buckets [- 1 ].size += size
506+ )
507+ offset = end
508+ except Exception as e :
509+ logger .error (f"fail to parse safetensors header from { file_path } : { e } " )
510+ raise
570511
571- local_memory_buffers = [
572- MemoryBuffer (buffer = torch .empty (0 ), size = bucket .size , metas = bucket .metas )
573- for bucket in buckets
574- ]
512+ buffer = t [start_pos :]
513+ assert offset == buffer .nbytes , (
514+ f"offset { offset } should be equal to buffer.nbytes { buffer .nbytes } "
515+ )
516+ # Remove the file after successfully loading. This will avoid doubling the memory usage.
517+ # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
518+ os .remove (file_path )
519+ _pin (buffer )
520+ logger .info (
521+ f"[rank{ rank } ] inplace pin memory for file { file_path } finished, size { buffer .nbytes / 1024 / 1024 :.2f} MiB"
522+ )
523+ return MemoryBuffer (buffer = buffer , size = buffer .nbytes , metas = metas )
575524
576- def register_pin_memory (idx : int , size : int ) -> tuple [int , torch .Tensor ]:
577- buffer = torch .empty (size , dtype = torch .uint8 , pin_memory = True )
578- return idx , buffer
579-
580- def register_tensor (buffer : torch .Tensor , offset : int , tensor : torch .Tensor ):
581- buffer [offset : offset + tensor .nbytes ] = tensor .view (- 1 ).view (dtype = torch .uint8 )
582-
583- with concurrent .futures .ThreadPoolExecutor (max_workers = 32 ) as executor :
584- futures = [
585- executor .submit (register_pin_memory , idx , bucket .size )
586- for idx , bucket in enumerate (buckets )
587- ]
588- new_futures = []
589- for future in concurrent .futures .as_completed (futures ):
590- idx , buffer = future .result ()
591- assert buffer .numel () == buckets [idx ].size , (
592- f"buffer numel { buffer .numel ()} should be equal to bucket size { buckets [idx ].size } "
593- )
594- local_memory_buffers [idx ].buffer = buffer
595- logger .info (
596- f"[rank{ rank } ] register pin_memory for bucket { idx + 1 } /{ len (buckets )} finished, "
597- f"size { buffer .numel () / 1024 / 1024 :.2f} MiB, start to copy tensors to buffer"
525+ local_memory_buffers : list [MemoryBuffer ] = []
526+ with concurrent .futures .ThreadPoolExecutor (max_workers = 32 ) as executor :
527+ local_memory_buffers = list (executor .map (_parse_and_pin_from_safetensors , files ))
528+ return local_memory_buffers
529+
530+
531+ def _normal_pin_memory (
532+ files : list [str ],
533+ named_tensors : dict [str , torch .Tensor ],
534+ rank : int | None = None ,
535+ ) -> list [MemoryBuffer ]:
536+ parameters = _load_checkpoint (files )
537+ if named_tensors :
538+ parameters .update (named_tensors )
539+ bucket_size = max (4 << 30 , max (_align_size (x .dtype , x .shape ) for x in parameters .values ()))
540+
541+ class MemoryBucket (BaseModel ):
542+ size : int
543+ metas : list [ParameterMeta ]
544+
545+ buckets : list [MemoryBucket ] = []
546+ buckets .append (MemoryBucket (size = 0 , metas = []))
547+ for name , tensor in sorted (parameters .items ()):
548+ size = _align_size (tensor .dtype , tensor .shape )
549+ if buckets [- 1 ].size + size > bucket_size :
550+ assert buckets [- 1 ], f"buckets[{ len (buckets ) - 1 } ] should not be empty"
551+ buckets .append (MemoryBucket (size = 0 , metas = []))
552+ buckets [- 1 ].metas .append (
553+ ParameterMeta (name = name , shape = tensor .shape , dtype = tensor .dtype , aligned_size = size )
554+ )
555+ buckets [- 1 ].size += size
556+
557+ local_memory_buffers = [
558+ MemoryBuffer (buffer = torch .empty (0 ), size = bucket .size , metas = bucket .metas )
559+ for bucket in buckets
560+ ]
561+
562+ def register_pin_memory (idx : int , size : int ) -> tuple [int , torch .Tensor ]:
563+ buffer = torch .empty (size , dtype = torch .uint8 , pin_memory = True )
564+ return idx , buffer
565+
566+ def register_tensor (buffer : torch .Tensor , offset : int , tensor : torch .Tensor ):
567+ buffer [offset : offset + tensor .nbytes ] = tensor .view (- 1 ).view (dtype = torch .uint8 )
568+
569+ with concurrent .futures .ThreadPoolExecutor (max_workers = 32 ) as executor :
570+ futures = [
571+ executor .submit (register_pin_memory , idx , bucket .size )
572+ for idx , bucket in enumerate (buckets )
573+ ]
574+ new_futures = []
575+ for future in concurrent .futures .as_completed (futures ):
576+ idx , buffer = future .result ()
577+ assert buffer .numel () == buckets [idx ].size , (
578+ f"buffer numel { buffer .numel ()} should be equal to bucket size { buckets [idx ].size } "
579+ )
580+ local_memory_buffers [idx ].buffer = buffer
581+ logger .info (
582+ f"[rank{ rank } ] register pin_memory for bucket { idx + 1 } /{ len (buckets )} finished, "
583+ f"size { buffer .numel () / 1024 / 1024 :.2f} MiB, start to copy tensors to buffer"
584+ )
585+ offset = 0
586+ for meta in buckets [idx ].metas :
587+ name = meta .name
588+ tensor = parameters [name ]
589+ size = _align_size (tensor .dtype , tensor .shape )
590+ assert size == _align_size (meta .dtype , meta .shape ), (
591+ f"tensor { name } size { size } should be equal to meta size { _align_size (meta .dtype , meta .shape )} "
598592 )
599- offset = 0
600- for meta in buckets [idx ].metas :
601- name = meta .name
602- tensor = parameters [name ]
603- size = _align_size (tensor .dtype , tensor .shape )
604- assert size == _align_size (meta .dtype , meta .shape ), (
605- f"tensor { name } size { size } should be equal to meta size { _align_size (meta .dtype , meta .shape )} "
606- )
607- new_futures .append (executor .submit (register_tensor , buffer , offset , tensor ))
608- offset += size
609- for future in concurrent .futures .as_completed (new_futures ):
610- future .result ()
611- return local_memory_buffers
593+ new_futures .append (executor .submit (register_tensor , buffer , offset , tensor ))
594+ offset += size
595+ for future in concurrent .futures .as_completed (new_futures ):
596+ future .result ()
597+ return local_memory_buffers
598+
612599
600+ def _register_checkpoint (
601+ * ,
602+ files : list [str ],
603+ named_tensors : dict [str , torch .Tensor ],
604+ rank : int | None = None ,
605+ ) -> list [MemoryBuffer ]:
606+ logger .info (
607+ f"[rank{ rank } ] start to register checkpoint with { len (files )} files and { len (named_tensors )} named_tensors"
608+ )
609+ if not files and not named_tensors :
610+ return []
611+ memory_buffers : list [MemoryBuffer ] = []
613612 files_to_inplace_pin = [
614613 file
615614 for file in files
@@ -618,11 +617,10 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
618617 files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin ]
619618 if files_to_normal_pin or named_tensors :
620619 memory_buffers .extend (
621- normal_pin_memory (files = files_to_normal_pin , named_tensors = named_tensors )
620+ _normal_pin_memory (files = files_to_normal_pin , named_tensors = named_tensors , rank = rank )
622621 )
623622 if files_to_inplace_pin :
624- memory_buffers .extend (inplace_pin_memory (files_to_inplace_pin ))
625-
623+ memory_buffers .extend (_inplace_pin_memory (files_to_inplace_pin , rank = rank ))
626624 return memory_buffers
627625
628626
@@ -671,11 +669,7 @@ def _gen_h2d_buckets(
671669 for idx , metas in enumerate (items .memory_buffer_metas_list ):
672670 start_offset , offset = 0 , 0
673671 for meta in metas .metas :
674- s = (
675- _align_size (meta .dtype , meta .shape )
676- if meta .manually_aligned
677- else meta .dtype .itemsize * meta .shape .numel ()
678- )
672+ s = meta .aligned_size
679673 if buckets [- 1 ][1 ].size + s > bucket_size :
680674 if offset - start_offset > 0 :
681675 buckets [- 1 ][1 ].ranges .append (
@@ -1159,12 +1153,7 @@ def _detect_bucket_size(self, *, disable_h2d_buffer: bool = False) -> tuple[int,
11591153 for items in self ._current_global_parameter_metas .values ():
11601154 for metas_list in items .memory_buffer_metas_list :
11611155 for meta in metas_list .metas :
1162- max_tensor_bytes = max (
1163- max_tensor_bytes ,
1164- _align_size (meta .dtype , meta .shape )
1165- if meta .manually_aligned
1166- else meta .dtype .itemsize * meta .shape .numel (),
1167- )
1156+ max_tensor_bytes = max (max_tensor_bytes , meta .aligned_size )
11681157 free_bytes_divided_3 = free_bytes // (3 * _ALIGN_SIZE ) * _ALIGN_SIZE
11691158 if max_tensor_bytes <= free_bytes_divided_3 and not disable_h2d_buffer :
11701159 self ._logger_rank0 (f"[rank{ self ._rank } ] use h2d buffer" )
0 commit comments