File tree Expand file tree Collapse file tree 2 files changed +6
-4
lines changed
tensorrt_llm/_torch/modules/fused_moe Expand file tree Collapse file tree 2 files changed +6
-4
lines changed Original file line number Diff line number Diff line change @@ -1237,15 +1237,15 @@ def create_weights(self):
12371237 )
12381238 return self .backend .create_weights ()
12391239
1240- def load_weights (self , weights : List [Dict ]):
1240+ def load_weights (self , weights : List [Dict ], allow_partial_loading : bool = False ):
12411241 """
12421242 Load weights - delegated to backend
12431243
12441244 """
12451245 assert hasattr (self .backend , "load_weights" ), (
12461246 f"Backend { self .backend .__class__ .__name__ } must implement load_weights()"
12471247 )
1248- return self .backend .load_weights (weights )
1248+ return self .backend .load_weights (weights , allow_partial_loading )
12491249
12501250 def post_load_weights (self ):
12511251 """
Original file line number Diff line number Diff line change @@ -979,8 +979,10 @@ def forward_chunk(
979979 enable_alltoall = False )
980980 return x
981981
982- def load_weights (self , weights : Dict [str , torch .Tensor ]):
983- super ().load_weights (weights )
982+ def load_weights (self ,
983+ weights : List [Dict ],
984+ allow_partial_loading : bool = False ):
985+ super ().load_weights (weights , allow_partial_loading )
984986 dwdp_handle_collector = getattr (self , "dwdp_handle_collector" , None )
985987 if dwdp_handle_collector is not None :
986988 dwdp_handle_collector .register_weights (self )
You can’t perform that action at this time.
0 commit comments