@@ -72,6 +72,14 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
7272 return shard_size * marlin_tile_size , shard_offset * marlin_tile_size
7373
7474
75+ def adjust_block_scale_shard (weight_block_size , shard_size , shard_offset ):
76+ assert weight_block_size is not None
77+ block_n = weight_block_size [0 ]
78+ shard_offset = (shard_offset + block_n - 1 ) // block_n
79+ shard_size = (shard_size + block_n - 1 ) // block_n
80+ return shard_size , shard_offset
81+
82+
7583def adjust_bitsandbytes_4bit_shard (
7684 param : Parameter , shard_offsets : dict [str , tuple [int , int ]], loaded_shard_id : str
7785) -> tuple [int , int ]:
@@ -744,7 +752,12 @@ def weight_loader(
744752 assert param_data .shape == loaded_weight .shape
745753 param_data .copy_ (loaded_weight )
746754
747- def _load_fused_module_from_checkpoint (self , param : BaseAphroditeParameter , loaded_weight : torch .Tensor ):
755+ def _load_fused_module_from_checkpoint (
756+ self ,
757+ param : BaseAphroditeParameter ,
758+ loaded_weight : torch .Tensor ,
759+ output_sizes : list [int ] | None = None ,
760+ ):
748761 """
749762 Handle special case for models where MLP layers are already
750763 fused on disk. In this case, we have no shard id. This function
@@ -757,7 +770,8 @@ def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter, load
757770
758771 current_shard_offset = 0
759772 shard_offsets : list [tuple [int , int , int ]] = []
760- for i , output_size in enumerate (self .output_sizes ):
773+ output_sizes = output_sizes or self .output_sizes
774+ for i , output_size in enumerate (output_sizes ):
761775 shard_offsets .append ((i , current_shard_offset , output_size ))
762776 current_shard_offset += output_size
763777
@@ -776,37 +790,76 @@ def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter, load
776790 loaded_weight_shard = loaded_weight .narrow (param .output_dim , shard_offset , shard_size )
777791 self .weight_loader_v2 (param , loaded_weight_shard , shard_id )
778792
793+ def validate_shard_id (self , loaded_shard_id : int | tuple [int , ...] | None ):
794+ if loaded_shard_id is None :
795+ return
796+ if isinstance (loaded_shard_id , tuple ):
797+ for idx in loaded_shard_id :
798+ if not (0 <= idx < len (self .output_sizes )):
799+ raise ValueError (
800+ f"Shard id index { idx } should be between 0 and "
801+ f"{ len (self .output_sizes ) - 1 } . Got shard id { loaded_shard_id } ."
802+ )
803+ if len (loaded_shard_id ) > 1 and any (
804+ b - a != 1 for a , b in zip (loaded_shard_id [:- 1 ], loaded_shard_id [1 :])
805+ ):
806+ raise ValueError (
807+ "Shard id with multiple indices should be consecutive. "
808+ f"Got shard id { loaded_shard_id } ."
809+ )
810+ return
811+ if isinstance (loaded_shard_id , int ):
812+ if loaded_shard_id < 0 or loaded_shard_id >= len (self .output_sizes ):
813+ raise ValueError (
814+ f"Shard id should be between 0 and { len (self .output_sizes ) - 1 } . "
815+ f"Got shard id { loaded_shard_id } ."
816+ )
817+ return
818+ raise ValueError ("This line should not be reached" )
819+
779820 def weight_loader_v2 (
780821 self ,
781822 param : BaseAphroditeParameter ,
782823 loaded_weight : torch .Tensor ,
783- loaded_shard_id : int | None = None ,
824+ loaded_shard_id : tuple [ int , ...] | int | None = None ,
784825 ):
785- if loaded_shard_id is None :
826+ self .validate_shard_id (loaded_shard_id )
827+ if loaded_shard_id is None or isinstance (loaded_shard_id , tuple ):
786828 if isinstance (param , PerTensorScaleParameter ):
787829 param .load_merged_column_weight (loaded_weight = loaded_weight , shard_id = 0 )
788830 return
789831 elif type (param ) in (RowAphroditeParameter , BaseAphroditeParameter ):
790832 param .load_merged_column_weight (loaded_weight = loaded_weight )
791833 return
792- # TODO: @dsikka - move to parameter.py
793- self ._load_fused_module_from_checkpoint (param , loaded_weight )
834+ output_sizes = (
835+ [self .output_sizes [idx ] for idx in loaded_shard_id ]
836+ if loaded_shard_id
837+ else None
838+ )
839+ if isinstance (param , BlockQuantScaleParameter ):
840+ weight_block_size = getattr (self , "weight_block_size" , None )
841+ output_sizes = [
842+ adjust_block_scale_shard (weight_block_size , size , 0 )[0 ]
843+ for size in (output_sizes or self .output_sizes )
844+ ]
845+ self ._load_fused_module_from_checkpoint (
846+ param , loaded_weight , output_sizes = output_sizes
847+ )
794848 return
795849
796850 assert loaded_shard_id < len (self .output_sizes )
797851
852+ shard_offset = sum (self .output_sizes [:loaded_shard_id ]) // self .tp_size
853+ shard_size = self .output_sizes [loaded_shard_id ] // self .tp_size
854+
798855 if isinstance (param , BlockQuantScaleParameter ):
799856 assert self .quant_method is not None
800- # Assume the weight block size has been set by quant method
801857 assert hasattr (self , "weight_block_size" )
802858 weight_block_size = self .weight_block_size
803859 assert weight_block_size is not None
804- block_n , _ = weight_block_size [0 ], weight_block_size [1 ]
805- shard_offset = ((sum (self .output_sizes [:loaded_shard_id ]) + block_n - 1 ) // block_n ) // self .tp_size
806- shard_size = (self .output_sizes [loaded_shard_id ] + block_n - 1 ) // block_n // self .tp_size
807- else :
808- shard_offset = sum (self .output_sizes [:loaded_shard_id ]) // self .tp_size
809- shard_size = self .output_sizes [loaded_shard_id ] // self .tp_size
860+ shard_size , shard_offset = adjust_block_scale_shard (
861+ weight_block_size , shard_size , shard_offset
862+ )
810863
811864 param .load_merged_column_weight (
812865 loaded_weight = loaded_weight ,
0 commit comments