Skip to content

Commit d6bd97e

Browse files
committed
Use @pytest.mark.parametrize to reduce test duplication
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent 27424ba commit d6bd97e

15 files changed

Lines changed: 254 additions & 266 deletions

packages/pynumaflow/tests/accumulator/test_async_accumulator.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -409,26 +409,22 @@ class ExampleBadClass:
409409
AccumulatorAsyncServer(accumulator_instance=ExampleBadClass)
410410

411411

412-
def test_max_threads():
413-
# max cap at 16
414-
server = AccumulatorAsyncServer(accumulator_instance=ExampleClass, max_threads=32)
415-
assert server.max_threads == 16
416-
417-
# use argument provided
418-
server = AccumulatorAsyncServer(accumulator_instance=ExampleClass, max_threads=5)
419-
assert server.max_threads == 5
420-
421-
# defaults to 4
422-
server = AccumulatorAsyncServer(accumulator_instance=ExampleClass)
423-
assert server.max_threads == 4
424-
425-
# zero threads
426-
server = AccumulatorAsyncServer(ExampleClass, max_threads=0)
427-
assert server.max_threads == 0
428-
429-
# negative threads
430-
server = AccumulatorAsyncServer(ExampleClass, max_threads=-5)
431-
assert server.max_threads == -5
412+
@pytest.mark.parametrize(
413+
"max_threads_arg,expected",
414+
[
415+
(32, 16), # max cap at 16
416+
(5, 5), # use argument provided
417+
(None, 4), # defaults to 4
418+
(0, 0), # zero threads
419+
(-5, -5), # negative threads
420+
],
421+
)
422+
def test_max_threads(max_threads_arg, expected):
423+
kwargs = {"accumulator_instance": ExampleClass}
424+
if max_threads_arg is not None:
425+
kwargs["max_threads"] = max_threads_arg
426+
server = AccumulatorAsyncServer(**kwargs)
427+
assert server.max_threads == expected
432428

433429

434430
def test_server_info_file_path_handling():

packages/pynumaflow/tests/batchmap/test_async_batch_map.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -169,15 +169,17 @@ def test_is_ready(async_batch_map_server) -> None:
169169
assert response.ready
170170

171171

172-
def test_max_threads():
173-
# max cap at 16
174-
server = BatchMapAsyncServer(batch_mapper_instance=handler, max_threads=32)
175-
assert server.max_threads == 16
176-
177-
# use argument provided
178-
server = BatchMapAsyncServer(batch_mapper_instance=handler, max_threads=5)
179-
assert server.max_threads == 5
180-
181-
# defaults to 4
182-
server = BatchMapAsyncServer(batch_mapper_instance=handler)
183-
assert server.max_threads == 4
172+
@pytest.mark.parametrize(
173+
"max_threads_arg,expected",
174+
[
175+
(32, 16), # max cap at 16
176+
(5, 5), # use argument provided
177+
(None, 4), # defaults to 4
178+
],
179+
)
180+
def test_max_threads(max_threads_arg, expected):
181+
kwargs = {"batch_mapper_instance": handler}
182+
if max_threads_arg is not None:
183+
kwargs["max_threads"] = max_threads_arg
184+
server = BatchMapAsyncServer(**kwargs)
185+
assert server.max_threads == expected

packages/pynumaflow/tests/map/test_async_mapper.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,17 @@ def test_invalid_input():
206206
MapAsyncServer()
207207

208208

209-
def test_max_threads():
210-
# max cap at 16
211-
server = MapAsyncServer(mapper_instance=async_map_handler, max_threads=32)
212-
assert server.max_threads == 16
213-
214-
# use argument provided
215-
server = MapAsyncServer(mapper_instance=async_map_handler, max_threads=5)
216-
assert server.max_threads == 5
217-
218-
# defaults to 4
219-
server = MapAsyncServer(mapper_instance=async_map_handler)
220-
assert server.max_threads == 4
209+
@pytest.mark.parametrize(
210+
"max_threads_arg,expected",
211+
[
212+
(32, 16), # max cap at 16
213+
(5, 5), # use argument provided
214+
(None, 4), # defaults to 4
215+
],
216+
)
217+
def test_max_threads(max_threads_arg, expected):
218+
kwargs = {"mapper_instance": async_map_handler}
219+
if max_threads_arg is not None:
220+
kwargs["max_threads"] = max_threads_arg
221+
server = MapAsyncServer(**kwargs)
222+
assert server.max_threads == expected

packages/pynumaflow/tests/map/test_multiproc_mapper.py

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -27,51 +27,41 @@ def _invoke_map_fn(test_server, timeout=1):
2727
)
2828

2929

