Skip to content

Commit 0be3a34

Browse files
authored
pnnx ncnn handle optional sdpa scale param (#6000)
1 parent 8363040 commit 0be3a34

1 file changed

Lines changed: 57 additions & 5 deletions

File tree

tools/pnnx/src/pass_ncnn/F_scaled_dot_product_attention.cpp

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%fea
3636
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
3737
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
3838
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
39-
F.scaled_dot_product_attention op_9 4 1 16 17 18 attn_mask 19 dropout_p=0.0 is_causal=False scale=%scale
39+
F.scaled_dot_product_attention sdpa 4 1 16 17 18 attn_mask 19 %*=%*
4040
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
4141
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
4242
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
@@ -54,6 +54,23 @@ pnnx.Output output 1 0 out
5454
return "sdpa_attention";
5555
}
5656

57+
bool match(const std::map<std::string, Parameter>& captured_params) const
58+
{
59+
if (captured_params.find("sdpa.dropout_p") != captured_params.end())
60+
{
61+
if (captured_params.at("sdpa.dropout_p").type != 3 || captured_params.at("sdpa.dropout_p").f != 0.f)
62+
return false;
63+
}
64+
65+
if (captured_params.find("sdpa.is_causal") != captured_params.end())
66+
{
67+
if (captured_params.at("sdpa.is_causal").type != 1 || captured_params.at("sdpa.is_causal").b != false)
68+
return false;
69+
}
70+
71+
return true;
72+
}
73+
5774
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
5875
{
5976
op->params["0"] = captured_params.at("embed_dim");
@@ -68,7 +85,8 @@ pnnx.Output output 1 0 out
6885
op->params["3"] = kdim;
6986
op->params["4"] = vdim;
7087
op->params["5"] = 1;
71-
op->params["6"] = captured_params.at("scale");
88+
if (captured_params.find("sdpa.scale") != captured_params.end())
89+
op->params["6"] = captured_params.at("sdpa.scale");
7290

7391
op->attrs["0"] = Attribute();
7492
op->attrs["0"].data = {0, 0, 0, 0};
@@ -138,7 +156,7 @@ Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%fea
138156
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
139157
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
140158
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
141-
F.scaled_dot_product_attention op_9 4 1 16 17 18 attn_mask 19 dropout_p=0.0 is_causal=False scale=%scale
159+
F.scaled_dot_product_attention sdpa 4 1 16 17 18 attn_mask 19 %*=%*
142160
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
143161
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%qsize,%embed_dim)
144162
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
@@ -166,7 +184,7 @@ Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%fea
166184
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
167185
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
168186
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
169-
F.scaled_dot_product_attention op_9 3 1 16 17 18 19 dropout_p=0.0 is_causal=False attn_mask=None scale=%scale
187+
F.scaled_dot_product_attention sdpa 3 1 16 17 18 19 %*=%*
170188
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
171189
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
172190
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
@@ -201,7 +219,7 @@ Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%fea
201219
Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
202220
Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
203221
Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
204-
F.scaled_dot_product_attention op_9 3 1 16 17 18 19 dropout_p=0.0 is_causal=False attn_mask=None scale=%scale
222+
F.scaled_dot_product_attention sdpa 3 1 16 17 18 19 %*=%*
205223
Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
206224
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%qsize,%embed_dim)
207225
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
@@ -218,6 +236,40 @@ pnnx.Output output 1 0 out
218236

219237
REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_3, 10)
220238

239+
class F_scaled_dot_product_attention_4 : public F_scaled_dot_product_attention
240+
{
241+
public:
242+
const char* match_pattern_graph() const
243+
{
244+
return R"PNNXIR(7767517
245+
15 14
246+
pnnx.Input input 0 1 input
247+
nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
248+
nn.Linear op_1 1 1 input k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
249+
nn.Linear op_2 1 1 input v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
250+
Tensor.view op_3 1 1 q 10 shape=(%batch,%size,%num_heads,%feat_per_head)
251+
Tensor.view op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
252+
Tensor.view op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
253+
torch.transpose op_6 1 1 10 16 dim0=1 dim1=2
254+
torch.transpose op_7 1 1 12 17 dim0=1 dim1=2
255+
torch.transpose op_8 1 1 14 18 dim0=1 dim1=2
256+
F.scaled_dot_product_attention sdpa 3 1 16 17 18 19 %*=%*
257+
torch.transpose op_10 1 1 19 20 dim0=1 dim1=2
258+
Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
259+
nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
260+
pnnx.Output output 1 0 out
261+
)PNNXIR";
262+
}
263+
264+
void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
265+
{
266+
F_scaled_dot_product_attention::write(op, captured_params, captured_attrs);
267+
op->params["5"] = 0;
268+
}
269+
};
270+
271+
REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_4, 10)
272+
221273
} // namespace ncnn
222274

223275
} // namespace pnnx

0 commit comments

Comments
 (0)