@@ -36,7 +36,7 @@ Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%fea
3636Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
3737Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
3838Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
39- F.scaled_dot_product_attention op_9 4 1 16 17 18 attn_mask 19 dropout_p=0.0 is_causal=False scale=%scale
39+ F.scaled_dot_product_attention sdpa 4 1 16 17 18 attn_mask 19 %*=%*
4040Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
4141Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
4242nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
@@ -54,6 +54,23 @@ pnnx.Output output 1 0 out
5454 return " sdpa_attention" ;
5555 }
5656
57+ bool match (const std::map<std::string, Parameter>& captured_params) const
58+ {
59+ if (captured_params.find (" sdpa.dropout_p" ) != captured_params.end ())
60+ {
61+ if (captured_params.at (" sdpa.dropout_p" ).type != 3 || captured_params.at (" sdpa.dropout_p" ).f != 0 .f )
62+ return false ;
63+ }
64+
65+ if (captured_params.find (" sdpa.is_causal" ) != captured_params.end ())
66+ {
67+ if (captured_params.at (" sdpa.is_causal" ).type != 1 || captured_params.at (" sdpa.is_causal" ).b != false )
68+ return false ;
69+ }
70+
71+ return true ;
72+ }
73+
5774 void write (Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
5875 {
5976 op->params [" 0" ] = captured_params.at (" embed_dim" );
@@ -68,7 +85,8 @@ pnnx.Output output 1 0 out
6885 op->params [" 3" ] = kdim;
6986 op->params [" 4" ] = vdim;
7087 op->params [" 5" ] = 1 ;
71- op->params [" 6" ] = captured_params.at (" scale" );
88+ if (captured_params.find (" sdpa.scale" ) != captured_params.end ())
89+ op->params [" 6" ] = captured_params.at (" sdpa.scale" );
7290
7391 op->attrs [" 0" ] = Attribute ();
7492 op->attrs [" 0" ].data = {0 , 0 , 0 , 0 };
@@ -138,7 +156,7 @@ Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%fea
138156Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
139157Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
140158Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
141- F.scaled_dot_product_attention op_9 4 1 16 17 18 attn_mask 19 dropout_p=0.0 is_causal=False scale=%scale
159+ F.scaled_dot_product_attention sdpa 4 1 16 17 18 attn_mask 19 %*=%*
142160Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
143161Tensor.reshape op_11 1 1 20 21 shape=(%batch,%qsize,%embed_dim)
144162nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
@@ -166,7 +184,7 @@ Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%fea
166184Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
167185Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
168186Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
169- F.scaled_dot_product_attention op_9 3 1 16 17 18 19 dropout_p=0.0 is_causal=False attn_mask=None scale=%scale
187+ F.scaled_dot_product_attention sdpa 3 1 16 17 18 19 %*=%*
170188Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
171189Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
172190nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
@@ -201,7 +219,7 @@ Tensor.reshape op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%fea
201219Tensor.permute op_6 1 1 10 16 dims=(0,2,1,3)
202220Tensor.permute op_7 1 1 12 17 dims=(0,2,1,3)
203221Tensor.permute op_8 1 1 14 18 dims=(0,2,1,3)
204- F.scaled_dot_product_attention op_9 3 1 16 17 18 19 dropout_p=0.0 is_causal=False attn_mask=None scale=%scale
222+ F.scaled_dot_product_attention sdpa 3 1 16 17 18 19 %*=%*
205223Tensor.permute op_10 1 1 19 20 dims=(0,2,1,3)
206224Tensor.reshape op_11 1 1 20 21 shape=(%batch,%qsize,%embed_dim)
207225nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
@@ -218,6 +236,40 @@ pnnx.Output output 1 0 out
218236
219237REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS (F_scaled_dot_product_attention_3, 10 )
220238
239+ class F_scaled_dot_product_attention_4 : public F_scaled_dot_product_attention
240+ {
241+ public:
242+ const char * match_pattern_graph () const
243+ {
244+ return R"PNNXIR( 7767517
245+ 15 14
246+ pnnx.Input input 0 1 input
247+ nn.Linear op_0 1 1 input q bias=%qbias in_features=%qdim out_features=%embed_dim @bias @weight
248+ nn.Linear op_1 1 1 input k bias=%kbias in_features=%kdim out_features=%embed_dim @bias @weight
249+ nn.Linear op_2 1 1 input v bias=%vbias in_features=%vdim out_features=%embed_dim @bias @weight
250+ Tensor.view op_3 1 1 q 10 shape=(%batch,%size,%num_heads,%feat_per_head)
251+ Tensor.view op_4 1 1 k 12 shape=(%batch,%size,%num_heads,%feat_per_head)
252+ Tensor.view op_5 1 1 v 14 shape=(%batch,%size,%num_heads,%feat_per_head)
253+ torch.transpose op_6 1 1 10 16 dim0=1 dim1=2
254+ torch.transpose op_7 1 1 12 17 dim0=1 dim1=2
255+ torch.transpose op_8 1 1 14 18 dim0=1 dim1=2
256+ F.scaled_dot_product_attention sdpa 3 1 16 17 18 19 %*=%*
257+ torch.transpose op_10 1 1 19 20 dim0=1 dim1=2
258+ Tensor.reshape op_11 1 1 20 21 shape=(%batch,%size,%embed_dim)
259+ nn.Linear out_proj 1 1 21 out bias=%outbias in_features=%embed_dim out_features=%qdim @bias @weight
260+ pnnx.Output output 1 0 out
261+ )PNNXIR" ;
262+ }
263+
264+ void write (Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
265+ {
266+ F_scaled_dot_product_attention::write (op, captured_params, captured_attrs);
267+ op->params [" 5" ] = 0 ;
268+ }
269+ };
270+
271+ REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS (F_scaled_dot_product_attention_4, 10 )
272+
221273} // namespace ncnn
222274
223275} // namespace pnnx
0 commit comments