@@ -1921,24 +1921,11 @@ def _transform_checkpoint_tensor(
19211921
19221922 expected_shape = tuple (expected_shape )
19231923
1924- # Some checkpoints store expert projections as (out, in) while others store
1925- # them as (in, out). Keep both candidates and let the defused leaf shape be
1926- # the final arbiter instead of hard-coding one model family's layout.
1927- candidates : list [tuple [torch .Tensor , bool ]] = [(tensor , False )]
1928- if tensor .ndim == 2 :
1929- transposed = tensor .transpose (0 , 1 ).contiguous ()
1930- if prefer_transposed is True :
1931- candidates = [(transposed , True ), (tensor , False )]
1932- elif prefer_transposed is None and transposed .shape != tensor .shape :
1933- candidates .append ((transposed , True ))
1934- elif prefer_transposed is False and transposed .shape != tensor .shape :
1935- candidates .append ((transposed , True ))
1936-
1937- for candidate , used_transpose in candidates :
1924+ def try_candidate (candidate : torch .Tensor , * , used_transpose : bool ) -> Optional [torch .Tensor ]:
19381925 if split_index is None :
19391926 if tuple (candidate .shape ) == expected_shape :
19401927 return candidate .contiguous ()
1941- continue
1928+ return None
19421929
19431930 preferred_dims : list [int ] = []
19441931 mapped_split_dim = split_dim
@@ -1976,7 +1963,27 @@ def _transform_checkpoint_tensor(
19761963 if tuple (split_tensor .shape ) == expected_shape :
19771964 return split_tensor
19781965
1979- return None
1966+ return None
1967+
1968+ # Some checkpoints store expert projections as (out, in) while others store
1969+ # them as (in, out). Try one layout at a time so the transpose copy is only
1970+ # materialized when the first candidate cannot satisfy the target shape.
1971+ if tensor .ndim == 2 and prefer_transposed is True :
1972+ # Honor explicit layout hints first; fall back only if shape/split checks fail.
1973+ transposed = tensor .transpose (0 , 1 ).contiguous ()
1974+ matched = try_candidate (transposed , used_transpose = True )
1975+ if matched is not None :
1976+ return matched
1977+ return try_candidate (tensor , used_transpose = False )
1978+
1979+ # Most checkpoint tensors already match the shell shape; avoid an eager transpose copy.
1980+ matched = try_candidate (tensor , used_transpose = False )
1981+ if matched is not None or tensor .ndim != 2 or tensor .shape [0 ] == tensor .shape [1 ]:
1982+ return matched
1983+
1984+ # Last resort for rectangular weights whose checkpoint layout is opposite of the shell.
1985+ transposed = tensor .transpose (0 , 1 ).contiguous ()
1986+ return try_candidate (transposed , used_transpose = True )
19801987
19811988 @staticmethod
19821989 def _resolve_prefer_transposed_hint (
@@ -2170,124 +2177,149 @@ def _copy_checkpoint_tensors_into_submodule(
21702177 continue
21712178 grouped_names .setdefault (shard , []).append (("buffer" , rel_name , full_name , expert_index , split_index , split_dim ))
21722179
2173- with torch .inference_mode ():
2174- for shard , entries in grouped_names .items ():
2175- shard_path = os .path .join (self .model_local_path , shard )
2176- with safe_open (shard_path , framework = "pt" , device = "cpu" ) as handler :
2177- for kind , rel_name , full_name , expert_index , split_index , split_dim in entries :
2178- target_tensor = t_params .get (rel_name ) if kind == "param" else t_bufs .get (rel_name )
2179- expected_shape = tuple (target_tensor .shape ) if target_tensor is not None else None
2180- prefer_transposed = self ._resolve_prefer_transposed_hint (
2181- target_model = target_model ,
2182- module_path = module_path ,
2183- rel_name = rel_name ,
2184- modules_by_name = modules_by_name ,
2185- )
2186- checkpoint_tensor = handler .get_tensor (full_name )
2187- tensor = self ._transform_checkpoint_tensor (
2188- checkpoint_tensor ,
2189- expert_index = expert_index ,
2190- split_index = split_index ,
2191- split_dim = split_dim ,
2192- expected_shape = expected_shape ,
2193- prefer_transposed = prefer_transposed ,
2194- )
2195- if tensor is None :
2196- raise RuntimeError (self ._materialization_issue_message (
2197- phase = "submodule materialization" ,
2198- kind = kind ,
2180+ total_entries = sum (len (entries ) for entries in grouped_names .values ())
2181+ progress = None
2182+ loaded_entries = 0
2183+ if total_entries :
2184+ progress = log .pb (range (total_entries )).manual ().set (show_left_steps = False )
2185+ module_label = module_path or "<root>"
2186+ progress .title (f"Loading checkpoint tensors ({ total_entries } )" )
2187+ progress .subtitle (f"{ module_label } : 0/{ total_entries } " )
2188+ progress .draw (force = True )
2189+
2190+ try :
2191+ with torch .inference_mode ():
2192+ for shard , entries in grouped_names .items ():
2193+ shard_path = os .path .join (self .model_local_path , shard )
2194+ with safe_open (shard_path , framework = "pt" , device = "cpu" ) as handler :
2195+ for kind , rel_name , full_name , expert_index , split_index , split_dim in entries :
2196+ if progress is not None :
2197+ progress .current_iter_step = loaded_entries
2198+ progress .subtitle (f"{ rel_name } : { loaded_entries + 1 } /{ total_entries } " )
2199+ progress .draw ()
2200+ target_tensor = t_params .get (rel_name ) if kind == "param" else t_bufs .get (rel_name )
2201+ expected_shape = tuple (target_tensor .shape ) if target_tensor is not None else None
2202+ prefer_transposed = self ._resolve_prefer_transposed_hint (
2203+ target_model = target_model ,
21992204 module_path = module_path ,
22002205 rel_name = rel_name ,
2201- reason = "checkpoint tensor could not be reshaped into the target layout" ,
2202- full_name = full_name ,
2203- source_shape = tuple (checkpoint_tensor .shape ),
2204- target_shape = expected_shape ,
2206+ modules_by_name = modules_by_name ,
2207+ )
2208+ checkpoint_tensor = handler .get_tensor (full_name )
2209+ tensor = self ._transform_checkpoint_tensor (
2210+ checkpoint_tensor ,
22052211 expert_index = expert_index ,
22062212 split_index = split_index ,
22072213 split_dim = split_dim ,
2208- ))
2209- if kind == "param" :
2210- target_param = t_params .get (rel_name )
2211- if target_param is None :
2212- raise RuntimeError (self ._materialization_issue_message (
2213- phase = "submodule materialization" ,
2214- kind = kind ,
2215- module_path = module_path ,
2216- rel_name = rel_name ,
2217- reason = "target tensor disappeared before materialization" ,
2218- full_name = full_name ,
2219- source_shape = tuple (tensor .shape ),
2220- expert_index = expert_index ,
2221- split_index = split_index ,
2222- split_dim = split_dim ,
2223- ))
2224- if target_param .shape != tensor .shape :
2214+ expected_shape = expected_shape ,
2215+ prefer_transposed = prefer_transposed ,
2216+ )
2217+ if tensor is None :
22252218 raise RuntimeError (self ._materialization_issue_message (
22262219 phase = "submodule materialization" ,
22272220 kind = kind ,
22282221 module_path = module_path ,
22292222 rel_name = rel_name ,
2230- reason = "target tensor shape does not match the transformed checkpoint tensor " ,
2223+ reason = "checkpoint tensor could not be reshaped into the target layout " ,
22312224 full_name = full_name ,
2232- source_shape = tuple (tensor .shape ),
2233- target_shape = tuple ( target_param . shape ) ,
2225+ source_shape = tuple (checkpoint_tensor .shape ),
2226+ target_shape = expected_shape ,
22342227 expert_index = expert_index ,
22352228 split_index = split_index ,
22362229 split_dim = split_dim ,
22372230 ))
2238- target_param_new = _ensure_target_storage_on_device_ (target_param , device )
2239- if target_param_new is not target_param :
2231+ if kind == "param" :
2232+ target_param = t_params .get (rel_name )
2233+ if target_param is None :
2234+ raise RuntimeError (self ._materialization_issue_message (
2235+ phase = "submodule materialization" ,
2236+ kind = kind ,
2237+ module_path = module_path ,
2238+ rel_name = rel_name ,
2239+ reason = "target tensor disappeared before materialization" ,
2240+ full_name = full_name ,
2241+ source_shape = tuple (tensor .shape ),
2242+ expert_index = expert_index ,
2243+ split_index = split_index ,
2244+ split_dim = split_dim ,
2245+ ))
2246+ if target_param .shape != tensor .shape :
2247+ raise RuntimeError (self ._materialization_issue_message (
2248+ phase = "submodule materialization" ,
2249+ kind = kind ,
2250+ module_path = module_path ,
2251+ rel_name = rel_name ,
2252+ reason = "target tensor shape does not match the transformed checkpoint tensor" ,
2253+ full_name = full_name ,
2254+ source_shape = tuple (tensor .shape ),
2255+ target_shape = tuple (target_param .shape ),
2256+ expert_index = expert_index ,
2257+ split_index = split_index ,
2258+ split_dim = split_dim ,
2259+ ))
2260+ target_param_new = _ensure_target_storage_on_device_ (target_param , device )
2261+ if target_param_new is not target_param :
2262+ t_parent , leaf = _get_parent_and_leaf_by_path (target_submodule , rel_name )
2263+ setattr (t_parent , leaf , target_param_new )
2264+ target_param = target_param_new
2265+ source = tensor .detach ()
2266+ if source .dtype != target_param .dtype :
2267+ source = source .to (dtype = target_param .dtype )
2268+ target_param .detach ().copy_ (source , non_blocking = (non_blocking and source .is_pinned ()))
2269+ else :
2270+ target_buffer = t_bufs .get (rel_name )
22402271 t_parent , leaf = _get_parent_and_leaf_by_path (target_submodule , rel_name )
2241- setattr (t_parent , leaf , target_param_new )
2242- target_param = target_param_new
2243- source = tensor .detach ()
2244- if source .dtype != target_param .dtype :
2245- source = source .to (dtype = target_param .dtype )
2246- target_param .detach ().copy_ (source , non_blocking = (non_blocking and source .is_pinned ()))
2247- continue
2248-
2249- target_buffer = t_bufs .get (rel_name )
2250- t_parent , leaf = _get_parent_and_leaf_by_path (target_submodule , rel_name )
2251- persistent = leaf not in getattr (t_parent , "_non_persistent_buffers_set" , set ())
2252-
2253- source = tensor .detach ()
2254- if target_buffer is None :
2255- new_buffer = source .to (device = device )
2256- t_parent .register_buffer (leaf , new_buffer , persistent = persistent )
2257- t_bufs [rel_name ] = new_buffer
2258- continue
2259-
2260- if tuple (target_buffer .shape ) != tuple (source .shape ):
2261- raise RuntimeError (self ._materialization_issue_message (
2262- phase = "submodule materialization" ,
2263- kind = kind ,
2264- module_path = module_path ,
2265- rel_name = rel_name ,
2266- reason = "target tensor shape does not match the transformed checkpoint tensor" ,
2267- full_name = full_name ,
2268- source_shape = tuple (source .shape ),
2269- target_shape = tuple (target_buffer .shape ),
2270- expert_index = expert_index ,
2271- split_index = split_index ,
2272- split_dim = split_dim ,
2273- ))
2274-
2275- if getattr (target_buffer , "is_meta" , False ) or target_buffer .device .type == "meta" :
2276- new_buffer = torch .empty_like (target_buffer , device = device )
2277- new_buffer .copy_ (source .to (dtype = new_buffer .dtype ), non_blocking = (non_blocking and source .is_pinned ()))
2278- t_parent .register_buffer (leaf , new_buffer , persistent = persistent )
2279- t_bufs [rel_name ] = new_buffer
2280- continue
2281-
2282- if target_buffer .device != device :
2283- new_buffer = torch .empty_like (target_buffer , device = device )
2284- new_buffer .copy_ (source .to (dtype = new_buffer .dtype ), non_blocking = (non_blocking and source .is_pinned ()))
2285- t_parent .register_buffer (leaf , new_buffer , persistent = persistent )
2286- t_bufs [rel_name ] = new_buffer
2287- else :
2288- if source .dtype != target_buffer .dtype :
2289- source = source .to (dtype = target_buffer .dtype )
2290- target_buffer .copy_ (source , non_blocking = (non_blocking and source .is_pinned ()))
2272+ persistent = leaf not in getattr (t_parent , "_non_persistent_buffers_set" , set ())
2273+
2274+ source = tensor .detach ()
2275+ if target_buffer is None :
2276+ new_buffer = source .to (device = device )
2277+ t_parent .register_buffer (leaf , new_buffer , persistent = persistent )
2278+ t_bufs [rel_name ] = new_buffer
2279+ else :
2280+ if tuple (target_buffer .shape ) != tuple (source .shape ):
2281+ raise RuntimeError (self ._materialization_issue_message (
2282+ phase = "submodule materialization" ,
2283+ kind = kind ,
2284+ module_path = module_path ,
2285+ rel_name = rel_name ,
2286+ reason = "target tensor shape does not match the transformed checkpoint tensor" ,
2287+ full_name = full_name ,
2288+ source_shape = tuple (source .shape ),
2289+ target_shape = tuple (target_buffer .shape ),
2290+ expert_index = expert_index ,
2291+ split_index = split_index ,
2292+ split_dim = split_dim ,
2293+ ))
2294+
2295+ if getattr (target_buffer , "is_meta" , False ) or target_buffer .device .type == "meta" :
2296+ new_buffer = torch .empty_like (target_buffer , device = device )
2297+ new_buffer .copy_ (
2298+ source .to (dtype = new_buffer .dtype ),
2299+ non_blocking = (non_blocking and source .is_pinned ()),
2300+ )
2301+ t_parent .register_buffer (leaf , new_buffer , persistent = persistent )
2302+ t_bufs [rel_name ] = new_buffer
2303+ elif target_buffer .device != device :
2304+ new_buffer = torch .empty_like (target_buffer , device = device )
2305+ new_buffer .copy_ (
2306+ source .to (dtype = new_buffer .dtype ),
2307+ non_blocking = (non_blocking and source .is_pinned ()),
2308+ )
2309+ t_parent .register_buffer (leaf , new_buffer , persistent = persistent )
2310+ t_bufs [rel_name ] = new_buffer
2311+ else :
2312+ if source .dtype != target_buffer .dtype :
2313+ source = source .to (dtype = target_buffer .dtype )
2314+ target_buffer .copy_ (source , non_blocking = (non_blocking and source .is_pinned ()))
2315+
2316+ loaded_entries += 1
2317+ if progress is not None :
2318+ progress .current_iter_step = loaded_entries
2319+ progress .draw ()
2320+ finally :
2321+ if progress is not None :
2322+ progress .close ()
22912323
22922324 self ._restore_missing_nonpersistent_buffers (
22932325 target_model = target_model ,
0 commit comments