Skip to content

Commit 975c52b

Browse files
modify qkv attention axis
1 parent 6ccd433 commit 975c52b

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

onnxscript/rewriter/ort_fusions/attention.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -210,7 +210,7 @@ def rewrite(
         # Dh_v = self.bindings.get("Dh_v")
         # qkv_hidden_sizes = [Dh_q, Dh_k, Dh_v]
         if self._no_slice:
-            qkv_weight = op.Concat(q_mul, k_mul, v_mul, axis=0)
+            qkv_weight = op.Concat(q_mul, k_mul, v_mul, axis=1)
 
         if self._has_past:
             attention, present = op.Attention(

0 commit comments

Comments
 (0)