Skip to content

Commit 2806376

Browse files
committed
add unit test
Signed-off-by: Sidhant Kohli <sidhant.kohli@gmail.com>
1 parent 1ca6249 commit 2806376

2 files changed

Lines changed: 84 additions & 1 deletion

File tree

pynumaflow/mapper/async_multiproc_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def start(self):
8282
bind_address = f"0.0.0.0:{reserved_port}"
8383
ports.append(f"http://{bind_address}")
8484
else:
85-
bind_address = f"unix://{self.sock_path}{idx}.sock"
85+
bind_address = f"{self.sock_path}{idx}.sock"
8686
_LOGGER.info("Binding server to: %s", bind_address)
8787

8888
worker = multiprocessing.Process(

tests/map/test_async_multiproc.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import os
2+
import time
3+
import uuid
4+
import unittest
5+
import grpc
6+
from multiprocessing import Process
7+
from pynumaflow.mapper import Datum, Messages, Message
8+
from pynumaflow.proto.mapper import map_pb2_grpc
9+
from tests.map.test_async_mapper import request_generator
10+
from tests.map.utils import get_test_datums
11+
from pynumaflow.mapper import AsyncMapMultiprocServer # wherever your class is
12+
13+
sock_prefix = f"/tmp/test_async_multiproc_map_{uuid.uuid4().hex}_"
14+
15+
16+
async def async_handler(keys, datum: Datum) -> Messages:
17+
msg = f"payload:{datum.value.decode()} event_time:{datum.event_time} watermark:{datum.watermark}"
18+
return Messages(Message(value=msg.encode(), keys=keys))
19+
20+
class TestAsyncMapMultiprocServer(unittest.TestCase):
21+
def setUp(self):
22+
self.base_sock_path = sock_prefix
23+
self.server = AsyncMapMultiprocServer(
24+
mapper_instance=async_handler,
25+
server_count=2,
26+
sock_path=self.base_sock_path,
27+
use_tcp=False,
28+
server_info_file="/tmp/server_info"
29+
)
30+
self.process = Process(target=self.server.start)
31+
self.process.start()
32+
33+
# Wait for both servers to bind
34+
self.socket_paths = [f"{self.base_sock_path}{i}.sock" for i in range(2)]
35+
for path in self.socket_paths:
36+
for _ in range(10):
37+
if os.path.exists(path):
38+
break
39+
time.sleep(0.5)
40+
41+
def tearDown(self):
42+
self.process.terminate()
43+
self.process.join()
44+
for path in self.socket_paths:
45+
try:
46+
os.remove(path)
47+
except FileNotFoundError:
48+
pass
49+
50+
def test_map_fn(self):
51+
bind_address = f"unix://{self.socket_paths[0]}"
52+
request = get_test_datums()
53+
with grpc.insecure_channel(bind_address) as channel:
54+
stub = map_pb2_grpc.MapStub(channel)
55+
responses_iter = stub.MapFn(request_iterator=request_generator(request))
56+
responses = []
57+
# capture the output from the ReadFn generator and assert.
58+
for r in responses_iter:
59+
responses.append(r)
60+
61+
# 1 handshake + 3 data responses
62+
self.assertEqual(4, len(responses))
63+
64+
self.assertTrue(responses[0].handshake.sot)
65+
66+
idx = 1
67+
while idx < len(responses):
68+
_id = "test-id-" + str(idx)
69+
self.assertEqual(_id, responses[idx].id)
70+
self.assertEqual(
71+
bytes(
72+
"payload:test_mock_message "
73+
"event_time:2022-09-12 16:00:00 watermark:2022-09-12 16:01:00",
74+
encoding="utf-8",
75+
),
76+
responses[idx].results[0].value,
77+
)
78+
self.assertEqual(1, len(responses[idx].results))
79+
idx += 1
80+
81+
def test_server_start(self):
82+
for path in self.socket_paths:
83+
self.assertTrue(os.path.exists(path), f"Server socket {path} was not created successfully")

0 commit comments

Comments
 (0)