Skip to content

Commit 7211569

Browse files
committed
Refactor sample_inputs_roi_align for improved clarity and efficiency; remove redundant tests and simplify input handling
1 parent db3701c commit 7211569

2 files changed

Lines changed: 26 additions & 107 deletions

File tree

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 25 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,81 +1471,34 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa
14711471

14721472

14731473
def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs):
1474-
del op_info
1475-
del kwargs
1476-
# roi_align signature: (input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False)
1477-
1478-
# Test 1: spatial_scale=1, sampling_ratio=2
1479-
x1 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1480-
roi1 = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=dtype, device=device)
1481-
yield opinfo_core.SampleInput(
1482-
x1,
1483-
args=(roi1, (5, 5)),
1484-
kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": True},
1485-
)
1486-
1487-
# Test 2: spatial_scale=0.5, sampling_ratio=3
1488-
x2 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1489-
roi2 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
1490-
yield opinfo_core.SampleInput(
1491-
x2,
1492-
args=(roi2, (5, 5)),
1493-
kwargs={"spatial_scale": 0.5, "sampling_ratio": 3, "aligned": True},
1494-
)
1495-
1496-
# Test 3: spatial_scale=1.8, sampling_ratio=2
1497-
x3 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1498-
roi3 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
1499-
yield opinfo_core.SampleInput(
1500-
x3,
1501-
args=(roi3, (5, 5)),
1502-
kwargs={"spatial_scale": 1.8, "sampling_ratio": 2, "aligned": True},
1503-
)
1504-
1505-
# Test 4: spatial_scale=2.5, sampling_ratio=0, output_size=(2,2)
1506-
x4 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1507-
roi4 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
1508-
yield opinfo_core.SampleInput(
1509-
x4,
1510-
args=(roi4, (2, 2)),
1511-
kwargs={"spatial_scale": 2.5, "sampling_ratio": 0, "aligned": True},
1512-
)
1474+
del op_info, kwargs
15131475

1514-
# Test 5: spatial_scale=2.5, sampling_ratio=-1, output_size=(2,2)
1515-
x5 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1516-
roi5 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
1517-
yield opinfo_core.SampleInput(
1518-
x5,
1519-
args=(roi5, (2, 2)),
1520-
kwargs={"spatial_scale": 2.5, "sampling_ratio": -1, "aligned": True},
1521-
)
1476+
def make_x():
1477+
return torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
15221478

1523-
# Test 6: malformed boxes (test_roi_align_malformed_boxes)
1524-
x6 = torch.randn(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1525-
roi6 = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=dtype, device=device)
1526-
yield opinfo_core.SampleInput(
1527-
x6,
1528-
args=(roi6, (5, 5)),
1529-
kwargs={"spatial_scale": 1.0, "sampling_ratio": 1, "aligned": True},
1530-
)
1479+
# rois is [K, 5] = [batch_idx, x1, y1, x2, y2]
1480+
roi_a = torch.tensor([[0, 1.5, 1.5, 3.0, 3.0]], dtype=dtype, device=device)
1481+
roi_b = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
1482+
roi_int = torch.tensor([[0, 0.0, 0.0, 4.0, 4.0]], dtype=dtype, device=device)
1483+
roi_malformed = torch.tensor([[0, 2.0, 0.3, 1.5, 1.5]], dtype=dtype, device=device) # x1 > x2-ish
15311484

1532-
# Test 7: aligned=False, spatial_scale=1, sampling_ratio=2
1533-
x7 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1534-
roi7 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
1535-
yield opinfo_core.SampleInput(
1536-
x7,
1537-
args=(roi7, (5, 5)),
1538-
kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": False},
1539-
)
1485+
# (rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned)
1486+
cases = [
1487+
(roi_a, 1.0, 5, 5, 2, True),
1488+
(roi_b, 0.5, 5, 5, 3, True),
1489+
(roi_b, 1.8, 5, 5, 2, True),
1490+
(roi_b, 2.5, 2, 2, 0, True),
1491+
(roi_b, 2.5, 2, 2, -1, True),
1492+
(roi_malformed, 1.0, 5, 5, 1, True),
1493+
(roi_int, 1.0, 5, 5, 2, False),
1494+
(roi_int, 1.0, 5, 5, -1, False),
1495+
]
15401496

1541-
# Test 8: aligned=False, spatial_scale=1, sampling_ratio=-1
1542-
x8 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1543-
roi8 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
1544-
yield opinfo_core.SampleInput(
1545-
x8,
1546-
args=(roi8, (5, 5)),
1547-
kwargs={"spatial_scale": 1.0, "sampling_ratio": -1, "aligned": False},
1548-
)
1497+
for rois, spatial_scale, ph, pw, sr, aligned in cases:
1498+
yield opinfo_core.SampleInput(
1499+
make_x(),
1500+
args=(rois, float(spatial_scale), int(ph), int(pw), int(sr), bool(aligned)),
1501+
)
15491502

15501503

15511504
def sample_inputs_roi_pool(op_info, device, dtype, requires_grad, **kwargs):
@@ -3132,7 +3085,7 @@ def __init__(self):
31323085
),
31333086
opinfo_core.OpInfo(
31343087
"torchvision.ops.roi_align",
3135-
op=torchvision.ops.roi_align,
3088+
op=torch.ops.torchvision.roi_align.default,
31363089
dtypes=common_dtype.floating_types(),
31373090
sample_inputs_func=sample_inputs_roi_align,
31383091
supports_out=False,

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -448,36 +448,6 @@ def _where_input_wrangler(
448448
args[0], args[1] = args[1], args[0]
449449
return args, kwargs
450450

451-
452-
def _torchvision_roi_align_default_input_wrangler(
453-
args: list[Any], kwargs: dict[str, Any]
454-
) -> tuple[list[Any], dict[str, Any]]:
455-
# Convert:
456-
# roi_align(input, boxes, output_size, spatial_scale=..., sampling_ratio=..., aligned=...)
457-
# into:
458-
# roi_align(input, boxes, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned)
459-
output_size = args.pop(2)
460-
if isinstance(output_size, np.ndarray):
461-
if output_size.ndim == 0:
462-
pooled_height = int(output_size)
463-
pooled_width = int(output_size)
464-
else:
465-
pooled_height, pooled_width = output_size.tolist()
466-
elif isinstance(output_size, (tuple, list)):
467-
pooled_height, pooled_width = output_size
468-
else:
469-
pooled_height = output_size
470-
pooled_width = output_size
471-
472-
pooled_height = int(pooled_height)
473-
pooled_width = int(pooled_width)
474-
spatial_scale = float(kwargs.pop("spatial_scale", 1.0))
475-
sampling_ratio = int(kwargs.pop("sampling_ratio", -1))
476-
aligned = bool(kwargs.pop("aligned", False))
477-
args.extend([spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned])
478-
return args, {}
479-
480-
481451
# Ops to be tested for numerical consistency between onnx and pytorch
482452
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
483453
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (
@@ -1948,11 +1918,7 @@ def _torchvision_roi_align_default_input_wrangler(
19481918
),
19491919
TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like),
19501920
TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms),
1951-
TorchLibOpInfo(
1952-
"torchvision.ops.roi_align",
1953-
vision_ops.torchvision_roi_align,
1954-
input_wrangler=_torchvision_roi_align_default_input_wrangler,
1955-
),
1921+
TorchLibOpInfo("torchvision.ops.roi_align", vision_ops.torchvision_roi_align),
19561922
TorchLibOpInfo("torchvision.ops.roi_pool", vision_ops.torchvision_roi_pool),
19571923
)
19581924

0 commit comments

Comments
 (0)