Skip to content

Commit f7e1e61

Browse files
shs037tensorflower-gardener
authored andcommitted
Adds a utility function for formating list into string.
PiperOrigin-RevId: 484026229
1 parent 7d7b670 commit f7e1e61

3 files changed

Lines changed: 50 additions & 2 deletions

File tree

tensorflow_privacy/privacy/privacy_tests/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ py_test(
1515
srcs = ["utils_test.py"],
1616
python_version = "PY3",
1717
srcs_version = "PY3",
18-
deps = [":utils"],
18+
deps = [
19+
":utils",
20+
"//third_party/py/parameterized",
21+
],
1922
)
2023

2124
py_test(

tensorflow_privacy/privacy/privacy_tests/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515

1616
import enum
1717
import logging
18-
from typing import Callable, Optional, Union
18+
import numbers
19+
from typing import Callable, Iterable, Optional, Union
1920

2021
import numpy as np
2122
from scipy import special
@@ -254,3 +255,9 @@ def get_loss(
254255
else:
255256
loss = loss_function(labels, predictions, sample_weight)
256257
return loss
258+
259+
260+
def format_number_list(input_list: Iterable[numbers.Number],
261+
precision: int = 4) -> str:
262+
"""Formats list of numbers as a string."""
263+
return ', '.join([f'{x:.{precision}f}' for x in input_list])

tensorflow_privacy/privacy/privacy_tests/utils_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,5 +280,43 @@ def test_get_loss_call_loss_function(self, loss_function, multilabel_data,
280280
mock_squared_loss.assert_called_once()
281281

282282

283+
@parameterized.parameters(
284+
# integer list, one element
285+
([1], 0, '1'),
286+
([-1], 1, '-1.0'),
287+
# integer list, multiple elements
288+
([1, 2, 3], 0, '1, 2, 3'),
289+
([-1, 2, -3], 1, '-1.0, 2.0, -3.0'),
290+
# float list, one element
291+
([1.1], 0, '1'),
292+
([1.1], 1, '1.1'),
293+
([-1.1], 3, '-1.100'),
294+
# float and integer combined, multiple elements
295+
([0, 1.1, 2.22, 3.333], 0, '0, 1, 2, 3'),
296+
([0, 1.1, -2.22, 3.333], 1, '0.0, 1.1, -2.2, 3.3'),
297+
([0, 1.1, 2.22, 3.333], 3, '0.000, 1.100, 2.220, 3.333'),
298+
# inf and nan
299+
([np.inf, 1, -2.22, -np.inf, np.nan], 1, 'inf, 1.0, -2.2, -inf, nan'),
300+
# empty list
301+
([], 1, ''),
302+
# iterables other than list
303+
((np.inf, 1, 2.2), 0, 'inf, 1, 2'),
304+
(range(-1, 3), 1, '-1.0, 0.0, 1.0, 2.0'))
305+
class TestPrintNumberList(parameterized.TestCase):
306+
307+
def test_format_list(self, input_list, precision, expected_output):
308+
self.assertEqual(
309+
utils.format_number_list(input_list, precision), expected_output)
310+
311+
def test_format_iterator(self, input_list, precision, expected_output):
312+
self.assertEqual(
313+
utils.format_number_list(iter(input_list), precision), expected_output)
314+
315+
def test_format_numpy_array(self, input_list, precision, expected_output):
316+
self.assertEqual(
317+
utils.format_number_list(np.array(input_list), precision),
318+
expected_output)
319+
320+
283321
if __name__ == '__main__':
284322
absltest.main()

0 commit comments

Comments
 (0)