|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
8 | 8 | from abc import ABC, abstractmethod |
9 | | -from typing import Any, Union, overload |
| 9 | +from typing import Any, FrozenSet, Union, overload |
10 | 10 |
|
11 | 11 | from typeguard import check_type, typechecked |
12 | 12 |
|
13 | 13 | from tmlt.core.domains.base import Domain |
14 | 14 | from tmlt.core.measurements.base import Measurement |
15 | 15 | from tmlt.core.metrics import Metric, UnsupportedCombinationError |
| 16 | +from tmlt.core.utils.format import default_format_attrs, default_format_children |
16 | 17 |
|
17 | 18 |
|
18 | 19 | class Transformation(ABC): |
19 | 20 | """Abstract base class for transformations.""" |
20 | 21 |
|
| 22 | + _FORMAT_EXCLUDED_ATTRS: FrozenSet[str] = frozenset( |
| 23 | + {"input_domain", "input_metric", "output_domain", "output_metric"} |
| 24 | + ) |
| 25 | + """Fields hidden from output when formatting this transformation.""" |
| 26 | + |
21 | 27 | @typechecked |
22 | 28 | def __init__( |
23 | 29 | self, |
@@ -125,3 +131,29 @@ def __or__(self, other: Any) -> Union[Measurement, Transformation]: |
125 | 131 | @abstractmethod |
126 | 132 | def __call__(self, data: Any) -> Any: |
127 | 133 | """Perform transformation.""" |
| 134 | + |
| 135 | + def format(self) -> str: |
| 136 | + """Return a human-readable multi-line description of this transformation. |
| 137 | +
|
| 138 | + The default implementation assembles :meth:`_format_head` and |
| 139 | + :meth:`_format_children`; subclasses can override either of these |
| 140 | + hooks (or :meth:`format` itself) to customize the rendering. |
| 141 | + """ |
| 142 | + head = self._format_head() |
| 143 | + children = self._format_children() |
| 144 | + if not children: |
| 145 | + return head |
| 146 | + return f"{head}\n{children}" |
| 147 | + |
| 148 | + def _format_head(self) -> str: |
| 149 | + """Render this component's head line: class name followed by its attrs.""" |
| 150 | + parts = [type(self).__name__] |
| 151 | + parts.extend( |
| 152 | + f"{name}={value}" |
| 153 | + for name, value in default_format_attrs(self, self._FORMAT_EXCLUDED_ATTRS) |
| 154 | + ) |
| 155 | + return " ".join(parts) |
| 156 | + |
| 157 | + def _format_children(self) -> str: |
| 158 | + """Return the rendered block for nested transformations.""" |
| 159 | + return default_format_children(self) |
0 commit comments