@@ -522,28 +522,23 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
522522 if self.fuse_gate_up_exps and bid is not None:
523523 if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid):
524524 self._gate_exp_buffer[bid] = data_torch
525- # Check if up_exps is already buffered for this layer
526- if bid in self._up_exp_buffer:
527- gate_data = self._gate_exp_buffer.pop(bid)
528- up_data = self._up_exp_buffer.pop(bid)
529- # gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
530- fused_data = torch.cat([gate_data, up_data], dim=1)
531- fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
532- logger.info(f"Fused gate_exps and up_exps for layer {bid}")
533- return [(fused_name, fused_data)]
534- return [] # Wait for up_exps
535525 elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
536526 self._up_exp_buffer[bid] = data_torch
537- # Check if gate_exps is already buffered for this layer
538- if bid in self._gate_exp_buffer:
539- gate_data = self._gate_exp_buffer.pop(bid)
540- up_data = self._up_exp_buffer.pop(bid)
541- # gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
542- fused_data = torch.cat([gate_data, up_data], dim=1)
543- fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
544- logger.info(f"Fused gate_exps and up_exps for layer {bid}")
545- return [(fused_name, fused_data)]
546- return [] # Wait for gate_exps
527+
528+ # Check if both gate and up are buffered for this layer
529+ if bid in self._gate_exp_buffer and bid in self._up_exp_buffer:
530+ gate_data = self._gate_exp_buffer.pop(bid)
531+ up_data = self._up_exp_buffer.pop(bid)
532+ # gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
533+ fused_data = torch.cat([gate_data, up_data], dim=1)
534+ fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
535+ logger.info(f"Fused gate_exps and up_exps for layer {bid}")
536+ return [(fused_name, fused_data)]
537+
538+ # If we buffered a gate/up tensor, wait for the other
539+ if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid) or \
540+ self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
541+ return []
547542
548543 return [(new_name, data_torch)]
549544
0 commit comments