@@ -51,3 +51,61 @@ def forward(self, x):
5151 attention_value = attention_value + x
5252 attention_value = self .ff_self (attention_value ) + attention_value
5353 return attention_value .swapaxes (2 , 1 ).view (- 1 , self .channels , self .size [0 ], self .size [1 ])
54+
55+
56+ class SelfAttentionAD (nn .Module ):
57+ """
58+ Adaptive head count SelfAttention block
59+ """
60+
61+ def __init__ (self , channels , size , act = "silu" , dropout = 0.1 ):
62+ """
63+ Initialize the adaptive head count self-attention block
64+ :param channels: Channels
65+ :param size: Size
66+ :param act: Activation function
67+ """
68+ super (SelfAttentionAD , self ).__init__ ()
69+ self .channels = channels
70+ self .size = size
71+ self .dropout = dropout
72+
73+ # Adaptive head count
74+ head_count = max (1 , channels // 64 )
75+
76+ # batch_first is not supported in pytorch 1.8.
77+ # If you want to support upgrading to 1.9 and above, or use the following code to transpose
78+ self .mha = nn .MultiheadAttention (embed_dim = channels , num_heads = head_count , batch_first = True )
79+ self .ln = nn .LayerNorm (normalized_shape = [channels ])
80+ self .ff_self = nn .Sequential (
81+ nn .LayerNorm (normalized_shape = [channels ]),
82+ nn .Linear (in_features = channels , out_features = channels ),
83+ get_activation_function (name = act ),
84+ nn .Dropout (dropout ),
85+ nn .Linear (in_features = channels , out_features = channels ),
86+ nn .Dropout (dropout ),
87+ )
88+
89+ def forward (self , x ):
90+ """
91+ SelfAttention forward
92+ :param x: Input
93+ :return: attention_value
94+ """
95+ batch , channels , height , width = x .shape
96+ assert height == self .size [0 ] and width == self .size [1 ], \
97+ f"Input size { height } x{ width } does not match the expected size { self .size [0 ]} x{ self .size [1 ]} "
98+ # Flatten the spatial dimension into sequence dimensions
99+ # (batch, channels, height*width) -> (batch, seq_len, channels)
100+ x_flat = x .flatten (2 ).swapaxes (1 , 2 )
101+
102+ # First residual calculation
103+ x_ln = self .ln (x_flat )
104+ # batch_first is not supported in pytorch 1.8.
105+ # If you want to support upgrading to 1.9 and above, or use the following code to transpose
106+ attention_value , _ = self .mha (x_ln , x_ln , x_ln )
107+ attention_value = attention_value + x_flat
108+
109+ # Second residual calculation
110+ attention_value = self .ff_self (attention_value ) + attention_value
111+ return attention_value .swapaxes (1 , 2 ).view (batch , channels , height , width )
0 commit comments