Skip to content

Commit f1dc97c

Browse files
committed
convert: simplify merge tensor condition
1 parent f327a49 commit f1dc97c

1 file changed

Lines changed: 15 additions & 20 deletions

File tree

convert_hf_to_gguf.py

Lines changed: 15 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -522,28 +522,23 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
522522
if self.fuse_gate_up_exps and bid is not None:
523523
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid):
524524
self._gate_exp_buffer[bid] = data_torch
525-
# Check if up_exps is already buffered for this layer
526-
if bid in self._up_exp_buffer:
527-
gate_data = self._gate_exp_buffer.pop(bid)
528-
up_data = self._up_exp_buffer.pop(bid)
529-
# gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
530-
fused_data = torch.cat([gate_data, up_data], dim=1)
531-
fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
532-
logger.info(f"Fused gate_exps and up_exps for layer {bid}")
533-
return [(fused_name, fused_data)]
534-
return [] # Wait for up_exps
535525
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
536526
self._up_exp_buffer[bid] = data_torch
537-
# Check if gate_exps is already buffered for this layer
538-
if bid in self._gate_exp_buffer:
539-
gate_data = self._gate_exp_buffer.pop(bid)
540-
up_data = self._up_exp_buffer.pop(bid)
541-
# gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
542-
fused_data = torch.cat([gate_data, up_data], dim=1)
543-
fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
544-
logger.info(f"Fused gate_exps and up_exps for layer {bid}")
545-
return [(fused_name, fused_data)]
546-
return [] # Wait for gate_exps
527+
528+
# Check if both gate and up are buffered for this layer
529+
if bid in self._gate_exp_buffer and bid in self._up_exp_buffer:
530+
gate_data = self._gate_exp_buffer.pop(bid)
531+
up_data = self._up_exp_buffer.pop(bid)
532+
# gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
533+
fused_data = torch.cat([gate_data, up_data], dim=1)
534+
fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
535+
logger.info(f"Fused gate_exps and up_exps for layer {bid}")
536+
return [(fused_name, fused_data)]
537+
538+
# If we buffered a gate/up tensor, wait for the other
539+
if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid) or \
540+
self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
541+
return []
547542

548543
return [(new_name, data_torch)]
549544

0 commit comments

Comments
 (0)