 import onnxscript.ir as ir
 from onnxscript.rewriter import _fusion_utils, pattern
 
-"""
-The MultiHeadAttention pattern: generate an instance
-   MHA (query, key, value, None, None, mask, past_key, past_value)
-where query has shape (B, S, D), key has shape (B, Skv, D), and value has shape (B, Skv, Dv).
-The next two inputs bias and key_padding_mask are None in this pattern. The mask (attention_bias)
-must be of shape (1 or B, 1 or H, S, St). past_key and past_value are of shape (B, H, Spast, Dh).
-
-We use the following abbreviations for the dimensions:
-B: Batch size
-S: Sequence length
-D: input embedding dimension
-Dv: value hidden size (usually, Dv = D)
-H: number of heads
-Dh: head size or embedding dimension per head (usually, D = H * Dh)
-Skv: key/value sequence length
-St: total sequence length
-
-In the sequel, the suffix "_BHSDh" indicates that the tensor has the shape (B, H, S, Dh).
-The suffix "BH_Skv_Dh" indicates that the tensor has the shape (B*H, Skv, Dh).
-"""
+valid_float_types = [ir.DataType.FLOAT, ir.DataType.FLOAT16]
 
 Dim = Union[int, ir.SymbolicDim]
 
@@ -102,6 +83,13 @@ def check( |
         def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
             return not _fusion_utils._check_shape(self.bindings, val, dims)
 
+        if query_matmul.dtype not in valid_float_types:
+            return check_result.fail("Query is not a float or float16 type.", query_matmul)
+        if key_matmul.dtype not in valid_float_types:
+            return check_result.fail("Key is not a float or float16 type.", key_matmul)
+        if value_matmul.dtype not in valid_float_types:
+            return check_result.fail("Value is not a float or float16 type.", value_matmul)
+
         if no_match(query_matmul, ["B", "S", "D"]):
             return check_result.fail(
                 f"Shape mismatch: {query_matmul} does not match expected dimensions ['B', 'S', 'D']",
@@ -148,19 +136,26 @@ def rewrite( |
         num_heads,
         **_,
     ):
-        if self._q_no_bias:
-            q_bias = op.Constant(
-                value=ir.tensor(numpy.zeros((self.Dh_q,), dtype=numpy.float32))
-            )
-        if self._k_no_bias:
-            k_bias = op.Constant(
-                value=ir.tensor(numpy.zeros((self.Dh_k,), dtype=numpy.float32))
-            )
-        if self._v_no_bias:
-            v_bias = op.Constant(
-                value=ir.tensor(numpy.zeros((self.Dh_v,), dtype=numpy.float32))
-            )
-        bias = op.Concat(q_bias, k_bias, v_bias, axis=0)
+        if self._q_no_bias and self._k_no_bias and self._v_no_bias:
+            bias = None
+        else:
+            if self._q_no_bias:
+                q_bias = op.Constant(
+                    value=ir.tensor(
+                        numpy.zeros((self.Dh_q,), dtype=query_matmul.dtype.numpy())
+                    )
+                )
+            if self._k_no_bias:
+                k_bias = op.Constant(
+                    value=ir.tensor(numpy.zeros((self.Dh_k,), dtype=key_matmul.dtype.numpy()))
+                )
+            if self._v_no_bias:
+                v_bias = op.Constant(
+                    value=ir.tensor(
+                        numpy.zeros((self.Dh_v,), dtype=value_matmul.dtype.numpy())
+                    )
+                )
+            bias = op.Concat(q_bias, k_bias, v_bias, axis=0)
         return op.MultiHeadAttention(
             query_matmul,
             key_matmul,
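Note on the change: the check now rejects Q/K/V projections whose element type is not FLOAT or FLOAT16, and the rewrite builds each zero bias in the projection's own element type (via DataType.numpy()) instead of hard-coding float32, passing bias=None to MultiHeadAttention when no branch had a bias. A minimal sketch of the dtype handling, assuming only numpy and onnxscript.ir; FLOAT16 and the size 64 are illustrative stand-ins for query_matmul.dtype and self.Dh_q:

    import numpy
    import onnxscript.ir as ir

    # Illustrative element type; in the rewrite above it comes from query_matmul.dtype.
    dtype = ir.DataType.FLOAT16
    # DataType.numpy() maps the IR dtype to the matching numpy dtype, so the zero fill
    # is created directly in the projection's precision and needs no cast before Concat.
    zeros = numpy.zeros((64,), dtype=dtype.numpy())  # 64 stands in for self.Dh_q
    print(zeros.dtype)                               # -> float16
    zero_bias = ir.tensor(zeros)                     # wrapped as an IR tensor for op.Constant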