diff --git a/.github/workflows/build-push.yaml b/.github/workflows/build-push.yaml index ab6a122f..a2c1ed8c 100644 --- a/.github/workflows/build-push.yaml +++ b/.github/workflows/build-push.yaml @@ -23,7 +23,7 @@ jobs: "examples/reducestream/counter", "examples/reducestream/sum", "examples/sideinput/simple_sideinput", "examples/sideinput/simple_sideinput/udf", "examples/sink/async_log", "examples/sink/log", "examples/source/simple_source", "examples/sourcetransform/event_time_filter", - "examples/batchmap/flatmap" + "examples/batchmap/flatmap", "examples/accumulator/streamsorter" ] steps: diff --git a/Makefile b/Makefile index 1ababadc..0e5334f2 100644 --- a/Makefile +++ b/Makefile @@ -32,6 +32,7 @@ proto: poetry run python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/sourcetransformer -I=pynumaflow/proto/sourcetransformer --python_out=pynumaflow/proto/sourcetransformer --grpc_python_out=pynumaflow/proto/sourcetransformer pynumaflow/proto/sourcetransformer/*.proto poetry run python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/sideinput -I=pynumaflow/proto/sideinput --python_out=pynumaflow/proto/sideinput --grpc_python_out=pynumaflow/proto/sideinput pynumaflow/proto/sideinput/*.proto poetry run python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/sourcer -I=pynumaflow/proto/sourcer --python_out=pynumaflow/proto/sourcer --grpc_python_out=pynumaflow/proto/sourcer pynumaflow/proto/sourcer/*.proto + poetry run python3 -m grpc_tools.protoc --pyi_out=pynumaflow/proto/accumulator -I=pynumaflow/proto/accumulator --python_out=pynumaflow/proto/accumulator --grpc_python_out=pynumaflow/proto/accumulator pynumaflow/proto/accumulator/*.proto sed -i.bak -e 's/^\(import.*_pb2\)/from . \1/' pynumaflow/proto/*/*.py diff --git a/examples/accumulator/streamsorter/Dockerfile b/examples/accumulator/streamsorter/Dockerfile new file mode 100644 index 00000000..dd2d605b --- /dev/null +++ b/examples/accumulator/streamsorter/Dockerfile @@ -0,0 +1,55 @@ +#################################################################################################### +# Stage 1: Base Builder - installs core dependencies using poetry +#################################################################################################### +FROM python:3.10-slim-bullseye AS base-builder + +ENV PYSETUP_PATH="/opt/pysetup" +WORKDIR $PYSETUP_PATH + +# Copy only core dependency files first for better caching +COPY pyproject.toml poetry.lock README.md ./ +COPY pynumaflow/ ./pynumaflow/ +RUN apt-get update && apt-get install --no-install-recommends -y \ + curl wget build-essential git \ + && apt-get clean && rm -rf /var/lib/apt/lists/* \ + && pip install poetry \ + && poetry install --no-root --no-interaction + +#################################################################################################### +# Stage 2: UDF Builder - adds UDF code and installs UDF-specific deps +#################################################################################################### +FROM base-builder AS udf-builder + +ENV EXAMPLE_PATH="/opt/pysetup/examples/accumulator/streamsorter" +ENV POETRY_VIRTUALENVS_IN_PROJECT=true + +WORKDIR $EXAMPLE_PATH +COPY examples/accumulator/streamsorter/ ./ +RUN poetry install --no-root --no-interaction + +#################################################################################################### +# Stage 3: UDF Runtime - clean container with only needed stuff +#################################################################################################### +FROM python:3.10-slim-bullseye AS udf + +ENV PYSETUP_PATH="/opt/pysetup" +ENV EXAMPLE_PATH="$PYSETUP_PATH/examples/accumulator/streamsorter" +ENV VENV_PATH="$EXAMPLE_PATH/.venv" +ENV PATH="$VENV_PATH/bin:$PATH" + +RUN apt-get update && apt-get install --no-install-recommends -y wget \ + && apt-get clean && rm -rf /var/lib/apt/lists/* \ + && wget -O /dumb-init https://github.com/Yelp/dumb-init/releases/download/v1.2.5/dumb-init_1.2.5_x86_64 \ + && chmod +x /dumb-init + +WORKDIR $PYSETUP_PATH +COPY --from=udf-builder $VENV_PATH $VENV_PATH +COPY --from=udf-builder $EXAMPLE_PATH $EXAMPLE_PATH + +WORKDIR $EXAMPLE_PATH +RUN chmod +x entry.sh + +ENTRYPOINT ["/dumb-init", "--"] +CMD ["sh", "-c", "$EXAMPLE_PATH/entry.sh"] + +EXPOSE 5000 \ No newline at end of file diff --git a/examples/accumulator/streamsorter/Makefile b/examples/accumulator/streamsorter/Makefile new file mode 100644 index 00000000..5eb6a3e8 --- /dev/null +++ b/examples/accumulator/streamsorter/Makefile @@ -0,0 +1,22 @@ +TAG ?= stable +PUSH ?= false +IMAGE_REGISTRY = quay.io/numaio/numaflow-python/streamsorter:${TAG} +DOCKER_FILE_PATH = examples/accumulator/streamsorter/Dockerfile + +.PHONY: update +update: + poetry update -vv + +.PHONY: image-push +image-push: update + cd ../../../ && docker buildx build \ + -f ${DOCKER_FILE_PATH} \ + -t ${IMAGE_REGISTRY} \ + --platform linux/amd64,linux/arm64 . --push + +.PHONY: image +image: update + cd ../../../ && docker build \ + -f ${DOCKER_FILE_PATH} \ + -t ${IMAGE_REGISTRY} . + @if [ "$(PUSH)" = "true" ]; then docker push ${IMAGE_REGISTRY}; fi diff --git a/examples/accumulator/streamsorter/Makefile.optimized b/examples/accumulator/streamsorter/Makefile.optimized new file mode 100644 index 00000000..136be046 --- /dev/null +++ b/examples/accumulator/streamsorter/Makefile.optimized @@ -0,0 +1,52 @@ +TAG ?= stable +PUSH ?= false +IMAGE_REGISTRY = quay.io/numaio/numaflow-python/streamsorter:${TAG} +DOCKER_FILE_PATH = examples/accumulator/streamsorter/Dockerfile +BASE_IMAGE_NAME = numaflow-python-base + +.PHONY: base-image +base-image: + @echo "Building shared base image..." + docker build -f Dockerfile.base -t ${BASE_IMAGE_NAME} . + +.PHONY: update +update: + poetry update -vv + +.PHONY: image-push +image-push: base-image update + cd ../../../ && docker buildx build \ + -f ${DOCKER_FILE_PATH} \ + -t ${IMAGE_REGISTRY} \ + --platform linux/amd64,linux/arm64 . --push + +.PHONY: image +image: base-image update + cd ../../../ && docker build \ + -f ${DOCKER_FILE_PATH} \ + -t ${IMAGE_REGISTRY} . + @if [ "$(PUSH)" = "true" ]; then docker push ${IMAGE_REGISTRY}; fi + +.PHONY: image-fast +image-fast: update + @echo "Building with shared base image (fastest option)..." + cd ../../../ && docker build \ + -f examples/map/even_odd/Dockerfile.shared-base \ + -t ${IMAGE_REGISTRY} . + @if [ "$(PUSH)" = "true" ]; then docker push ${IMAGE_REGISTRY}; fi + +.PHONY: clean +clean: + docker rmi ${BASE_IMAGE_NAME} 2>/dev/null || true + docker rmi ${IMAGE_REGISTRY} 2>/dev/null || true + +.PHONY: help +help: + @echo "Available targets:" + @echo " base-image - Build the shared base image with pynumaflow" + @echo " image - Build UDF image with optimized multi-stage build" + @echo " image-fast - Build UDF image using shared base (fastest)" + @echo " image-push - Build and push multi-platform image" + @echo " update - Update poetry dependencies" + @echo " clean - Remove built images" + @echo " help - Show this help message" \ No newline at end of file diff --git a/examples/accumulator/streamsorter/README.md b/examples/accumulator/streamsorter/README.md new file mode 100644 index 00000000..19b8da6e --- /dev/null +++ b/examples/accumulator/streamsorter/README.md @@ -0,0 +1,43 @@ +# Stream Sorter + +An example User Defined Function that sorts the incoming stream by event time. + +### Applying the Pipeline + +To apply the pipeline, use the following command: + +```shell + kubectl apply -f pipeline.yaml +``` + +### Publish messages + +Port-forward the HTTP endpoint, and make POST requests using curl. Remember to replace xxxx with the appropriate pod names. + +```shell + kubectl port-forward stream-sorter-http-one-0-xxxx 8444:8443 + + # Post data to the HTTP endpoint + curl -kq -X POST -d "101" https://localhost:8444/vertices/http-one -H "X-Numaflow-Event-Time: 60000" + curl -kq -X POST -d "102" https://localhost:8444/vertices/http-one -H "X-Numaflow-Event-Time: 61000" + curl -kq -X POST -d "103" https://localhost:8444/vertices/http-one -H "X-Numaflow-Event-Time: 62000" + curl -kq -X POST -d "104" https://localhost:8444/vertices/http-one -H "X-Numaflow-Event-Time: 63000" +``` + +```shell + kubectl port-forward stream-sorter-http-two-0-xxxx 8445:8443 + + # Post data to the HTTP endpoint + curl -kq -X POST -d "105" https://localhost:8445/vertices/http-two -H "X-Numaflow-Event-Time: 70000" + curl -kq -X POST -d "106" https://localhost:8445/vertices/http-two -H "X-Numaflow-Event-Time: 71000" + curl -kq -X POST -d "107" https://localhost:8445/vertices/http-two -H "X-Numaflow-Event-Time: 72000" + curl -kq -X POST -d "108" https://localhost:8445/vertices/http-two -H "X-Numaflow-Event-Time: 73000" +``` + +### Verify the output + +```shell + kubectl logs -f stream-sorter-log-sink-0-xxxx +``` + +The output should be sorted by event time. \ No newline at end of file diff --git a/examples/accumulator/streamsorter/entry.sh b/examples/accumulator/streamsorter/entry.sh new file mode 100644 index 00000000..073b05e3 --- /dev/null +++ b/examples/accumulator/streamsorter/entry.sh @@ -0,0 +1,4 @@ +#!/bin/sh +set -eux + +python example.py diff --git a/examples/accumulator/streamsorter/example.py b/examples/accumulator/streamsorter/example.py new file mode 100644 index 00000000..8e0615ed --- /dev/null +++ b/examples/accumulator/streamsorter/example.py @@ -0,0 +1,72 @@ +import logging +import os +from collections.abc import AsyncIterable +from datetime import datetime + +from pynumaflow import setup_logging +from pynumaflow.accumulator import Accumulator, AccumulatorAsyncServer +from pynumaflow.accumulator import ( + Message, + Datum, +) +from pynumaflow.shared.asynciter import NonBlockingIterator + +_LOGGER = setup_logging(__name__) +if os.getenv("PYTHONDEBUG"): + _LOGGER.setLevel(logging.DEBUG) + + +class StreamSorter(Accumulator): + def __init__(self): + _LOGGER.info("StreamSorter initialized") + self.latest_wm = datetime.fromtimestamp(-1) + self.sorted_buffer: list[Datum] = [] + + async def handler( + self, + datums: AsyncIterable[Datum], + output: NonBlockingIterator, + ): + _LOGGER.info("StreamSorter handler started") + async for datum in datums: + _LOGGER.info( + f"Received datum with event time: {datum.event_time}, " + f"Current latest watermark: {self.latest_wm}, " + f"Datum watermark: {datum.watermark}" + ) + + # If watermark has moved forward + if datum.watermark and datum.watermark > self.latest_wm: + self.latest_wm = datum.watermark + await self.flush_buffer(output) + + self.insert_sorted(datum) + + def insert_sorted(self, datum: Datum): + # Binary insert to keep sorted buffer in order + left, right = 0, len(self.sorted_buffer) + while left < right: + mid = (left + right) // 2 + if self.sorted_buffer[mid].event_time > datum.event_time: + right = mid + else: + left = mid + 1 + self.sorted_buffer.insert(left, datum) + + async def flush_buffer(self, output: NonBlockingIterator): + _LOGGER.info(f"Watermark updated, flushing sortedBuffer: {self.latest_wm}") + i = 0 + for datum in self.sorted_buffer: + if datum.event_time > self.latest_wm: + break + await output.put(Message.from_datum(datum)) + _LOGGER.info(f"Sent datum with event time: {datum.event_time}") + i += 1 + # Remove flushed items + self.sorted_buffer = self.sorted_buffer[i:] + + +if __name__ == "__main__": + grpc_server = None + grpc_server = AccumulatorAsyncServer(StreamSorter) + grpc_server.start() diff --git a/examples/accumulator/streamsorter/pipeline.yaml b/examples/accumulator/streamsorter/pipeline.yaml new file mode 100644 index 00000000..d4ccab96 --- /dev/null +++ b/examples/accumulator/streamsorter/pipeline.yaml @@ -0,0 +1,49 @@ +apiVersion: numaflow.numaproj.io/v1alpha1 +kind: Pipeline +metadata: + name: stream-sorter +spec: + limits: + readBatchSize: 1 + vertices: + - name: http-one + scale: + min: 1 + max: 1 + source: + http: {} + - name: http-two + scale: + min: 1 + max: 1 + source: + http: {} + - name: py-accum + udf: + container: + image: quay.io/numaio/numaflow-python/streamsorter:stable + imagePullPolicy: Always + env: + - name: PYTHONDEBUG + value: "true" + groupBy: + window: + accumulator: + timeout: 10s + keyed: true + storage: + persistentVolumeClaim: + volumeSize: 1Gi + - name: py-sink + scale: + min: 1 + max: 1 + sink: + log: {} + edges: + - from: http-one + to: py-accum + - from: http-two + to: py-accum + - from: py-accum + to: py-sink diff --git a/examples/accumulator/streamsorter/pyproject.toml b/examples/accumulator/streamsorter/pyproject.toml new file mode 100644 index 00000000..9397268d --- /dev/null +++ b/examples/accumulator/streamsorter/pyproject.toml @@ -0,0 +1,13 @@ +[tool.poetry] +name = "stream-sorter" +version = "0.2.4" +description = "" +authors = ["Numaflow developers"] + +[tool.poetry.dependencies] +python = "~3.10" +pynumaflow = { path = "../../../"} + +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/examples/map/even_odd/Dockerfile b/examples/map/even_odd/Dockerfile index 1bf155ca..0e9be000 100644 --- a/examples/map/even_odd/Dockerfile +++ b/examples/map/even_odd/Dockerfile @@ -9,7 +9,6 @@ WORKDIR $PYSETUP_PATH # Copy only core dependency files first for better caching COPY pyproject.toml poetry.lock README.md ./ COPY pynumaflow/ ./pynumaflow/ -RUN echo "Simulating long build step..." && sleep 20 RUN apt-get update && apt-get install --no-install-recommends -y \ curl wget build-essential git \ && apt-get clean && rm -rf /var/lib/apt/lists/* \ diff --git a/pynumaflow/_constants.py b/pynumaflow/_constants.py index ea5e2b9d..01ae44d5 100644 --- a/pynumaflow/_constants.py +++ b/pynumaflow/_constants.py @@ -26,6 +26,7 @@ MULTIPROC_MAP_SOCK_ADDR = "/var/run/numaflow/multiproc" FALLBACK_SINK_SOCK_PATH = "/var/run/numaflow/fb-sink.sock" BATCH_MAP_SOCK_PATH = "/var/run/numaflow/batchmap.sock" +ACCUMULATOR_SOCK_PATH = "/var/run/numaflow/accumulator.sock" # Server information file configs MAP_SERVER_INFO_FILE_PATH = "/var/run/numaflow/mapper-server-info" @@ -36,6 +37,7 @@ SIDE_INPUT_SERVER_INFO_FILE_PATH = "/var/run/numaflow/sideinput-server-info" SOURCE_SERVER_INFO_FILE_PATH = "/var/run/numaflow/sourcer-server-info" FALLBACK_SINK_SERVER_INFO_FILE_PATH = "/var/run/numaflow/fb-sinker-server-info" +ACCUMULATOR_SERVER_INFO_FILE_PATH = "/var/run/numaflow/accumulator-server-info" ENV_UD_CONTAINER_TYPE = "NUMAFLOW_UD_CONTAINER_TYPE" UD_CONTAINER_FALLBACK_SINK = "fb-udsink" diff --git a/pynumaflow/accumulator/__init__.py b/pynumaflow/accumulator/__init__.py new file mode 100644 index 00000000..0d1368d8 --- /dev/null +++ b/pynumaflow/accumulator/__init__.py @@ -0,0 +1,19 @@ +from pynumaflow.accumulator._dtypes import ( + Message, + Datum, + IntervalWindow, + DROP, + KeyedWindow, + Accumulator, +) +from pynumaflow.accumulator.async_server import AccumulatorAsyncServer + +__all__ = [ + "Message", + "Datum", + "IntervalWindow", + "DROP", + "AccumulatorAsyncServer", + "KeyedWindow", + "Accumulator", +] diff --git a/pynumaflow/accumulator/_dtypes.py b/pynumaflow/accumulator/_dtypes.py new file mode 100644 index 00000000..31a0d5fe --- /dev/null +++ b/pynumaflow/accumulator/_dtypes.py @@ -0,0 +1,554 @@ +from abc import ABCMeta, abstractmethod +from asyncio import Task +from dataclasses import dataclass +from datetime import datetime +from enum import IntEnum +from typing import TypeVar, Callable, Union, Optional +from collections.abc import AsyncIterable + +from pynumaflow.shared.asynciter import NonBlockingIterator +from pynumaflow._constants import DROP + +M = TypeVar("M", bound="Message") + + +class WindowOperation(IntEnum): + """ + Enumerate the type of Window operation received. + """ + + OPEN = (0,) + CLOSE = (1,) + APPEND = (2,) + + +@dataclass(init=False) +class Datum: + """ + Class to define the important information for the event. + Args: + keys: the keys of the event. + value: the payload of the event. + event_time: the event time of the event. + watermark: the watermark of the event. + >>> # Example usage + >>> from pynumaflow.accumulator import Datum + >>> from datetime import datetime, timezone + >>> payload = bytes("test_mock_message", encoding="utf-8") + >>> t1 = datetime.fromtimestamp(1662998400, timezone.utc) + >>> t2 = datetime.fromtimestamp(1662998460, timezone.utc) + >>> msg_headers = {"key1": "value1", "key2": "value2"} + >>> d = Datum( + ... keys=["test_key"], + ... value=payload, + ... event_time=t1, + ... watermark=t2, + ... headers=msg_headers + ... ) + """ + + __slots__ = ("_keys", "_value", "_event_time", "_watermark", "_headers", "_id") + + _keys: list[str] + _value: bytes + _event_time: datetime + _watermark: datetime + _headers: dict[str, str] + _id: str + + def __init__( + self, + keys: list[str], + value: bytes, + event_time: datetime, + watermark: datetime, + id_: str, + headers: Optional[dict[str, str]] = None, + ): + self._keys = keys or list() + self._value = value or b"" + if not isinstance(event_time, datetime): + raise TypeError(f"Wrong data type: {type(event_time)} for Datum.event_time") + self._event_time = event_time + if not isinstance(watermark, datetime): + raise TypeError(f"Wrong data type: {type(watermark)} for Datum.watermark") + self._watermark = watermark + self._headers = headers or {} + self._id = id_ + + def keys(self) -> list[str]: + """Returns the keys of the event. + + Returns: + list[str]: A list of string keys associated with this event. + """ + return self._keys + + @property + def value(self) -> bytes: + """Returns the value of the event. + + Returns: + bytes: The payload data of the event as bytes. + """ + return self._value + + @property + def event_time(self) -> datetime: + """Returns the event time of the event. + + Returns: + datetime: The timestamp when the event occurred. + """ + return self._event_time + + @property + def watermark(self) -> datetime: + """Returns the watermark of the event. + + Returns: + datetime: The watermark timestamp indicating the progress of event time. + """ + return self._watermark + + @property + def headers(self) -> dict[str, str]: + """Returns the headers of the event. + + Returns: + dict[str, str]: A dictionary containing header key-value pairs for this event. + """ + return self._headers.copy() + + @property + def id(self) -> str: + """Returns the id of the event. + + Returns: + str: The unique identifier for this event. + """ + return self._id + + +@dataclass(init=False) +class IntervalWindow: + """Defines the start and end of the interval window for the event.""" + + __slots__ = ("_start", "_end") + + _start: datetime + _end: datetime + + def __init__(self, start: datetime, end: datetime): + self._start = start + self._end = end + + @property + def start(self) -> datetime: + """Returns the start point of the interval window. + + Returns: + datetime: The start timestamp of the interval window. + """ + return self._start + + @property + def end(self) -> datetime: + """Returns the end point of the interval window. + + Returns: + datetime: The end timestamp of the interval window. + """ + return self._end + + +@dataclass(init=False) +class KeyedWindow: + """ + Defines the window for a accumulator operation which includes the + interval window along with the slot. + """ + + __slots__ = ("_window", "_slot", "_keys") + + _window: IntervalWindow + _slot: str + _keys: list[str] + + def __init__(self, start: datetime, end: datetime, slot: str = "", keys: list[str] = []): + self._window = IntervalWindow(start=start, end=end) + self._slot = slot + self._keys = keys + + @property + def start(self) -> datetime: + """Returns the start point of the interval window. + + Returns: + datetime: The start timestamp of the interval window. + """ + return self._window.start + + @property + def end(self) -> datetime: + """Returns the end point of the interval window. + + Returns: + datetime: The end timestamp of the interval window. + """ + return self._window.end + + @property + def slot(self) -> str: + """Returns the slot from the window. + + Returns: + str: The slot identifier for this window. + """ + return self._slot + + @property + def window(self) -> IntervalWindow: + """Returns the interval window. + + Returns: + IntervalWindow: The underlying interval window object. + """ + return self._window + + @property + def keys(self) -> list[str]: + """Returns the keys for window. + + Returns: + list[str]: A list of keys associated with this window. + """ + return self._keys + + +@dataclass +class AccumulatorResult: + """Defines the object to hold the result of accumulator computation.""" + + __slots__ = ( + "_future", + "_iterator", + "_key", + "_result_queue", + "_consumer_future", + "_latest_watermark", + ) + + _future: Task + _iterator: NonBlockingIterator + _key: list[str] + _result_queue: NonBlockingIterator + _consumer_future: Task + _latest_watermark: datetime + + @property + def future(self) -> Task: + """Returns the future result of computation. + + Returns: + Task: The asyncio Task representing the computation future. + """ + return self._future + + @property + def iterator(self) -> NonBlockingIterator: + """Returns the handle to the producer queue. + + Returns: + NonBlockingIterator: The iterator for producing data to the queue. + """ + return self._iterator + + @property + def keys(self) -> list[str]: + """Returns the keys of the partition. + + Returns: + list[str]: The keys associated with this partition. + """ + return self._key + + @property + def result_queue(self) -> NonBlockingIterator: + """Returns the async queue used to write the output for the tasks. + + Returns: + NonBlockingIterator: The queue for writing task output. + """ + return self._result_queue + + @property + def consumer_future(self) -> Task: + """Returns the async consumer task for the result queue. + + Returns: + Task: The asyncio Task for consuming from the result queue. + """ + return self._consumer_future + + @property + def latest_watermark(self) -> datetime: + """Returns the latest watermark for task. + + Returns: + datetime: The latest watermark timestamp for this task. + """ + return self._latest_watermark + + def update_watermark(self, new_watermark: datetime): + """Updates the latest watermark value. + + Args: + new_watermark (datetime): The new watermark timestamp to set. + + Raises: + TypeError: If new_watermark is not a datetime object. + """ + if not isinstance(new_watermark, datetime): + raise TypeError("new_watermark must be a datetime object") + self._latest_watermark = new_watermark + + +@dataclass +class AccumulatorRequest: + """Defines the object to hold a request for the accumulator operation.""" + + __slots__ = ("_operation", "_keyed_window", "_payload") + + _operation: WindowOperation + _keyed_window: KeyedWindow + _payload: Datum + + def __init__(self, operation: WindowOperation, keyed_window: KeyedWindow, payload: Datum): + self._operation = operation + self._keyed_window = keyed_window + self._payload = payload + + @property + def operation(self) -> WindowOperation: + """Returns the operation type. + + Returns: + WindowOperation: The type of window operation (OPEN, CLOSE, or APPEND). + """ + return self._operation + + @property + def keyed_window(self) -> KeyedWindow: + """Returns the keyed window. + + Returns: + KeyedWindow: The keyed window associated with this request. + """ + return self._keyed_window + + @property + def payload(self) -> Datum: + """Returns the payload of the window. + + Returns: + Datum: The data payload for this accumulator request. + """ + return self._payload + + +@dataclass(init=False) +class Message: + """ + Basic datatype for data passing to the next vertex/vertices. + + Args: + value: data in bytes + keys: []string keys for vertex (optional) + tags: []string tags for conditional forwarding (optional) + watermark: watermark for this message (optional) + event_time: event time for this message (optional) + headers: headers for this message (optional) + id: message id (optional) + """ + + __slots__ = ("_value", "_keys", "_tags", "_watermark", "_event_time", "_headers", "_id") + + _value: bytes + _keys: list[str] + _tags: list[str] + _watermark: datetime + _event_time: datetime + _headers: dict[str, str] + _id: str + + def __init__( + self, + value: bytes, + keys: list[str] = None, + tags: list[str] = None, + watermark: datetime = None, + event_time: datetime = None, + headers: dict[str, str] = None, + id: str = None, + ): + """ + Creates a Message object to send value to a vertex. + """ + self._keys = keys or [] + self._tags = tags or [] + self._value = value or b"" + self._watermark = watermark + self._event_time = event_time + self._headers = headers or {} + self._id = id or "" + + @classmethod + def to_drop(cls: type[M]) -> M: + """Creates a Message instance that indicates the message should be dropped. + + Returns: + M: A Message instance with empty value and DROP tag indicating + the message should be dropped. + """ + return cls(b"", None, [DROP]) + + @property + def value(self) -> bytes: + """Returns the message payload value. + + Returns: + bytes: The message payload data as bytes. + """ + return self._value + + @property + def keys(self) -> list[str]: + """Returns the message keys. + + Returns: + list[str]: A list of string keys associated with this message. + """ + return self._keys + + @property + def tags(self) -> list[str]: + """Returns the message tags for conditional forwarding. + + Returns: + list[str]: A list of string tags used for conditional forwarding. + """ + return self._tags + + @property + def watermark(self) -> datetime: + """Returns the watermark timestamp for this message. + + Returns: + datetime: The watermark timestamp, or None if not set. + """ + return self._watermark + + @property + def event_time(self) -> datetime: + """Returns the event time for this message. + + Returns: + datetime: The event time timestamp, or None if not set. + """ + return self._event_time + + @property + def headers(self) -> dict[str, str]: + """Returns the message headers. + + Returns: + dict[str, str]: A dictionary containing header key-value pairs for this message. + """ + return self._headers.copy() + + @property + def id(self) -> str: + """Returns the message ID. + + Returns: + str: The unique identifier for this message. + """ + return self._id + + @classmethod + def from_datum(cls, datum: Datum): + """Create a Message instance from a Datum object. + + Args: + datum: The Datum object to convert + + Returns: + Message: A new Message instance with data from the datum + """ + return cls( + value=datum.value, + keys=datum.keys(), + watermark=datum.watermark, + event_time=datum.event_time, + headers=datum.headers, + id=datum.id, + ) + + +AccumulatorAsyncCallable = Callable[[list[str], AsyncIterable[Datum], NonBlockingIterator], None] + + +class Accumulator(metaclass=ABCMeta): + """ + Accumulate can read unordered from the input stream and emit the ordered + data to the output stream. Once the watermark (WM) of the output stream progresses, + the data in WAL until that WM will be garbage collected. + NOTE: A message can be silently dropped if need be, + and it will be cleared from the WAL when the WM progresses. + """ + + def __call__(self, *args, **kwargs): + """ + Allow to call handler function directly if class instance is sent + as the accumulator_instance. + """ + return self.handler(*args, **kwargs) + + @abstractmethod + async def handler( + self, + datums: AsyncIterable[Datum], + output: NonBlockingIterator, + ): + """ + Implement this handler function which implements the AccumulatorStreamCallable interface. + """ + pass + + +class _AccumulatorBuilderClass: + """ + Class to build an Accumulator class instance. + Used Internally + + Args: + accumulator_class: the Accumulator class to be used for Accumulator UDF + args: the arguments to be passed to the accumulator class + kwargs: the keyword arguments to be passed to the accumulator class + """ + + def __init__(self, accumulator_class: type[Accumulator], args: tuple, kwargs: dict): + self._accumulator_class: type[Accumulator] = accumulator_class + self._args = args + self._kwargs = kwargs + + def create(self) -> Accumulator: + """ + Create a new Accumulator instance. + """ + return self._accumulator_class(*self._args, **self._kwargs) + + +# AccumulatorStreamCallable is a callable which can be used as a handler for the Accumulator UDF. +AccumulatorStreamCallable = Union[AccumulatorAsyncCallable, type[Accumulator]] diff --git a/pynumaflow/accumulator/async_server.py b/pynumaflow/accumulator/async_server.py new file mode 100644 index 00000000..042359ca --- /dev/null +++ b/pynumaflow/accumulator/async_server.py @@ -0,0 +1,206 @@ +import inspect +from typing import Optional + +import aiorun +import grpc + +from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer +from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION +from pynumaflow.proto.accumulator import accumulator_pb2_grpc + + +from pynumaflow._constants import ( + MAX_MESSAGE_SIZE, + NUM_THREADS_DEFAULT, + _LOGGER, + MAX_NUM_THREADS, + ACCUMULATOR_SOCK_PATH, + ACCUMULATOR_SERVER_INFO_FILE_PATH, +) + +from pynumaflow.accumulator._dtypes import ( + AccumulatorStreamCallable, + _AccumulatorBuilderClass, + Accumulator, +) + +from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server + + +def get_handler( + accumulator_handler: AccumulatorStreamCallable, + init_args: tuple = (), + init_kwargs: Optional[dict] = None, +): + """ + Get the correct handler type based on the arguments passed + """ + if inspect.isfunction(accumulator_handler): + if init_args or init_kwargs: + # if the init_args or init_kwargs are passed, then the accumulator_instance + # can only be of class Accumulator type + raise TypeError("Cannot pass function handler with init args or kwargs") + # return the function handler + return accumulator_handler + elif not check_instance(accumulator_handler, Accumulator) and issubclass( + accumulator_handler, Accumulator + ): + # if handler is type of Class Accumulator, create a new instance of + # a AccumulatorBuilderClass + return _AccumulatorBuilderClass(accumulator_handler, init_args, init_kwargs) + else: + _LOGGER.error( + _error_msg := f"Invalid Class Type {accumulator_handler}: " + f"Please make sure the class type is passed, and it is a subclass of Accumulator" + ) + raise TypeError(_error_msg) + + +class AccumulatorAsyncServer(NumaflowServer): + """ + Class for a new Accumulator Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. + Args: + accumulator_instance: The accumulator instance to be used for + Accumulator UDF + init_args: The arguments to be passed to the accumulator_handler + init_kwargs: The keyword arguments to be passed to the + accumulator_handler + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to 4 and max capped at 16 + server_info_file: The path to the server info file + Example invocation: + import os + from collections.abc import AsyncIterable + from datetime import datetime + + from pynumaflow.accumulator import Accumulator, AccumulatorAsyncServer + from pynumaflow.accumulator import ( + Message, + Datum, + ) + from pynumaflow.shared.asynciter import NonBlockingIterator + + class StreamSorter(Accumulator): + def __init__(self, counter): + self.latest_wm = datetime.fromtimestamp(-1) + self.sorted_buffer: list[Datum] = [] + + async def handler( + self, + datums: AsyncIterable[Datum], + output: NonBlockingIterator, + ): + async for _ in datums: + # Process the datums and send output + if datum.watermark and datum.watermark > self.latest_wm: + self.latest_wm = datum.watermark + await self.flush_buffer(output) + + self.insert_sorted(datum) + + def insert_sorted(self, datum: Datum): + # Binary insert to keep sorted buffer in order + left, right = 0, len(self.sorted_buffer) + while left < right: + mid = (left + right) // 2 + if self.sorted_buffer[mid].event_time > datum.event_time: + right = mid + else: + left = mid + 1 + self.sorted_buffer.insert(left, datum) + + async def flush_buffer(self, output: NonBlockingIterator): + i = 0 + for datum in self.sorted_buffer: + if datum.event_time > self.latest_wm: + break + await output.put(Message.from_datum(datum)) + i += 1 + # Remove flushed items + self.sorted_buffer = self.sorted_buffer[i:] + + + if __name__ == "__main__": + grpc_server = AccumulatorAsyncServer(StreamSorter) + grpc_server.start() + + """ + + def __init__( + self, + accumulator_instance: AccumulatorStreamCallable, + init_args: tuple = (), + init_kwargs: dict = None, + sock_path=ACCUMULATOR_SOCK_PATH, + max_message_size=MAX_MESSAGE_SIZE, + max_threads=NUM_THREADS_DEFAULT, + server_info_file=ACCUMULATOR_SERVER_INFO_FILE_PATH, + ): + """ + Create a new grpc Accumulator Server instance. + A new servicer instance is created and attached to the server. + The server instance is returned. + Args: + accumulator_instance: The Accumulator instance to be used for + Accumulator UDF + init_args: The arguments to be passed to the accumulator_handler + init_kwargs: The keyword arguments to be passed to the + accumulator_handler + sock_path: The UNIX socket path to be used for the server + max_message_size: The max message size in bytes the server can receive and send + max_threads: The max number of threads to be spawned; + defaults to 4 and max capped at 16 + server_info_file: The path to the server info file + """ + if init_kwargs is None: + init_kwargs = {} + self.accumulator_handler = get_handler(accumulator_instance, init_args, init_kwargs) + self.sock_path = f"unix://{sock_path}" + self.max_message_size = max_message_size + self.max_threads = min(max_threads, MAX_NUM_THREADS) + self.server_info_file = server_info_file + + self._server_options = [ + ("grpc.max_send_message_length", self.max_message_size), + ("grpc.max_receive_message_length", self.max_message_size), + ] + # Get the servicer instance for the async server + self.servicer = AsyncAccumulatorServicer(self.accumulator_handler) + + def start(self): + """ + Starter function for the Async server class, need a separate caller + so that all the async coroutines can be started from a single context + """ + _LOGGER.info( + "Starting Async Accumulator Server", + ) + aiorun.run(self.aexec(), use_uvloop=True) + + async def aexec(self): + """ + Starts the Async gRPC server on the given UNIX socket with + given max threads. + """ + # As the server is async, we need to create a new server instance in the + # same thread as the event loop so that all the async calls are made in the + # same context + # Create a new async server instance and add the servicer to it + server = grpc.aio.server(options=self._server_options) + server.add_insecure_port(self.sock_path) + accumulator_pb2_grpc.add_AccumulatorServicer_to_server(self.servicer, server) + + serv_info = ServerInfo.get_default_server_info() + serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Accumulator] + await start_async_server( + server_async=server, + sock_path=self.sock_path, + max_threads=self.max_threads, + cleanup_coroutines=list(), + server_info_file=self.server_info_file, + server_info=serv_info, + ) diff --git a/pynumaflow/accumulator/servicer/__init__.py b/pynumaflow/accumulator/servicer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/accumulator/servicer/async_servicer.py b/pynumaflow/accumulator/servicer/async_servicer.py new file mode 100644 index 00000000..6eebdbeb --- /dev/null +++ b/pynumaflow/accumulator/servicer/async_servicer.py @@ -0,0 +1,133 @@ +import asyncio +from collections.abc import AsyncIterable +from typing import Union + +from google.protobuf import empty_pb2 as _empty_pb2 + +from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING +from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc +from pynumaflow.accumulator._dtypes import ( + Datum, + AccumulatorAsyncCallable, + _AccumulatorBuilderClass, + AccumulatorRequest, + KeyedWindow, +) +from pynumaflow.accumulator.servicer.task_manager import TaskManager +from pynumaflow.shared.server import handle_async_error +from pynumaflow.types import NumaflowServicerContext + + +async def datum_generator( + request_iterator: AsyncIterable[accumulator_pb2.AccumulatorRequest], +) -> AsyncIterable[AccumulatorRequest]: + """Generate a AccumulatorRequest from a AccumulatorRequest proto message.""" + async for d in request_iterator: + # Convert protobuf KeyedWindow to our KeyedWindow dataclass + keyed_window = KeyedWindow( + start=d.operation.keyedWindow.start.ToDatetime(), + end=d.operation.keyedWindow.end.ToDatetime(), + slot=d.operation.keyedWindow.slot, + keys=list(d.operation.keyedWindow.keys), + ) + + accumulator_request = AccumulatorRequest( + operation=d.operation.event, + keyed_window=keyed_window, # Use the new parameter name + payload=Datum( + keys=list(d.payload.keys), + value=d.payload.value, + event_time=d.payload.event_time.ToDatetime(), + watermark=d.payload.watermark.ToDatetime(), + id_=d.payload.id, + headers=dict(d.payload.headers), + ), + ) + yield accumulator_request + + +class AsyncAccumulatorServicer(accumulator_pb2_grpc.AccumulatorServicer): + """ + This class is used to create a new grpc Accumulator servicer instance. + Provides the functionality for the required rpc methods. + """ + + def __init__( + self, + handler: Union[AccumulatorAsyncCallable, _AccumulatorBuilderClass], + ): + # The accumulator handler can be a function or a builder class instance. + self.__accumulator_handler: Union[ + AccumulatorAsyncCallable, _AccumulatorBuilderClass + ] = handler + + async def AccumulateFn( + self, + request_iterator: AsyncIterable[accumulator_pb2.AccumulatorRequest], + context: NumaflowServicerContext, + ) -> accumulator_pb2.AccumulatorResponse: + """ + Applies a accumulator function to a datum stream. + The pascal case function name comes from the proto accumulator_pb2_grpc.py file. + """ + # Create a task manager instance + task_manager = TaskManager(handler=self.__accumulator_handler) + + # Create a consumer task to read from the result queue + # All the results from the accumulator function will be sent to the result queue + # We will read from the result queue and send the results to the client + consumer = task_manager.global_result_queue.read_iterator() + + # Create an async iterator from the request iterator + datum_iterator = datum_generator(request_iterator=request_iterator) + + # Create a process_input_stream task in the task manager, + # this would read from the datum iterator + # and then create the required tasks to process the data requests + # The results from these tasks are then sent to the result queue + producer = asyncio.create_task(task_manager.process_input_stream(datum_iterator)) + + # Start the consumer task where we read from the result queue + # and send the results to the client + # The task manager can write the following to the result queue: + # 1. A accumulator_pb2.AccumulatorResponse message + # This is the result of the accumulator function, it contains the window and the + # result of the accumulator function + # The result of the accumulator function is a accumulator_pb2.AccumulatorResponse message + # and can be directly sent to the client + # + # 2. An Exception + # Any exceptions that occur during the processing accumulator function tasks are + # sent to the result queue. We then forward these exception to the client + # + # 3. A accumulator_pb2.AccumulatorResponse message with EOF=True + # This is a special message that indicates the end of the processing for a window + # When we get this message, we send an EOF message to the client + try: + async for msg in consumer: + # If the message is an exception, we raise the exception + if isinstance(msg, BaseException): + await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING) + return + # Send window EOF response or Window result response + # back to the client + else: + yield msg + except BaseException as e: + await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + return + # Wait for the process_input_stream task to finish for a clean exit + try: + await producer + except BaseException as e: + await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING) + return + + async def IsReady( + self, request: _empty_pb2.Empty, context: NumaflowServicerContext + ) -> accumulator_pb2.ReadyResponse: + """ + IsReady is the heartbeat endpoint for gRPC. + The pascal case function name comes from the proto accumulator_pb2_grpc.py file. + """ + return accumulator_pb2.ReadyResponse(ready=True) diff --git a/pynumaflow/accumulator/servicer/task_manager.py b/pynumaflow/accumulator/servicer/task_manager.py new file mode 100644 index 00000000..a7c80968 --- /dev/null +++ b/pynumaflow/accumulator/servicer/task_manager.py @@ -0,0 +1,349 @@ +import asyncio +from collections.abc import AsyncIterable +from datetime import datetime +from typing import Union + +from google.protobuf import timestamp_pb2 +from pynumaflow._constants import ( + STREAM_EOF, + DELIMITER, + _LOGGER, +) +from pynumaflow.accumulator._dtypes import ( + AccumulatorResult, + Datum, + _AccumulatorBuilderClass, + AccumulatorAsyncCallable, + WindowOperation, +) +from pynumaflow.proto.accumulator import accumulator_pb2 +from pynumaflow.shared.asynciter import NonBlockingIterator + + +def build_unique_key_name(keys): + """ + Builds a unique key name for the given keys and window. + The key name is used to identify the Accumulator task. + The format is: start_time:end_time:key1:key2:... + """ + return f"{DELIMITER.join(keys)}" + + +def build_window_hash(window): + """ + Builds a hash for the given window. + The hash is used to identify the Accumulator Window + The format is: start_time:end_time + """ + return f"{window.start.ToMilliseconds()}:{window.end.ToMilliseconds()}" + + +class TaskManager: + """ + TaskManager is responsible for managing the Accumulator tasks. + It is created whenever a new accumulator operation is requested. + """ + + def __init__(self, handler: Union[AccumulatorAsyncCallable, _AccumulatorBuilderClass]): + # A dictionary to store the task information + self.tasks: dict[str, AccumulatorResult] = {} + # Collection for storing strong references to all running tasks. + # Event loop only keeps a weak reference, which can cause it to + # get lost during execution. + self.background_tasks = set() + # Handler for the accumulator operation + self.__accumulator_handler = handler + # Queue to store the results of the accumulator operation + # This queue is used to send the results to the client + # once the accumulator operation is completed. + # This queue is also used to send the error/exceptions to the client + # if the accumulator operation fails. + self.global_result_queue = NonBlockingIterator() + + def get_unique_windows(self): + """ + Returns the unique windows that are currently being processed + """ + # Dict to store unique windows + windows = dict() + # Iterate over all the tasks and add the windows + for task in self.tasks.values(): + window_hash = build_window_hash(task.window) + window_found = windows.get(window_hash, None) + # if window not seen yet, add to the dict + if not window_found: + windows[window_hash] = task.window + return windows + + def get_tasks(self): + """ + Returns the list of accumulator tasks that are + currently being processed + """ + return list(self.tasks.values()) + + async def stream_send_eof(self): + """ + Function used to indicate to all processing tasks that no + more requests are expected by sending EOF message to + local input streams of individual tasks. + This is called when the input grpc stream is closed. + """ + # Create a copy of the keys to avoid dictionary size change during iteration + task_keys = list(self.tasks.keys()) + for unified_key in task_keys: + await self.tasks[unified_key].iterator.put(STREAM_EOF) + + async def close_task(self, req): + """ + Closes a running accumulator task for a given key. + Based on the request we compute the unique key, and then + signal the corresponding task for it to closure. + The steps involve + 1. Send a signal to the local request queue of the task to stop reading + 2. Wait for the user function to complete + 3. Wait for all the results from the task to be written to the global result queue + 4. Remove the task from the tracker + """ + d = req.payload + keys = d.keys() + unified_key = build_unique_key_name(keys) + curr_task = self.tasks.get(unified_key, None) + + if curr_task: + await self.tasks[unified_key].iterator.put(STREAM_EOF) + await curr_task.future + await curr_task.consumer_future + self.tasks.pop(unified_key) + else: + _LOGGER.critical("accumulator task not found", exc_info=True) + err = BaseException("accumulator task not found") + # Put the exception in the result queue + await self.global_result_queue.put(err) + + async def create_task(self, req): + """ + Creates a new accumulator task for the given request. + Based on the request we compute a unique key, and then + it creates a new task or appends the request to the existing task. + """ + d = req.payload + keys = d.keys() + unified_key = build_unique_key_name(keys) + curr_task = self.tasks.get(unified_key, None) + + # If the task does not exist, create a new task + if not curr_task: + niter = NonBlockingIterator() + riter = niter.read_iterator() + # Create a new result queue for the current task + # We create a new result queue for each task, so that + # the results of the accumulator operation can be sent to the + # the global result queue, which in turn sends the results + # to the client. + res_queue = NonBlockingIterator() + + # Create a new write_to_global_queue task for the current, this will read from the + # result queue specifically for the current task and update the + # global result queue + consumer = asyncio.create_task( + self.write_to_global_queue(res_queue, self.global_result_queue, unified_key) + ) + # Save a reference to the result of this function, to avoid a + # task disappearing mid-execution. + self.background_tasks.add(consumer) + consumer.add_done_callback(self.clean_background) + + # Create a new task for the accumulator operation, this will invoke the + # Accumulator handler with the given keys, request iterator, and window. + task = asyncio.create_task(self.__invoke_accumulator(riter, res_queue)) + # Save a reference to the result of this function, to avoid a + # task disappearing mid-execution. + self.background_tasks.add(task) + task.add_done_callback(self.clean_background) + + # Create a new AccumulatorResult object to store the task information + curr_task = AccumulatorResult( + task, niter, keys, res_queue, consumer, datetime.fromtimestamp(-1) + ) + + # Save the result of the accumulator operation to the task list + self.tasks[unified_key] = curr_task + + # Put the request in the iterator + await curr_task.iterator.put(d) + + async def send_datum_to_task(self, req): + """ + Appends the request to the existing window reduce task. + If the task does not exist, create it. + """ + d = req.payload + keys = d.keys() + unified_key = build_unique_key_name(keys) + result = self.tasks.get(unified_key, None) + if not result: + await self.create_task(req) + else: + await result.iterator.put(d) + + async def __invoke_accumulator( + self, + request_iterator: AsyncIterable[Datum], + output: NonBlockingIterator, + ): + """ + Invokes the UDF accumulator handler with the given keys, + request iterator, and window. Returns the result of the + accumulator operation. + """ + new_instance = self.__accumulator_handler + + # If the accumulator handler is a class instance, create a new instance of it. + # It is required for a new key to be processed by a + # new instance of the accumulator for a given window + # Otherwise the function handler can be called directly + if isinstance(self.__accumulator_handler, _AccumulatorBuilderClass): + new_instance = self.__accumulator_handler.create() + try: + _ = await new_instance(request_iterator, output) + # send EOF to the output stream + await output.put(STREAM_EOF) + # If there is an error in the accumulator operation, log and + # then send the error to the result queue + except BaseException as err: + _LOGGER.critical("panic inside accumulator handle", exc_info=True) + # Put the exception in the result queue + await self.global_result_queue.put(err) + + async def process_input_stream( + self, request_iterator: AsyncIterable[accumulator_pb2.AccumulatorRequest] + ): + # Start iterating through the request iterator and create tasks + # based on the operation type received. + try: + request_count = 0 + async for request in request_iterator: + request_count += 1 + # check whether the request is an open or append operation + if request.operation is int(WindowOperation.OPEN): + # create a new task for the open operation and + # put the request in the task iterator + await self.create_task(request) + elif request.operation is int(WindowOperation.APPEND): + # append the task data to the existing task + # if the task does not exist, create a new task + await self.send_datum_to_task(request) + elif request.operation is int(WindowOperation.CLOSE): + # close the current task for req + await self.close_task(request) + else: + _LOGGER.debug(f"No operation matched for request: {request}", exc_info=True) + + # If there is an error in the accumulator operation, log and + # then send the error to the result queue + except BaseException as e: + err_msg = f"Accumulator Error: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + # Put the exception in the global result queue + await self.global_result_queue.put(e) + return + + try: + # send EOF to all the tasks once the request iterator is exhausted + # This will signal the tasks to stop reading the data on their + # respective iterators. + await self.stream_send_eof() + + # get the list of accumulator tasks that are currently being processed + # iterate through the tasks and wait for them to complete + for task in self.get_tasks(): + # Once this is done, we know that the task has written all the results + # to the local result queue + fut = task.future + await fut + + # Wait for the local queue to write + # all the results of this task to the global result queue + con_future = task.consumer_future + await con_future + self.tasks.clear() + + # Now send STREAM_EOF to terminate the global result queue iterator + await self.global_result_queue.put(STREAM_EOF) + except BaseException as e: + err_msg = f"Accumulator Streaming Error: {repr(e)}" + _LOGGER.critical(err_msg, exc_info=True) + await self.global_result_queue.put(e) + + async def write_to_global_queue( + self, input_queue: NonBlockingIterator, output_queue: NonBlockingIterator, unified_key: str + ): + """ + This function is used to route the messages from the + local result queue for a given task to the global result queue. + Once all messages are routed, it sends the window EOF messages for the same. + """ + reader = input_queue.read_iterator() + task = self.tasks[unified_key] + + wm: datetime = task.latest_watermark + async for msg in reader: + # Convert the window to a datetime object + # Only update watermark if msg.watermark is not None + if msg.watermark is not None and wm < msg.watermark: + task.update_watermark(msg.watermark) + self.tasks[unified_key] = task + wm = msg.watermark + + # Convert datetime to protobuf timestamp + event_time_pb = timestamp_pb2.Timestamp() + if msg.event_time is not None: + event_time_pb.FromDatetime(msg.event_time) + + watermark_pb = timestamp_pb2.Timestamp() + if msg.watermark is not None: + watermark_pb.FromDatetime(msg.watermark) + + start_dt_pb = timestamp_pb2.Timestamp() + start_dt_pb.FromDatetime(datetime.fromtimestamp(0)) + + end_dt_pb = timestamp_pb2.Timestamp() + end_dt_pb.FromDatetime(wm) + + res = accumulator_pb2.AccumulatorResponse( + payload=accumulator_pb2.Payload( + keys=msg.keys, + value=msg.value, + event_time=event_time_pb, + watermark=watermark_pb, + headers=msg.headers, + id=msg.id, + ), + window=accumulator_pb2.KeyedWindow( + start=start_dt_pb, end=end_dt_pb, slot="slot-0", keys=task.keys + ), + EOF=False, + tags=msg.tags, + ) + await output_queue.put(res) + # send EOF + start_eof_pb = timestamp_pb2.Timestamp() + start_eof_pb.FromDatetime(datetime.fromtimestamp(0)) + + end_eof_pb = timestamp_pb2.Timestamp() + end_eof_pb.FromDatetime(wm) + + res = accumulator_pb2.AccumulatorResponse( + window=accumulator_pb2.KeyedWindow( + start=start_eof_pb, end=end_eof_pb, slot="slot-0", keys=task.keys + ), + EOF=True, + ) + await output_queue.put(res) + + def clean_background(self, task): + """ + Remove the task from the background tasks collection + """ + self.background_tasks.remove(task) diff --git a/pynumaflow/info/types.py b/pynumaflow/info/types.py index 2845c264..12375a70 100644 --- a/pynumaflow/info/types.py +++ b/pynumaflow/info/types.py @@ -71,6 +71,7 @@ class ContainerType(str, Enum): Sessionreducer = "sessionreducer" Sideinput = "sideinput" Fbsinker = "fb-sinker" + Accumulator = "accumulator" # Minimum version of Numaflow required by the current SDK version @@ -86,6 +87,7 @@ class ContainerType(str, Enum): ContainerType.Sessionreducer: "1.4.0-z", ContainerType.Sideinput: "1.4.0-z", ContainerType.Fbsinker: "1.4.0-z", + ContainerType.Accumulator: "1.5.0-z", } diff --git a/pynumaflow/proto/accumulator/__init__.py b/pynumaflow/proto/accumulator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pynumaflow/proto/accumulator/accumulator.proto b/pynumaflow/proto/accumulator/accumulator.proto new file mode 100644 index 00000000..acde986b --- /dev/null +++ b/pynumaflow/proto/accumulator/accumulator.proto @@ -0,0 +1,90 @@ +/* +Copyright 2022 The Numaproj Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +syntax = "proto3"; + +option go_package = "github.com/numaproj/numaflow-go/pkg/apis/proto/accumulator/v1"; +option java_package = "io.numaproj.numaflow.accumulator.v1"; + +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; + + +package accumulator.v1; + +// AccumulatorWindow describes a special kind of SessionWindow (similar to Global Window) where output should +// always have monotonically increasing WM but it can be manipulated through event-time by reordering the messages. +// NOTE: Quite powerful, should not be abused; it can cause stalling of pipelines and leaks +service Accumulator { + // AccumulateFn applies a accumulate function to a request stream. + rpc AccumulateFn(stream AccumulatorRequest) returns (stream AccumulatorResponse); + + // IsReady is the heartbeat endpoint for gRPC. + rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); +} + +// Payload represents a payload element. +message Payload { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + string id = 5; + map headers = 6; +} + +// AccumulatorRequest represents a request element. +message AccumulatorRequest { + // WindowOperation represents a window operation. + // For Unaligned windows, OPEN, APPEND and CLOSE events are sent. + message WindowOperation { + enum Event { + OPEN = 0; + CLOSE = 1; + APPEND = 2; + } + Event event = 1; + KeyedWindow keyedWindow = 2; + } + + Payload payload = 1; + WindowOperation operation = 2; +} + + +// Window represents a window. +message KeyedWindow { + google.protobuf.Timestamp start = 1; + google.protobuf.Timestamp end = 2; + string slot = 3; + repeated string keys = 4; +} + +// AccumulatorResponse represents a response element. +message AccumulatorResponse { + Payload payload = 1; + // window represents a window to which the result belongs. + KeyedWindow window = 2; + repeated string tags = 3; + // EOF represents the end of the response for a window. + bool EOF = 4; +} + + +// ReadyResponse is the health check result. +message ReadyResponse { + bool ready = 1; +} \ No newline at end of file diff --git a/pynumaflow/proto/accumulator/accumulator_pb2.py b/pynumaflow/proto/accumulator/accumulator_pb2.py new file mode 100644 index 00000000..f1e8ec8d --- /dev/null +++ b/pynumaflow/proto/accumulator/accumulator_pb2.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: accumulator.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 +from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x11\x61\x63\x63umulator.proto\x12\x0e\x61\x63\x63umulator.v1\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto"\xf8\x01\n\x07Payload\x12\x0c\n\x04keys\x18\x01 \x03(\t\x12\r\n\x05value\x18\x02 \x01(\x0c\x12.\n\nevent_time\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\twatermark\x18\x04 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\n\n\x02id\x18\x05 \x01(\t\x12\x35\n\x07headers\x18\x06 \x03(\x0b\x32$.accumulator.v1.Payload.HeadersEntry\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01"\xbe\x02\n\x12\x41\x63\x63umulatorRequest\x12(\n\x07payload\x18\x01 \x01(\x0b\x32\x17.accumulator.v1.Payload\x12\x45\n\toperation\x18\x02 \x01(\x0b\x32\x32.accumulator.v1.AccumulatorRequest.WindowOperation\x1a\xb6\x01\n\x0fWindowOperation\x12G\n\x05\x65vent\x18\x01 \x01(\x0e\x32\x38.accumulator.v1.AccumulatorRequest.WindowOperation.Event\x12\x30\n\x0bkeyedWindow\x18\x02 \x01(\x0b\x32\x1b.accumulator.v1.KeyedWindow"(\n\x05\x45vent\x12\x08\n\x04OPEN\x10\x00\x12\t\n\x05\x43LOSE\x10\x01\x12\n\n\x06\x41PPEND\x10\x02"}\n\x0bKeyedWindow\x12)\n\x05start\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\'\n\x03\x65nd\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04slot\x18\x03 \x01(\t\x12\x0c\n\x04keys\x18\x04 \x03(\t"\x87\x01\n\x13\x41\x63\x63umulatorResponse\x12(\n\x07payload\x18\x01 \x01(\x0b\x32\x17.accumulator.v1.Payload\x12+\n\x06window\x18\x02 \x01(\x0b\x32\x1b.accumulator.v1.KeyedWindow\x12\x0c\n\x04tags\x18\x03 \x03(\t\x12\x0b\n\x03\x45OF\x18\x04 \x01(\x08"\x1e\n\rReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08\x32\xac\x01\n\x0b\x41\x63\x63umulator\x12[\n\x0c\x41\x63\x63umulateFn\x12".accumulator.v1.AccumulatorRequest\x1a#.accumulator.v1.AccumulatorResponse(\x01\x30\x01\x12@\n\x07IsReady\x12\x16.google.protobuf.Empty\x1a\x1d.accumulator.v1.ReadyResponseBd\n#io.numaproj.numaflow.accumulator.v1Z=github.com/numaproj/numaflow-go/pkg/apis/proto/accumulator/v1b\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "accumulator_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + _globals["DESCRIPTOR"]._options = None + _globals[ + "DESCRIPTOR" + ]._serialized_options = b"\n#io.numaproj.numaflow.accumulator.v1Z=github.com/numaproj/numaflow-go/pkg/apis/proto/accumulator/v1" + _globals["_PAYLOAD_HEADERSENTRY"]._options = None + _globals["_PAYLOAD_HEADERSENTRY"]._serialized_options = b"8\001" + _globals["_PAYLOAD"]._serialized_start = 100 + _globals["_PAYLOAD"]._serialized_end = 348 + _globals["_PAYLOAD_HEADERSENTRY"]._serialized_start = 302 + _globals["_PAYLOAD_HEADERSENTRY"]._serialized_end = 348 + _globals["_ACCUMULATORREQUEST"]._serialized_start = 351 + _globals["_ACCUMULATORREQUEST"]._serialized_end = 669 + _globals["_ACCUMULATORREQUEST_WINDOWOPERATION"]._serialized_start = 487 + _globals["_ACCUMULATORREQUEST_WINDOWOPERATION"]._serialized_end = 669 + _globals["_ACCUMULATORREQUEST_WINDOWOPERATION_EVENT"]._serialized_start = 629 + _globals["_ACCUMULATORREQUEST_WINDOWOPERATION_EVENT"]._serialized_end = 669 + _globals["_KEYEDWINDOW"]._serialized_start = 671 + _globals["_KEYEDWINDOW"]._serialized_end = 796 + _globals["_ACCUMULATORRESPONSE"]._serialized_start = 799 + _globals["_ACCUMULATORRESPONSE"]._serialized_end = 934 + _globals["_READYRESPONSE"]._serialized_start = 936 + _globals["_READYRESPONSE"]._serialized_end = 966 + _globals["_ACCUMULATOR"]._serialized_start = 969 + _globals["_ACCUMULATOR"]._serialized_end = 1141 +# @@protoc_insertion_point(module_scope) diff --git a/pynumaflow/proto/accumulator/accumulator_pb2.pyi b/pynumaflow/proto/accumulator/accumulator_pb2.pyi new file mode 100644 index 00000000..d9f0f7a5 --- /dev/null +++ b/pynumaflow/proto/accumulator/accumulator_pb2.pyi @@ -0,0 +1,122 @@ +from google.protobuf import empty_pb2 as _empty_pb2 +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ( + ClassVar as _ClassVar, + Iterable as _Iterable, + Mapping as _Mapping, + Optional as _Optional, + Union as _Union, +) + +DESCRIPTOR: _descriptor.FileDescriptor + +class Payload(_message.Message): + __slots__ = ("keys", "value", "event_time", "watermark", "id", "headers") + + class HeadersEntry(_message.Message): + __slots__ = ("key", "value") + KEY_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + key: str + value: str + def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... + KEYS_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] + EVENT_TIME_FIELD_NUMBER: _ClassVar[int] + WATERMARK_FIELD_NUMBER: _ClassVar[int] + ID_FIELD_NUMBER: _ClassVar[int] + HEADERS_FIELD_NUMBER: _ClassVar[int] + keys: _containers.RepeatedScalarFieldContainer[str] + value: bytes + event_time: _timestamp_pb2.Timestamp + watermark: _timestamp_pb2.Timestamp + id: str + headers: _containers.ScalarMap[str, str] + def __init__( + self, + keys: _Optional[_Iterable[str]] = ..., + value: _Optional[bytes] = ..., + event_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., + watermark: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., + id: _Optional[str] = ..., + headers: _Optional[_Mapping[str, str]] = ..., + ) -> None: ... + +class AccumulatorRequest(_message.Message): + __slots__ = ("payload", "operation") + + class WindowOperation(_message.Message): + __slots__ = ("event", "keyedWindow") + + class Event(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + OPEN: _ClassVar[AccumulatorRequest.WindowOperation.Event] + CLOSE: _ClassVar[AccumulatorRequest.WindowOperation.Event] + APPEND: _ClassVar[AccumulatorRequest.WindowOperation.Event] + OPEN: AccumulatorRequest.WindowOperation.Event + CLOSE: AccumulatorRequest.WindowOperation.Event + APPEND: AccumulatorRequest.WindowOperation.Event + EVENT_FIELD_NUMBER: _ClassVar[int] + KEYEDWINDOW_FIELD_NUMBER: _ClassVar[int] + event: AccumulatorRequest.WindowOperation.Event + keyedWindow: KeyedWindow + def __init__( + self, + event: _Optional[_Union[AccumulatorRequest.WindowOperation.Event, str]] = ..., + keyedWindow: _Optional[_Union[KeyedWindow, _Mapping]] = ..., + ) -> None: ... + PAYLOAD_FIELD_NUMBER: _ClassVar[int] + OPERATION_FIELD_NUMBER: _ClassVar[int] + payload: Payload + operation: AccumulatorRequest.WindowOperation + def __init__( + self, + payload: _Optional[_Union[Payload, _Mapping]] = ..., + operation: _Optional[_Union[AccumulatorRequest.WindowOperation, _Mapping]] = ..., + ) -> None: ... + +class KeyedWindow(_message.Message): + __slots__ = ("start", "end", "slot", "keys") + START_FIELD_NUMBER: _ClassVar[int] + END_FIELD_NUMBER: _ClassVar[int] + SLOT_FIELD_NUMBER: _ClassVar[int] + KEYS_FIELD_NUMBER: _ClassVar[int] + start: _timestamp_pb2.Timestamp + end: _timestamp_pb2.Timestamp + slot: str + keys: _containers.RepeatedScalarFieldContainer[str] + def __init__( + self, + start: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., + end: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., + slot: _Optional[str] = ..., + keys: _Optional[_Iterable[str]] = ..., + ) -> None: ... + +class AccumulatorResponse(_message.Message): + __slots__ = ("payload", "window", "tags", "EOF") + PAYLOAD_FIELD_NUMBER: _ClassVar[int] + WINDOW_FIELD_NUMBER: _ClassVar[int] + TAGS_FIELD_NUMBER: _ClassVar[int] + EOF_FIELD_NUMBER: _ClassVar[int] + payload: Payload + window: KeyedWindow + tags: _containers.RepeatedScalarFieldContainer[str] + EOF: bool + def __init__( + self, + payload: _Optional[_Union[Payload, _Mapping]] = ..., + window: _Optional[_Union[KeyedWindow, _Mapping]] = ..., + tags: _Optional[_Iterable[str]] = ..., + EOF: bool = ..., + ) -> None: ... + +class ReadyResponse(_message.Message): + __slots__ = ("ready",) + READY_FIELD_NUMBER: _ClassVar[int] + ready: bool + def __init__(self, ready: bool = ...) -> None: ... diff --git a/pynumaflow/proto/accumulator/accumulator_pb2_grpc.py b/pynumaflow/proto/accumulator/accumulator_pb2_grpc.py new file mode 100644 index 00000000..f41606dd --- /dev/null +++ b/pynumaflow/proto/accumulator/accumulator_pb2_grpc.py @@ -0,0 +1,134 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import accumulator_pb2 as accumulator__pb2 +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 + + +class AccumulatorStub(object): + """AccumulatorWindow describes a special kind of SessionWindow (similar to Global Window) where output should + always have monotonically increasing WM but it can be manipulated through event-time by reordering the messages. + NOTE: Quite powerful, should not be abused; it can cause stalling of pipelines and leaks + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.AccumulateFn = channel.stream_stream( + "/accumulator.v1.Accumulator/AccumulateFn", + request_serializer=accumulator__pb2.AccumulatorRequest.SerializeToString, + response_deserializer=accumulator__pb2.AccumulatorResponse.FromString, + ) + self.IsReady = channel.unary_unary( + "/accumulator.v1.Accumulator/IsReady", + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=accumulator__pb2.ReadyResponse.FromString, + ) + + +class AccumulatorServicer(object): + """AccumulatorWindow describes a special kind of SessionWindow (similar to Global Window) where output should + always have monotonically increasing WM but it can be manipulated through event-time by reordering the messages. + NOTE: Quite powerful, should not be abused; it can cause stalling of pipelines and leaks + """ + + def AccumulateFn(self, request_iterator, context): + """AccumulateFn applies a accumulate function to a request stream.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def IsReady(self, request, context): + """IsReady is the heartbeat endpoint for gRPC.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_AccumulatorServicer_to_server(servicer, server): + rpc_method_handlers = { + "AccumulateFn": grpc.stream_stream_rpc_method_handler( + servicer.AccumulateFn, + request_deserializer=accumulator__pb2.AccumulatorRequest.FromString, + response_serializer=accumulator__pb2.AccumulatorResponse.SerializeToString, + ), + "IsReady": grpc.unary_unary_rpc_method_handler( + servicer.IsReady, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=accumulator__pb2.ReadyResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + "accumulator.v1.Accumulator", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. +class Accumulator(object): + """AccumulatorWindow describes a special kind of SessionWindow (similar to Global Window) where output should + always have monotonically increasing WM but it can be manipulated through event-time by reordering the messages. + NOTE: Quite powerful, should not be abused; it can cause stalling of pipelines and leaks + """ + + @staticmethod + def AccumulateFn( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_stream( + request_iterator, + target, + "/accumulator.v1.Accumulator/AccumulateFn", + accumulator__pb2.AccumulatorRequest.SerializeToString, + accumulator__pb2.AccumulatorResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def IsReady( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/accumulator.v1.Accumulator/IsReady", + google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + accumulator__pb2.ReadyResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/pynumaflow/proto/mapper/map_pb2.pyi b/pynumaflow/proto/mapper/map_pb2.pyi index e1279ff0..9832bc3e 100644 --- a/pynumaflow/proto/mapper/map_pb2.pyi +++ b/pynumaflow/proto/mapper/map_pb2.pyi @@ -26,7 +26,6 @@ class MapRequest(_message.Message): key: str value: str def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - KEYS_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] EVENT_TIME_FIELD_NUMBER: _ClassVar[int] @@ -45,7 +44,6 @@ class MapRequest(_message.Message): watermark: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., headers: _Optional[_Mapping[str, str]] = ..., ) -> None: ... - REQUEST_FIELD_NUMBER: _ClassVar[int] ID_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] @@ -91,7 +89,6 @@ class MapResponse(_message.Message): value: _Optional[bytes] = ..., tags: _Optional[_Iterable[str]] = ..., ) -> None: ... - RESULTS_FIELD_NUMBER: _ClassVar[int] ID_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] diff --git a/pynumaflow/proto/reducer/reduce_pb2.pyi b/pynumaflow/proto/reducer/reduce_pb2.pyi index 2c4b248c..88b27d53 100644 --- a/pynumaflow/proto/reducer/reduce_pb2.pyi +++ b/pynumaflow/proto/reducer/reduce_pb2.pyi @@ -48,7 +48,6 @@ class ReduceRequest(_message.Message): key: str value: str def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - KEYS_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] EVENT_TIME_FIELD_NUMBER: _ClassVar[int] @@ -67,7 +66,6 @@ class ReduceRequest(_message.Message): watermark: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., headers: _Optional[_Mapping[str, str]] = ..., ) -> None: ... - PAYLOAD_FIELD_NUMBER: _ClassVar[int] OPERATION_FIELD_NUMBER: _ClassVar[int] payload: ReduceRequest.Payload @@ -110,7 +108,6 @@ class ReduceResponse(_message.Message): value: _Optional[bytes] = ..., tags: _Optional[_Iterable[str]] = ..., ) -> None: ... - RESULT_FIELD_NUMBER: _ClassVar[int] WINDOW_FIELD_NUMBER: _ClassVar[int] EOF_FIELD_NUMBER: _ClassVar[int] diff --git a/pynumaflow/proto/sinker/sink_pb2.pyi b/pynumaflow/proto/sinker/sink_pb2.pyi index 18d4d3b6..78926321 100644 --- a/pynumaflow/proto/sinker/sink_pb2.pyi +++ b/pynumaflow/proto/sinker/sink_pb2.pyi @@ -37,7 +37,6 @@ class SinkRequest(_message.Message): key: str value: str def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - KEYS_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] EVENT_TIME_FIELD_NUMBER: _ClassVar[int] @@ -59,7 +58,6 @@ class SinkRequest(_message.Message): id: _Optional[str] = ..., headers: _Optional[_Mapping[str, str]] = ..., ) -> None: ... - REQUEST_FIELD_NUMBER: _ClassVar[int] STATUS_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] @@ -108,7 +106,6 @@ class SinkResponse(_message.Message): status: _Optional[_Union[Status, str]] = ..., err_msg: _Optional[str] = ..., ) -> None: ... - RESULTS_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] STATUS_FIELD_NUMBER: _ClassVar[int] diff --git a/pynumaflow/proto/sourcer/source_pb2.pyi b/pynumaflow/proto/sourcer/source_pb2.pyi index 8f588410..f2cdc70e 100644 --- a/pynumaflow/proto/sourcer/source_pb2.pyi +++ b/pynumaflow/proto/sourcer/source_pb2.pyi @@ -32,7 +32,6 @@ class ReadRequest(_message.Message): def __init__( self, num_records: _Optional[int] = ..., timeout_in_ms: _Optional[int] = ... ) -> None: ... - REQUEST_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] request: ReadRequest.Request @@ -56,7 +55,6 @@ class ReadResponse(_message.Message): key: str value: str def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - PAYLOAD_FIELD_NUMBER: _ClassVar[int] OFFSET_FIELD_NUMBER: _ClassVar[int] EVENT_TIME_FIELD_NUMBER: _ClassVar[int] @@ -107,7 +105,6 @@ class ReadResponse(_message.Message): error: _Optional[_Union[ReadResponse.Status.Error, str]] = ..., msg: _Optional[str] = ..., ) -> None: ... - RESULT_FIELD_NUMBER: _ClassVar[int] STATUS_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] @@ -131,7 +128,6 @@ class AckRequest(_message.Message): def __init__( self, offsets: _Optional[_Iterable[_Union[Offset, _Mapping]]] = ... ) -> None: ... - REQUEST_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] request: AckRequest.Request @@ -152,7 +148,6 @@ class AckResponse(_message.Message): def __init__( self, success: _Optional[_Union[_empty_pb2.Empty, _Mapping]] = ... ) -> None: ... - RESULT_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] result: AckResponse.Result @@ -177,7 +172,6 @@ class PendingResponse(_message.Message): COUNT_FIELD_NUMBER: _ClassVar[int] count: int def __init__(self, count: _Optional[int] = ...) -> None: ... - RESULT_FIELD_NUMBER: _ClassVar[int] result: PendingResponse.Result def __init__( @@ -192,7 +186,6 @@ class PartitionsResponse(_message.Message): PARTITIONS_FIELD_NUMBER: _ClassVar[int] partitions: _containers.RepeatedScalarFieldContainer[int] def __init__(self, partitions: _Optional[_Iterable[int]] = ...) -> None: ... - RESULT_FIELD_NUMBER: _ClassVar[int] result: PartitionsResponse.Result def __init__( diff --git a/pynumaflow/proto/sourcetransformer/transform_pb2.pyi b/pynumaflow/proto/sourcetransformer/transform_pb2.pyi index 1fe8cb08..cc8fe420 100644 --- a/pynumaflow/proto/sourcetransformer/transform_pb2.pyi +++ b/pynumaflow/proto/sourcetransformer/transform_pb2.pyi @@ -32,7 +32,6 @@ class SourceTransformRequest(_message.Message): key: str value: str def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - KEYS_FIELD_NUMBER: _ClassVar[int] VALUE_FIELD_NUMBER: _ClassVar[int] EVENT_TIME_FIELD_NUMBER: _ClassVar[int] @@ -54,7 +53,6 @@ class SourceTransformRequest(_message.Message): headers: _Optional[_Mapping[str, str]] = ..., id: _Optional[str] = ..., ) -> None: ... - REQUEST_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] request: SourceTransformRequest.Request @@ -85,7 +83,6 @@ class SourceTransformResponse(_message.Message): event_time: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., tags: _Optional[_Iterable[str]] = ..., ) -> None: ... - RESULTS_FIELD_NUMBER: _ClassVar[int] ID_FIELD_NUMBER: _ClassVar[int] HANDSHAKE_FIELD_NUMBER: _ClassVar[int] diff --git a/tests/accumulator/__init__.py b/tests/accumulator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/accumulator/test_async_accumulator.py b/tests/accumulator/test_async_accumulator.py new file mode 100644 index 00000000..292e3687 --- /dev/null +++ b/tests/accumulator/test_async_accumulator.py @@ -0,0 +1,476 @@ +import asyncio +import logging +import threading +import unittest +from collections.abc import AsyncIterable + +import grpc +from google.protobuf import empty_pb2 as _empty_pb2 +from grpc.aio._server import Server + +from pynumaflow import setup_logging +from pynumaflow.accumulator import ( + Message, + Datum, + AccumulatorAsyncServer, + Accumulator, +) +from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc +from pynumaflow.shared.asynciter import NonBlockingIterator +from tests.testing_utils import ( + mock_message, + mock_interval_window_start, + mock_interval_window_end, + get_time_args, +) + +LOGGER = setup_logging(__name__) + + +def request_generator(count, request, resetkey: bool = False, send_close: bool = False): + for i in range(count): + if resetkey: + # Clear previous keys and add new ones + del request.payload.keys[:] + request.payload.keys.extend([f"key-{i}"]) + + # Set operation based on index - first is OPEN, rest are APPEND + if i == 0: + request.operation.event = accumulator_pb2.AccumulatorRequest.WindowOperation.Event.OPEN + else: + request.operation.event = ( + accumulator_pb2.AccumulatorRequest.WindowOperation.Event.APPEND + ) + yield request + + if send_close: + # Send a close operation after all requests + request.operation.event = accumulator_pb2.AccumulatorRequest.WindowOperation.Event.CLOSE + yield request + + +def request_generator_append_only(count, request, resetkey: bool = False): + for i in range(count): + if resetkey: + # Clear previous keys and add new ones + del request.payload.keys[:] + request.payload.keys.extend([f"key-{i}"]) + + # Set operation to APPEND for all requests + request.operation.event = accumulator_pb2.AccumulatorRequest.WindowOperation.Event.APPEND + yield request + + +def request_generator_mixed(count, request, resetkey: bool = False): + for i in range(count): + if resetkey: + # Clear previous keys and add new ones + del request.payload.keys[:] + request.payload.keys.extend([f"key-{i}"]) + + if i % 2 == 0: + # Set operation to APPEND for even requests + request.operation.event = ( + accumulator_pb2.AccumulatorRequest.WindowOperation.Event.APPEND + ) + else: + # Set operation to CLOSE for odd requests + request.operation.event = accumulator_pb2.AccumulatorRequest.WindowOperation.Event.CLOSE + yield request + + +def start_request() -> accumulator_pb2.AccumulatorRequest: + event_time_timestamp, watermark_timestamp = get_time_args() + window = accumulator_pb2.KeyedWindow( + start=mock_interval_window_start(), + end=mock_interval_window_end(), + slot="slot-0", + keys=["test_key"], + ) + payload = accumulator_pb2.Payload( + keys=["test_key"], + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + id="test_id", + ) + operation = accumulator_pb2.AccumulatorRequest.WindowOperation( + event=accumulator_pb2.AccumulatorRequest.WindowOperation.Event.OPEN, + keyedWindow=window, + ) + request = accumulator_pb2.AccumulatorRequest( + payload=payload, + operation=operation, + ) + return request + + +def start_request_without_open() -> accumulator_pb2.AccumulatorRequest: + event_time_timestamp, watermark_timestamp = get_time_args() + + payload = accumulator_pb2.Payload( + keys=["test_key"], + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + id="test_id", + ) + + request = accumulator_pb2.AccumulatorRequest( + payload=payload, + ) + return request + + +_s: Server = None +_channel = grpc.insecure_channel("unix:///tmp/accumulator.sock") +_loop = None + + +def startup_callable(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +class ExampleClass(Accumulator): + def __init__(self, counter): + self.counter = counter + + async def handler(self, datums: AsyncIterable[Datum], output: NonBlockingIterator): + async for datum in datums: + self.counter += 1 + msg = f"counter:{self.counter}" + await output.put(Message(str.encode(msg), keys=datum.keys(), tags=[])) + + +async def accumulator_handler_func(datums: AsyncIterable[Datum], output: NonBlockingIterator): + counter = 0 + async for datum in datums: + counter += 1 + msg = f"counter:{counter}" + await output.put(Message(str.encode(msg), keys=datum.keys(), tags=[])) + + +def NewAsyncAccumulator(): + server_instance = AccumulatorAsyncServer(ExampleClass, init_args=(0,)) + udfs = server_instance.servicer + return udfs + + +async def start_server(udfs): + server = grpc.aio.server() + accumulator_pb2_grpc.add_AccumulatorServicer_to_server(udfs, server) + listen_addr = "unix:///tmp/accumulator.sock" + server.add_insecure_port(listen_addr) + logging.info("Starting server on %s", listen_addr) + global _s + _s = server + await server.start() + await server.wait_for_termination() + + +class TestAsyncAccumulator(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + global _loop + loop = asyncio.new_event_loop() + _loop = loop + _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + _thread.start() + udfs = NewAsyncAccumulator() + asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) + while True: + try: + with grpc.insecure_channel("unix:///tmp/accumulator.sock") as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") + LOGGER.error(e) + + @classmethod + def tearDownClass(cls) -> None: + try: + _loop.stop() + LOGGER.info("stopped the event loop") + except Exception as e: + LOGGER.error(e) + + def test_accumulate(self) -> None: + stub = self.__stub() + request = start_request() + generator_response = None + try: + generator_response = stub.AccumulateFn( + request_iterator=request_generator(count=5, request=request) + ) + except grpc.RpcError as e: + logging.error(e) + + # capture the output from the AccumulateFn generator and assert. + count = 0 + eof_count = 0 + for r in generator_response: + if hasattr(r, "payload") and r.payload.value: + count += 1 + # Each datum should increment the counter + expected_msg = f"counter:{count}" + self.assertEqual( + bytes(expected_msg, encoding="utf-8"), + r.payload.value, + ) + self.assertEqual(r.EOF, False) + # Check that keys are preserved + self.assertEqual(list(r.payload.keys), ["test_key"]) + else: + self.assertEqual(r.EOF, True) + eof_count += 1 + + # We should have received 5 messages (one for each datum) + self.assertEqual(5, count) + self.assertEqual(1, eof_count) + + def test_accumulate_with_multiple_keys(self) -> None: + stub = self.__stub() + request = start_request() + generator_response = None + try: + generator_response = stub.AccumulateFn( + request_iterator=request_generator(count=10, request=request, resetkey=True), + ) + except grpc.RpcError as e: + LOGGER.error(e) + + count = 0 + eof_count = 0 + key_counts = {} + + # capture the output from the AccumulateFn generator and assert. + for r in generator_response: + # Check for responses with values + if r.payload.value: + count += 1 + # Track count per key + key = r.payload.keys[0] if r.payload.keys else "no_key" + key_counts[key] = key_counts.get(key, 0) + 1 + + # Each key should have its own counter starting from 1 + expected_msg = f"counter:{key_counts[key]}" + self.assertEqual( + bytes(expected_msg, encoding="utf-8"), + r.payload.value, + ) + self.assertEqual(r.EOF, False) + else: + eof_count += 1 + self.assertEqual(r.EOF, True) + + # We should have 10 messages (one for each key) + self.assertEqual(10, count) + self.assertEqual(10, eof_count) # Each key/task sends its own EOF + # Each key should appear once + self.assertEqual(len(key_counts), 10) + + def test_accumulate_with_close(self) -> None: + stub = self.__stub() + request = start_request() + generator_response = None + try: + generator_response = stub.AccumulateFn( + request_iterator=request_generator(count=5, request=request, send_close=True) + ) + except grpc.RpcError as e: + logging.error(e) + + # capture the output from the AccumulateFn generator and assert. + count = 0 + eof_count = 0 + for r in generator_response: + if hasattr(r, "payload") and r.payload.value: + count += 1 + # Each datum should increment the counter + expected_msg = f"counter:{count}" + self.assertEqual( + bytes(expected_msg, encoding="utf-8"), + r.payload.value, + ) + self.assertEqual(r.EOF, False) + # Check that keys are preserved + self.assertEqual(list(r.payload.keys), ["test_key"]) + else: + self.assertEqual(r.EOF, True) + eof_count += 1 + + # We should have received 5 messages (one for each datum) + self.assertEqual(5, count) + self.assertEqual(1, eof_count) + + def test_accumulate_append_without_open(self) -> None: + stub = self.__stub() + request = start_request_without_open() + generator_response = None + try: + generator_response = stub.AccumulateFn( + request_iterator=request_generator_append_only(count=5, request=request) + ) + except grpc.RpcError as e: + logging.error(e) + + # capture the output from the AccumulateFn generator and assert. + count = 0 + eof_count = 0 + for r in generator_response: + if hasattr(r, "payload") and r.payload.value: + count += 1 + # Each datum should increment the counter + expected_msg = f"counter:{count}" + self.assertEqual( + bytes(expected_msg, encoding="utf-8"), + r.payload.value, + ) + self.assertEqual(r.EOF, False) + # Check that keys are preserved + self.assertEqual(list(r.payload.keys), ["test_key"]) + else: + self.assertEqual(r.EOF, True) + eof_count += 1 + + # We should have received 5 messages (one for each datum) + self.assertEqual(5, count) + self.assertEqual(1, eof_count) + + def test_accumulate_append_mixed(self) -> None: + stub = self.__stub() + request = start_request() + generator_response = None + try: + generator_response = stub.AccumulateFn( + request_iterator=request_generator_mixed(count=5, request=request) + ) + except grpc.RpcError as e: + logging.error(e) + + # capture the output from the AccumulateFn generator and assert. + count = 0 + eof_count = 0 + for r in generator_response: + if hasattr(r, "payload") and r.payload.value: + count += 1 + # Each datum should increment the counter + expected_msg = "counter:1" + self.assertEqual( + bytes(expected_msg, encoding="utf-8"), + r.payload.value, + ) + self.assertEqual(r.EOF, False) + # Check that keys are preserved + self.assertEqual(list(r.payload.keys), ["test_key"]) + else: + self.assertEqual(r.EOF, True) + eof_count += 1 + + # We should have received 5 messages (one for each datum) + self.assertEqual(3, count) + self.assertEqual(3, eof_count) + + def test_is_ready(self) -> None: + with grpc.insecure_channel("unix:///tmp/accumulator.sock") as channel: + stub = accumulator_pb2_grpc.AccumulatorStub(channel) + + request = _empty_pb2.Empty() + response = None + try: + response = stub.IsReady(request=request) + except grpc.RpcError as e: + logging.error(e) + + self.assertTrue(response.ready) + + def __stub(self): + return accumulator_pb2_grpc.AccumulatorStub(_channel) + + def test_error_init(self): + # Check that accumulator_instance is required + with self.assertRaises(TypeError): + AccumulatorAsyncServer() + # Check that the init_args and init_kwargs are passed + # only with an Accumulator class + with self.assertRaises(TypeError): + AccumulatorAsyncServer(accumulator_handler_func, init_args=(0, 1)) + # Check that an instance is not passed instead of the class + # signature + with self.assertRaises(TypeError): + AccumulatorAsyncServer(ExampleClass(0)) + + # Check that an invalid class is passed + class ExampleBadClass: + pass + + with self.assertRaises(TypeError): + AccumulatorAsyncServer(accumulator_instance=ExampleBadClass) + + def test_max_threads(self): + # max cap at 16 + server = AccumulatorAsyncServer(accumulator_instance=ExampleClass, max_threads=32) + self.assertEqual(server.max_threads, 16) + + # use argument provided + server = AccumulatorAsyncServer(accumulator_instance=ExampleClass, max_threads=5) + self.assertEqual(server.max_threads, 5) + + # defaults to 4 + server = AccumulatorAsyncServer(accumulator_instance=ExampleClass) + self.assertEqual(server.max_threads, 4) + + # zero threads + server = AccumulatorAsyncServer(ExampleClass, max_threads=0) + self.assertEqual(server.max_threads, 0) + + # negative threads + server = AccumulatorAsyncServer(ExampleClass, max_threads=-5) + self.assertEqual(server.max_threads, -5) + + def test_server_info_file_path_handling(self): + """Test AccumulatorAsyncServer with custom server info file path.""" + + server = AccumulatorAsyncServer( + ExampleClass, init_args=(0,), server_info_file="/custom/path/server_info.json" + ) + + self.assertEqual(server.server_info_file, "/custom/path/server_info.json") + + def test_init_kwargs_none_handling(self): + """Test init_kwargs None handling in AccumulatorAsyncServer.""" + + server = AccumulatorAsyncServer( + ExampleClass, init_args=(0,), init_kwargs=None # This should be converted to {} + ) + + # Should not raise any errors and should work correctly + self.assertIsNotNone(server.accumulator_handler) + + def test_server_start_method_logging(self): + """Test server start method includes proper logging.""" + from unittest.mock import patch + + server = AccumulatorAsyncServer(ExampleClass) + + # Mock aiorun.run to prevent actual server startup + with patch("pynumaflow.accumulator.async_server.aiorun") as mock_aiorun, patch( + "pynumaflow.accumulator.async_server._LOGGER" + ) as mock_logger: + server.start() + + # Verify logging was called + mock_logger.info.assert_called_once_with("Starting Async Accumulator Server") + + # Verify aiorun.run was called with correct parameters + mock_aiorun.run.assert_called_once() + self.assertTrue(mock_aiorun.run.call_args[1]["use_uvloop"]) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/accumulator/test_async_accumulator_err.py b/tests/accumulator/test_async_accumulator_err.py new file mode 100644 index 00000000..5b39174c --- /dev/null +++ b/tests/accumulator/test_async_accumulator_err.py @@ -0,0 +1,175 @@ +import asyncio +import logging +import threading +import unittest +from collections.abc import AsyncIterable +from unittest.mock import patch + +import grpc +from grpc.aio._server import Server + +from pynumaflow import setup_logging +from pynumaflow.accumulator import ( + Message, + Datum, + AccumulatorAsyncServer, + Accumulator, +) +from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc +from pynumaflow.shared.asynciter import NonBlockingIterator +from tests.testing_utils import ( + mock_message, + get_time_args, + mock_terminate_on_stop, +) + +LOGGER = setup_logging(__name__) + + +def request_generator(count, request): + for i in range(count): + yield request + + +def start_request() -> accumulator_pb2.AccumulatorRequest: + event_time_timestamp, watermark_timestamp = get_time_args() + window = accumulator_pb2.KeyedWindow( + start=event_time_timestamp, + end=watermark_timestamp, + slot="slot-0", + keys=["test_key"], + ) + payload = accumulator_pb2.Payload( + keys=["test_key"], + value=mock_message(), + event_time=event_time_timestamp, + watermark=watermark_timestamp, + id="test_id", + headers={"test_header_key": "test_header_value", "source": "test_source"}, + ) + operation = accumulator_pb2.AccumulatorRequest.WindowOperation( + event=accumulator_pb2.AccumulatorRequest.WindowOperation.Event.OPEN, + keyedWindow=window, + ) + request = accumulator_pb2.AccumulatorRequest( + payload=payload, + operation=operation, + ) + return request + + +_s: Server = None +_channel = grpc.insecure_channel("unix:///tmp/accumulator_err.sock") +_loop = None + + +def startup_callable(loop): + asyncio.set_event_loop(loop) + loop.run_forever() + + +class ExampleErrorClass(Accumulator): + def __init__(self, counter): + self.counter = counter + + async def handler(self, datums: AsyncIterable[Datum], output: NonBlockingIterator): + async for datum in datums: + self.counter += 1 + if self.counter == 2: + # Simulate an error on the second datum + raise RuntimeError("Simulated error in accumulator handler") + msg = f"counter:{self.counter}" + await output.put(Message(str.encode(msg), keys=datum.keys(), tags=[])) + + +async def error_accumulator_handler_func(datums: AsyncIterable[Datum], output: NonBlockingIterator): + counter = 0 + async for datum in datums: + counter += 1 + if counter == 2: + # Simulate an error on the second datum + raise RuntimeError("Simulated error in accumulator function") + msg = f"counter:{counter}" + await output.put(Message(str.encode(msg), keys=datum.keys(), tags=[])) + + +def NewAsyncAccumulatorError(): + server_instance = AccumulatorAsyncServer(ExampleErrorClass, init_args=(0,)) + udfs = server_instance.servicer + return udfs + + +@patch("psutil.Process.kill", mock_terminate_on_stop) +async def start_server(udfs): + server = grpc.aio.server() + accumulator_pb2_grpc.add_AccumulatorServicer_to_server(udfs, server) + listen_addr = "unix:///tmp/accumulator_err.sock" + server.add_insecure_port(listen_addr) + logging.info("Starting server on %s", listen_addr) + global _s + _s = server + await server.start() + await server.wait_for_termination() + + +@patch("psutil.Process.kill", mock_terminate_on_stop) +class TestAsyncAccumulatorError(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + global _loop + loop = asyncio.new_event_loop() + _loop = loop + _thread = threading.Thread(target=startup_callable, args=(loop,), daemon=True) + _thread.start() + udfs = NewAsyncAccumulatorError() + asyncio.run_coroutine_threadsafe(start_server(udfs), loop=loop) + while True: + try: + with grpc.insecure_channel("unix:///tmp/accumulator_err.sock") as channel: + f = grpc.channel_ready_future(channel) + f.result(timeout=10) + if f.done(): + break + except grpc.FutureTimeoutError as e: + LOGGER.error("error trying to connect to grpc server") + LOGGER.error(e) + + @classmethod + def tearDownClass(cls) -> None: + try: + _loop.stop() + LOGGER.info("stopped the event loop") + except Exception as e: + LOGGER.error(e) + + @patch("psutil.Process.kill", mock_terminate_on_stop) + def test_accumulate_partial_success(self) -> None: + """Test that the first datum is processed before error occurs""" + stub = self.__stub() + request = start_request() + + try: + generator_response = stub.AccumulateFn( + request_iterator=request_generator(count=5, request=request) + ) + + # Try to consume the generator + counter = 0 + for response in generator_response: + self.assertIsInstance(response, accumulator_pb2.AccumulatorResponse) + self.assertTrue(response.payload.value.startswith(b"counter:")) + counter += 1 + + self.assertEqual(counter, 1, "Expected only one successful response before error") + except BaseException as err: + self.assertTrue("Simulated error in accumulator handler" in str(err)) + return + self.fail("Expected an exception.") + + def __stub(self): + return accumulator_pb2_grpc.AccumulatorStub(_channel) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/tests/accumulator/test_datatypes.py b/tests/accumulator/test_datatypes.py new file mode 100644 index 00000000..a71f3452 --- /dev/null +++ b/tests/accumulator/test_datatypes.py @@ -0,0 +1,339 @@ +import unittest +from collections.abc import AsyncIterable +from datetime import datetime, timezone + +from google.protobuf import timestamp_pb2 as _timestamp_pb2 +from pynumaflow.accumulator import Accumulator + +from pynumaflow.accumulator._dtypes import ( + IntervalWindow, + KeyedWindow, + Datum, + AccumulatorResult, + AccumulatorRequest, + WindowOperation, + Message, +) +from pynumaflow.shared.asynciter import NonBlockingIterator +from tests.testing_utils import ( + mock_message, + mock_event_time, + mock_watermark, + mock_start_time, + mock_end_time, +) + +TEST_KEYS = ["test"] +TEST_ID = "test_id" +TEST_HEADERS = {"key1": "value1", "key2": "value2"} + + +class TestDatum(unittest.TestCase): + def test_err_event_time(self): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + headers = {"key1": "value1", "key2": "value2"} + with self.assertRaises(Exception) as context: + Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=ts, + watermark=mock_watermark(), + id_=TEST_ID, + headers=headers, + ) + self.assertEqual( + "Wrong data type: " + "for Datum.event_time", + str(context.exception), + ) + + def test_err_watermark(self): + ts = _timestamp_pb2.Timestamp() + ts.GetCurrentTime() + headers = {"key1": "value1", "key2": "value2"} + with self.assertRaises(Exception) as context: + Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=ts, + id_=TEST_ID, + headers=headers, + ) + self.assertEqual( + "Wrong data type: " + "for Datum.watermark", + str(context.exception), + ) + + def test_properties(self): + d = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + id_=TEST_ID, + headers=TEST_HEADERS, + ) + self.assertEqual(mock_message(), d.value) + self.assertEqual(TEST_KEYS, d.keys()) + self.assertEqual(mock_event_time(), d.event_time) + self.assertEqual(mock_watermark(), d.watermark) + self.assertEqual(TEST_HEADERS, d.headers) + self.assertEqual(TEST_ID, d.id) + + def test_default_values(self): + d = Datum( + keys=None, + value=None, + event_time=mock_event_time(), + watermark=mock_watermark(), + id_=TEST_ID, + ) + self.assertEqual([], d.keys()) + self.assertEqual(b"", d.value) + self.assertEqual({}, d.headers) + + +class TestIntervalWindow(unittest.TestCase): + def test_start(self): + i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) + self.assertEqual(mock_start_time(), i.start) + + def test_end(self): + i = IntervalWindow(start=mock_start_time(), end=mock_end_time()) + self.assertEqual(mock_end_time(), i.end) + + +class TestKeyedWindow(unittest.TestCase): + def test_create_window(self): + kw = KeyedWindow( + start=mock_start_time(), end=mock_end_time(), slot="slot-0", keys=["key1", "key2"] + ) + self.assertEqual(kw.start, mock_start_time()) + self.assertEqual(kw.end, mock_end_time()) + self.assertEqual(kw.slot, "slot-0") + self.assertEqual(kw.keys, ["key1", "key2"]) + + def test_default_values(self): + kw = KeyedWindow(start=mock_start_time(), end=mock_end_time()) + self.assertEqual(kw.slot, "") + self.assertEqual(kw.keys, []) + + def test_window_property(self): + kw = KeyedWindow(start=mock_start_time(), end=mock_end_time()) + self.assertIsInstance(kw.window, IntervalWindow) + self.assertEqual(kw.window.start, mock_start_time()) + self.assertEqual(kw.window.end, mock_end_time()) + + +class TestAccumulatorResult(unittest.TestCase): + def test_create_result(self): + # Create mock objects + future = None # In real usage, this would be an asyncio.Task + iterator = NonBlockingIterator() + keys = ["key1", "key2"] + result_queue = NonBlockingIterator() + consumer_future = None # In real usage, this would be an asyncio.Task + watermark = datetime.fromtimestamp(1662998400, timezone.utc) + + result = AccumulatorResult(future, iterator, keys, result_queue, consumer_future, watermark) + + self.assertEqual(result.future, future) + self.assertEqual(result.iterator, iterator) + self.assertEqual(result.keys, keys) + self.assertEqual(result.result_queue, result_queue) + self.assertEqual(result.consumer_future, consumer_future) + self.assertEqual(result.latest_watermark, watermark) + + def test_update_watermark(self): + result = AccumulatorResult( + None, None, [], None, None, datetime.fromtimestamp(1662998400, timezone.utc) + ) + new_watermark = datetime.fromtimestamp(1662998460, timezone.utc) + result.update_watermark(new_watermark) + self.assertEqual(result.latest_watermark, new_watermark) + + def test_update_watermark_invalid_type(self): + result = AccumulatorResult( + None, None, [], None, None, datetime.fromtimestamp(1662998400, timezone.utc) + ) + with self.assertRaises(TypeError): + result.update_watermark("not a datetime") + + +class TestAccumulatorRequest(unittest.TestCase): + def test_create_request(self): + operation = WindowOperation.OPEN + keyed_window = KeyedWindow(start=mock_start_time(), end=mock_end_time()) + payload = Datum( + keys=TEST_KEYS, + value=mock_message(), + event_time=mock_event_time(), + watermark=mock_watermark(), + id_=TEST_ID, + ) + + request = AccumulatorRequest(operation, keyed_window, payload) + self.assertEqual(request.operation, operation) + self.assertEqual(request.keyed_window, keyed_window) + self.assertEqual(request.payload, payload) + + +class TestWindowOperation(unittest.TestCase): + def test_enum_values(self): + self.assertEqual(WindowOperation.OPEN, 0) + self.assertEqual(WindowOperation.CLOSE, 1) + self.assertEqual(WindowOperation.APPEND, 2) + + +class TestMessage(unittest.TestCase): + def test_create_message(self): + value = b"test_value" + keys = ["key1", "key2"] + tags = ["tag1", "tag2"] + + msg = Message(value=value, keys=keys, tags=tags) + self.assertEqual(msg.value, value) + self.assertEqual(msg.keys, keys) + self.assertEqual(msg.tags, tags) + + def test_default_values(self): + msg = Message(value=b"test") + self.assertEqual(msg.keys, []) + self.assertEqual(msg.tags, []) + + def test_to_drop(self): + msg = Message.to_drop() + self.assertEqual(msg.value, b"") + self.assertEqual(msg.keys, []) + self.assertTrue("U+005C__DROP__" in msg.tags) + + def test_none_values(self): + msg = Message(value=None, keys=None, tags=None) + self.assertEqual(msg.value, b"") + self.assertEqual(msg.keys, []) + self.assertEqual(msg.tags, []) + + def test_from_datum(self): + """Test that Message.from_datum correctly creates a Message from a Datum""" + # Create a sample datum with all properties + test_keys = ["key1", "key2"] + test_value = b"test_message_value" + test_event_time = mock_event_time() + test_watermark = mock_watermark() + test_headers = {"header1": "value1", "header2": "value2"} + test_id = "test_datum_id" + + datum = Datum( + keys=test_keys, + value=test_value, + event_time=test_event_time, + watermark=test_watermark, + id_=test_id, + headers=test_headers, + ) + + # Create message from datum + message = Message.from_datum(datum) + + # Verify all properties are correctly transferred + self.assertEqual(message.value, test_value) + self.assertEqual(message.keys, test_keys) + self.assertEqual(message.event_time, test_event_time) + self.assertEqual(message.watermark, test_watermark) + self.assertEqual(message.headers, test_headers) + self.assertEqual(message.id, test_id) + + # Verify that tags are empty (default for Message) + self.assertEqual(message.tags, []) + + def test_from_datum_minimal(self): + """Test from_datum with minimal Datum (no headers)""" + test_keys = ["minimal_key"] + test_value = b"minimal_value" + test_event_time = mock_event_time() + test_watermark = mock_watermark() + test_id = "minimal_id" + + datum = Datum( + keys=test_keys, + value=test_value, + event_time=test_event_time, + watermark=test_watermark, + id_=test_id, + # headers not provided (will default to {}) + ) + + message = Message.from_datum(datum) + + self.assertEqual(message.value, test_value) + self.assertEqual(message.keys, test_keys) + self.assertEqual(message.event_time, test_event_time) + self.assertEqual(message.watermark, test_watermark) + self.assertEqual(message.headers, {}) + self.assertEqual(message.id, test_id) + self.assertEqual(message.tags, []) + + def test_from_datum_empty_keys(self): + """Test from_datum with empty keys""" + datum = Datum( + keys=None, # Will default to [] + value=b"test_value", + event_time=mock_event_time(), + watermark=mock_watermark(), + id_="test_id", + ) + + message = Message.from_datum(datum) + + self.assertEqual(message.keys, []) + self.assertEqual(message.value, b"test_value") + self.assertEqual(message.id, "test_id") + + +class TestAccumulatorClass(unittest.TestCase): + class ExampleClass(Accumulator): + async def handler(self, datums: AsyncIterable[Datum], output: NonBlockingIterator): + pass + + def __init__(self, test1, test2): + self.test1 = test1 + self.test2 = test2 + self.test3 = self.test1 + + def test_init(self): + r = self.ExampleClass(test1=1, test2=2) + self.assertEqual(1, r.test1) + self.assertEqual(2, r.test2) + self.assertEqual(1, r.test3) + + def test_callable(self): + """Test that accumulator instances can be called directly""" + r = self.ExampleClass(test1=1, test2=2) + # The __call__ method should be callable and delegate to the handler method + self.assertTrue(callable(r)) + # __call__ should return the result of calling handler + # Since handler is an async method, __call__ should return a coroutine + import asyncio + from pynumaflow.shared.asynciter import NonBlockingIterator + + async def test_datums(): + yield Datum( + keys=["test"], + value=b"test", + event_time=mock_event_time(), + watermark=mock_watermark(), + id_="test", + ) + + output = NonBlockingIterator() + result = r(test_datums(), output) + self.assertTrue(asyncio.iscoroutine(result)) + # Clean up the coroutine + result.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/accumulator/utils.py b/tests/accumulator/utils.py new file mode 100644 index 00000000..d0c68fbb --- /dev/null +++ b/tests/accumulator/utils.py @@ -0,0 +1,23 @@ +from datetime import datetime, timezone +from pynumaflow.accumulator import Datum + + +def create_test_datum(keys, value, event_time=None, watermark=None, id_=None, headers=None): + """Create a test Datum object with default values""" + if event_time is None: + event_time = datetime.fromtimestamp(1662998400, timezone.utc) + if watermark is None: + watermark = datetime.fromtimestamp(1662998460, timezone.utc) + if id_ is None: + id_ = "test_id" + if headers is None: + headers = {} + + return Datum( + keys=keys, + value=value, + event_time=event_time, + watermark=watermark, + id_=id_, + headers=headers, + )