1+ import os
2+ import time
3+ import uuid
4+ import unittest
5+ import grpc
6+ from multiprocessing import Process
7+ from pynumaflow .mapper import Datum , Messages , Message
8+ from pynumaflow .proto .mapper import map_pb2_grpc
9+ from tests .map .test_async_mapper import request_generator
10+ from tests .map .utils import get_test_datums
11+ from pynumaflow .mapper import AsyncMapMultiprocServer # wherever your class is
12+
13+ sock_prefix = f"/tmp/test_async_multiproc_map_{ uuid .uuid4 ().hex } _"
14+
15+
16+ async def async_handler (keys , datum : Datum ) -> Messages :
17+ msg = f"payload:{ datum .value .decode ()} event_time:{ datum .event_time } watermark:{ datum .watermark } "
18+ return Messages (Message (value = msg .encode (), keys = keys ))
19+
20+ class TestAsyncMapMultiprocServer (unittest .TestCase ):
21+ def setUp (self ):
22+ self .base_sock_path = sock_prefix
23+ self .server = AsyncMapMultiprocServer (
24+ mapper_instance = async_handler ,
25+ server_count = 2 ,
26+ sock_path = self .base_sock_path ,
27+ use_tcp = False ,
28+ server_info_file = "/tmp/server_info"
29+ )
30+ self .process = Process (target = self .server .start )
31+ self .process .start ()
32+
33+ # Wait for both servers to bind
34+ self .socket_paths = [f"{ self .base_sock_path } { i } .sock" for i in range (2 )]
35+ for path in self .socket_paths :
36+ for _ in range (10 ):
37+ if os .path .exists (path ):
38+ break
39+ time .sleep (0.5 )
40+
41+ def tearDown (self ):
42+ self .process .terminate ()
43+ self .process .join ()
44+ for path in self .socket_paths :
45+ try :
46+ os .remove (path )
47+ except FileNotFoundError :
48+ pass
49+
50+ def test_map_fn (self ):
51+ bind_address = f"unix://{ self .socket_paths [0 ]} "
52+ request = get_test_datums ()
53+ with grpc .insecure_channel (bind_address ) as channel :
54+ stub = map_pb2_grpc .MapStub (channel )
55+ responses_iter = stub .MapFn (request_iterator = request_generator (request ))
56+ responses = []
57+ # capture the output from the ReadFn generator and assert.
58+ for r in responses_iter :
59+ responses .append (r )
60+
61+ # 1 handshake + 3 data responses
62+ self .assertEqual (4 , len (responses ))
63+
64+ self .assertTrue (responses [0 ].handshake .sot )
65+
66+ idx = 1
67+ while idx < len (responses ):
68+ _id = "test-id-" + str (idx )
69+ self .assertEqual (_id , responses [idx ].id )
70+ self .assertEqual (
71+ bytes (
72+ "payload:test_mock_message "
73+ "event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00" ,
74+ encoding = "utf-8" ,
75+ ),
76+ responses [idx ].results [0 ].value ,
77+ )
78+ self .assertEqual (1 , len (responses [idx ].results ))
79+ idx += 1
80+
81+ def test_server_start (self ):
82+ for path in self .socket_paths :
83+ self .assertTrue (os .path .exists (path ), f"Server socket { path } was not created successfully" )
0 commit comments