Skip to content

Commit bd2e2e1

Browse files
committed
Add format() method for transformations and measurements
Adds a format() method for converting transformations and measurements into human-readable strings. Most components use a default implementation based on inspecting each class to get its public properties, but some (most notably any component with multiple children) use custom formatting logic. Also adds a `tmlt.core.utils.format` module containing shared helpers used during formatting. #31
1 parent ef98efe commit bd2e2e1

43 files changed

Lines changed: 1710 additions & 5 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CHANGELOG.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
Changelog
44
=========
55

6-
76
Unreleased
87
----------
98

9+
Added
10+
~~~~~
11+
12+
- :class:`.Transformation`\s and :class:`.Measurement`\ s have a new ``format`` method, which renders a human-readable string showing the structure of the transformation/measurement to aid in visualization and debugging.
13+
1014
.. _v0.19.0:
1115

1216
0.19.0 - 2026-05-22

src/tmlt/core/measurements/base.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,24 @@
33
# SPDX-License-Identifier: Apache-2.0
44
# Copyright Tumult Labs 2026
55
from abc import ABC, abstractmethod
6-
from typing import Any
6+
from typing import Any, FrozenSet
77

88
from typeguard import typechecked
99

1010
from tmlt.core.domains.base import Domain
1111
from tmlt.core.measures import Measure
1212
from tmlt.core.metrics import Metric, UnsupportedCombinationError
13+
from tmlt.core.utils.format import default_format_attrs, default_format_children
1314

1415

1516
class Measurement(ABC):
1617
"""Abstract base class for measurements."""
1718

