|
| 1 | +""" |
| 2 | +Phase 13.26.ADF — dtype_overrides on read_tree |
| 3 | +
|
| 4 | +D1: regex pattern matching converts float64→float16 |
| 5 | +D2: first-match-wins precedence |
| 6 | +D3: existing columns (no override) unchanged |
| 7 | +D4: overflow detection warns on finite→inf |
| 8 | +D5: NaN preserved across float downcast |
| 9 | +D6: schema dtype_hints not overridden when no override matches |
| 10 | +D7: round-trip: write float64, read with override, values within tolerance |
| 11 | +""" |
| 12 | + |
| 13 | +import os |
| 14 | +import sys |
| 15 | +import tempfile |
| 16 | +import warnings |
| 17 | +import pytest |
| 18 | +import numpy as np |
| 19 | +import pandas as pd |
| 20 | + |
| 21 | +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| 22 | +from AliasDataFrame import AliasDataFrame |
| 23 | + |
| 24 | +try: |
| 25 | + import ROOT |
| 26 | + import uproot |
| 27 | + _HAS_ROOT = ROOT is not None |
| 28 | +except ImportError: |
| 29 | + _HAS_ROOT = False |
| 30 | + |
| 31 | + |
| 32 | +@pytest.fixture |
| 33 | +def tmp_root_file(): |
| 34 | + """Create a temp ROOT file with float64 branches for testing dtype_overrides.""" |
| 35 | + if not _HAS_ROOT: |
| 36 | + pytest.skip("Requires ROOT + uproot") |
| 37 | + |
| 38 | + tmpdir = tempfile.mkdtemp() |
| 39 | + filepath = os.path.join(tmpdir, "test_dtype_overrides.root") |
| 40 | + |
| 41 | + rng = np.random.default_rng(42) |
| 42 | + n = 1000 |
| 43 | + df = pd.DataFrame({ |
| 44 | + 'x': rng.uniform(0, 300, n).astype(np.float64), |
| 45 | + 'dy_intercept_PIter1': rng.normal(0, 0.1, n).astype(np.float64), |
| 46 | + 'dy_slope_PIter1': rng.normal(0, 0.01, n).astype(np.float64), |
| 47 | + 'dy_err_PIter1': rng.uniform(0.001, 0.1, n).astype(np.float64), |
| 48 | + 'dz_intercept_PIter2': rng.normal(0, 0.2, n).astype(np.float64), |
| 49 | + 'row': np.arange(n, dtype=np.int64), |
| 50 | + 'firstTForbit': np.full(n, 29893280, dtype=np.int64), |
| 51 | + }) |
| 52 | + |
| 53 | + adf = AliasDataFrame(df) |
| 54 | + adf.export_tree(filepath, treename="tree") |
| 55 | + return filepath |
| 56 | + |
| 57 | + |
| 58 | +@pytest.mark.skipif(not _HAS_ROOT, reason="Requires ROOT + uproot") |
| 59 | +class TestDtypeOverrides: |
| 60 | + |
| 61 | + @pytest.mark.invariance |
| 62 | + def test_D1_regex_converts_float64_to_float16(self, tmp_root_file): |
| 63 | + """Branches matching regex pattern are converted to target dtype.""" |
| 64 | + adf = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={ |
| 65 | + r'dy_.*_PIter\d+': np.float16, |
| 66 | + }) |
| 67 | + assert adf.df['dy_intercept_PIter1'].dtype == np.float16 |
| 68 | + assert adf.df['dy_slope_PIter1'].dtype == np.float16 |
| 69 | + assert adf.df['dy_err_PIter1'].dtype == np.float16 |
| 70 | + # Non-matching columns unchanged |
| 71 | + assert adf.df['dz_intercept_PIter2'].dtype != np.float16 |
| 72 | + |
| 73 | + @pytest.mark.invariance |
| 74 | + def test_D2_first_match_wins(self, tmp_root_file): |
| 75 | + """First matching pattern takes priority.""" |
| 76 | + adf = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={ |
| 77 | + r'dy_err_.*': np.float32, # errors → float32 (matches first) |
| 78 | + r'dy_.*_PIter\d+': np.float16, # everything else → float16 |
| 79 | + }) |
| 80 | + assert adf.df['dy_err_PIter1'].dtype == np.float32 # first match |
| 81 | + assert adf.df['dy_intercept_PIter1'].dtype == np.float16 # second match |
| 82 | + assert adf.df['dy_slope_PIter1'].dtype == np.float16 |
| 83 | + |
| 84 | + @pytest.mark.invariance |
| 85 | + def test_D3_no_override_columns_unchanged(self, tmp_root_file): |
| 86 | + """Columns not matching any pattern keep their original dtype.""" |
| 87 | + adf = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={ |
| 88 | + r'dy_.*': np.float16, |
| 89 | + }) |
| 90 | + # 'x' and 'dz_*' should not be affected |
| 91 | + assert adf.df['x'].dtype != np.float16 |
| 92 | + assert adf.df['dz_intercept_PIter2'].dtype != np.float16 |
| 93 | + |
| 94 | + @pytest.mark.invariance |
| 95 | + def test_D4_overflow_warns(self, tmp_root_file): |
| 96 | + """Overflow on downcast produces a UserWarning.""" |
| 97 | + # Write a file with large values that overflow float16 |
| 98 | + tmpdir = tempfile.mkdtemp() |
| 99 | + filepath = os.path.join(tmpdir, "overflow_test.root") |
| 100 | + df = pd.DataFrame({ |
| 101 | + 'big_values': np.array([1e10, 1e20, 0.5, -0.5], dtype=np.float64), |
| 102 | + 'row': np.arange(4, dtype=np.int64), |
| 103 | + }) |
| 104 | + AliasDataFrame(df).export_tree(filepath, "tree") |
| 105 | + |
| 106 | + with warnings.catch_warnings(record=True) as w: |
| 107 | + warnings.simplefilter("always") |
| 108 | + adf = AliasDataFrame.read_tree(filepath, "tree", dtype_overrides={ |
| 109 | + r'big_values': np.float16, |
| 110 | + }) |
| 111 | + overflow_warnings = [x for x in w if "overflowed" in str(x.message)] |
| 112 | + assert len(overflow_warnings) >= 1 |
| 113 | + |
| 114 | + @pytest.mark.invariance |
| 115 | + def test_D5_nan_preserved(self, tmp_root_file): |
| 116 | + """NaN values survive float64→float16 downcast.""" |
| 117 | + tmpdir = tempfile.mkdtemp() |
| 118 | + filepath = os.path.join(tmpdir, "nan_test.root") |
| 119 | + arr = np.array([1.0, np.nan, 3.0, np.nan, 5.0], dtype=np.float64) |
| 120 | + df = pd.DataFrame({'val': arr, 'row': np.arange(5, dtype=np.int64)}) |
| 121 | + AliasDataFrame(df).export_tree(filepath, "tree") |
| 122 | + |
| 123 | + adf = AliasDataFrame.read_tree(filepath, "tree", dtype_overrides={ |
| 124 | + r'val': np.float16, |
| 125 | + }) |
| 126 | + assert adf.df['val'].dtype == np.float16 |
| 127 | + assert np.isnan(adf.df['val'].values[1]) |
| 128 | + assert np.isnan(adf.df['val'].values[3]) |
| 129 | + assert not np.isnan(adf.df['val'].values[0]) |
| 130 | + |
| 131 | + @pytest.mark.invariance |
| 132 | + def test_D6_no_overrides_matches_baseline(self, tmp_root_file): |
| 133 | + """read_tree with dtype_overrides=None produces same result as without.""" |
| 134 | + adf_base = AliasDataFrame.read_tree(tmp_root_file, "tree") |
| 135 | + adf_none = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides=None) |
| 136 | + adf_empty = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={}) |
| 137 | + |
| 138 | + for col in adf_base.df.columns: |
| 139 | + assert adf_base.df[col].dtype == adf_none.df[col].dtype |
| 140 | + assert adf_base.df[col].dtype == adf_empty.df[col].dtype |
| 141 | + |
| 142 | + @pytest.mark.invariance |
| 143 | + def test_D7_roundtrip_values_within_tolerance(self, tmp_root_file): |
| 144 | + """Values survive write→read-with-override within float16 tolerance.""" |
| 145 | + adf_orig = AliasDataFrame.read_tree(tmp_root_file, "tree") |
| 146 | + adf_f16 = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={ |
| 147 | + r'dy_intercept_PIter1': np.float16, |
| 148 | + }) |
| 149 | + |
| 150 | + orig_vals = adf_orig.df['dy_intercept_PIter1'].values |
| 151 | + f16_vals = adf_f16.df['dy_intercept_PIter1'].values.astype(np.float64) |
| 152 | + |
| 153 | + # float16 has ~3 decimal digits of precision |
| 154 | + np.testing.assert_allclose(orig_vals, f16_vals, rtol=1e-2, atol=1e-3, |
| 155 | + err_msg="D7: round-trip values diverged beyond float16 tolerance") |
| 156 | + |
| 157 | + @pytest.mark.invariance |
| 158 | + def test_D8_schema_roundtrip_preserves_overridden_dtype(self, tmp_root_file): |
| 159 | + """Export after read-with-override preserves the overridden dtype in schema. |
| 160 | + |
| 161 | + Sequence: read(override float16) → export → re-read(no override) |
| 162 | + Expected: re-read gets float16 via schema column_dtypes, not float64. |
| 163 | + """ |
| 164 | + tmpdir = tempfile.mkdtemp() |
| 165 | + reexport_path = os.path.join(tmpdir, "reexported.root") |
| 166 | + |
| 167 | + # Read with override |
| 168 | + adf = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={ |
| 169 | + r'dy_intercept_PIter1': np.float16, |
| 170 | + }) |
| 171 | + assert adf.df['dy_intercept_PIter1'].dtype == np.float16 |
| 172 | + |
| 173 | + # Export (schema records actual dtype) |
| 174 | + adf.export_tree(reexport_path, "tree") |
| 175 | + |
| 176 | + # Re-read WITHOUT override — schema should preserve float16 |
| 177 | + adf2 = AliasDataFrame.read_tree(reexport_path, "tree") |
| 178 | + assert adf2.df['dy_intercept_PIter1'].dtype == np.float16, \ |
| 179 | + "D8: schema round-trip lost overridden dtype" |
| 180 | + |
| 181 | + @pytest.mark.invariance |
| 182 | + def test_D9_entry_range_with_overrides(self, tmp_root_file): |
| 183 | + """dtype_overrides applied consistently with entry_start/entry_stop.""" |
| 184 | + # Read full file with override |
| 185 | + adf_full = AliasDataFrame.read_tree(tmp_root_file, "tree", dtype_overrides={ |
| 186 | + r'dy_.*_PIter\d+': np.float16, |
| 187 | + }) |
| 188 | + |
| 189 | + # Read subset with same override |
| 190 | + adf_sub = AliasDataFrame.read_tree(tmp_root_file, "tree", |
| 191 | + entry_start=100, entry_stop=500, |
| 192 | + dtype_overrides={r'dy_.*_PIter\d+': np.float16}) |
| 193 | + |
| 194 | + # Dtypes must match |
| 195 | + assert adf_sub.df['dy_intercept_PIter1'].dtype == np.float16 |
| 196 | + assert adf_sub.df['dy_slope_PIter1'].dtype == np.float16 |
| 197 | + |
| 198 | + # Values must match the corresponding slice of full read |
| 199 | + np.testing.assert_array_equal( |
| 200 | + adf_sub.df['dy_intercept_PIter1'].values, |
| 201 | + adf_full.df['dy_intercept_PIter1'].values[100:500], |
| 202 | + err_msg="D9: entry_range + override produced different values" |
| 203 | + ) |
| 204 | + |
| 205 | + @pytest.mark.invariance |
| 206 | + def test_D10_override_warning_shows_correct_dtypes(self, tmp_root_file): |
| 207 | + """Overflow warning message shows original→target dtype, not target→target.""" |
| 208 | + tmpdir = tempfile.mkdtemp() |
| 209 | + filepath = os.path.join(tmpdir, "dtype_msg_test.root") |
| 210 | + df = pd.DataFrame({ |
| 211 | + 'big': np.array([1e10, 1e20], dtype=np.float64), |
| 212 | + 'row': np.arange(2, dtype=np.int64), |
| 213 | + }) |
| 214 | + AliasDataFrame(df).export_tree(filepath, "tree") |
| 215 | + |
| 216 | + with warnings.catch_warnings(record=True) as w: |
| 217 | + warnings.simplefilter("always") |
| 218 | + AliasDataFrame.read_tree(filepath, "tree", dtype_overrides={ |
| 219 | + r'big': np.float16, |
| 220 | + }) |
| 221 | + overflow_msgs = [str(x.message) for x in w if "overflowed" in str(x.message)] |
| 222 | + assert len(overflow_msgs) >= 1 |
| 223 | + # Must show float64 → float16, NOT float16 → float16 |
| 224 | + assert "float64" in overflow_msgs[0], \ |
| 225 | + f"D10: warning should show original dtype float64, got: {overflow_msgs[0]}" |
| 226 | + assert "float16" in overflow_msgs[0] |
| 227 | + |
| 228 | + |
| 229 | +if __name__ == '__main__': |
| 230 | + pytest.main([__file__, '-v', '-s']) |
0 commit comments