|
| 1 | +"""Decoding pipeline helpers.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +from typing import Any, Iterable, List, Sequence, Tuple |
| 6 | + |
| 7 | +import numpy as np |
| 8 | + |
| 9 | +from .base import DecodeStep |
| 10 | +from .registry import DecoderRegistry, DEFAULT_DECODER_REGISTRY, register_builtin_decoders |
| 11 | + |
| 12 | + |
| 13 | +def _coerce_kwargs(kwargs: Any) -> dict: |
| 14 | + if kwargs is None: |
| 15 | + return {} |
| 16 | + if hasattr(kwargs, "items"): |
| 17 | + return dict(kwargs) |
| 18 | + raise TypeError( |
| 19 | + f"Decode kwargs must be a mapping, got {type(kwargs).__name__}" |
| 20 | + ) |
| 21 | + |
| 22 | + |
def normalize_decode_modes(decode_modes: Iterable[Any]) -> List[DecodeStep]:
    """Normalize decode configuration entries into DecodeStep objects.

    Accepted entry forms:
        * ``DecodeStep`` instances (copied, kwargs re-coerced to a dict),
        * plain strings naming a decoder with no kwargs,
        * objects with a ``name`` attribute and optional ``kwargs``
          (e.g. config nodes),
        * dicts with a ``"name"`` key and optional ``"kwargs"``.

    Raises:
        TypeError: for unsupported entry types or non-mapping kwargs.
        ValueError: if an entry is missing a decoder name.
    """
    steps: List[DecodeStep] = []
    for mode in decode_modes:
        if isinstance(mode, DecodeStep):
            name, kwargs = mode.name, _coerce_kwargs(mode.kwargs)
        elif isinstance(mode, str):
            # Convenience shorthand: a bare string means "run this decoder
            # with no kwargs". Previously this raised TypeError.
            name, kwargs = mode, {}
        elif hasattr(mode, "name"):
            # Attribute-style config object (checked before dict because a
            # dict stores "name" as a key, not an attribute).
            name = mode.name
            kwargs = _coerce_kwargs(getattr(mode, "kwargs", {}))
        elif isinstance(mode, dict):
            name = mode.get("name")
            kwargs = _coerce_kwargs(mode.get("kwargs", {}))
        else:
            raise TypeError(f"Unsupported decode mode type: {type(mode).__name__}")

        # Validate inline (fail fast) instead of in a second pass.
        if not name:
            raise ValueError("Decode step is missing required field 'name'.")
        steps.append(DecodeStep(name=name, kwargs=kwargs))

    return steps
| 50 | + |
| 51 | + |
| 52 | +def _prepare_batched_input(data: np.ndarray) -> Tuple[np.ndarray, int]: |
| 53 | + arr = np.asarray(data) |
| 54 | + if arr.ndim == 5: |
| 55 | + return arr, arr.shape[0] |
| 56 | + if arr.ndim == 4: |
| 57 | + return arr[np.newaxis, ...], 1 |
| 58 | + if arr.ndim == 3: |
| 59 | + return arr[np.newaxis, np.newaxis, ...], 1 |
| 60 | + if arr.ndim == 2: |
| 61 | + return arr[np.newaxis, np.newaxis, np.newaxis, ...], 1 |
| 62 | + raise ValueError( |
| 63 | + f"Expected input with 2-5 dimensions, got shape {arr.shape}." |
| 64 | + ) |
| 65 | + |
| 66 | + |
def apply_decode_pipeline(
    data: np.ndarray,
    decode_modes: Sequence[Any] | None,
    registry: DecoderRegistry | None = None,
) -> np.ndarray:
    """Apply configured decode steps to prediction data.

    Args:
        data: 2-5 dimensional array; lower-rank inputs are padded with
            leading unit axes before decoding (see ``_prepare_batched_input``).
        decode_modes: decode step configuration; falsy means no-op and
            *data* is returned unchanged.
        registry: decoder registry used to resolve step names; defaults to
            ``DEFAULT_DECODER_REGISTRY``.

    Raises:
        ValueError: if a step names an unknown decoder.
        RuntimeError: if a decoder raises while processing a sample.
    """
    if not decode_modes:
        return data

    # Make sure the built-in decoders are registered before lookup —
    # presumably idempotent; confirm in registry module.
    register_builtin_decoders()
    registry = registry or DEFAULT_DECODER_REGISTRY
    steps = normalize_decode_modes(decode_modes)

    # Resolve every decoder up front: fail fast on an unknown name instead
    # of discovering it partway through the batch, and hoist the repeated
    # registry lookup out of the per-sample loop.
    decoders = []
    for step in steps:
        try:
            decoders.append(registry.get(step.name))
        except KeyError as exc:
            available = ", ".join(registry.available())
            raise ValueError(
                f"Unknown decode function '{step.name}'. "
                f"Available functions: [{available}]."
            ) from exc

    batched, batch_size = _prepare_batched_input(data)

    results: List[np.ndarray] = []
    for batch_idx in range(batch_size):
        sample = batched[batch_idx]
        for step, decoder in zip(steps, decoders):
            try:
                sample = decoder(sample, **step.kwargs)
            except Exception as exc:
                # Wrap so the failing step is identifiable from the message.
                raise RuntimeError(
                    f"Error applying decode function '{step.name}': {exc}"
                ) from exc
        results.append(sample)

    # NOTE(review): with a single sample the decoded sample is returned
    # as-is, so a 5-D input with batch size 1 comes back without its batch
    # axis while padded low-rank inputs keep their added axes — callers
    # relying on exact output rank should be aware.
    if len(results) == 1:
        return results[0]
    return np.stack(results, axis=0)
| 106 | + |
| 107 | + |
def resolve_decode_modes_from_cfg(cfg: Any) -> Sequence[Any] | None:
    """Resolve decode mode list from config.

    Priority:
    1. ``cfg.test.decoding``
    2. ``cfg.inference.decoding``

    Returns ``None`` when neither section carries a ``decoding`` field.
    """
    # The test section must itself be truthy (a bare/empty section is
    # skipped) before its decoding field is consulted.
    test_section = getattr(cfg, "test", None)
    if test_section and hasattr(test_section, "decoding"):
        return test_section.decoding

    inference_section = getattr(cfg, "inference", None)
    if hasattr(inference_section, "decoding"):
        return inference_section.decoding

    return None
| 120 | + |
| 121 | + |
def apply_decode_mode(cfg: Any, data: np.ndarray, *, verbose: bool = True) -> np.ndarray:
    """Apply decode pipeline resolved from ``test.decoding`` or ``inference.decoding``.

    When no decoding configuration is found, *data* is returned unchanged
    (and a message is printed if *verbose*).
    """
    decode_modes = resolve_decode_modes_from_cfg(cfg)
    if not decode_modes:
        if verbose:
            print(" No decoding configuration found (test.decoding or inference.decoding)")
        return data

    if verbose:
        # Mirror the priority used by resolve_decode_modes_from_cfg to
        # report which config section supplied the steps.
        came_from_test = bool(
            hasattr(cfg, "test") and cfg.test and hasattr(cfg.test, "decoding")
        )
        source = "test.decoding" if came_from_test else "inference.decoding"
        print(f" Using {source}: {decode_modes}")

    return apply_decode_pipeline(data, decode_modes)