Skip to content

Commit b1c7bbc

Browse files
committed
Fix sample printer
1 parent c560494 commit b1c7bbc

2 files changed

Lines changed: 120 additions & 13 deletions

File tree

src/megatron/energon/sample_utils.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -157,22 +157,43 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str:
157157
158158
Example:
159159
>>> print(format_sample_detailed({"image": torch.zeros(3, 224, 224), "label": 5}))
160-
- image: Tensor(shape=(3, 224, 224), dtype=torch.float32, ...)
161-
- label: 5
160+
image: Tensor(shape=(3, 224, 224), dtype=torch.float32, ...)
161+
label: 5
162162
"""
163+
164+
def _child_indent(cur: str, value: Any) -> str:
165+
if cur:
166+
return cur + " "
167+
if isinstance(value, (dict, list, tuple)):
168+
return " "
169+
if dataclasses.is_dataclass(value):
170+
return " "
171+
return " "
172+
163173
if isinstance(sample, dict):
164174
result = []
165175
for _, (key, value) in zip(range(25), sample.items()):
166-
result.append(f"{indent} - {key}: {format_sample_detailed(value, indent + ' ')}")
176+
nested = format_sample_detailed(value, _child_indent(indent, value))
177+
head = f"{indent}{key}:"
178+
if "\n" not in nested:
179+
result.append(f"{head} {nested}")
180+
elif isinstance(value, str) or dataclasses.is_dataclass(value):
181+
result.append(f"{head} {nested}")
182+
else:
183+
result.append(f"{head}\n{nested}")
167184
if len(sample) > 25:
168-
result.append(f"{indent} - ... (and {len(sample) - 25} more items)")
185+
result.append(f"{indent}... (and {len(sample) - 25} more items)")
169186
return "\n".join(result)
170187
elif isinstance(sample, str):
171188
if len(sample) > 1000:
172189
sample = f"{sample[:1000]}... (and {len(sample) - 1000} more characters)"
173190
if "\n" in sample:
174-
# represent as """ string if it contains newlines:
175-
return '"""' + sample.replace("\n", "\n " + indent) + '"""'
191+
lines = sample.split("\n")
192+
out = '"""' + indent + lines[0]
193+
for line in lines[1:]:
194+
out += "\n" + indent + line
195+
out += '"""'
196+
return out
176197
return repr(sample)
177198
elif isinstance(sample, (int, float, bool, type(None))):
178199
return repr(sample)
@@ -181,9 +202,22 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str:
181202
return f"[{', '.join(repr(value) for value in sample)}]"
182203
result = []
183204
for _, value in zip(range(10), sample):
184-
result.append(f"{indent} - {format_sample_detailed(value, indent + ' ')}")
205+
if isinstance(value, dict) and len(value) == 1:
206+
(k, v), = value.items()
207+
nested_v = format_sample_detailed(v, indent + " ")
208+
item_head = f"{indent}- {k}:"
209+
if "\n" not in nested_v:
210+
result.append(f"{item_head} {nested_v}")
211+
else:
212+
result.append(f"{item_head}\n{nested_v}")
213+
else:
214+
nested = format_sample_detailed(value, indent + " ")
215+
if "\n" not in nested:
216+
result.append(f"{indent}- {nested}")
217+
else:
218+
result.append(f"{indent}-\n{nested}")
185219
if len(sample) > 10:
186-
result.append(f"{indent} - ... (and {len(sample) - 10} more items)")
220+
result.append(f"{indent}- ... (and {len(sample) - 10} more items)")
187221
return "\n".join(result)
188222
elif isinstance(sample, torch.Tensor):
189223
try:
@@ -235,12 +269,16 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str:
235269
# Handle empty arrays or non-numeric dtypes
236270
return f"np.ndarray(shape={sample.shape}, dtype={sample.dtype})"
237271
elif dataclasses.is_dataclass(sample):
238-
result = [f"{indent}{type(sample).__name__}("]
272+
result = [f"{type(sample).__name__}("]
239273
for field in dataclasses.fields(sample):
240-
result.append(
241-
f"{indent} {field.name}={format_sample_detailed(getattr(sample, field.name), indent + ' ')}"
242-
)
243-
result.append(f"{indent})")
274+
field_val = getattr(sample, field.name)
275+
nested = format_sample_detailed(field_val, indent + " ")
276+
head = f"{indent}{field.name}:"
277+
if "\n" not in nested:
278+
result.append(f"{head} {nested}")
279+
else:
280+
result.append(f"{head}\n{nested}")
281+
result.append(")")
244282
return "\n".join(result)
245283
else:
246284
repr_str = repr(sample)

tests/test_sample_utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION.
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
"""Tests for sample formatting helpers."""
5+
6+
import unittest
7+
from dataclasses import dataclass
8+
9+
import numpy as np
10+
import torch
11+
12+
from megatron.energon.sample_utils import format_sample_detailed
13+
14+
15+
@dataclass
16+
class _FormatProbeDc:
17+
n: int
18+
nested: dict[str, int]
19+
20+
21+
class TestFormatSampleDetailed(unittest.TestCase):
22+
def test_format_sample_detailed_complex_types(self) -> None:
23+
"""Exercise dict, scalars, strings, lists/tuples, tensors, ndarray, dataclass, and fallback repr."""
24+
25+
class _Unknown:
26+
def __repr__(self) -> str:
27+
return "<unknown-probe>"
28+
29+
sample = {
30+
"scalars": {"i": -3, "f": 2.5, "b": False, "n": None},
31+
"plain_str": "hi",
32+
"multiline_str": "line1\nline2",
33+
"primitive_seq": (1, 2, "x"),
34+
"hetero_list": [{"k": 1}, {"k": 2}],
35+
"tensor": torch.tensor([0.0, 2.0], dtype=torch.float32),
36+
"array": np.array([1, 2, 3], dtype=np.int64),
37+
"dataclass": _FormatProbeDc(n=9, nested={"a": 1, "b": 2}),
38+
"unknown": _Unknown(),
39+
}
40+
out = format_sample_detailed(sample)
41+
42+
print(out)
43+
44+
assert out == '''\
45+
scalars:
46+
i: -3
47+
f: 2.5
48+
b: False
49+
n: None
50+
plain_str: 'hi'
51+
multiline_str: """\
52+
line1
53+
line2"""
54+
primitive_seq: [1, 2, 'x']
55+
hetero_list:
56+
- k: 1
57+
- k: 2
58+
tensor: Tensor(shape=torch.Size([2]), dtype=torch.float32, device=cpu, min=0.0, max=2.0, values=[0.0, 2.0])
59+
array: np.ndarray(shape=(3,), dtype=int64, min=1, max=3, values=[np.int64(1), np.int64(2), np.int64(3)])
60+
dataclass: _FormatProbeDc(
61+
n: 9
62+
nested:
63+
a: 1
64+
b: 2
65+
)
66+
unknown: <unknown-probe>'''
67+
68+
if __name__ == "__main__":
69+
unittest.main()

0 commit comments

Comments
 (0)