Skip to content

Commit eb9ffd5

Browse files
authored
[None][fix] Add missing allow_partial_loading param to CuteDSL and ConfigurableMoE load_weights (#12761)
Signed-off-by: Xianjie <5410381+qiaoxj07@users.noreply.github.com>
1 parent 1431153 commit eb9ffd5

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1237,15 +1237,15 @@ def create_weights(self):
12371237
)
12381238
return self.backend.create_weights()
12391239

1240-
def load_weights(self, weights: List[Dict]):
1240+
def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False):
12411241
"""
12421242
Load weights - delegated to backend
12431243
12441244
"""
12451245
assert hasattr(self.backend, "load_weights"), (
12461246
f"Backend {self.backend.__class__.__name__} must implement load_weights()"
12471247
)
1248-
return self.backend.load_weights(weights)
1248+
return self.backend.load_weights(weights, allow_partial_loading)
12491249

12501250
def post_load_weights(self):
12511251
"""

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -979,8 +979,10 @@ def forward_chunk(
979979
enable_alltoall=False)
980980
return x
981981

982-
def load_weights(self, weights: Dict[str, torch.Tensor]):
983-
super().load_weights(weights)
982+
def load_weights(self,
983+
weights: List[Dict],
984+
allow_partial_loading: bool = False):
985+
super().load_weights(weights, allow_partial_loading)
984986
dwdp_handle_collector = getattr(self, "dwdp_handle_collector", None)
985987
if dwdp_handle_collector is not None:
986988
dwdp_handle_collector.register_weights(self)

0 commit comments

Comments
 (0)