@@ -1474,13 +1474,17 @@ def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs):
14741474 del op_info , kwargs
14751475
14761476 def make_x ():
1477- return torch .rand (1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad )
1477+ return torch .rand (
1478+ 1 , 1 , 10 , 10 , dtype = dtype , device = device , requires_grad = requires_grad
1479+ )
14781480
14791481 # rois is [K, 5] = [batch_idx, x1, y1, x2, y2]
14801482 roi_a = torch .tensor ([[0 , 1.5 , 1.5 , 3.0 , 3.0 ]], dtype = dtype , device = device )
14811483 roi_b = torch .tensor ([[0 , 0.2 , 0.3 , 4.5 , 3.5 ]], dtype = dtype , device = device )
14821484 roi_int = torch .tensor ([[0 , 0.0 , 0.0 , 4.0 , 4.0 ]], dtype = dtype , device = device )
1483- roi_malformed = torch .tensor ([[0 , 2.0 , 0.3 , 1.5 , 1.5 ]], dtype = dtype , device = device ) # x1 > x2-ish
1485+ roi_malformed = torch .tensor (
1486+ [[0 , 2.0 , 0.3 , 1.5 , 1.5 ]], dtype = dtype , device = device
1487+ ) # x1 > x2-ish
14841488
14851489 # (rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned)
14861490 cases = [
0 commit comments