1+ import torch
2+ import torch .nn as nn
3+ import torch .nn .functional as F
4+
5+ # Define the Self-Attention layer
6+ class SelfAttention (nn .Module ):
7+ def __init__ (self , in_channels ):
8+ super (SelfAttention , self ).__init__ ()
9+ # Convolutional layers for generating query, key, and value matrices from the input feature maps
10+ self .query = nn .Conv2d (in_channels , in_channels , kernel_size = 1 ) # For the query matrix (Q)
11+ self .key = nn .Conv2d (in_channels , in_channels , kernel_size = 1 ) # For the key matrix (K)
12+ self .value = nn .Conv2d (in_channels , in_channels , kernel_size = 1 ) # For the value matrix (V)
13+ # Softmax layer to normalize the attention scores
14+ self .softmax = nn .Softmax (dim = - 1 )
15+
16+ def forward (self , x ):
17+ batch_size , C , H , W = x .size () # Get the dimensions of the input tensor
18+
19+ # Generate query, key, and value matrices by applying respective convolutions
20+ queries = self .query (x ).view (batch_size , C , - 1 ) # Reshape to (B, C, H*W)
21+ keys = self .key (x ).view (batch_size , C , - 1 ) # Reshape to (B, C, H*W)
22+ values = self .value (x ).view (batch_size , C , - 1 ) # Reshape to (B, C, H*W)
23+
24+ # Compute the attention scores using matrix multiplication between queries and keys
25+ attention_scores = torch .bmm (queries .permute (0 , 2 , 1 ), keys ) # Output: (B, H*W, H*W)
26+ attention_scores = self .softmax (attention_scores ) # Apply softmax to get the attention weights
27+
28+ # Multiply values with attention scores and reshape back to the original size
29+ out = torch .bmm (values , attention_scores .permute (0 , 2 , 1 )) # Output: (B, C, H*W)
30+ return out .view (batch_size , C , H , W ) # Reshape to (B, C, H, W) without changing the original shape
31+
32+ # Define the CNN with Self-Attention model
33+ class CNNattention (nn .Module ):
34+ def __init__ (self , in_channels , out_channels = 1 ): # Default output channels set to 1
35+ super (CNNattention , self ).__init__ ()
36+ # Encoder part (downsampling with convolutional layers)
37+ self .enc1 = nn .Conv2d (in_channels , 64 , kernel_size = 3 , padding = 1 ) # First encoding layer, input has 4 channels
38+ self .enc2 = nn .Conv2d (64 , 128 , kernel_size = 3 , padding = 1 ) # Second encoding layer
39+ self .enc3 = nn .Conv2d (128 , 256 , kernel_size = 3 , padding = 1 ) # Third encoding layer
40+
41+ # Self-Attention layer applied after the encoder
42+ self .attention = SelfAttention (256 ) # Self-Attention applied to 256 channels
43+
44+ # Decoder part (upsampling with convolutional layers)
45+ self .dec1 = nn .Conv2d (256 , 128 , kernel_size = 3 , padding = 1 ) # First decoding layer
46+ self .dec2 = nn .Conv2d (128 , 64 , kernel_size = 3 , padding = 1 ) # Second decoding layer
47+ self .dec3 = nn .Conv2d (64 , out_channels , kernel_size = 3 , padding = 1 ) # Final layer, output has 1 channel
48+ nn .Sigmoid () # Sigmoid activation (not used in forward)
49+
50+ def forward (self , x ):
51+ # Encoder: apply convolutional layers and relu activation
52+ enc1 = F .relu (self .enc1 (x )) # (4, 56, 56) -> (64, 56, 56)
53+ enc2 = F .relu (self .enc2 (F .max_pool2d (enc1 , 2 ))) # Downsampling: (64, 28, 28) -> (128, 28, 28)
54+ enc3 = F .relu (self .enc3 (F .max_pool2d (enc2 , 2 ))) # Downsampling: (128, 14, 14) -> (256, 14, 14)
55+
56+ # Self-Attention layer to refine feature maps
57+ attn_out = self .attention (enc3 ) # Attention output: (256, 14, 14)
58+
59+ # Decoder: upsample and apply convolutional layers
60+ dec1 = F .relu (self .dec1 (F .upsample (attn_out , scale_factor = 2 , mode = 'bilinear' , align_corners = False ))) # Upsample: (128, 28, 28)
61+ dec2 = F .relu (self .dec2 (F .upsample (dec1 , scale_factor = 2 , mode = 'bilinear' , align_corners = False ))) # Upsample: (64, 56, 56)
62+ out = self .dec3 (dec2 ) # Final output: (64, 56, 56) -> (1, 56, 56)
63+
64+ return out # Return the final output
0 commit comments