Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
664 changes: 573 additions & 91 deletions scripts/generate_torch_ops.py

Large diffs are not rendered by default.

227 changes: 159 additions & 68 deletions scripts/generate_wrappers.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions scripts/torch_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@
- _upsample_nearest_exact2d_backward
- _upsample_nearest_exact3d
- _upsample_nearest_exact3d_backward
- add
- add_
- argsort
- bernoulli_
Expand Down Expand Up @@ -527,6 +528,7 @@
- less_equal_
- lt_
- masked_fill_
- mul
- mul_
- multiply_
- ne_
Expand Down
9 changes: 9 additions & 0 deletions src/hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define INFINI_OPS_HASH_H_

#include <functional>
#include <optional>
#include <vector>

template <typename T>
Expand All @@ -18,4 +19,12 @@ inline void HashCombine(std::size_t& seed, const std::vector<T>& v) {
}
}

template <typename T>
inline void HashCombine(std::size_t& seed, const std::optional<T>& v) {
HashCombine(seed, v.has_value());
if (v.has_value()) {
HashCombine(seed, *v);
}
}

#endif
10 changes: 8 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,13 @@ def skip_op_without_platform_impl(request):

op_cls = _op_class_from_module(request.node.module)

if op_cls is None or not hasattr(op_cls, "active_implementation_indices"):
if op_cls is None:
if "op_meta" in params:
return

pytest.skip("operator wrapper is not available in this build")

if not hasattr(op_cls, "active_implementation_indices"):
return

if not any(op_cls.active_implementation_indices(d) for d in device_selectors):
Expand Down Expand Up @@ -308,7 +314,7 @@ def pytest_pyfunc_call(pyfuncitem):
atol = payload.atol

# `torch.allclose` rejects `bool` dtypes — use `torch.equal` for
# non-floating outputs (bool, int) so comparison ops work. Pass
# non-floating outputs (bool, int) so comparison ops work. Pass
# `equal_nan=True` so NaN-in-both-positions (common for special
# functions fed out-of-domain inputs) does not fail the test.
if output.dtype.is_floating_point:
Expand Down
189 changes: 186 additions & 3 deletions tests/test_generate_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,189 @@ def test_load_aten_yaml_uses_packaged_torchgen(monkeypatch):
def test_public_op_name_normalizes_aten_internal_and_inplace_names():
module = _load_generator_module()

assert module._public_op_name("_softmax") == "aten_softmax"
assert module._public_op_name("add_") == "add_inplace"
assert module._public_op_name("_add_relu_") == "aten_add_relu_inplace"
assert module._public_op_name("_softmax") == "internal_softmax"
assert module._public_op_name("add_") == "add"
assert module._public_op_name("_add_relu_") == "internal_add_relu"


def test_schema_self_param_renders_as_input_in_public_cpp_api():
module = _load_generator_module()
op = module._parse_func(
"_softmax(Tensor self, int dim, bool half_to_float, *, "
"Tensor(a!) out) -> Tensor(a!)"
)

assert op.params[0].name == "self"
assert op.params[0].api_name == "input"

base = module._generate_base_header("internal_softmax", [op])
source = module._generate_torch_method_source("internal_softmax", op)

assert "Softmax(const Tensor input, const int64_t dim" in base
assert "self_shape_" not in base
assert "input_shape_" in base
assert "auto at_self = ToAtenTensor<kDev>" in source
assert "input_shape_" in source
assert "at::_softmax_out(at_out, at_self" in source


def test_optional_tensor_params_are_exposed_and_forwarded_to_aten():
module = _load_generator_module()
op = module._parse_func(
"batch_norm_elemt(Tensor input, Tensor? weight=None, "
"Tensor? bias=None, Tensor mean, Tensor invstd, float eps, "
"*, Tensor(a!) out) -> Tensor(a!)"
)

assert [param.cpp_type for param in op.visible_params] == [
"Tensor",
"std::optional<Tensor>",
"std::optional<Tensor>",
"Tensor",
"Tensor",
"double",
"Tensor",
]

base = module._generate_base_header("batch_norm_elemt", [op])
source = module._generate_torch_method_source("batch_norm_elemt", op)

assert "#include <optional>" in base
assert "std::optional<Tensor> weight" in base
assert "std::optional<Tensor> bias" in base
assert "bool has_weight_" in base
assert "bool has_bias_" in base
assert "c10::optional<at::Tensor> at_weight" in source
assert "c10::optional<at::Tensor> at_bias" in source
assert "weight->shape()" in source
assert "weight_shape_" not in source
assert "at::batch_norm_elemt_out" in source
assert "at_weight" in source
assert "at_bias" in source


def test_optional_scalar_and_array_params_are_exposed_and_forwarded_to_aten():
module = _load_generator_module()
quantile = module._parse_func(
"quantile(Tensor input, Tensor q, int? dim=None, bool keepdim=False, "
"str interpolation='linear', *, Tensor(a!) out) -> Tensor(a!)"
)
upsample = module._parse_func(
"upsample_bicubic2d(Tensor input, SymInt[2] output_size, "
"bool align_corners, float[]? scale_factors=None, "
"*, Tensor(a!) out) -> Tensor(a!)"
)

