Skip to content

Commit 4bd0e55

Browse files
author
miranov25
committed
Phase 11.1a: Namespace function support (TMath, ROOT.Math, etc.)
Implement namespace function calls in DSL expressions:
- TMath.Pi(), TMath.Sin(x), TMath.Sqrt(x)
- ROOT.Math.VectorUtil.DeltaPhi(v1, v2)
- Support both Python (.) and C++ (::) syntax

Components:
- dsl_compiler.py: Preprocess :: to . for AST parsing
  - Preserves :: inside strings and brackets (slice syntax)
- ir_builder.py: Namespace detection and CallNode building
  - _extract_namespace_chain(): Identify namespace vs method calls
  - _build_namespace_call(): Create CallNode with namespace info
  - Correctly disambiguates track.Pt() (method) vs TMath.Pi() (namespace)
- backend_cpp.py: Generate qualified C++ names and headers
  - TMath::Sqrt, ROOT::Math::VectorUtil::DeltaPhi
  - Adds appropriate headers (<TMath.h>, etc.)
- constants.py: KNOWN_NAMESPACES, NAMESPACE_HEADERS, NAMESPACE_FUNCTION_TYPES

LIMITATION: Scalar functions on vectors are not yet supported. TMath.Sqrt(RVec<double>) fails - requires Phase 11.1b. Use sqrt(pt) with ADL for vectorized operations until then.

Tests: 722 passed, 1 skipped
1 parent 6e64a0b commit 4bd0e55

5 files changed

Lines changed: 741 additions & 5 deletions

File tree

UTILS/dfextensions/RDataFrameDSL/RDataFrameDSL/backend_cpp.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
SubscriptNode, SliceNode, UnaryOp, BinaryOp, RVecSliceNode, SliceKind
6565
)
6666
from .ir_errors import IRError, IRErrorKind
67+
from .constants import NAMESPACE_HEADERS # Phase 11.1
6768

6869
__all__ = [
6970
'CppCodeGenerator',
@@ -1201,8 +1202,21 @@ def _cpp_function_name(self, node: CallNode) -> str:
12011202
# === PHASE 9: For RVec operations, use unqualified names to enable ADL ===
12021203
# ROOT provides vectorized functions like sqrt, sin, cos via ROOT::VecOps
12031204
# ADL (Argument Dependent Lookup) finds them when arguments are ROOT::RVec
1205+
# Note: This applies to std:: functions but NOT TMath:: (which need explicit namespace)
12041206
if node.rank > 0:
1205-
# Use unqualified name for ADL with RVec
1207+
# Phase 11.1: For non-std namespace functions (like TMath::Sin), preserve the namespace
1208+
# cpp_name contains full qualified name for namespace functions
1209+
is_std_function = (node.namespace == "std" or
1210+
(node.cpp_name and node.cpp_name.startswith("std::")))
1211+
1212+
if not is_std_function and (node.namespace or node.cpp_name):
1213+
# Non-std namespace function - use cpp_name which has proper qualification
1214+
if node.cpp_name:
1215+
return node.cpp_name
1216+
if node.namespace:
1217+
return f"{node.namespace}::{node.func}"
1218+
1219+
# Standard math functions - use unqualified name for ADL
12061220
func_name = node.func
12071221
if "." in func_name:
12081222
func_name = func_name.replace(".", "::")
@@ -1211,8 +1225,7 @@ def _cpp_function_name(self, node: CallNode) -> str:
12111225

12121226
# Check custom cpp_name first - it takes priority
12131227
if node.cpp_name:
1214-
# cpp_name already contains full qualified name (e.g., "std::sqrt")
1215-
# Don't add namespace again even if node.namespace is set
1228+
# cpp_name already contains full qualified name (e.g., "std::sqrt", "TMath::Sin")
12161229
return node.cpp_name
12171230

12181231
# Check if it's a namespaced function (e.g., TMath.Gaus)
@@ -1256,6 +1269,20 @@ def _collect_headers(self, ir: IRNode) -> List[str]:
12561269
# Add headers from node itself
12571270
if node.headers:
12581271
headers.update(node.headers)
1272+
1273+
# Phase 11.1: Add namespace headers
1274+
if node.namespace:
1275+
# Convert C++ namespace back to dot notation for lookup
1276+
ns_dot = node.namespace.replace("::", ".")
1277+
if ns_dot in NAMESPACE_HEADERS:
1278+
headers.add(NAMESPACE_HEADERS[ns_dot])
1279+
# Check parent namespaces
1280+
parts = ns_dot.split(".")
1281+
for i in range(len(parts), 0, -1):
1282+
parent = ".".join(parts[:i])
1283+
if parent in NAMESPACE_HEADERS:
1284+
headers.add(NAMESPACE_HEADERS[parent])
1285+
break
12591286

12601287
elif isinstance(node, BinaryOpNode):
12611288
if node.op == BinaryOp.POW:

UTILS/dfextensions/RDataFrameDSL/RDataFrameDSL/constants.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
'FUNCTION_CPP_NAMES',
2121
'CLASS_HEADERS',
2222
'SCALAR_TYPES',
23+
# Phase 11.1: Namespace support
24+
'KNOWN_NAMESPACES',
25+
'NAMESPACE_HEADERS',
26+
'NAMESPACE_FUNCTION_TYPES',
2327
]
2428