30-
def test_multiproc_init():
31-
my_server = MapMultiprocServer(mapper_instance=map_handler, server_count=3)
32-
assert my_server._process_count == 3
33-
34-
35-
def test_multiproc_process_count():
36-
default_val = os.cpu_count()
37-
my_server = MapMultiprocServer(mapper_instance=map_handler)
38-
assert my_server._process_count == default_val
39-
40-
41-
def test_max_process_count():
42-
"""Max process count is capped at 2 * os.cpu_count, irrespective of what the user
43-
provides as input"""
44-
default_val = os.cpu_count()
45-
server = MapMultiprocServer(mapper_instance=map_handler, server_count=100)
46-
assert server._process_count == default_val * 2
47-
48-
49-
def test_udf_map_err_handshake():
30+
@pytest.mark.parametrize(
31+
"server_count,expected",
32+
[
33+
(3, 3), # explicit count
34+
(None, os.cpu_count()), # default to cpu count
35+
(100, os.cpu_count() * 2), # max cap at 2 * cpu count
36+
],
37+
)
38+
def test_process_count(server_count, expected):
39+
kwargs = {"mapper_instance": map_handler}
40+
if server_count is not None:
41+
kwargs["server_count"] = server_count
42+
server = MapMultiprocServer(**kwargs)
43+
assert server._process_count == expected
44+
45+
46+
@pytest.mark.parametrize(
47+
"handshake,expected_msg",
48+
[
49+
(False, "MapFn: expected handshake as the first message"),
50+
(True, "Something is fishy!"),
51+
],
52+
)
53+
def test_udf_map_error(handshake, expected_msg):
5054
my_server = MapMultiprocServer(mapper_instance=err_map_handler)
5155
services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer}
5256
test_server = server_from_dictionary(services, strict_real_time())
5357

54-
test_datums = get_test_datums(handshake=False)
55-
method = _invoke_map_fn(test_server)
56-
send_test_requests(method, test_datums)
57-
drain_responses(method)
58-
59-
metadata, code, details = method.termination()
60-
assert "MapFn: expected handshake as the first message" in details
61-
assert code == StatusCode.INTERNAL
62-
63-
64-
def test_udf_map_err():
65-
my_server = MapMultiprocServer(mapper_instance=err_map_handler)
66-
services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer}
67-
test_server = server_from_dictionary(services, strict_real_time())
68-
test_datums = get_test_datums(handshake=True)
58+
test_datums = get_test_datums(handshake=handshake)
6959
method = _invoke_map_fn(test_server)
7060
send_test_requests(method, test_datums)
7161
drain_responses(method)
7262

7363
metadata, code, details = method.termination()
74-
assert "Something is fishy!" in details
64+
assert expected_msg in details
7565
assert code == StatusCode.INTERNAL
7666

7767

packages/pynumaflow/tests/map/test_sync_mapper.py

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,33 +36,25 @@ def test_init_with_args():
3636
assert my_servicer.max_message_size == 1024 * 1024 * 5
3737

3838

39-
def test_udf_map_err_handshake():
39+
@pytest.mark.parametrize(
40+
"handshake,expected_msg",
41+
[
42+
(False, "MapFn: expected handshake as the first message"),
43+
(True, "Something is fishy!"),
44+
],
45+
)
46+
def test_udf_map_error(handshake, expected_msg):
4047
my_server = MapServer(mapper_instance=err_map_handler)
4148
services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer}
4249
test_server = server_from_dictionary(services, strict_real_time())
4350

44-
test_datums = get_test_datums(handshake=False)
51+
test_datums = get_test_datums(handshake=handshake)
4552
method = _invoke_map_fn(test_server)
4653
send_test_requests(method, test_datums)
4754
drain_responses(method)
4855

4956
metadata, code, details = method.termination()
50-
assert "MapFn: expected handshake as the first message" in details
51-
assert code == StatusCode.INTERNAL
52-
53-
54-
def test_udf_map_error_response():
55-
my_server = MapServer(mapper_instance=err_map_handler)
56-
services = {map_pb2.DESCRIPTOR.services_by_name["Map"]: my_server.servicer}
57-
test_server = server_from_dictionary(services, strict_real_time())
58-
59-
test_datums = get_test_datums(handshake=True)
60-
method = _invoke_map_fn(test_server)
61-
send_test_requests(method, test_datums)
62-
drain_responses(method)
63-
64-
metadata, code, details = method.termination()
65-
assert "Something is fishy!" in details
57+
assert expected_msg in details
6658
assert code == StatusCode.INTERNAL
6759

6860

@@ -110,15 +102,17 @@ def test_invalid_input():
110102
MapServer()
111103

112104

113-
def test_max_threads():
114-
# max cap at 16
115-
server = MapServer(mapper_instance=map_handler, max_threads=32)
116-
assert server.max_threads == 16
117-
118-
# use argument provided
119-
server = MapServer(mapper_instance=map_handler, max_threads=5)
120-
assert server.max_threads == 5
121-
122-
# defaults to 4
123-
server = MapServer(mapper_instance=map_handler)
124-
assert server.max_threads == 4
105+
@pytest.mark.parametrize(
106+
"max_threads_arg,expected",
107+
[
108+
(32, 16), # max cap at 16
109+
(5, 5), # use argument provided
110+
(None, 4), # defaults to 4
111+
],
112+
)
113+
def test_max_threads(max_threads_arg, expected):
114+
kwargs = {"mapper_instance": map_handler}
115+
if max_threads_arg is not None:
116+
kwargs["max_threads"] = max_threads_arg
117+
server = MapServer(**kwargs)
118+
assert server.max_threads == expected

