Skip to content

Commit 902eddd

Browse files
committed
feat: Test new network and self-attention.
1 parent b4268c0 commit 902eddd

5 files changed

Lines changed: 124 additions & 1 deletion

File tree

iddm/config/choices.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# Support option
1818
bool_choices = [True, False]
1919
sample_choices = ["ddpm", "ddim", "plms"]
20-
network_choices = ["unet", "cspdarkunet", "unetv2"]
20+
network_choices = ["unet", "cspdarkunet", "unetv2", "unet-slim"]
2121
optim_choices = ["adam", "adamw", "sgd"]
2222
act_choices = ["gelu", "silu", "relu", "relu6", "lrelu"]
2323
lr_func_choices = ["linear", "cosine", "warmup_cosine"]

iddm/model/modules/attention.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,61 @@ def forward(self, x):
5151
attention_value = attention_value + x
5252
attention_value = self.ff_self(attention_value) + attention_value
5353
return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size[0], self.size[1])
54+
55+
56+
class SelfAttentionAD(nn.Module):
57+
"""
58+
Adaptive head count SelfAttention block
59+
"""
60+
61+
def __init__(self, channels, size, act="silu", dropout=0.1):
62+
"""
63+
Initialize the adaptive head count self-attention block
64+
:param channels: Channels
65+
:param size: Size
66+
:param act: Activation function
67+
"""
68+
super(SelfAttentionAD, self).__init__()
69+
self.channels = channels
70+
self.size = size
71+
self.dropout = dropout
72+
73+
# Adaptive head count
74+
head_count = max(1, channels // 64)
75+
76+
# batch_first is not supported in pytorch 1.8.
77+
# If you want to support upgrading to 1.9 and above, or use the following code to transpose
78+
self.mha = nn.MultiheadAttention(embed_dim=channels, num_heads=head_count, batch_first=True)
79+
self.ln = nn.LayerNorm(normalized_shape=[channels])
80+
self.ff_self = nn.Sequential(
81+
nn.LayerNorm(normalized_shape=[channels]),
82+
nn.Linear(in_features=channels, out_features=channels),
83+
get_activation_function(name=act),
84+
nn.Dropout(dropout),
85+
nn.Linear(in_features=channels, out_features=channels),
86+
nn.Dropout(dropout),
87+
)
88+
89+
def forward(self, x):
90+
"""
91+
SelfAttention forward
92+
:param x: Input
93+
:return: attention_value
94+
"""
95+
batch, channels, height, width = x.shape
96+
assert height == self.size[0] and width == self.size[1], \
97+
f"Input size {height}x{width} does not match the expected size {self.size[0]}x{self.size[1]}"
98+
# Flatten the spatial dimension into sequence dimensions
99+
# (batch, channels, height*width) -> (batch, seq_len, channels)
100+
x_flat = x.flatten(2).swapaxes(1, 2)
101+
102+
# First residual calculation
103+
x_ln = self.ln(x_flat)
104+
# batch_first is not supported in pytorch 1.8.
105+
# If you want to support upgrading to 1.9 and above, or use the following code to transpose
106+
attention_value, _ = self.mha(x_ln, x_ln, x_ln)
107+
attention_value = attention_value + x_flat
108+
109+
# Second residual calculation
110+
attention_value = self.ff_self(attention_value) + attention_value
111+
return attention_value.swapaxes(1, 2).view(batch, channels, height, width)

iddm/model/networks/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .cspdarkunet import CSPDarkUnet
1111
from .unet import UNet
1212
from .unetv2 import UNetV2
13+
from .unet_slim import UNetSlim
1314

1415
# Super resolution network
1516
from .sr.srv1 import SRv1

iddm/model/networks/unet_slim.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python
2+
# -*- coding:utf-8 -*-
3+
"""
4+
@Date : 2023/6/23 22:26
5+
@Author : chairc
6+
@Site : https://github.com/chairc
7+
"""
8+
9+
from iddm.model.networks.unet import UNet
10+
11+
12+
class UNetSlim(UNet):
13+
"""
14+
UNet-Slim
15+
This is a slim network demo, reduce 45% GPU used
16+
"""
17+
18+
def __init__(self, **kwargs):
19+
"""
20+
Initialize the UNet-Slim network
21+
:param in_channel: Input channel
22+
:param out_channel: Output channel
23+
:param channel: The list of channel
24+
:param time_channel: Time channel
25+
:param num_classes: Number of classes
26+
:param image_size: Adaptive image size
27+
:param device: Device type
28+
:param act: Activation function
29+
"""
30+
super(UNetSlim, self).__init__(**kwargs)
31+
32+
def forward(self, x, time, y=None):
33+
"""
34+
Forward
35+
:param x: Input
36+
:param time: Time
37+
:param y: Input label
38+
:return: output
39+
"""
40+
time = self.encode_time_with_label(time=time, y=y)
41+
42+
x = self.inc(x)
43+
x1 = x
44+
x = self.down1(x, time)
45+
x = self.sa1(x)
46+
x2_sa = x
47+
x = self.down2(x, time)
48+
x3_sa = x
49+
x = self.down3(x, time)
50+
x = self.sa3(x)
51+
52+
x = self.bot1(x)
53+
x = self.bot2(x)
54+
x = self.bot3(x)
55+
56+
x = self.up1(x, x3_sa, time)
57+
x = self.up2(x, x2_sa, time)
58+
x = self.sa5(x)
59+
x = self.up3(x, x1, time)
60+
output = self.outc(x)
61+
return output

iddm/utils/initializer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from iddm.model.networks.unet import UNet
1919
from iddm.model.networks.unetv2 import UNetV2
20+
from iddm.model.networks.unet_slim import UNetSlim
2021
from iddm.model.networks.cspdarkunet import CSPDarkUnet
2122
from iddm.model.networks.sr.srv1 import SRv1
2223
from iddm.model.samples.ddim import DDIMDiffusion
@@ -94,6 +95,8 @@ def network_initializer(network, device):
9495
Network = UNetV2
9596
elif network == "cspdarkunet":
9697
Network = CSPDarkUnet
98+
elif network == "unet-slim":
99+
Network = UNetSlim
97100
else:
98101
Network = UNet
99102
logger.warning(msg=f"[{device}]: Setting network error, we has been automatically set to unet.")

0 commit comments

Comments
 (0)