Skip to content

Commit 18cd621

Browse files
committed
Fixed lintrunner
1 parent 7211569 commit 18cd621

2 files changed

Lines changed: 7 additions & 2 deletions

File tree

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,13 +1474,17 @@ def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs):
14741474
del op_info, kwargs
14751475

14761476
def make_x():
1477-
return torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
1477+
return torch.rand(
1478+
1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad
1479+
)
14781480

14791481
# rois is [K, 5] = [batch_idx, x1, y1, x2, y2]
14801482
roi_a = torch.tensor([[0, 1.5, 1.5, 3.0, 3.0]], dtype=dtype, device=device)
14811483
roi_b = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
14821484
roi_int = torch.tensor([[0, 0.0, 0.0, 4.0, 4.0]], dtype=dtype, device=device)
1483-
roi_malformed = torch.tensor([[0, 2.0, 0.3, 1.5, 1.5]], dtype=dtype, device=device) # x1 > x2-ish
1485+
roi_malformed = torch.tensor(
1486+
[[0, 2.0, 0.3, 1.5, 1.5]], dtype=dtype, device=device
1487+
) # x1 > x2-ish
14841488

14851489
# (rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned)
14861490
cases = [

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,7 @@ def _where_input_wrangler(
448448
args[0], args[1] = args[1], args[0]
449449
return args, kwargs
450450

451+
451452
# Ops to be tested for numerical consistency between onnx and pytorch
452453
# Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py
453454
TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (

0 commit comments

Comments
 (0)