Skip to content

Commit 30450e9

Browse files
remove c++ std version for aten tests
Differential Revision: D94554543 Pull Request resolved: pytorch#17747
1 parent 25f2a3f commit 30450e9

2 files changed

Lines changed: 34 additions & 6 deletions

File tree

shim_et/xplat/executorch/build/runtime_wrapper.bzl

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,14 +125,43 @@ def _patch_build_mode_flags(kwargs):
125125

126126
return kwargs
127127

128+
def _has_pytorch_dep(dep_list):
129+
"""Check if a dependency list contains PyTorch/ATen dependencies."""
130+
if not dep_list:
131+
return False
132+
for dep in dep_list:
133+
if type(dep) == "string":
134+
if "torch" in dep or "libtorch" in dep or "caffe2" in dep:
135+
return True
136+
return False
137+
128138
def _patch_test_compiler_flags(kwargs):
129139
if "compiler_flags" not in kwargs:
130140
kwargs["compiler_flags"] = []
131141

132-
# Required globally by all c++ tests.
133-
kwargs["compiler_flags"] += [
134-
"-std=c++17",
135-
]
142+
# Determine C++ standard based on whether this is an aten test.
143+
# Aten tests require at least C++20 to compile against PyTorch, while
144+
# non-aten tests are pinned to C++17 for embedded.
145+
name = kwargs.get("name", "")
146+
external_deps = kwargs.get("external_deps", [])
147+
deps = kwargs.get("deps", [])
148+
xplat_deps = kwargs.get("xplat_deps", [])
149+
fbcode_deps = kwargs.get("fbcode_deps", [])
150+
is_aten_test = (
151+
"_aten" in name or
152+
"aten_" in name or
153+
"libtorch" in external_deps or
154+
"gtest_aten" in external_deps or
155+
"gmock_aten" in external_deps or
156+
_has_pytorch_dep(deps) or
157+
_has_pytorch_dep(xplat_deps) or
158+
_has_pytorch_dep(fbcode_deps)
159+
)
160+
161+
if not is_aten_test:
162+
kwargs["compiler_flags"] += [
163+
"-std=c++17",
164+
]
136165

137166
# Relaxing some constraints for tests
138167
kwargs["compiler_flags"] += [

third-party/gtest_defs.bzl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_aten_mode_opt
44
COMPILER_FLAGS = [
55
"-std=c++17",
66
]
7-
COMPILER_FLAGS_ATEN = [
8-
"-std=c++17",]
7+
COMPILER_FLAGS_ATEN = []
98

109
# define_gtest_targets
1110
def define_gtest_targets():

0 commit comments

Comments
 (0)