2529

@@ -273,3 +277,71 @@
273277
"<TClass.h>",
274278
"<TDataMember.h>",
275279
]
280+
281+
282+
# =============================================================================
283+
# Phase 11.1: Namespace Support
284+
# =============================================================================
285+
286+
# Known ROOT/C++ namespaces (builtins)
# Entries are written in Python dot notation. Nested namespaces are listed
# with their full dotted path next to each of their parents
# ("ROOT", "ROOT.Math", "ROOT.Math.VectorUtil").
# NOTE(review): presumably consulted by the IR builder to tell a namespace
# call (TMath.Pi()) from a method call (track.Pt()) - confirm in ir_builder.py.
287+
KNOWN_NAMESPACES: Set[str] = {
288+
"TMath",
289+
"ROOT",
290+
"ROOT.Math",
291+
"ROOT.Math.VectorUtil",
292+
"ROOT.VecOps",
293+
"std",
294+
}
295+
296+
# Headers required for namespaces
# Maps a namespace in dot notation to one C++ include target, stored with its
# angle brackets. The C++ backend's header collection converts a node's "::"
# namespace back to dot form before looking it up here, and falls back to the
# nearest listed parent namespace when the exact path has no entry.
# "ROOT" on its own has no entry, so a call placed directly in namespace ROOT
# contributes no header through this table.
297+
NAMESPACE_HEADERS: Dict[str, str] = {
298+
"TMath": "<TMath.h>",
299+
"ROOT.Math": "<Math/Vector4D.h>",
300+
"ROOT.Math.VectorUtil": "<Math/VectorUtil.h>",
301+
"ROOT.VecOps": "<ROOT/RVec.hxx>",
302+
"std": "<cmath>",
303+
}
304+
305+
# Return types for common namespace functions (explicit, no guessing)
# Outer key: namespace in dot notation. Inner mapping: function name -> C++
# return type spelled as a string. Only "TMath" and "ROOT.Math.VectorUtil"
# have tables; every listed function is typed "double" except TMath "Nint",
# which is typed "int".
# NOTE(review): TMath's Abs, Min, Max, Sign and Range are overloaded/templated
# in ROOT and return their argument type; "double" is a fixed choice here -
# confirm that is intended when the arguments are integers.
306+
NAMESPACE_FUNCTION_TYPES: Dict[str, Dict[str, str]] = {
307+
"TMath": {
308+
"Pi": "double",
309+
"E": "double",
310+
"Sin": "double",
311+
"Cos": "double",
312+
"Tan": "double",
313+
"ASin": "double",
314+
"ACos": "double",
315+
"ATan": "double",
316+
"ATan2": "double",
317+
"Sqrt": "double",
318+
"Exp": "double",
319+
"Log": "double",
320+
"Log10": "double",
321+
"Abs": "double",
322+
"Power": "double",
323+
"Min": "double",
324+
"Max": "double",
325+
"Sign": "double",
326+
"Gaus": "double",
327+
"BreitWigner": "double",
328+
"Landau": "double",
329+
"TwoPi": "double",
330+
"PiOver2": "double",
331+
"PiOver4": "double",
332+
"DegToRad": "double",
333+
"RadToDeg": "double",
334+
"Hypot": "double",
335+
"Range": "double",
336+
"Floor": "double",
337+
"Ceil": "double",
338+
"Nint": "int",
339+
},
340+
"ROOT.Math.VectorUtil": {
341+
"DeltaPhi": "double",
342+
"DeltaR": "double",
343+
"CosTheta": "double",
344+
"Angle": "double",
345+
"InvariantMass": "double",
346+
},
347+
}

