Skip to content

Commit 249fd55

Browse files
committed
Phase 13.26.ADF: read_tree dtype_overrides for on-the-fly type conversion
New parameter dtype_overrides={regex: np.dtype} on read_tree(). Patterns matched via re.fullmatch, first match wins, compiled once. Injected into dtype_hints (priority: overrides > compression > schema). Overflow safety: warns with correct original→target dtype when finite values become inf during downcast. NaN preserved (IEEE 754). Applies to current tree only, not subframes. Tests D1-D10: regex matching, precedence, overflow warning (correct dtypes), NaN preservation, baseline equivalence, round-trip tolerance, schema round-trip, entry_range consistency. Reviewed-by: Sonnet1
1 parent b9c2866 commit 249fd55

2 files changed

Lines changed: 309 additions & 3 deletions

File tree

UTILS/dfextensions/AliasDataFrame/AliasDataFrame.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5556,7 +5556,7 @@ def _write_metadata_to_tree(self, open_tfile, treename):
55565556

55575557
@staticmethod
55585558
def read_tree(filename, treename="tree", entry_start=None, entry_stop=None,
5559-
num_workers=8, load_subframes=True):
5559+
num_workers=8, load_subframes=True, dtype_overrides=None):
55605560
"""
55615561
Read AliasDataFrame from ROOT TTree with optimized memory and speed.
55625562

@@ -5582,6 +5582,22 @@ def read_tree(filename, treename="tree", entry_start=None, entry_stop=None,
55825582
If True (default), automatically load and register subframes defined
55835583
in schema. Tries both Python naming ({treename}__subframe__{name})
55845584
and C++ naming ({name}) conventions.
5585+
dtype_overrides : dict, optional
5586+
Regex pattern → numpy dtype mapping for on-the-fly type conversion
5587+
during read. Patterns are matched against branch names using
5588+
``re.fullmatch``. First matching pattern (in dict insertion order) wins. Applied AFTER
5589+
schema/compression dtype hints (higher priority).
5590+
5591+
Example::
5592+
5593+
dtype_overrides={
5594+
r'.*_err_.*': np.float32, # errors stay float32 (listed first: first match wins)
5595+
r'.*_PIter\\d+$': np.float16, # remaining iteration coefficients
5596+
r'firstTForbit': np.uint32, # orbit counter
5597+
}
5598+
5599+
Safety: float → float conversions warn on overflow (finite value → inf after downcast); conversions to or from non-float dtypes are not checked.
5600+
NaN values are preserved across all float conversions.
55855601

55865602
Returns
55875603
-------
@@ -5595,6 +5611,7 @@ def read_tree(filename, treename="tree", entry_start=None, entry_stop=None,
55955611
- entry_start/entry_stop apply only to main tree, not subframes
55965612
- Subframes are always fully loaded (they contain small calibration data)
55975613
- Backward compatible with files created by older versions
5614+
- dtype_overrides applies to the current tree only, not subframes
55985615

55995616
Examples
56005617
--------
@@ -5609,6 +5626,12 @@ def read_tree(filename, treename="tree", entry_start=None, entry_stop=None,
56095626

56105627
>>> # Skip subframe loading (faster, for main tree only)
56115628
>>> adf = AliasDataFrame.read_tree("data.root", "tree", load_subframes=False)
5629+
5630+
>>> # Read with dtype conversion (3GB → 800MB)
5631+
>>> adf = AliasDataFrame.read_tree("data.root", "tree", dtype_overrides={
5632+
... r'.*_err_.*': np.float32,  # specific pattern first: first match wins
5633+
... r'.*_PIter\\d+$': np.float16,
5634+
... })
56125635
"""
56135636
import warnings
56145637
import concurrent.futures
@@ -5742,13 +5765,36 @@ def read_tree(filename, treename="tree", entry_start=None, entry_stop=None,
57425765
f"Using default."
57435766
)
57445767

5768+
# =========================================================================
5769+
# Step 2c: Apply user-specified dtype_overrides (Phase 13.26.ADF)
5770+
# Priority: dtype_overrides > compression_info > column_dtypes
5771+
# =========================================================================
5772+
if dtype_overrides:
5773+
# Pre-compile patterns for efficiency
5774+
compiled_overrides = []
5775+
for pattern, dtype in dtype_overrides.items():
5776+
try:
5777+
compiled_overrides.append((re.compile(pattern), np.dtype(dtype)))
5778+
except (re.error, TypeError) as e:
5779+
warnings.warn(
5780+
f"Invalid dtype_override: pattern={pattern!r}, dtype={dtype}: {e}"
5781+
)
5782+
57455783
# =========================================================================
57465784
# Step 3: Read branches with uproot (branch-by-branch for memory efficiency)
57475785
# =========================================================================
57485786
with uproot.open(filename) as f:
57495787
tree = f[treename]
57505788
branch_names = list(tree.keys())
57515789

5790+
# Apply dtype_overrides: regex match branch names → inject into dtype_hints
5791+
if dtype_overrides and compiled_overrides:
5792+
for branch_name in branch_names:
5793+
for regex, target_dtype in compiled_overrides:
5794+
if regex.fullmatch(branch_name):
5795+
dtype_hints[branch_name] = target_dtype
5796+
break # first match wins
5797+
57525798
if not branch_names:
57535799
df = pd.DataFrame()
57545800

@@ -5765,7 +5811,22 @@ def read_branch(branch_name):
57655811
if branch_name in dtype_hints:
57665812
target_dtype = dtype_hints[branch_name]
57675813
if arr.dtype != target_dtype:
5768-
arr = arr.astype(target_dtype)
5814+
# Safety: detect overflow on downcast (finite→inf)
5815+
if np.issubdtype(arr.dtype, np.floating) and np.issubdtype(target_dtype, np.floating):
5816+
original_dtype = arr.dtype
5817+
finite_before = np.isfinite(arr).sum()
5818+
arr = arr.astype(target_dtype)
5819+
finite_after = np.isfinite(arr).sum()
5820+
if finite_after < finite_before:
5821+
n_overflow = finite_before - finite_after
5822+
warnings.warn(
5823+
f"[read_tree] dtype_overrides: {n_overflow} values overflowed "
5824+
f"to inf in column '{branch_name}' during "
5825+
f"{original_dtype} → {target_dtype} conversion",
5826+
UserWarning,
5827+
)
5828+
else:
5829+
arr = arr.astype(target_dtype)
57695830

57705831
return branch_name, arr
57715832

@@ -5807,7 +5868,22 @@ def read_branch(branch_name):
58075868
if branch_name in dtype_hints:
58085869
target_dtype = dtype_hints[branch_name]
58095870
if arr.dtype != target_dtype:
5810-
arr = arr.astype(target_dtype)
5871+
# Safety: detect overflow on downcast (finite→inf)
5872+
if np.issubdtype(arr.dtype, np.floating) and np.issubdtype(target_dtype, np.floating):
5873+
original_dtype = arr.dtype
5874+
finite_before = np.isfinite(arr).sum()
5875+
arr = arr.astype(target_dtype)
5876+
finite_after = np.isfinite(arr).sum()
5877+
if finite_after < finite_before:
5878+
n_overflow = finite_before - finite_after
5879+
warnings.warn(
5880+
f"[read_tree] dtype_overrides: {n_overflow} values overflowed "
5881+
f"to inf in column '{branch_name}' during "
5882+
f"{original_dtype} → {target_dtype} conversion",
5883+
UserWarning,
5884+
)
5885+
else:
5886+
arr = arr.astype(target_dtype)
58115887

58125888
arrays[branch_name] = arr
58135889

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
"""
2+
Phase 13.26.ADF — dtype_overrides on read_tree
3+
4+
D1: regex pattern matching converts float64→float16
5+
D2: first-match-wins precedence
6+
D3: existing columns (no override) unchanged
7+
D4: overflow detection warns on finite→inf
8+
D5: NaN preserved across float downcast
9+
D6: dtype_overrides=None or {} yields the same dtypes as a read without the argument
10+
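D7: round-trip: write float64, read with override, values within tolerance
D8: export after read-with-override keeps the overridden dtype on re-read (via schema)
D9: entry_start/entry_stop read matches the corresponding slice of a full read with the same override
D10: overflow warning message reports original → target dtype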
11+
"""
12+
13+
import os
14+
import sys
15+
import tempfile
16+
import warnings
17+
import pytest
18+
import numpy as np
19+
import pandas as pd
20+
21+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
22+
from AliasDataFrame import AliasDataFrame
23+
24+
try:
25+
import ROOT
26+
import uproot
27+
_HAS_ROOT = ROOT is not None
28+
except ImportError:
29+
_HAS_ROOT = False
30+
31+
32+
@pytest.fixture
33+
def tmp_root_file():
34+
"""Create a temp ROOT file with float64 branches for testing dtype_overrides."""
35+
if not _HAS_ROOT:
36+
pytest.skip("Requires ROOT + uproot")
37+
38+
tmpdir = tempfile.mkdtemp()
39+
filepath = os.path.join(tmpdir, "test_dtype_overrides.root")
40+
41+
rng = np.random.default_rng(42)
42+
n = 1000
43+
df = pd.DataFrame({
44+
'x': rng.uniform(0, 300, n).astype(np.float64),
45+
'dy_intercept_PIter1': rng.normal(0, 0.1, n).astype(np.float64),
46+
'dy_slope_PIter1': rng.normal(0, 0.01, n).astype(np.float64),
47+
'dy_err_PIter1': rng.uniform(0.001, 0.1, n).astype(np.float64),
48+
'dz_intercept_PIter2': rng.normal(0, 0.2, n).astype(np.float64),
49+
'row': np.arange(n, dtype=np.int64),
50+
'firstTForbit': np.full(n, 29893280, dtype=np.int64),
51+
})
52+
53+
adf = AliasDataFrame(df)
54+
adf.export_tree(filepath, treename="tree")
55+
return filepath
56+
57+
58+
@pytest.mark.skipif(not _HAS_ROOT, reason="Requires ROOT + uproot")
59+
class TestDtypeOverrides:
60+
61+
@pytest.mark.invariance
62+
def test_D1_regex_converts_float64_to_float16(self, tmp_root_file):
63+
"""Branches matching regex pattern are converted to target dtype."""
64+
adf = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={
65+
r'dy_.*_PIter\d+': np.float16,
66+
})
67+
assert adf.df['dy_intercept_PIter1'].dtype == np.float16
68+
assert adf.df['dy_slope_PIter1'].dtype == np.float16
69+
assert adf.df['dy_err_PIter1'].dtype == np.float16
70+
# Non-matching columns unchanged
71+
assert adf.df['dz_intercept_PIter2'].dtype != np.float16
72+
73+
@pytest.mark.invariance
74+
def test_D2_first_match_wins(self, tmp_root_file):
75+
"""First matching pattern takes priority."""
76+
adf = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={
77+
r'dy_err_.*': np.float32, # errors → float32 (matches first)
78+
r'dy_.*_PIter\d+': np.float16, # everything else → float16
79+
})
80+
assert adf.df['dy_err_PIter1'].dtype == np.float32 # first match
81+
assert adf.df['dy_intercept_PIter1'].dtype == np.float16 # second match
82+
assert adf.df['dy_slope_PIter1'].dtype == np.float16
83+
84+
@pytest.mark.invariance
85+
def test_D3_no_override_columns_unchanged(self, tmp_root_file):
86+
"""Columns not matching any pattern keep their original dtype."""
87+
adf = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={
88+
r'dy_.*': np.float16,
89+
})
90+
# 'x' and 'dz_*' should not be affected
91+
assert adf.df['x'].dtype != np.float16
92+
assert adf.df['dz_intercept_PIter2'].dtype != np.float16
93+
94+
@pytest.mark.invariance
95+
def test_D4_overflow_warns(self, tmp_root_file):
96+
"""Overflow on downcast produces a UserWarning."""
97+
# Write a file with large values that overflow float16
98+
tmpdir = tempfile.mkdtemp()
99+
filepath = os.path.join(tmpdir, "overflow_test.root")
100+
df = pd.DataFrame({
101+
'big_values': np.array([1e10, 1e20, 0.5, -0.5], dtype=np.float64),
102+
'row': np.arange(4, dtype=np.int64),
103+
})
104+
AliasDataFrame(df).export_tree(filepath, "tree")
105+
106+
with warnings.catch_warnings(record=True) as w:
107+
warnings.simplefilter("always")
108+
adf = AliasDataFrame.read_tree(filepath, "tree", dtype_overrides={
109+
r'big_values': np.float16,
110+
})
111+
overflow_warnings = [x for x in w if "overflowed" in str(x.message)]
112+
assert len(overflow_warnings) >= 1
113+
114+
@pytest.mark.invariance
115+
def test_D5_nan_preserved(self, tmp_root_file):
116+
"""NaN values survive float64→float16 downcast."""
117+
tmpdir = tempfile.mkdtemp()
118+
filepath = os.path.join(tmpdir, "nan_test.root")
119+
arr = np.array([1.0, np.nan, 3.0, np.nan, 5.0], dtype=np.float64)
120+
df = pd.DataFrame({'val': arr, 'row': np.arange(5, dtype=np.int64)})
121+
AliasDataFrame(df).export_tree(filepath, "tree")
122+
123+
adf = AliasDataFrame.read_tree(filepath, "tree", dtype_overrides={
124+
r'val': np.float16,
125+
})
126+
assert adf.df['val'].dtype == np.float16
127+
assert np.isnan(adf.df['val'].values[1])
128+
assert np.isnan(adf.df['val'].values[3])
129+
assert not np.isnan(adf.df['val'].values[0])
130+
131+
@pytest.mark.invariance
132+
def test_D6_no_overrides_matches_baseline(self, tmp_root_file):
133+
"""read_tree with dtype_overrides=None produces same result as without."""
134+
adf_base = AliasDataFrame.read_tree(tmp_root_file, "tree")
135+
adf_none = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides=None)
136+
adf_empty = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={})
137+
138+
for col in adf_base.df.columns:
139+
assert adf_base.df[col].dtype == adf_none.df[col].dtype
140+
assert adf_base.df[col].dtype == adf_empty.df[col].dtype
141+
142+
@pytest.mark.invariance
143+
def test_D7_roundtrip_values_within_tolerance(self, tmp_root_file):
144+
"""Values survive write→read-with-override within float16 tolerance."""
145+
adf_orig = AliasDataFrame.read_tree(tmp_root_file, "tree")
146+
adf_f16 = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={
147+
r'dy_intercept_PIter1': np.float16,
148+
})
149+
150+
orig_vals = adf_orig.df['dy_intercept_PIter1'].values
151+
f16_vals = adf_f16.df['dy_intercept_PIter1'].values.astype(np.float64)
152+
153+
# float16 has ~3 decimal digits of precision
154+
np.testing.assert_allclose(orig_vals, f16_vals, rtol=1e-2, atol=1e-3,
155+
err_msg="D7: round-trip values diverged beyond float16 tolerance")
156+
157+
@pytest.mark.invariance
158+
def test_D8_schema_roundtrip_preserves_overridden_dtype(self, tmp_root_file):
159+
"""Export after read-with-override preserves the overridden dtype in schema.
160+
161+
Sequence: read(override float16) → export → re-read(no override)
162+
Expected: re-read gets float16 via schema column_dtypes, not float64.
163+
"""
164+
tmpdir = tempfile.mkdtemp()
165+
reexport_path = os.path.join(tmpdir, "reexported.root")
166+
167+
# Read with override
168+
adf = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={
169+
r'dy_intercept_PIter1': np.float16,
170+
})
171+
assert adf.df['dy_intercept_PIter1'].dtype == np.float16
172+
173+
# Export (schema records actual dtype)
174+
adf.export_tree(reexport_path, "tree")
175+
176+
# Re-read WITHOUT override — schema should preserve float16
177+
adf2 = AliasDataFrame.read_tree(reexport_path, "tree")
178+
assert adf2.df['dy_intercept_PIter1'].dtype == np.float16, \
179+
"D8: schema round-trip lost overridden dtype"
180+
181+
@pytest.mark.invariance
182+
def test_D9_entry_range_with_overrides(self, tmp_root_file):
183+
"""dtype_overrides applied consistently with entry_start/entry_stop."""
184+
# Read full file with override
185+
adf_full = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={
186+
r'dy_.*_PIter\d+': np.float16,
187+
})
188+
189+
# Read subset with same override
190+
adf_sub = AliasDataFrame.read_tree(tmp_root_file, "tree",
191+
entry_start=100, entry_stop=500,
192+
dtype_overrides={r'dy_.*_PIter\d+': np.float16})
193+
194+
# Dtypes must match
195+
assert adf_sub.df['dy_intercept_PIter1'].dtype == np.float16
196+
assert adf_sub.df['dy_slope_PIter1'].dtype == np.float16
197+
198+
# Values must match the corresponding slice of full read
199+
np.testing.assert_array_equal(
200+
adf_sub.df['dy_intercept_PIter1'].values,
201+
adf_full.df['dy_intercept_PIter1'].values[100:500],
202+
err_msg="D9: entry_range + override produced different values"
203+
)
204+
205+
@pytest.mark.invariance
206+
def test_D10_override_warning_shows_correct_dtypes(self, tmp_root_file):
207+
"""Overflow warning message shows original→target dtype, not target→target."""
208+
tmpdir = tempfile.mkdtemp()
209+
filepath = os.path.join(tmpdir, "dtype_msg_test.root")
210+
df = pd.DataFrame({
211+
'big': np.array([1e10, 1e20], dtype=np.float64),
212+
'row': np.arange(2, dtype=np.int64),
213+
})
214+
AliasDataFrame(df).export_tree(filepath, "tree")
215+
216+
with warnings.catch_warnings(record=True) as w:
217+
warnings.simplefilter("always")
218+
AliasDataFrame.read_tree(filepath, "tree", dtype_overrides={
219+
r'big': np.float16,
220+
})
221+
overflow_msgs = [str(x.message) for x in w if "overflowed" in str(x.message)]
222+
assert len(overflow_msgs) >= 1
223+
# Must show float64 → float16, NOT float16 → float16
224+
assert "float64" in overflow_msgs[0], \
225+
f"D10: warning should show original dtype float64, got: {overflow_msgs[0]}"
226+
assert "float16" in overflow_msgs[0]
227+
228+
229+
if __name__ == '__main__':
230+
pytest.main([__file__, '-v', '-s'])

0 commit comments

Comments
 (0)