Skip to content

Commit 016d3b0

Browse files
committed
Fix test mistakes
Signed-off-by: R. Garcia-Dias <rafaelagd@gmail.com>
1 parent ef2388e commit 016d3b0

2 files changed

Lines changed: 7 additions & 6 deletions

File tree

tests/networks/blocks/test_patchembedding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,16 @@
4545
)
4646
]
4747

48+
img_size = 96
4849
TEST_CASE_PATCHEMBED = [
4950
[
5051
params,
51-
(2, params["in_chans"], *([params["img_size"]] * params["spatial_dims"])),
52-
(2, (params["img_size"] // params["patch_size"]) ** params["spatial_dims"], params["embed_dim"]),
52+
(2, params["in_chans"], *([img_size] * params["spatial_dims"])),
53+
(2, params["embed_dim"], *([img_size // params["patch_size"]]) * params["spatial_dims"]),
5354
]
5455
for params in dict_product(
5556
patch_size=[2],
5657
in_chans=[1, 4],
57-
img_size=[96],
5858
embed_dim=[6, 12],
5959
norm_layer=[nn.LayerNorm],
6060
spatial_dims=[2, 3],

tests/transforms/test_gibbs_noise.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525

2626
_, has_torch_fft = optional_import("torch.fft", name="fftshift")
2727

28-
params = {"shape": ((128, 64), (64, 48, 80)), "input_type": TEST_NDARRAYS if has_torch_fft else [np.array]}
29-
TEST_CASES = list(dict_product(format="list", **params))
30-
28+
TEST_CASES = dict_product(shape=((128, 64), (64, 48, 80)), input_type=TEST_NDARRAYS if has_torch_fft else [np.array])
29+
TEST_CASES = [
30+
[p_dict["shape"], p_dict["input_type"]] for p_dict in TEST_CASES
31+
]
3132

3233
class TestGibbsNoise(unittest.TestCase):
3334
def setUp(self):

0 commit comments

Comments
 (0)