Skip to content

Commit 67fd389

Browse files
committed
multiproc async
Signed-off-by: Sidhant Kohli <sidhant.kohli@gmail.com>
1 parent 157a90d commit 67fd389

11 files changed

Lines changed: 141 additions & 25 deletions

File tree

pynumaflow/batchmapper/servicer/async_servicer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def MapFn(
9898

9999
except BaseException as err:
100100
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
101-
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
101+
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
102102
return
103103

104104
async def IsReady(

pynumaflow/info/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
# MULTIPROC_KEY is the field used to indicate that Multiproc map mode is enabled
2727
# The value contains the number of servers spawned.
2828
MULTIPROC_KEY = "MULTIPROC"
29+
MULTIPROC_ENDPOINTS = "MULTIPROC_ENDPOINTS"
2930

3031
SI = TypeVar("SI", bound="ServerInfo")
3132

pynumaflow/mapper/_servicer/_async_servicer.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,10 @@ class AsyncMapServicer(map_pb2_grpc.MapServicer):
1818
Provides the functionality for the required rpc methods.
1919
"""
2020

21-
def __init__(
22-
self,
23-
handler: MapAsyncCallable,
24-
):
21+
def __init__(self, handler: MapAsyncCallable, multiproc: bool = False):
2522
self.background_tasks = set()
23+
# This indicates whether the grpc server attached is multiproc or not
24+
self.multiproc = multiproc
2625
self.__map_handler: MapAsyncCallable = handler
2726

2827
async def MapFn(
@@ -56,7 +55,7 @@ async def MapFn(
5655
async for msg in consumer:
5756
# If the message is an exception, we raise the exception
5857
if isinstance(msg, BaseException):
59-
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
58+
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING, self.multiproc)
6059
return
6160
# Send window response back to the client
6261
else:
@@ -65,7 +64,7 @@ async def MapFn(
6564
await producer
6665
except BaseException as e:
6766
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
68-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
67+
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, self.multiproc)
6968
return
7069

7170
async def _process_inputs(
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import logging
2+
import multiprocessing
3+
4+
import aiorun
5+
import grpc
6+
7+
from pynumaflow._constants import MAX_NUM_THREADS, MAX_MESSAGE_SIZE, MAP_SOCK_PATH, MAP_SERVER_INFO_FILE_PATH, \
8+
_PROCESS_COUNT, NUM_THREADS_DEFAULT, MULTIPROC_MAP_SOCK_ADDR
9+
from pynumaflow.info.server import get_metadata_env
10+
from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType, MAP_MODE_KEY, MapMode, \
11+
METADATA_ENVS, MULTIPROC_KEY, MULTIPROC_ENDPOINTS, Protocol
12+
from pynumaflow.mapper._dtypes import MapAsyncCallable
13+
from pynumaflow.mapper._servicer._async_servicer import AsyncMapServicer
14+
from pynumaflow.proto.mapper import map_pb2_grpc
15+
from pynumaflow.shared.server import start_async_server, NumaflowServer, reserve_port
16+
from pynumaflow.info.server import write as info_server_write
17+
18+
_LOGGER = logging.getLogger(__name__)
19+
20+
21+
class AsyncMultiprocMapServer(NumaflowServer):
22+
"""
23+
A multiprocess asynchronous gRPC server for Numaflow Map UDFs.
24+
Spawns N worker processes, each running an asyncio-based gRPC server.
25+
"""
26+
27+
def __init__(
28+
self,
29+
mapper_instance: MapAsyncCallable,
30+
server_count: int = _PROCESS_COUNT,
31+
sock_path: str = MULTIPROC_MAP_SOCK_ADDR,
32+
max_message_size: int = MAX_MESSAGE_SIZE,
33+
max_threads: int = NUM_THREADS_DEFAULT,
34+
server_info_file: str = MAP_SERVER_INFO_FILE_PATH,
35+
use_tcp: bool = False,
36+
):
37+
self.sock_path = f"unix://{sock_path}"
38+
self.max_threads = min(max_threads, MAX_NUM_THREADS)
39+
self.max_message_size = max_message_size
40+
self.server_info_file = server_info_file
41+
self.use_tcp = use_tcp
42+
43+
self.mapper_instance = mapper_instance
44+
45+
self._server_options = [
46+
("grpc.max_send_message_length", self.max_message_size),
47+
("grpc.max_receive_message_length", self.max_message_size),
48+
("grpc.so_reuseport", 1),
49+
("grpc.so_reuseaddr", 1),
50+
]
51+
52+
self._process_count = min(server_count, 2 * _PROCESS_COUNT)
53+
self.servicer = AsyncMapServicer(handler=self.mapper_instance, multiproc=True)
54+
55+
def start(self):
56+
"""
57+
Starts the multiprocess async gRPC servers.
58+
"""
59+
_LOGGER.info(
60+
"Starting async multiprocess gRPC server with %d workers", self._process_count
61+
)
62+
63+
workers = []
64+
ports = []
65+
66+
for idx in range(self._process_count):
67+
if self.use_tcp:
68+
with reserve_port(0) as reserved_port:
69+
bind_address = f"0.0.0.0:{reserved_port}"
70+
ports.append(f"http://{bind_address}")
71+
else:
72+
bind_address = f"unix://{self.sock_path}{idx}.sock"
73+
_LOGGER.info("Binding server to: %s", bind_address)
74+
75+
worker = multiprocessing.Process(
76+
target=self._run_server_process,
77+
args=(bind_address,),
78+
)
79+
worker.start()
80+
workers.append(worker)
81+
82+
# Write server info file
83+
server_info = ServerInfo.get_default_server_info()
84+
server_info.metadata[MULTIPROC_KEY] = str(self._process_count)
85+
server_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap
86+
if self.use_tcp:
87+
server_info.protocol = Protocol.TCP
88+
server_info.metadata[MULTIPROC_ENDPOINTS] = ",".join(map(str, ports))
89+
info_server_write(server_info=server_info, info_file=self.server_info_file)
90+
91+
for worker in workers:
92+
worker.join()
93+
94+
def _run_server_process(self, bind_address):
95+
async def run_server():
96+
server = grpc.aio.server(options=self._server_options)
97+
server.add_insecure_port(bind_address)
98+
map_pb2_grpc.add_MapServicer_to_server(self.servicer, server)
99+
100+
server_info = ServerInfo.get_default_server_info()
101+
server_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper]
102+
server_info.metadata = get_metadata_env(envs=METADATA_ENVS)
103+
# Add the MULTIPROC metadata using the number of servers to use
104+
server_info.metadata[MULTIPROC_KEY] = str(self._process_count)
105+
# Add the MAP_MODE metadata to the server info for the correct map mode
106+
server_info.metadata[MAP_MODE_KEY] = MapMode.UnaryMap
107+
108+
await start_async_server(
109+
server_async=server,
110+
sock_path=bind_address,
111+
max_threads=self.max_threads,
112+
cleanup_coroutines=list(),
113+
server_info_file=self.server_info_file,
114+
server_info=server_info,
115+
)
116+
117+
aiorun.run(run_server(), use_uvloop=True)

pynumaflow/mapper/async_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282
("grpc.max_receive_message_length", self.max_message_size),
8383
]
8484
# Get the servicer instance for the async server
85-
self.servicer = AsyncMapServicer(handler=mapper_instance)
85+
self.servicer = AsyncMapServicer(handler=mapper_instance, multiproc=False)
8686

8787
def start(self) -> None:
8888
"""

pynumaflow/mapstreamer/servicer/async_servicer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ async def MapFn(
5959
yield map_pb2.MapResponse(status=map_pb2.TransmissionStatus(eot=True), id=req.id)
6060
except BaseException as err:
6161
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
62-
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
62+
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
6363
return
6464

6565
async def __invoke_map_stream(self, keys: list[str], req: Datum):

pynumaflow/reducer/servicer/async_servicer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ async def ReduceFn(
105105
_LOGGER.critical("Reduce Error", exc_info=True)
106106
# Send a context abort signal for the rpc, this is required for numa container to get
107107
# the correct grpc error
108-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
108+
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)
109109

110110
# send EOF to all the tasks once the request iterator is exhausted
111111
# This will signal the tasks to stop reading the data on their
@@ -136,7 +136,7 @@ async def ReduceFn(
136136
_LOGGER.critical("Reduce Error", exc_info=True)
137137
# Send a context abort signal for the rpc, this is required for numa container to get
138138
# the correct grpc error
139-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
139+
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)
140140

141141
async def IsReady(
142142
self, request: _empty_pb2.Empty, context: NumaflowServicerContext

pynumaflow/reducestreamer/servicer/async_servicer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,20 +95,20 @@ async def ReduceFn(
9595
async for msg in consumer:
9696
# If the message is an exception, we raise the exception
9797
if isinstance(msg, BaseException):
98-
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
98+
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING, False)
9999
return
100100
# Send window EOF response or Window result response
101101
# back to the client
102102
else:
103103
yield msg
104104
except BaseException as e:
105-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
105+
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)
106106
return
107107
# Wait for the process_input_stream task to finish for a clean exit
108108
try:
109109
await producer
110110
except BaseException as e:
111-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
111+
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING, False)
112112
return
113113

114114
async def IsReady(

pynumaflow/shared/server.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ async def server_graceful_shutdown():
217217

218218

219219
@contextlib.contextmanager
220-
def _reserve_port(port_num: int) -> Iterator[int]:
220+
def reserve_port(port_num: int) -> Iterator[int]:
221221
"""Find and reserve a port for all subprocesses to use."""
222222
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
223223
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -311,9 +311,8 @@ def get_exception_traceback_str(exc) -> str:
311311
return file.getvalue().rstrip()
312312

313313

314-
async def handle_async_error(
315-
context: NumaflowServicerContext, exception: BaseException, exception_type: str
316-
):
314+
async def handle_async_error(context: NumaflowServicerContext, exception: BaseException,
315+
exception_type: str, parent: bool = False):
317316
"""
318317
Handle exceptions for async servers by updating the context and exiting.
319318
"""
@@ -322,4 +321,4 @@ async def handle_async_error(
322321
await asyncio.gather(
323322
context.abort(grpc.StatusCode.INTERNAL, details=err_msg), return_exceptions=True
324323
)
325-
exit_on_error(err=err_msg, parent=False, context=context, update_context=False)
324+
exit_on_error(err=err_msg, parent=parent, context=context, update_context=False)

pynumaflow/sinker/servicer/async_servicer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ async def SinkFn(
8585
# if there is an exception, we will mark all the responses as a failure
8686
err_msg = f"UDSinkError: {repr(err)}"
8787
_LOGGER.critical(err_msg, exc_info=True)
88-
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING)
88+
await handle_async_error(context, err, ERR_UDF_EXCEPTION_STRING, False)
8989
return
9090

9191
async def __invoke_sink(

0 commit comments

Comments
 (0)