Skip to content

Commit 87e86ab

Browse files
committed
fix(tests): keep aggregate torch op tests active
1 parent 2a5d6af commit 87e86ab

4 files changed

Lines changed: 36 additions & 2 deletions

File tree

scripts/generate_torch_ops.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -781,6 +781,9 @@ def _default_call_value(param: Param) -> str:
781781
if param.aten_type in _NULLOPT_BY_TYPE:
782782
return _NULLOPT_BY_TYPE[param.aten_type]
783783

784+
if param.default is not None:
785+
return _translate_default(param)
786+
784787
return param.hidden_value()
785788

786789

@@ -1144,6 +1147,8 @@ def _generate_torch_header(name: str, ops: list[Op]) -> str:
11441147
def _generate_torch_method_source(name: str, op: Op) -> str:
11451148
op_type = _op_cpp_type(name)
11461149
conversion_lines = []
1150+
out_device_index = f"{op.out_params[0].api_name}.device().index()"
1151+
conversion_lines.append(f" const auto device_index = {out_device_index};")
11471152

11481153
def _optional_aten_type(param: Param) -> str:
11491154
return _NULLOPT_BY_TYPE[param.aten_type].removesuffix("{}")
@@ -1157,7 +1162,7 @@ def _optional_aten_value(schema_param: Param, api_param: Param) -> str:
11571162
return (
11581163
f"ToAtenTensor<kDev>({data_expr}, {api_name}->shape(), "
11591164
f"{api_name}->strides(), {api_name}->dtype(), "
1160-
f"device_index_)"
1165+
f"device_index)"
11611166
)
11621167

11631168
if schema_param.aten_type == "Scalar?":
@@ -1211,7 +1216,7 @@ def _append_optional_conversion(schema_param: Param, api_param: Param) -> None:
12111216
conversion_lines.append(
12121217
f" auto at_{param.name} = ToAtenTensor<kDev>(\n"
12131218
f" {data_expr}, {api_name}_shape_, {api_name}_strides_,\n"
1214-
f" {api_name}_type_, device_index_);"
1219+
f" {api_name}_type_, device_index);"
12151220
)
12161221

12171222
for schema_index, param in enumerate(op.params):

scripts/torch_ops.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,7 @@
499499
- _upsample_nearest_exact2d_backward
500500
- _upsample_nearest_exact3d
501501
- _upsample_nearest_exact3d_backward
502+
- add
502503
- add_
503504
- argsort
504505
- bernoulli_
@@ -527,6 +528,7 @@
527528
- less_equal_
528529
- lt_
529530
- masked_fill_
531+
- mul
530532
- mul_
531533
- multiply_
532534
- ne_

tests/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ def skip_op_without_platform_impl(request):
120120
op_cls = _op_class_from_module(request.node.module)
121121

122122
if op_cls is None:
123+
if "op_meta" in params:
124+
return
125+
123126
pytest.skip("operator wrapper is not available in this build")
124127

125128
if not hasattr(op_cls, "active_implementation_indices"):

tests/test_generate_torch_ops.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,30 @@ def test_existing_base_overload_can_omit_optional_schema_params():
165165
assert "at::slow_conv3d_out" in source
166166

167167

168+
def test_existing_base_overload_can_omit_defaulted_schema_params():
169+
module = _load_generator_module()
170+
op = module._parse_func(
171+
"add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1, "
172+
"Tensor(a!) out) -> Tensor(a!)"
173+
)
174+
signature = [
175+
("const Tensor", "input"),
176+
("const Tensor", "other"),
177+
("Tensor", "out"),
178+
]
179+
180+
bound = module._bind_base_signature(op, signature)
181+
182+
assert bound is not None
183+
184+
source = module._generate_torch_method_source("add", bound)
185+
186+
assert "double alpha" not in source
187+
assert "const auto device_index = out.device().index();" in source
188+
assert "device_index_)" not in source
189+
assert "at::add_out(at_out, at_self, at_other, 1)" in source
190+
191+
168192
def test_existing_base_overload_matches_by_name_when_types_repeat():
169193
module = _load_generator_module()
170194
op = module._parse_func(

0 commit comments

Comments
 (0)