@@ -157,22 +157,43 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str:
157157
158158 Example:
159159 >>> print(format_sample_detailed({"image": torch.zeros(3, 224, 224), "label": 5}))
160- - image: Tensor(shape=(3, 224, 224), dtype=torch.float32, ...)
161- - label: 5
160+ image: Tensor(shape=(3, 224, 224), dtype=torch.float32, ...)
161+ label: 5
162162 """
163+
164+ def _child_indent (cur : str , value : Any ) -> str :
165+ if cur :
166+ return cur + " "
167+ if isinstance (value , (dict , list , tuple )):
168+ return " "
169+ if dataclasses .is_dataclass (value ):
170+ return " "
171+ return " "
172+
163173 if isinstance (sample , dict ):
164174 result = []
165175 for _ , (key , value ) in zip (range (25 ), sample .items ()):
166- result .append (f"{ indent } - { key } : { format_sample_detailed (value , indent + ' ' )} " )
176+ nested = format_sample_detailed (value , _child_indent (indent , value ))
177+ head = f"{ indent } { key } :"
178+ if "\n " not in nested :
179+ result .append (f"{ head } { nested } " )
180+ elif isinstance (value , str ) or dataclasses .is_dataclass (value ):
181+ result .append (f"{ head } { nested } " )
182+ else :
183+ result .append (f"{ head } \n { nested } " )
167184 if len (sample ) > 25 :
168- result .append (f"{ indent } - ... (and { len (sample ) - 25 } more items)" )
185+ result .append (f"{ indent } ... (and { len (sample ) - 25 } more items)" )
169186 return "\n " .join (result )
170187 elif isinstance (sample , str ):
171188 if len (sample ) > 1000 :
172189 sample = f"{ sample [:1000 ]} ... (and { len (sample ) - 1000 } more characters)"
173190 if "\n " in sample :
174- # represent as """ string if it contains newlines:
175- return '"""' + sample .replace ("\n " , "\n " + indent ) + '"""'
191+ lines = sample .split ("\n " )
192+ out = '"""' + indent + lines [0 ]
193+ for line in lines [1 :]:
194+ out += "\n " + indent + line
195+ out += '"""'
196+ return out
176197 return repr (sample )
177198 elif isinstance (sample , (int , float , bool , type (None ))):
178199 return repr (sample )
@@ -181,9 +202,22 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str:
181202 return f"[{ ', ' .join (repr (value ) for value in sample )} ]"
182203 result = []
183204 for _ , value in zip (range (10 ), sample ):
184- result .append (f"{ indent } - { format_sample_detailed (value , indent + ' ' )} " )
205+ if isinstance (value , dict ) and len (value ) == 1 :
206+ (k , v ), = value .items ()
207+ nested_v = format_sample_detailed (v , indent + " " )
208+ item_head = f"{ indent } - { k } :"
209+ if "\n " not in nested_v :
210+ result .append (f"{ item_head } { nested_v } " )
211+ else :
212+ result .append (f"{ item_head } \n { nested_v } " )
213+ else :
214+ nested = format_sample_detailed (value , indent + " " )
215+ if "\n " not in nested :
216+ result .append (f"{ indent } - { nested } " )
217+ else :
218+ result .append (f"{ indent } -\n { nested } " )
185219 if len (sample ) > 10 :
186- result .append (f"{ indent } - ... (and { len (sample ) - 10 } more items)" )
220+ result .append (f"{ indent } - ... (and { len (sample ) - 10 } more items)" )
187221 return "\n " .join (result )
188222 elif isinstance (sample , torch .Tensor ):
189223 try :
@@ -235,12 +269,16 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str:
235269 # Handle empty arrays or non-numeric dtypes
236270 return f"np.ndarray(shape={ sample .shape } , dtype={ sample .dtype } )"
237271 elif dataclasses .is_dataclass (sample ):
238- result = [f"{ indent } { type (sample ).__name__ } (" ]
272+ result = [f"{ type (sample ).__name__ } (" ]
239273 for field in dataclasses .fields (sample ):
240- result .append (
241- f"{ indent } { field .name } ={ format_sample_detailed (getattr (sample , field .name ), indent + ' ' )} "
242- )
243- result .append (f"{ indent } )" )
274+ field_val = getattr (sample , field .name )
275+ nested = format_sample_detailed (field_val , indent + " " )
276+ head = f"{ indent } { field .name } :"
277+ if "\n " not in nested :
278+ result .append (f"{ head } { nested } " )
279+ else :
280+ result .append (f"{ head } \n { nested } " )
281+ result .append (")" )
244282 return "\n " .join (result )
245283 else :
246284 repr_str = repr (sample )
0 commit comments