Skip to content

Commit c37c74f

Browse files
authored
Fused QKV add node issue for GQA graph surgery (#1057)
### What does this PR do? Type of change: Bug fix There was a small issue where for models like qwen which have bias add nodes, while fusing the q,k,v matmul and q,k,v add nodes , the fused qkv bias add node was added to the graph before the fused qkv matmul node, causing the removal script to assume that the fused matmul and the add node were part of dead subgraph hence removing them. I just changed the order in which they are added. Now there are no issues. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Optimized graph surgery operations for ONNX model processing by adjusting node insertion timing during the multi-head to grouped-query attention transformation, maintaining functional equivalence while improving internal processing flow. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
1 parent 4c399af commit c37c74f

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

modelopt/onnx/graph_surgery/gqa_replacement.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ def _find_node_by_pattern(pattern: str, op_type: str | None = None) -> onnx.Node
707707
outputs=[qkv_add_output],
708708
name=qkv_add_name,
709709
)
710-
graph.node.append(qkv_add_node)
710+
qkv_matmul_nodes.append(qkv_add_node)
711711

712712
# Add value_info
713713
qkv_add_info = helper.make_tensor_value_info(

0 commit comments

Comments
 (0)