UTILS/dfextensions/RDataFrameDSL/RDataFrameDSL/dsl_compiler.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,62 @@ def __init__(self, schema: Dict[str, str], safe_indexing: bool = True):
158158
self._functions: Dict[str, GeneratedFunction] = {}
159159
self.library = FunctionLibrary()
160160

161+
def _preprocess_expression(self, expr: str) -> str:
162+
"""
163+
Convert C++ :: notation to Python dot notation.
164+
165+
Phase 11.1: Handles namespace syntax preprocessing before AST parsing.
166+
167+
Handles:
168+
- TMath::Pi() -> TMath.Pi()
169+
- ROOT::Math::VectorUtil::DeltaPhi() -> ROOT.Math.VectorUtil.DeltaPhi()
170+
171+
Skips:
172+
- String literals to avoid corrupting "Error::Message"
173+
- Square brackets to preserve slice syntax like pt[::-1]
174+
175+
Args:
176+
expr: DSL expression potentially with C++ :: notation
177+
178+
Returns:
179+
Expression with :: replaced by . (except in strings and brackets)
180+
"""
181+
result = []
182+
i = 0
183+
in_string = False
184+
string_char = None
185+
bracket_depth = 0 # Track [] nesting for slice syntax
186+
187+
while i < len(expr):
188+
char = expr[i]
189+
190+
# Track string literals
191+
if char in ('"', "'") and (i == 0 or expr[i-1] != '\\'):
192+
if not in_string:
193+
in_string = True
194+
string_char = char
195+
elif char == string_char:
196+
in_string = False
197+
string_char = None
198+
199+
# Track square brackets (for slice syntax like [::-1])
200+
if not in_string:
201+
if char == '[':
202+
bracket_depth += 1
203+
elif char == ']':
204+
bracket_depth -= 1
205+
206+
# Replace :: with . only outside strings AND outside brackets
207+
if not in_string and bracket_depth == 0 and expr[i:i+2] == '::':
208+
result.append('.')
209+
i += 2
210+
continue
211+
212+
result.append(char)
213+
i += 1
214+
215+
return ''.join(result)
216+
161217
def define(self, name: str, expression: str) -> 'DSLCompiler':
162218
"""
163219
Define a new column from a DSL expression.
@@ -198,9 +254,12 @@ def define(self, name: str, expression: str) -> 'DSLCompiler':
198254
suggestions=["Each column name must be unique"]
199255
)
200256

257+
# Phase 11.1: Preprocess C++ :: syntax to Python dot syntax
258+
preprocessed = self._preprocess_expression(expression)
259+
201260
# Parse and generate
202261
builder = IRBuilder(self._inferrer)
203-
ir = builder.build(expression)
262+
ir = builder.build(preprocessed)
204263

205264
# Use unique suffix to avoid collisions in parallel execution
206265
unique_name = f"{name}_{self._unique_id}"

0 commit comments

Comments
 (0)