@@ -27,51 +27,41 @@ def _invoke_map_fn(test_server, timeout=1):
2727 )
2828
2929
30- def test_multiproc_init ():
31- my_server = MapMultiprocServer (mapper_instance = map_handler , server_count = 3 )
32- assert my_server ._process_count == 3
33-
34-
35- def test_multiproc_process_count ():
36- default_val = os .cpu_count ()
37- my_server = MapMultiprocServer (mapper_instance = map_handler )
38- assert my_server ._process_count == default_val
39-
40-
41- def test_max_process_count ():
42- """Max process count is capped at 2 * os.cpu_count, irrespective of what the user
43- provides as input"""
44- default_val = os .cpu_count ()
45- server = MapMultiprocServer (mapper_instance = map_handler , server_count = 100 )
46- assert server ._process_count == default_val * 2
47-
48-
49- def test_udf_map_err_handshake ():
30+ @pytest .mark .parametrize (
31+ "server_count,expected" ,
32+ [
33+ (3 , 3 ), # explicit count
34+ (None , os .cpu_count ()), # default to cpu count
35+ (100 , os .cpu_count () * 2 ), # max cap at 2 * cpu count
36+ ],
37+ )
38+ def test_process_count (server_count , expected ):
39+ kwargs = {"mapper_instance" : map_handler }
40+ if server_count is not None :
41+ kwargs ["server_count" ] = server_count
42+ server = MapMultiprocServer (** kwargs )
43+ assert server ._process_count == expected
44+
45+
46+ @pytest .mark .parametrize (
47+ "handshake,expected_msg" ,
48+ [
49+ (False , "MapFn: expected handshake as the first message" ),
50+ (True , "Something is fishy!" ),
51+ ],
52+ )
53+ def test_udf_map_error (handshake , expected_msg ):
5054 my_server = MapMultiprocServer (mapper_instance = err_map_handler )
5155 services = {map_pb2 .DESCRIPTOR .services_by_name ["Map" ]: my_server .servicer }
5256 test_server = server_from_dictionary (services , strict_real_time ())
5357
54- test_datums = get_test_datums (handshake = False )
55- method = _invoke_map_fn (test_server )
56- send_test_requests (method , test_datums )
57- drain_responses (method )
58-
59- metadata , code , details = method .termination ()
60- assert "MapFn: expected handshake as the first message" in details
61- assert code == StatusCode .INTERNAL
62-
63-
64- def test_udf_map_err ():
65- my_server = MapMultiprocServer (mapper_instance = err_map_handler )
66- services = {map_pb2 .DESCRIPTOR .services_by_name ["Map" ]: my_server .servicer }
67- test_server = server_from_dictionary (services , strict_real_time ())
68- test_datums = get_test_datums (handshake = True )
58+ test_datums = get_test_datums (handshake = handshake )
6959 method = _invoke_map_fn (test_server )
7060 send_test_requests (method , test_datums )
7161 drain_responses (method )
7262
7363 metadata , code , details = method .termination ()
74- assert "Something is fishy!" in details
64+ assert expected_msg in details
7565 assert code == StatusCode .INTERNAL
7666
7767
0 commit comments