packages/pynumaflow/tests/mapstream/test_async_map_stream.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,17 @@ def test_is_ready(async_map_stream_server):
153153
assert response.ready
154154

155155

156-
def test_max_threads():
157-
# max cap at 16
158-
server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler, max_threads=32)
159-
assert server.max_threads == 16
160-
161-
# use argument provided
162-
server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler, max_threads=5)
163-
assert server.max_threads == 5
164-
165-
# defaults to 4
166-
server = MapStreamAsyncServer(map_stream_instance=async_map_stream_handler)
167-
assert server.max_threads == 4
156+
@pytest.mark.parametrize(
157+
"max_threads_arg,expected",
158+
[
159+
(32, 16), # max cap at 16
160+
(5, 5), # use argument provided
161+
(None, 4), # defaults to 4
162+
],
163+
)
164+
def test_max_threads(max_threads_arg, expected):
165+
kwargs = {"map_stream_instance": async_map_stream_handler}
166+
if max_threads_arg is not None:
167+
kwargs["max_threads"] = max_threads_arg
168+
server = MapStreamAsyncServer(**kwargs)
169+
assert server.max_threads == expected

packages/pynumaflow/tests/reduce/test_async_reduce.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -248,15 +248,17 @@ class ExampleBadClass:
248248
ReduceAsyncServer(reducer_instance=ExampleBadClass)
249249

250250

251-
def test_max_threads():
252-
# max cap at 16
253-
server = ReduceAsyncServer(reducer_instance=ExampleClass, max_threads=32)
254-
assert server.max_threads == 16
255-
256-
# use argument provided
257-
server = ReduceAsyncServer(reducer_instance=ExampleClass, max_threads=5)
258-
assert server.max_threads == 5
259-
260-
# defaults to 4
261-
server = ReduceAsyncServer(reducer_instance=ExampleClass)
262-
assert server.max_threads == 4
251+
@pytest.mark.parametrize(
252+
"max_threads_arg,expected",
253+
[
254+
(32, 16), # max cap at 16
255+
(5, 5), # use argument provided
256+
(None, 4), # defaults to 4
257+
],
258+
)
259+
def test_max_threads(max_threads_arg, expected):
260+
kwargs = {"reducer_instance": ExampleClass}
261+
if max_threads_arg is not None:
262+
kwargs["max_threads"] = max_threads_arg
263+
server = ReduceAsyncServer(**kwargs)
264+
assert server.max_threads == expected

packages/pynumaflow/tests/reducestreamer/test_async_reduce.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -279,18 +279,20 @@ class ExampleBadClass:
279279
ReduceStreamAsyncServer(reduce_stream_instance=ExampleBadClass)
280280

281281

282-
def test_max_threads():
283-
# max cap at 16
284-
server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass, max_threads=32)
285-
assert server.max_threads == 16
286-
287-
# use argument provided
288-
server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass, max_threads=5)
289-
assert server.max_threads == 5
290-
291-
# defaults to 4
292-
server = ReduceStreamAsyncServer(reduce_stream_instance=ExampleClass)
293-
assert server.max_threads == 4
282+
@pytest.mark.parametrize(
283+
"max_threads_arg,expected",
284+
[
285+
(32, 16), # max cap at 16
286+
(5, 5), # use argument provided
287+
(None, 4), # defaults to 4
288+
],
289+
)
290+
def test_max_threads(max_threads_arg, expected):
291+
kwargs = {"reduce_stream_instance": ExampleClass}
292+
if max_threads_arg is not None:
293+
kwargs["max_threads"] = max_threads_arg
294+
server = ReduceStreamAsyncServer(**kwargs)
295+
assert server.max_threads == expected
294296

295297

296298
def test_start_shutdown_handler_without_callback():

packages/pynumaflow/tests/sideinput/test_side_input_server.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,17 @@ def test_invalid_input():
108108
SideInputServer()
109109

110110

111-
def test_max_threads():
112-
# max cap at 16
113-
server = SideInputServer(retrieve_side_input_handler, max_threads=32)
114-
assert server.max_threads == 16
115-
116-
# use argument provided
117-
server = SideInputServer(retrieve_side_input_handler, max_threads=5)
118-
assert server.max_threads == 5
119-
120-
# defaults to 4
121-
server = SideInputServer(retrieve_side_input_handler)
122-
assert server.max_threads == 4
111+
@pytest.mark.parametrize(
112+
"max_threads_arg,expected",
113+
[
114+
(32, 16), # max cap at 16
115+
(5, 5), # use argument provided
116+
(None, 4), # defaults to 4
117+
],
118+
)
119+
def test_max_threads(max_threads_arg, expected):
120+
kwargs = {"side_input_instance": retrieve_side_input_handler}
121+
if max_threads_arg is not None:
122+
kwargs["max_threads"] = max_threads_arg
123+
server = SideInputServer(**kwargs)
124+
assert server.max_threads == expected

0 commit comments

Comments
 (0)