19+
_FORMAT_EXCLUDED_ATTRS: FrozenSet[str] = frozenset(
20+
{"input_domain", "input_metric", "output_measure", "is_interactive"}
21+
)
22+
"""Fields hidden from output when formatting this measurement."""
23+
1824
@typechecked
1925
def __init__(
2026
self,
@@ -96,3 +102,29 @@ def privacy_relation(self, d_in: Any, d_out: Any) -> bool:
96102
@abstractmethod
97103
def __call__(self, data: Any) -> Any:
98104
"""Performs measurement."""
105+
106+
def format(self) -> str:
107+
"""Return a human-readable multi-line description of this measurement.
108+
109+
The default implementation assembles :meth:`_format_head` and
110+
:meth:`_format_children`; subclasses can override either of these
111+
hooks (or :meth:`format` itself) to customize the rendering.
112+
"""
113+
head = self._format_head()
114+
children = self._format_children()
115+
if not children:
116+
return head
117+
return f"{head}\n{children}"
118+
119+
def _format_head(self) -> str:
120+
"""Render this measurement's head line: class name followed by its attrs."""
121+
parts = [type(self).__name__]
122+
parts.extend(
123+
f"{name}={value}"
124+
for name, value in default_format_attrs(self, self._FORMAT_EXCLUDED_ATTRS)
125+
)
126+
return " ".join(parts)
127+
128+
def _format_children(self) -> str:
129+
"""Return the rendered block for nested transformations/measurements."""
130+
return default_format_children(self)

src/tmlt/core/measurements/chaining.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from tmlt.core.exceptions import DomainMismatchError, MetricMismatchError
1111
from tmlt.core.measurements.base import Measurement
1212
from tmlt.core.transformations.base import Transformation
13+
from tmlt.core.utils.format import format_chain, get_chain_children
1314

1415

1516
class ChainTM(Measurement):
@@ -140,3 +141,7 @@ def privacy_relation(self, d_in: Any, d_out: Any) -> bool:
140141
def __call__(self, data: Any) -> Any:
141142
"""Computes measurement after applying transformation on input data."""
142143
return self._measurement(self._transformation(data))
144+
145+
def format(self) -> str:
146+
"""Return a human-readable multi-line description of this measurement."""
147+
return format_chain(get_chain_children(self))

src/tmlt/core/measurements/composition.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from tmlt.core.measurements.base import Measurement
1717
from tmlt.core.measures import ApproxDP, PureDP, RhoZCDP
18+
from tmlt.core.utils.format import format_siblings
1819

1920

2021
class Composition(Measurement):
@@ -178,3 +179,6 @@ def privacy_relation(self, d_in: Any, d_out: Any) -> bool:
178179
def __call__(self, data: Any) -> List:
179180
"""Return answers to composed measurements."""
180181
return [measurement(data) for measurement in self._measurements]
182+
183+
def _format_children(self) -> str:
184+
return format_siblings(self._measurements)

src/tmlt/core/measurements/interactive_measurements.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from tmlt.core.transformations.base import Transformation
3636
from tmlt.core.transformations.chaining import ChainTT
3737
from tmlt.core.transformations.identity import Identity
38+
from tmlt.core.utils.format import format_siblings
3839
from tmlt.core.utils.misc import copy_if_mutable
3940

4041

@@ -720,6 +721,9 @@ def __call__(self, data: Any) -> ParallelQueryable:
720721
"""Returns a :class:`~.ParallelQueryable`."""
721722
return ParallelQueryable(data, self._measurements)
722723

724+
def _format_children(self) -> str:
725+
return format_siblings(self._measurements)
726+
723727

724728
class MakeInteractive(Measurement):
725729
"""Creates a :class:`~.GetAnswerQueryable`.

src/tmlt/core/measurements/pandas_measurements/dataframe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tmlt.core.measures import Measure, PureDP, RhoZCDP
2727
from tmlt.core.metrics import HammingDistance, SymmetricDifference
2828
from tmlt.core.utils.exact_number import ExactNumber, ExactNumberInput
29+
from tmlt.core.utils.format import format_labeled_siblings
2930

3031

3132
class Aggregate(Measurement):
@@ -273,3 +274,14 @@ def __call__(self, df: pd.DataFrame) -> pd.DataFrame:
273274
for column_name, aggregation in self.column_to_aggregation.items()
274275
}
275276
)
277+
278+
def format(self) -> str:
279+
"""Return a human-readable multi-line description of this measurement.
280+
281+
The per-column aggregations are rendered as labeled sibling children
282+
(the ``output_schema`` is derivable from them, so it is not shown).
283+
"""
284+
return (
285+
f"{type(self).__name__}\n"
286+
f"{format_labeled_siblings(self.column_to_aggregation.items())}"
287+
)

src/tmlt/core/transformations/base.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,24 @@
66
from __future__ import annotations
77

88
from abc import ABC, abstractmethod
9-
from typing import Any, Union, overload
9+
from typing import Any, FrozenSet, Union, overload
1010

1111
from typeguard import check_type, typechecked
1212

1313
from tmlt.core.domains.base import Domain
1414
from tmlt.core.measurements.base import Measurement
1515
from tmlt.core.metrics import Metric, UnsupportedCombinationError
16+
from tmlt.core.utils.format import default_format_attrs, default_format_children
1617

1718

1819
class Transformation(ABC):
1920
"""Abstract base class for transformations."""
2021

22+
_FORMAT_EXCLUDED_ATTRS: FrozenSet[str] = frozenset(
23+
{"input_domain", "input_metric", "output_domain", "output_metric"}
24+
)
25+
"""Fields hidden from output when formatting this transformation."""
26+
2127
@typechecked
2228
def __init__(
2329
self,
@@ -125,3 +131,29 @@ def __or__(self, other: Any) -> Union[Measurement, Transformation]:
125131
@abstractmethod
126132
def __call__(self, data: Any) -> Any:
127133
"""Perform transformation."""
134+
135+
def format(self) -> str:
136+
"""Return a human-readable multi-line description of this transformation.
137+
138+
The default implementation assembles :meth:`_format_head` and
139+
:meth:`_format_children`; subclasses can override either of these
140+
hooks (or :meth:`format` itself) to customize the rendering.
141+
"""
142+
head = self._format_head()
143+
children = self._format_children()
144+
if not children:
145+
return head
146+
return f"{head}\n{children}"
147+
148+
def _format_head(self) -> str:
149+
"""Render this component's head line: class name followed by its attrs."""
150+
parts = [type(self).__name__]
151+
parts.extend(
152+
f"{name}={value}"
153+
for name, value in default_format_attrs(self, self._FORMAT_EXCLUDED_ATTRS)
154+
)
155+
return " ".join(parts)
156+
157+
def _format_children(self) -> str:
158+
"""Return the rendered block for nested transformations."""
159+
return default_format_children(self)

src/tmlt/core/transformations/chaining.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from tmlt.core.exceptions import DomainMismatchError, MetricMismatchError
1111
from tmlt.core.transformations.base import Transformation
12+
from tmlt.core.utils.format import format_chain, get_chain_children
1213

1314

1415
class ChainTT(Transformation):
@@ -126,3 +127,7 @@ def transformation2(self) -> Transformation:
126127
def __call__(self, data: Any) -> Any:
127128
"""Performs transformation1 followed by transformation2."""
128129
return self._transformation2(self._transformation1(data))
130+
131+
def format(self) -> str:
132+
"""Return a human-readable multi-line description of this measurement."""
133+
return format_chain(get_chain_children(self))

src/tmlt/core/transformations/spark_transformations/groupby.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ class GroupBy(Transformation):
123123
1
124124
""" # noqa: E501
125125

126+
# When formatted, group_keys provides no information that isn't in groupby_columns
127+
_FORMAT_EXCLUDED_ATTRS = Transformation._FORMAT_EXCLUDED_ATTRS | { # noqa: SLF001
128+
"group_keys"
129+
}
130+
126131
@typechecked
127132
def __init__(
128133
self,

0 commit comments

Comments
 (0)