@@ -125,14 +125,43 @@ def _patch_build_mode_flags(kwargs):
125125
126126 return kwargs
127127
128+ def _has_pytorch_dep (dep_list ):
129+ """Check if a dependency list contains PyTorch/ATen dependencies."""
130+ if not dep_list :
131+ return False
132+ for dep in dep_list :
133+ if type (dep ) == "string" :
134+ if "torch" in dep or "libtorch" in dep or "caffe2" in dep :
135+ return True
136+ return False
137+
128138def _patch_test_compiler_flags (kwargs ):
129139 if "compiler_flags" not in kwargs :
130140 kwargs ["compiler_flags" ] = []
131141
132- # Required globally by all c++ tests.
133- kwargs ["compiler_flags" ] += [
134- "-std=c++17" ,
135- ]
142+ # Determine C++ standard based on whether this is an aten test.
143+ # Aten tests require at least C++20 to compile against PyTorch, while
144+ # non-aten tests are pinned to C++17 for embedded.
145+ name = kwargs .get ("name" , "" )
146+ external_deps = kwargs .get ("external_deps" , [])
147+ deps = kwargs .get ("deps" , [])
148+ xplat_deps = kwargs .get ("xplat_deps" , [])
149+ fbcode_deps = kwargs .get ("fbcode_deps" , [])
150+ is_aten_test = (
151+ "_aten" in name or
152+ "aten_" in name or
153+ "libtorch" in external_deps or
154+ "gtest_aten" in external_deps or
155+ "gmock_aten" in external_deps or
156+ _has_pytorch_dep (deps ) or
157+ _has_pytorch_dep (xplat_deps ) or
158+ _has_pytorch_dep (fbcode_deps )
159+ )
160+
161+ if not is_aten_test :
162+ kwargs ["compiler_flags" ] += [
163+ "-std=c++17" ,
164+ ]
136165
137166 # Relaxing some constraints for tests
138167 kwargs ["compiler_flags" ] += [
0 commit comments