|
| 1 | +"""Visualize a 2D chunk plan from ``waterdata.chunking._plan_joint``. |
| 2 | +
|
| 3 | +Builds a representative over-budget query (long site list + long |
| 4 | +top-level-``OR`` filter), runs the joint planner against the real |
| 5 | +``_construct_api_requests`` builder, and renders the resulting |
| 6 | +list × filter cartesian product as a heatmap. Each cell is one |
| 7 | +sub-request the wrapper would issue; the colour is its URL byte |
| 8 | +size relative to the limit. Use this to eyeball plans for the kind |
| 9 | +of correctness properties that are easy to miss in unit tests: |
| 10 | +
|
| 11 | +- every cell ≤ the limit (no plan is allowed to overflow), |
| 12 | +- the headroom is reasonably balanced (no chunk is wasted), |
| 13 | +- the filter partition matches the OR-axis layout you expect, |
| 14 | +- the cartesian product is rectangular (every list chunk pairs |
| 15 | + with every filter chunk exactly once). |
| 16 | +
|
| 17 | +Run: ``python demos/visualize_chunk_plan.py``. Saves the figure to |
| 18 | +``demos/chunk_plan.png`` and also prints the plan as a text table. |
| 19 | +""" |
| 20 | + |
| 21 | +from __future__ import annotations |
| 22 | + |
| 23 | +from pathlib import Path |
| 24 | + |
| 25 | +import matplotlib.pyplot as plt |
| 26 | +import numpy as np |
| 27 | + |
| 28 | +from dataretrieval.waterdata.chunking import _plan_joint, _request_bytes |
| 29 | +from dataretrieval.waterdata.utils import _construct_api_requests |
| 30 | + |
| 31 | + |
| 32 | +def build_demo_args() -> tuple[dict, int]: |
| 33 | + """A query that needs both list and filter chunking to fit. |
| 34 | +
|
| 35 | + 100 USGS sites + 12 ``time`` OR-clauses against a 1500-byte URL |
| 36 | + limit. Real ``_construct_api_requests`` URL-encoding applies so |
| 37 | + the bytes match what production would build. |
| 38 | + """ |
| 39 | + sites = [f"USGS-{i:08d}" for i in range(60)] |
| 40 | + # 16 datetime equality clauses — long enough that ``k=1`` forces |
| 41 | + # many list chunks, but halving the filter frees enough budget that |
| 42 | + # a moderate ``k>1`` plus a coarser list split is the joint |
| 43 | + # planner's optimum. The demo aims to exercise BOTH chunking dims. |
| 44 | + clauses = [ |
| 45 | + f"time='2024-{m:02d}-{d:02d}T00:00:00Z'" |
| 46 | + for m in range(1, 5) |
| 47 | + for d in (1, 8, 15, 22) |
| 48 | + ] |
| 49 | + args = { |
| 50 | + "service": "daily", |
| 51 | + "monitoring_location_id": sites, |
| 52 | + "filter": " OR ".join(clauses), |
| 53 | + } |
| 54 | + return args, 800 |
| 55 | + |
| 56 | + |
| 57 | +def gather_subrequest_bytes( |
| 58 | + args: dict, |
| 59 | + list_plan: dict[str, list[list]], |
| 60 | + filter_chunks: list[str | None], |
| 61 | +) -> tuple[np.ndarray, list[str], list[str], str | None]: |
| 62 | + """Build every sub-request URL the planner would emit and return a |
| 63 | + ``(rows, cols)`` matrix of byte counts. The list dim becomes rows; |
| 64 | + filter chunks become columns. Row/column labels summarise the |
| 65 | + contents of each chunk for the figure axes. ``list_dim`` (4th |
| 66 | + return value) is the name of the chunked list parameter, used by |
| 67 | + the partition spot-check.""" |
| 68 | + if not list_plan: |
| 69 | + # No list chunking; render as a single row. |
| 70 | + list_dim, list_chunks = None, [None] |
| 71 | + else: |
| 72 | + # The demo uses a single list dim; if multiple, take the first. |
| 73 | + list_dim = next(iter(list_plan)) |
| 74 | + list_chunks = list_plan[list_dim] |
| 75 | + |
| 76 | + n_rows = len(list_chunks) |
| 77 | + n_cols = len(filter_chunks) |
| 78 | + bytes_ = np.zeros((n_rows, n_cols), dtype=int) |
| 79 | + |
| 80 | + for r, list_chunk in enumerate(list_chunks): |
| 81 | + for c, filter_chunk in enumerate(filter_chunks): |
| 82 | + sub_args = dict(args) |
| 83 | + if list_chunk is not None: |
| 84 | + sub_args[list_dim] = list_chunk |
| 85 | + if filter_chunk is not None: |
| 86 | + sub_args["filter"] = filter_chunk |
| 87 | + bytes_[r, c] = _request_bytes(_construct_api_requests(**sub_args)) |
| 88 | + |
| 89 | + row_labels = [] |
| 90 | + cursor = 0 |
| 91 | + for list_chunk in list_chunks: |
| 92 | + if list_chunk is None: |
| 93 | + row_labels.append("all (no chunking)") |
| 94 | + else: |
| 95 | + end = cursor + len(list_chunk) |
| 96 | + row_labels.append(f"[{cursor}:{end}]\n({len(list_chunk)} items)") |
| 97 | + cursor = end |
| 98 | + |
| 99 | + col_labels = [] |
| 100 | + cursor = 0 |
| 101 | + for fc in filter_chunks: |
| 102 | + if fc is None: |
| 103 | + col_labels.append("no filter") |
| 104 | + else: |
| 105 | + n_clauses = fc.count(" OR ") + 1 |
| 106 | + end = cursor + n_clauses |
| 107 | + col_labels.append(f"[{cursor}:{end}]\n({n_clauses} clauses)") |
| 108 | + cursor = end |
| 109 | + |
| 110 | + return bytes_, row_labels, col_labels, list_dim |
| 111 | + |
| 112 | + |
| 113 | +def draw_heatmap( |
| 114 | + bytes_: np.ndarray, |
| 115 | + row_labels: list[str], |
| 116 | + col_labels: list[str], |
| 117 | + url_limit: int, |
| 118 | + out_path: Path, |
| 119 | + list_dim: str | None, |
| 120 | +) -> None: |
| 121 | + """Render the byte matrix as a heatmap. Cells are coloured by |
| 122 | + ``bytes / url_limit``; the limit itself is the colour-scale's red |
| 123 | + boundary so anything over budget stands out. Each cell is |
| 124 | + annotated with its byte count; a red cell would mean the planner |
| 125 | + produced an over-budget sub-request (visible bug).""" |
| 126 | + n_rows, n_cols = bytes_.shape |
| 127 | + fig, ax = plt.subplots(figsize=(max(6, 1.2 * n_cols), max(4, 0.35 * n_rows + 1.5))) |
| 128 | + |
| 129 | + # vmax = url_limit pins the red end of the colour scale to the |
| 130 | + # budget. Anything over the limit saturates and becomes obvious. |
| 131 | + im = ax.imshow( |
| 132 | + bytes_, |
| 133 | + cmap="RdYlGn_r", |
| 134 | + vmin=0, |
| 135 | + vmax=url_limit, |
| 136 | + aspect="auto", |
| 137 | + ) |
| 138 | + ax.set_xticks(range(n_cols)) |
| 139 | + ax.set_xticklabels(col_labels, rotation=30, ha="right") |
| 140 | + ax.set_yticks(range(n_rows)) |
| 141 | + ax.set_yticklabels(row_labels) |
| 142 | + ax.set_xlabel("Filter sub-chunk (OR-clause range)") |
| 143 | + ax.set_ylabel( |
| 144 | + f"List sub-chunk ({list_dim} range)" |
| 145 | + if list_dim is not None |
| 146 | + else "List sub-chunk" |
| 147 | + ) |
| 148 | + ax.set_title( |
| 149 | + f"Joint chunk plan: {n_rows} × {n_cols} = {n_rows * n_cols} " |
| 150 | + f"sub-requests · url_limit={url_limit} bytes" |
| 151 | + ) |
| 152 | + |
| 153 | + # Per-cell annotations. |
| 154 | + for r in range(n_rows): |
| 155 | + for c in range(n_cols): |
| 156 | + ax.text( |
| 157 | + c, |
| 158 | + r, |
| 159 | + f"{bytes_[r, c]}", |
| 160 | + ha="center", |
| 161 | + va="center", |
| 162 | + color="black" if bytes_[r, c] < 0.6 * url_limit else "white", |
| 163 | + fontsize=9, |
| 164 | + ) |
| 165 | + |
| 166 | + fig.colorbar(im, ax=ax, label="URL bytes") |
| 167 | + fig.tight_layout() |
| 168 | + fig.savefig(out_path, dpi=120) |
| 169 | + plt.close(fig) |
| 170 | + |
| 171 | + |
| 172 | +def print_text_table( |
| 173 | + bytes_: np.ndarray, |
| 174 | + row_labels: list[str], |
| 175 | + col_labels: list[str], |
| 176 | + url_limit: int, |
| 177 | +) -> None: |
| 178 | + """ASCII fallback so the plan is also legible without opening the |
| 179 | + PNG (CI logs, terminals without graphics, etc.).""" |
| 180 | + print(f"\nurl_limit = {url_limit} bytes") |
| 181 | + print( |
| 182 | + f"plan shape: {bytes_.shape[0]} list × {bytes_.shape[1]} filter " |
| 183 | + f"= {bytes_.size} sub-requests" |
| 184 | + ) |
| 185 | + print( |
| 186 | + f"min cell: {bytes_.min()} bytes · max cell: {bytes_.max()} bytes " |
| 187 | + f"(headroom: {url_limit - bytes_.max()} bytes)" |
| 188 | + ) |
| 189 | + print() |
| 190 | + col_w = max(8, max(len(c.replace("\n", " ")) for c in col_labels) + 1) |
| 191 | + row_w = max(len(r.replace("\n", " ")) for r in row_labels) + 2 |
| 192 | + print(" " * row_w + "".join(c.replace("\n", " ").rjust(col_w) for c in col_labels)) |
| 193 | + for r, row_label in enumerate(row_labels): |
| 194 | + cells = "".join( |
| 195 | + f"{int(bytes_[r, c])}".rjust(col_w) for c in range(bytes_.shape[1]) |
| 196 | + ) |
| 197 | + print(row_label.replace("\n", " ").ljust(row_w) + cells) |
| 198 | + |
| 199 | + |
| 200 | +def spot_check_partition( |
| 201 | + args: dict, |
| 202 | + list_plan: dict[str, list[list]], |
| 203 | + filter_chunks: list[str | None], |
| 204 | + list_dim: str | None, |
| 205 | +) -> None: |
| 206 | + """Sanity-check that the cartesian-product plan covers every |
| 207 | + original list value and OR-clause exactly once. Catches partition |
| 208 | + bugs that the heatmap alone wouldn't surface (e.g. a chunk that |
| 209 | + drops or duplicates members).""" |
| 210 | + if list_dim is not None: |
| 211 | + original = list(args[list_dim]) |
| 212 | + seen = [v for chunk in list_plan[list_dim] for v in chunk] |
| 213 | + assert sorted(seen) == sorted(original), ( |
| 214 | + f"list partition lost or duplicated values: " |
| 215 | + f"{len(seen)} seen vs {len(original)} expected" |
| 216 | + ) |
| 217 | + print(f"list partition covers all {len(original)} {list_dim}s exactly once") |
| 218 | + |
| 219 | + original_filter = args.get("filter") |
| 220 | + if original_filter and len(filter_chunks) > 1: |
| 221 | + original_clauses = [c.strip() for c in original_filter.split(" OR ")] |
| 222 | + seen_clauses: list[str] = [] |
| 223 | + for fc in filter_chunks: |
| 224 | + if fc is None: |
| 225 | + continue |
| 226 | + seen_clauses.extend(c.strip() for c in fc.split(" OR ")) |
| 227 | + assert seen_clauses == original_clauses, ( |
| 228 | + "filter partition must cover original clauses in order, exactly once" |
| 229 | + ) |
| 230 | + print( |
| 231 | + f"filter partition covers all {len(original_clauses)} " |
| 232 | + f"OR-clauses exactly once, in order" |
| 233 | + ) |
| 234 | + |
| 235 | + |
| 236 | +def main() -> None: |
| 237 | + args, url_limit = build_demo_args() |
| 238 | + |
| 239 | + plan = _plan_joint(args, _construct_api_requests, url_limit) |
| 240 | + if plan is None: |
| 241 | + raise SystemExit( |
| 242 | + "Demo args fit under url_limit — pick a tighter limit or a " |
| 243 | + "longer query so the planner actually fans out." |
| 244 | + ) |
| 245 | + list_plan, filter_chunks = plan |
| 246 | + |
| 247 | + bytes_, row_labels, col_labels, list_dim = gather_subrequest_bytes( |
| 248 | + args, list_plan, filter_chunks |
| 249 | + ) |
| 250 | + |
| 251 | + out = Path(__file__).parent / "chunk_plan.png" |
| 252 | + draw_heatmap(bytes_, row_labels, col_labels, url_limit, out, list_dim) |
| 253 | + print(f"wrote {out}") |
| 254 | + |
| 255 | + print_text_table(bytes_, row_labels, col_labels, url_limit) |
| 256 | + spot_check_partition(args, list_plan, filter_chunks, list_dim) |
| 257 | + |
| 258 | + over = np.argwhere(bytes_ > url_limit) |
| 259 | + if len(over): |
| 260 | + print( |
| 261 | + f"\nBUG: {len(over)} cell(s) over the {url_limit}-byte limit " |
| 262 | + f"(first: row {over[0, 0]}, col {over[0, 1]} = " |
| 263 | + f"{bytes_[over[0, 0], over[0, 1]]} bytes)" |
| 264 | + ) |
| 265 | + raise SystemExit(1) |
| 266 | + print("all cells within url_limit — plan is valid") |
| 267 | + |
| 268 | + |
| 269 | +if __name__ == "__main__": |
| 270 | + main() |
0 commit comments