22 | 22 | from monai.networks.blocks.selfattention import SABlock |
23 | 23 | from monai.networks.layers.factories import RelPosEmbedding |
24 | 24 | from monai.utils import optional_import |
25 | | -from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, dict_product, test_script_save |
| 25 | +from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save |
26 | 26 |  |
27 | 27 | einops, has_einops = optional_import("einops") |
28 | 28 |  |
29 | | -TEST_CASE_SABLOCK = [ |
30 | | - [params, (2, 512, params["hidden_size"]), (2, 512, params["hidden_size"])] |
31 | | - for params in dict_product( |
32 | | - dropout_rate=np.linspace(0, 1, 4), |
33 | | - hidden_size=[360, 480, 600, 768], |
34 | | - num_heads=[4, 6, 8, 12], |
35 | | - rel_pos_embedding=[None, RelPosEmbedding.DECOMPOSED], |
36 | | - input_size=[(16, 32), (8, 8, 8)], |
37 | | - include_fc=[True, False], |
38 | | - use_combined_linear=[True, False], |
39 | | - ) |
40 | | -] |
| 29 | +TEST_CASE_SABLOCK = [] |
| 30 | +for dropout_rate in np.linspace(0, 1, 4): |
| 31 | + for hidden_size in [360, 480, 600, 768]: |
| 32 | + for num_heads in [4, 6, 8, 12]: |
| 33 | + for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: |
| 34 | + for input_size in [(16, 32), (8, 8, 8)]: |
| 35 | + for include_fc in [True, False]: |
| 36 | + for use_combined_linear in [True, False]: |
| 37 | + test_case = [ |
| 38 | + { |
| 39 | + "hidden_size": hidden_size, |
| 40 | + "num_heads": num_heads, |
| 41 | + "dropout_rate": dropout_rate, |
| 42 | + "rel_pos_embedding": rel_pos_embedding, |
| 43 | + "input_size": input_size, |
| 44 | + "include_fc": include_fc, |
| 45 | + "use_combined_linear": use_combined_linear, |
| 46 | +                                    "use_flash_attention": rel_pos_embedding is None, |
| 47 | + }, |
| 48 | + (2, 512, hidden_size), |
| 49 | + (2, 512, hidden_size), |
| 50 | + ] |
| 51 | + TEST_CASE_SABLOCK.append(test_case) |
41 | 52 |  |
42 | 53 |  |
43 | 54 | class TestResBlock(unittest.TestCase): |
@@ -216,6 +227,27 @@ def test_flash_attention(self): |
216 | 227 | out_2 = block_wo_flash_attention(test_data) |
217 | 228 | assert_allclose(out_1, out_2, atol=1e-4) |
218 | 229 |  |
| 230 | +    @parameterized.expand([[True, False], [True, True], [False, True], [False, False]]) |
| 231 | +    def test_no_extra_weights_if_no_fc(self, include_fc, use_combined_linear): |
| 232 | + input_param = { |
| 233 | + "hidden_size": 360, |
| 234 | + "num_heads": 4, |
| 235 | + "dropout_rate": 0.0, |
| 236 | + "rel_pos_embedding": None, |
| 237 | + "input_size": (16, 32), |
| 238 | + "include_fc": include_fc, |
| 239 | + "use_combined_linear": use_combined_linear, |
| 240 | + } |
| 241 | + net = SABlock(**input_param) |
| 242 | + if not include_fc: |
| 243 | + self.assertNotIn("out_proj.weight", net.state_dict()) |
| 244 | + self.assertNotIn("out_proj.bias", net.state_dict()) |
| 245 | + self.assertIsInstance(net.out_proj, torch.nn.Identity) |
| 246 | + else: |
| 247 | + self.assertIn("out_proj.weight", net.state_dict()) |
| 248 | + self.assertIn("out_proj.bias", net.state_dict()) |
| 249 | + self.assertIsInstance(net.out_proj, torch.nn.Linear) |
| 250 | + |
219 | 251 |  |
220 | 252 | if __name__ == "__main__": |
221 | 253 | unittest.main() |
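
For reference, a minimal sketch of how the `[input_param, input_shape, expected_shape]` triplets in `TEST_CASE_SABLOCK` are typically consumed. Note that the cases only enable `use_flash_attention` when `rel_pos_embedding` is `None`, since `SABlock` rejects the combination of flash attention and relative positional embedding. The test body below follows MONAI's usual shape-test pattern and is illustrative, not part of this diff; the class name is hypothetical.

```python
import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.selfattention import SABlock


class TestSABlockShape(unittest.TestCase):  # hypothetical name, for illustration only
    @parameterized.expand(TEST_CASE_SABLOCK)
    def test_shape(self, input_param, input_shape, expected_shape):
        # Each case unpacks into the SABlock kwargs plus the input and
        # expected output tensor shapes, e.g. (2, 512, hidden_size).
        net = SABlock(**input_param)
        with eval_mode(net):  # eval mode, so dropout is inactive during the check
            result = net(torch.randn(input_shape))
            self.assertEqual(result.shape, expected_shape)
```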