assert [param.cpp_type for param in quantile.visible_params] == [
"Tensor",
"Tensor",
"std::optional<int64_t>",
"bool",
"std::string",
"Tensor",
]
assert [param.cpp_type for param in upsample.visible_params] == [
"Tensor",
"std::vector<int64_t>",
"bool",
"std::optional<std::vector<double>>",
"Tensor",
]

quantile_source = module._generate_torch_method_source("quantile", quantile)
upsample_source = module._generate_torch_method_source(
"upsample_bicubic2d", upsample
)

assert "c10::optional<int64_t> at_dim" in quantile_source
assert "at::quantile_out" in quantile_source
assert "at_dim" in quantile_source
assert "c10::optional<at::ArrayRef<double>> at_scale_factors" in upsample_source
assert "at::upsample_bicubic2d_out" in upsample_source
assert "at_scale_factors" in upsample_source


def test_existing_base_overload_can_omit_optional_schema_params():
module = _load_generator_module()
op = module._parse_func(
"slow_conv3d(Tensor input, Tensor weight, int[3] kernel_size, "
"Tensor? bias=None, int[3] stride=1, int[3] padding=0, "
"*, Tensor(a!) out) -> Tensor(a!)"
)
signature = [
("Tensor", "input"),
("Tensor", "weight"),
("std::vector<int64_t>", "kernel_size"),
("std::vector<int64_t>", "stride"),
("std::vector<int64_t>", "padding"),
("Tensor", "out"),
]

bound = module._bind_base_signature(op, signature)

assert bound is not None
assert [param.name for param in bound.visible_params] == [
"input",
"weight",
"kernel_size",
"stride",
"padding",
"out",
]

source = module._generate_torch_method_source("slow_conv3d", bound)

assert "std::optional<Tensor> bias" not in source
assert "c10::optional<at::Tensor>{}" in source
assert "at::slow_conv3d_out" in source


def test_existing_base_overload_can_omit_defaulted_schema_params():
module = _load_generator_module()
op = module._parse_func(
"add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1, "
"Tensor(a!) out) -> Tensor(a!)"
)
signature = [
("const Tensor", "input"),
("const Tensor", "other"),
("Tensor", "out"),
]

bound = module._bind_base_signature(op, signature)

assert bound is not None

source = module._generate_torch_method_source("add", bound)

assert "double alpha" not in source
assert "const auto device_index = out.device().index();" in source
assert "device_index_)" not in source
assert "at::add_out(at_out, at_self, at_other, 1)" in source


def test_existing_base_overload_matches_by_name_when_types_repeat():
module = _load_generator_module()
op = module._parse_func(
"std(Tensor input, int[1]? dim=None, bool unbiased=True, "
"bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)"
)
signature = [
("Tensor", "input"),
("bool", "keepdim"),
("Tensor", "out"),
]

bound = module._bind_base_signature(op, signature)

assert bound is not None
assert [param.name for param in bound.visible_params] == [
"input",
"keepdim",
"out",
]

source = module._generate_torch_method_source("std", bound)

assert "c10::optional<at::IntArrayRef>{}, true, keepdim" in source
assert "unbiased" not in source
72 changes: 72 additions & 0 deletions tests/test_generate_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import importlib.util
import pathlib
import sys


def _load_generator_module():
path = (
pathlib.Path(__file__).resolve().parents[1] / "scripts" / "generate_wrappers.py"
)
spec = importlib.util.spec_from_file_location("generate_wrappers_under_test", path)
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
sys.modules[spec.name] = module
spec.loader.exec_module(module)

return module


def test_generated_dispatch_keeps_optional_scalar_and_tensor_overloads_distinct(
monkeypatch, tmp_path
):
module = _load_generator_module()
base_header = tmp_path / "clamp.h"
base_header.write_text(
"""
class Clamp {
public:
virtual void operator()(const Tensor input, const std::optional<double> min,
const std::optional<double> max, Tensor out) const = 0;
virtual void operator()(const Tensor input, const std::optional<Tensor> min,
const std::optional<Tensor> max, Tensor out) const = 0;
};
"""
)
monkeypatch.setattr(module, "_find_base_header", lambda op_name: base_header)

operator = module._Operator(
"clamp",
constructors=[
module._ParsedFunction(
[
module._ParsedArgument("const Tensor", "input"),
module._ParsedArgument("const std::optional<double>", "min"),
module._ParsedArgument("const std::optional<double>", "max"),
module._ParsedArgument("Tensor", "out"),
]
),
module._ParsedFunction(
[
module._ParsedArgument("const Tensor", "input"),
module._ParsedArgument("const std::optional<Tensor>", "min"),
module._ParsedArgument("const std::optional<Tensor>", "max"),
module._ParsedArgument("Tensor", "out"),
]
),
],
calls=[],
)

declarations, _ = module._generate_generated_dispatch_entries(operator)

text = "\n".join(declarations)

assert (
"MakeClamp(const Config& config, const Tensor input, "
"const std::optional<double> min, const std::optional<double> max, "
"Tensor out)"
) in text
assert (
"MakeClamp(const Config& config, const Tensor input, "
"std::optional<Tensor> min, std::optional<Tensor> max, Tensor out)"
) in text
Loading
Loading