
Commit 827460e

revert self attention changes
Signed-off-by: R. Garcia-Dias <rafaelagd@gmail.com>
1 parent 89c43b8 commit 827460e

1 file changed: tests/networks/blocks/test_selfattention.py (45 additions & 13 deletions)
@@ -22,22 +22,33 @@
 from monai.networks.blocks.selfattention import SABlock
 from monai.networks.layers.factories import RelPosEmbedding
 from monai.utils import optional_import
-from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, dict_product, test_script_save
+from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save

 einops, has_einops = optional_import("einops")

-TEST_CASE_SABLOCK = [
-    [params, (2, 512, params["hidden_size"]), (2, 512, params["hidden_size"])]
-    for params in dict_product(
-        dropout_rate=np.linspace(0, 1, 4),
-        hidden_size=[360, 480, 600, 768],
-        num_heads=[4, 6, 8, 12],
-        rel_pos_embedding=[None, RelPosEmbedding.DECOMPOSED],
-        input_size=[(16, 32), (8, 8, 8)],
-        include_fc=[True, False],
-        use_combined_linear=[True, False],
-    )
-]
+TEST_CASE_SABLOCK = []
+for dropout_rate in np.linspace(0, 1, 4):
+    for hidden_size in [360, 480, 600, 768]:
+        for num_heads in [4, 6, 8, 12]:
+            for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
+                for input_size in [(16, 32), (8, 8, 8)]:
+                    for include_fc in [True, False]:
+                        for use_combined_linear in [True, False]:
+                            test_case = [
+                                {
+                                    "hidden_size": hidden_size,
+                                    "num_heads": num_heads,
+                                    "dropout_rate": dropout_rate,
+                                    "rel_pos_embedding": rel_pos_embedding,
+                                    "input_size": input_size,
+                                    "include_fc": include_fc,
+                                    "use_combined_linear": use_combined_linear,
+                                    "use_flash_attention": True if rel_pos_embedding is None else False,
+                                },
+                                (2, 512, hidden_size),
+                                (2, 512, hidden_size),
+                            ]
+                            TEST_CASE_SABLOCK.append(test_case)


 class TestResBlock(unittest.TestCase):
@@ -216,6 +227,27 @@ def test_flash_attention(self):
         out_2 = block_wo_flash_attention(test_data)
         assert_allclose(out_1, out_2, atol=1e-4)

+    @parameterized.expand([[True, False], [True, True], [False, True], [False, False]])
+    def test_no_extra_weights_if_no_fc(self, include_fc, use_combined_linear):
+        input_param = {
+            "hidden_size": 360,
+            "num_heads": 4,
+            "dropout_rate": 0.0,
+            "rel_pos_embedding": None,
+            "input_size": (16, 32),
+            "include_fc": include_fc,
+            "use_combined_linear": use_combined_linear,
+        }
+        net = SABlock(**input_param)
+        if not include_fc:
+            self.assertNotIn("out_proj.weight", net.state_dict())
+            self.assertNotIn("out_proj.bias", net.state_dict())
+            self.assertIsInstance(net.out_proj, torch.nn.Identity)
+        else:
+            self.assertIn("out_proj.weight", net.state_dict())
+            self.assertIn("out_proj.bias", net.state_dict())
+            self.assertIsInstance(net.out_proj, torch.nn.Linear)
+

 if __name__ == "__main__":
     unittest.main()
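
Note on the helper dropped from the import line in the first hunk: dict_product in tests.test_utils expands keyword iterables into a sequence of parameter dicts, which is what the restored nested loops now spell out by hand. Below is a minimal sketch of such a helper, assuming it is a thin wrapper around itertools.product; the name, signature, and return type here are assumptions for illustration, not the MONAI implementation.

# Hypothetical stand-in for a dict_product-style helper; the real
# tests.test_utils.dict_product may differ in signature and behavior.
from itertools import product


def dict_product(**kwargs):
    """Yield one dict per combination of the supplied keyword iterables."""
    keys = list(kwargs.keys())
    for values in product(*kwargs.values()):
        yield dict(zip(keys, values))


# Example: a small slice of the kind of grid the reverted loops build.
cases = [
    [params, (2, 512, params["hidden_size"]), (2, 512, params["hidden_size"])]
    for params in dict_product(hidden_size=[360, 480], num_heads=[4, 8], dropout_rate=[0.0, 0.5])
]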

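The rebuilt TEST_CASE_SABLOCK list pairs each parameter dict with an input shape and an expected output shape of (2, 512, hidden_size). A test in this file would typically consume such a list via parameterized.expand; the sketch below shows that consumption pattern under a trimmed-down case list, with an illustrative method name and body rather than the file's actual test.

# Illustrative only: how a (params, input_shape, expected_shape) case list
# like TEST_CASE_SABLOCK is usually fed to parameterized.expand.
import unittest

import torch
from parameterized import parameterized

from monai.networks.blocks.selfattention import SABlock

EXAMPLE_CASES = [
    [{"hidden_size": 360, "num_heads": 4, "dropout_rate": 0.0}, (2, 512, 360), (2, 512, 360)]
]


class TestSABlockShape(unittest.TestCase):
    @parameterized.expand(EXAMPLE_CASES)
    def test_shape(self, input_param, input_shape, expected_shape):
        net = SABlock(**input_param)
        with torch.no_grad():
            out = net(torch.randn(input_shape))
        # Self-attention should preserve the (batch, sequence, hidden) shape.
        self.assertEqual(out.shape, expected_shape)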