diff --git a/src/graphon/dsl/entities.py b/src/graphon/dsl/entities.py index 3871b77..8645a9e 100644 --- a/src/graphon/dsl/entities.py +++ b/src/graphon/dsl/entities.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from collections.abc import Mapping from enum import StrEnum, auto from typing import Any, Protocol @@ -125,4 +126,5 @@ def loadable(self) -> bool: class TypedNodeFactory(Protocol): + @abstractmethod def create_node(self, node_config: NodeConfigDict) -> Any: ... diff --git a/src/graphon/file/protocols.py b/src/graphon/file/protocols.py index c376c02..f42470a 100644 --- a/src/graphon/file/protocols.py +++ b/src/graphon/file/protocols.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from collections.abc import Generator from typing import TYPE_CHECKING, Literal, Protocol @@ -18,8 +19,10 @@ class WorkflowFileRuntimeProtocol(Protocol): """ @property + @abstractmethod def multimodal_send_format(self) -> str: ... + @abstractmethod def http_get( self, url: str, @@ -27,10 +30,13 @@ def http_get( follow_redirects: bool = True, ) -> HttpResponseProtocol: ... + @abstractmethod def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... + @abstractmethod def load_file_bytes(self, *, file: File) -> bytes: ... + @abstractmethod def resolve_file_url( self, *, @@ -38,6 +44,7 @@ def resolve_file_url( for_external: bool = True, ) -> str | None: ... + @abstractmethod def resolve_upload_file_url( self, *, @@ -46,6 +53,7 @@ def resolve_upload_file_url( for_external: bool = True, ) -> str: ... + @abstractmethod def resolve_tool_file_url( self, *, @@ -54,6 +62,7 @@ def resolve_tool_file_url( for_external: bool = True, ) -> str: ... + @abstractmethod def verify_preview_signature( self, *, diff --git a/src/graphon/graph/graph.py b/src/graphon/graph/graph.py index ac79390..a07dfaa 100644 --- a/src/graphon/graph/graph.py +++ b/src/graphon/graph/graph.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from abc import abstractmethod from collections import defaultdict from collections.abc import Mapping, Sequence from typing import Any, Protocol, final @@ -27,6 +28,7 @@ class NodeFactory(Protocol): allowing for different node creation strategies while maintaining type safety. """ + @abstractmethod def create_node(self, node_config: NodeConfigDict) -> Node: """Create a Node instance from node configuration data. diff --git a/src/graphon/graph/validation.py b/src/graphon/graph/validation.py index b66def8..712cbf0 100644 --- a/src/graphon/graph/validation.py +++ b/src/graphon/graph/validation.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol @@ -34,6 +35,7 @@ def __init__(self, issues: Sequence[GraphValidationIssue]) -> None: class GraphValidationRule(Protocol): """Protocol that individual validation rules must satisfy.""" + @abstractmethod def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: """Validate the provided graph and return any discovered issues.""" ... diff --git a/src/graphon/graph_engine/command_channels/protocol.py b/src/graphon/graph_engine/command_channels/protocol.py index 10e3e59..350f6c6 100644 --- a/src/graphon/graph_engine/command_channels/protocol.py +++ b/src/graphon/graph_engine/command_channels/protocol.py @@ -4,6 +4,7 @@ to/from a GraphEngine instance, supporting both local and distributed scenarios. """ +from abc import abstractmethod from typing import Protocol from ..entities.commands import GraphEngineCommand @@ -16,6 +17,7 @@ class CommandChannel(Protocol): this channel is dedicated to that single execution. """ + @abstractmethod def fetch_commands(self) -> list[GraphEngineCommand]: """Fetch pending commands for this GraphEngine instance. @@ -27,6 +29,7 @@ def fetch_commands(self) -> list[GraphEngineCommand]: """ ... + @abstractmethod def send_command(self, command: GraphEngineCommand) -> None: """Send a command to be processed by this GraphEngine instance. diff --git a/src/graphon/graph_engine/command_channels/redis_channel.py b/src/graphon/graph_engine/command_channels/redis_channel.py index 79312dc..f55b045 100644 --- a/src/graphon/graph_engine/command_channels/redis_channel.py +++ b/src/graphon/graph_engine/command_channels/redis_channel.py @@ -6,6 +6,7 @@ """ import json +from abc import abstractmethod from contextlib import AbstractContextManager from typing import Any, Protocol, final @@ -27,18 +28,32 @@ class RedisPipelineProtocol(Protocol): """Minimal Redis pipeline contract used by the command channel.""" + @abstractmethod def lrange(self, name: str, start: int, end: int) -> Any: ... + + @abstractmethod def delete(self, *names: str) -> Any: ... + + @abstractmethod def execute(self) -> list[Any]: ... + + @abstractmethod def rpush(self, name: str, *values: str) -> Any: ... + + @abstractmethod def expire(self, name: str, time: int) -> Any: ... + + @abstractmethod def set(self, name: str, value: str, ex: int | None = None) -> Any: ... + + @abstractmethod def get(self, name: str) -> Any: ... class RedisClientProtocol(Protocol): """Redis client contract required by the command channel.""" + @abstractmethod def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ... diff --git a/src/graphon/graph_engine/command_processing/command_processor.py b/src/graphon/graph_engine/command_processing/command_processor.py index 99903d9..1cadb50 100644 --- a/src/graphon/graph_engine/command_processing/command_processor.py +++ b/src/graphon/graph_engine/command_processing/command_processor.py @@ -1,6 +1,7 @@ """Main command processor for handling external commands.""" import logging +from abc import abstractmethod from collections.abc import Callable from typing import Protocol, final @@ -15,6 +16,7 @@ class CommandHandler[CommandT: GraphEngineCommand](Protocol): """Protocol for command handlers.""" + @abstractmethod def handle( self, command: CommandT, diff --git a/src/graphon/graph_engine/domain/graph_execution.py b/src/graphon/graph_engine/domain/graph_execution.py index 996de86..a785e63 100644 --- a/src/graphon/graph_engine/domain/graph_execution.py +++ b/src/graphon/graph_engine/domain/graph_execution.py @@ -10,7 +10,6 @@ from graphon.entities.pause_reason import PauseReason from graphon.enums import NodeState -from graphon.runtime.graph_runtime_state import GraphExecutionProtocol from .node_execution import NodeExecution @@ -246,6 +245,3 @@ def loads(self, data: str) -> None: def record_node_failure(self) -> None: """Increment the count of node failures encountered during execution.""" self.exceptions_count += 1 - - -_: GraphExecutionProtocol = GraphExecution(workflow_id="") diff --git a/src/graphon/graph_engine/ready_queue/protocol.py b/src/graphon/graph_engine/ready_queue/protocol.py index 6c53677..c4e4348 100644 --- a/src/graphon/graph_engine/ready_queue/protocol.py +++ b/src/graphon/graph_engine/ready_queue/protocol.py @@ -4,6 +4,7 @@ for execution, supporting both in-memory and persistent storage scenarios. """ +from abc import abstractmethod from collections.abc import Sequence from typing import Protocol @@ -35,6 +36,7 @@ class ReadyQueue(Protocol): that can be serialized for state storage. """ + @abstractmethod def put(self, item: str) -> None: """Add a node ID to the ready queue. @@ -44,6 +46,7 @@ def put(self, item: str) -> None: """ ... + @abstractmethod def get(self, timeout: float | None = None) -> str: """Retrieve and remove a node ID from the queue. @@ -56,6 +59,7 @@ def get(self, timeout: float | None = None) -> str: """ ... + @abstractmethod def task_done(self) -> None: """Indicate that a previously retrieved task is complete. @@ -64,6 +68,7 @@ def task_done(self) -> None: """ ... + @abstractmethod def empty(self) -> bool: """Check if the queue is empty. @@ -73,6 +78,7 @@ def empty(self) -> bool: """ ... + @abstractmethod def qsize(self) -> int: """Get the approximate size of the queue. @@ -82,6 +88,7 @@ def qsize(self) -> int: """ ... + @abstractmethod def dumps(self) -> str: """Serialize the queue state to a JSON string for storage. @@ -92,6 +99,7 @@ def dumps(self) -> str: """ ... + @abstractmethod def loads(self, data: str) -> None: """Restore the queue state from a JSON string. diff --git a/src/graphon/http/protocols.py b/src/graphon/http/protocols.py index 505dfa0..9ad6d2a 100644 --- a/src/graphon/http/protocols.py +++ b/src/graphon/http/protocols.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from collections.abc import Mapping from typing import Any, Protocol @@ -6,38 +7,51 @@ class HttpResponseProtocol(Protocol): @property + @abstractmethod def headers(self) -> Mapping[str, str]: ... @property + @abstractmethod def content(self) -> bytes: ... @property + @abstractmethod def status_code(self) -> int: ... @property + @abstractmethod def text(self) -> str: ... @property + @abstractmethod def is_success(self) -> bool: ... + @abstractmethod def raise_for_status(self) -> None: ... class HttpClientProtocol(Protocol): @property + @abstractmethod def max_retries_exceeded_error(self) -> type[Exception]: ... @property + @abstractmethod def request_error(self) -> type[Exception]: ... + @abstractmethod def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> HttpResponse: ... + @abstractmethod def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> HttpResponse: ... + @abstractmethod def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> HttpResponse: ... + @abstractmethod def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> HttpResponse: ... + @abstractmethod def delete( self, url: str, @@ -45,6 +59,7 @@ def delete( **kwargs: Any, ) -> HttpResponse: ... + @abstractmethod def patch( self, url: str, diff --git a/src/graphon/model_runtime/memory/prompt_message_memory.py b/src/graphon/model_runtime/memory/prompt_message_memory.py index ea2c245..09d5bff 100644 --- a/src/graphon/model_runtime/memory/prompt_message_memory.py +++ b/src/graphon/model_runtime/memory/prompt_message_memory.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from collections.abc import Sequence from typing import Protocol @@ -11,6 +12,7 @@ class PromptMessageMemory(Protocol): """Port for loading memory as prompt messages.""" + @abstractmethod def get_history_prompt_messages( self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, diff --git a/src/graphon/model_runtime/model_providers/base/tokenizers/gpt2_tokenizer.py b/src/graphon/model_runtime/model_providers/base/tokenizers/gpt2_tokenizer.py index 2e33253..5ed4871 100644 --- a/src/graphon/model_runtime/model_providers/base/tokenizers/gpt2_tokenizer.py +++ b/src/graphon/model_runtime/model_providers/base/tokenizers/gpt2_tokenizer.py @@ -1,4 +1,5 @@ import logging +from abc import abstractmethod from collections.abc import Sequence from pathlib import Path from threading import Lock @@ -8,6 +9,7 @@ class _TokenizerProtocol(Protocol): + @abstractmethod def encode(self, text: str) -> Sequence[int]: ... diff --git a/src/graphon/model_runtime/protocols/llm_runtime.py b/src/graphon/model_runtime/protocols/llm_runtime.py index 08b52d1..7fad95c 100644 --- a/src/graphon/model_runtime/protocols/llm_runtime.py +++ b/src/graphon/model_runtime/protocols/llm_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from collections.abc import Generator, Sequence from typing import Any, Literal, Protocol, overload, runtime_checkable @@ -49,6 +50,7 @@ def invoke_llm( stream: Literal[True], ) -> Generator[LLMResultChunk, None, None]: ... + @abstractmethod def invoke_llm( self, *, @@ -90,6 +92,7 @@ def invoke_llm_with_structured_output( stream: Literal[True], ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + @abstractmethod def invoke_llm_with_structured_output( self, *, @@ -106,6 +109,7 @@ def invoke_llm_with_structured_output( | Generator[LLMResultChunkWithStructuredOutput, None, None] ): ... + @abstractmethod def get_llm_num_tokens( self, *, diff --git a/src/graphon/model_runtime/protocols/moderation_runtime.py b/src/graphon/model_runtime/protocols/moderation_runtime.py index 8cf6ae9..15e0371 100644 --- a/src/graphon/model_runtime/protocols/moderation_runtime.py +++ b/src/graphon/model_runtime/protocols/moderation_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from typing import Any, Protocol, runtime_checkable from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime @@ -9,6 +10,7 @@ class ModerationModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by moderation model wrappers.""" + @abstractmethod def invoke_moderation( self, *, diff --git a/src/graphon/model_runtime/protocols/provider_runtime.py b/src/graphon/model_runtime/protocols/provider_runtime.py index 83c35fe..030a32d 100644 --- a/src/graphon/model_runtime/protocols/provider_runtime.py +++ b/src/graphon/model_runtime/protocols/provider_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from collections.abc import Sequence from typing import Any, Protocol, runtime_checkable @@ -11,8 +12,10 @@ class ModelProviderRuntime(Protocol): """Shared provider discovery, credential validation, and schema lookup.""" + @abstractmethod def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... + @abstractmethod def get_provider_icon( self, *, @@ -21,6 +24,7 @@ def get_provider_icon( lang: str, ) -> tuple[bytes, str]: ... + @abstractmethod def validate_provider_credentials( self, *, @@ -28,6 +32,7 @@ def validate_provider_credentials( credentials: dict[str, Any], ) -> None: ... + @abstractmethod def validate_model_credentials( self, *, @@ -37,6 +42,7 @@ def validate_model_credentials( credentials: dict[str, Any], ) -> None: ... + @abstractmethod def get_model_schema( self, *, diff --git a/src/graphon/model_runtime/protocols/rerank_runtime.py b/src/graphon/model_runtime/protocols/rerank_runtime.py index aa3814a..66f70ff 100644 --- a/src/graphon/model_runtime/protocols/rerank_runtime.py +++ b/src/graphon/model_runtime/protocols/rerank_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from typing import Any, Protocol, runtime_checkable from graphon.model_runtime.entities.rerank_entities import ( @@ -13,6 +14,7 @@ class RerankModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by rerank model wrappers.""" + @abstractmethod def invoke_rerank( self, *, @@ -25,6 +27,7 @@ def invoke_rerank( top_n: int | None, ) -> RerankResult: ... + @abstractmethod def invoke_multimodal_rerank( self, *, diff --git a/src/graphon/model_runtime/protocols/speech_to_text_runtime.py b/src/graphon/model_runtime/protocols/speech_to_text_runtime.py index 8f59a62..ddcae2b 100644 --- a/src/graphon/model_runtime/protocols/speech_to_text_runtime.py +++ b/src/graphon/model_runtime/protocols/speech_to_text_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from typing import IO, Any, Protocol, runtime_checkable from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime @@ -9,6 +10,7 @@ class SpeechToTextModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by speech-to-text model wrappers.""" + @abstractmethod def invoke_speech_to_text( self, *, diff --git a/src/graphon/model_runtime/protocols/text_embedding_runtime.py b/src/graphon/model_runtime/protocols/text_embedding_runtime.py index 4938ccd..d94f80a 100644 --- a/src/graphon/model_runtime/protocols/text_embedding_runtime.py +++ b/src/graphon/model_runtime/protocols/text_embedding_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from typing import Any, Protocol, runtime_checkable from graphon.model_runtime.entities.text_embedding_entities import ( @@ -13,6 +14,7 @@ class TextEmbeddingModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by text and multimodal embedding wrappers.""" + @abstractmethod def invoke_text_embedding( self, *, @@ -23,6 +25,7 @@ def invoke_text_embedding( input_type: EmbeddingInputType, ) -> EmbeddingResult: ... + @abstractmethod def invoke_multimodal_embedding( self, *, @@ -33,6 +36,7 @@ def invoke_multimodal_embedding( input_type: EmbeddingInputType, ) -> EmbeddingResult: ... + @abstractmethod def get_text_embedding_num_tokens( self, *, diff --git a/src/graphon/model_runtime/protocols/tts_runtime.py b/src/graphon/model_runtime/protocols/tts_runtime.py index 2b9129a..a1a2388 100644 --- a/src/graphon/model_runtime/protocols/tts_runtime.py +++ b/src/graphon/model_runtime/protocols/tts_runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from collections.abc import Iterable from typing import Any, Protocol, runtime_checkable @@ -10,6 +11,7 @@ class TTSModelRuntime(ModelProviderRuntime, Protocol): """Runtime surface required by text-to-speech model wrappers.""" + @abstractmethod def invoke_tts( self, *, @@ -20,6 +22,7 @@ def invoke_tts( voice: str, ) -> Iterable[bytes]: ... + @abstractmethod def get_tts_model_voices( self, *, diff --git a/src/graphon/nodes/code/code_node.py b/src/graphon/nodes/code/code_node.py index a37039d..fbdb2e6 100644 --- a/src/graphon/nodes/code/code_node.py +++ b/src/graphon/nodes/code/code_node.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from collections.abc import Mapping, Sequence from decimal import Decimal from textwrap import dedent @@ -23,6 +24,7 @@ class CodeExecutorProtocol(Protocol): + @abstractmethod def execute( self, *, @@ -31,6 +33,7 @@ def execute( inputs: Mapping[str, Any], ) -> Mapping[str, Any]: ... + @abstractmethod def is_execution_error(self, error: Exception) -> bool: ... diff --git a/src/graphon/nodes/llm/file_saver.py b/src/graphon/nodes/llm/file_saver.py index c6708ae..8053017 100644 --- a/src/graphon/nodes/llm/file_saver.py +++ b/src/graphon/nodes/llm/file_saver.py @@ -2,6 +2,7 @@ import mimetypes import typing as tp +from abc import abstractmethod from graphon.file.constants import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE from graphon.file.enums import ( @@ -21,6 +22,7 @@ class LLMFileSaver(tp.Protocol): LLM. """ + @abstractmethod def save_binary_string( self, data: bytes, @@ -58,6 +60,7 @@ def save_binary_string( """ raise NotImplementedError + @abstractmethod def save_remote_url(self, url: str, file_type: FileType) -> File: """save_remote_url saves the file from a remote url returned by LLM. diff --git a/src/graphon/nodes/llm/protocols.py b/src/graphon/nodes/llm/protocols.py index 740d039..dd393fb 100644 --- a/src/graphon/nodes/llm/protocols.py +++ b/src/graphon/nodes/llm/protocols.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from typing import Any, Protocol from graphon.nodes.llm.runtime_protocols import LLMProtocol @@ -8,6 +9,7 @@ class CredentialsProvider(Protocol): """Port for loading runtime credentials for a provider/model pair.""" + @abstractmethod def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: """Return credentials for the target provider/model or raise a domain error.""" ... @@ -16,6 +18,7 @@ def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: class ModelFactory(Protocol): """Port for creating prepared graph-facing LLM runtimes for execution.""" + @abstractmethod def init_model_instance( self, provider_name: str, diff --git a/src/graphon/nodes/llm/runtime_protocols.py b/src/graphon/nodes/llm/runtime_protocols.py index efafc4c..361a336 100644 --- a/src/graphon/nodes/llm/runtime_protocols.py +++ b/src/graphon/nodes/llm/runtime_protocols.py @@ -1,5 +1,6 @@ from __future__ import annotations +from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from typing import Any, Literal, Protocol, overload @@ -22,22 +23,29 @@ class LLMProtocol(Protocol): """A graph-facing LLM runtime adapter for node execution.""" @property + @abstractmethod def provider(self) -> str: ... @property + @abstractmethod def model_name(self) -> str: ... @property + @abstractmethod def parameters(self) -> Mapping[str, Any]: ... @parameters.setter + @abstractmethod def parameters(self, value: Mapping[str, Any]) -> None: ... @property + @abstractmethod def stop(self) -> Sequence[str] | None: ... + @abstractmethod def get_model_schema(self) -> AIModelEntity: ... + @abstractmethod def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... @overload @@ -62,6 +70,7 @@ def invoke_llm( stream: Literal[True], ) -> Generator[LLMResultChunk, None, None]: ... + @abstractmethod def invoke_llm( self, *, @@ -94,6 +103,7 @@ def invoke_llm_with_structured_output( stream: Literal[True], ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + @abstractmethod def invoke_llm_with_structured_output( self, *, @@ -107,12 +117,14 @@ def invoke_llm_with_structured_output( | Generator[LLMResultChunkWithStructuredOutput, None, None] ): ... + @abstractmethod def is_structured_output_parse_error(self, error: Exception) -> bool: ... class PromptMessageSerializerProtocol(Protocol): """Port for converting compiled prompt messages into persisted process data.""" + @abstractmethod def serialize( self, *, @@ -124,4 +136,5 @@ def serialize( class RetrieverAttachmentLoaderProtocol(Protocol): """Port for resolving retriever segment attachments into graph file references.""" + @abstractmethod def load(self, *, segment_id: str) -> Sequence[File]: ... diff --git a/src/graphon/nodes/protocols.py b/src/graphon/nodes/protocols.py index 3557792..ea9d9a5 100644 --- a/src/graphon/nodes/protocols.py +++ b/src/graphon/nodes/protocols.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from collections.abc import Generator, Mapping from typing import Any, Protocol @@ -6,10 +7,12 @@ class FileManagerProtocol(Protocol): + @abstractmethod def download(self, f: File, /) -> bytes: ... class ToolFileManagerProtocol(Protocol): + @abstractmethod def create_file_by_raw( self, *, @@ -18,6 +21,7 @@ def create_file_by_raw( filename: str | None = None, ) -> Any: ... + @abstractmethod def get_file_generator_by_tool_file_id( self, tool_file_id: str, @@ -29,6 +33,7 @@ class FileReferenceFactoryProtocol(Protocol): format. It enforces approriate permission filtering for the file. """ + @abstractmethod def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ... diff --git a/src/graphon/nodes/runtime.py b/src/graphon/nodes/runtime.py index 2843e4a..1426cf1 100644 --- a/src/graphon/nodes/runtime.py +++ b/src/graphon/nodes/runtime.py @@ -1,6 +1,6 @@ from __future__ import annotations -import abc +from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from datetime import datetime from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable @@ -26,6 +26,7 @@ class ToolNodeRuntimeProtocol(Protocol): translate between graph-owned abstractions and `core.tools` internals. """ + @abstractmethod def get_runtime( self, *, @@ -35,12 +36,14 @@ def get_runtime( node_execution_id: str | None = None, ) -> ToolRuntimeHandle: ... + @abstractmethod def get_runtime_parameters( self, *, tool_runtime: ToolRuntimeHandle, ) -> Sequence[ToolRuntimeParameter]: ... + @abstractmethod def invoke( self, *, @@ -50,12 +53,14 @@ def invoke( provider_name: str, ) -> Generator[ToolRuntimeMessage, None, None]: ... + @abstractmethod def get_usage( self, *, tool_runtime: ToolRuntimeHandle, ) -> LLMUsage: ... + @abstractmethod def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ... @@ -63,14 +68,14 @@ def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ... class HumanInputNodeRuntimeProtocol(Protocol): """Workflow-layer adapter for human-input runtime persistence and delivery.""" - @abc.abstractmethod + @abstractmethod def get_form( self, *, node_id: str, ) -> HumanInputFormStateProtocol | None: ... - @abc.abstractmethod + @abstractmethod def create_form( self, *, @@ -85,6 +90,7 @@ def create_form( class HumanInputFormRepositoryBindableRuntimeProtocol(Protocol): """Optional capability for runtimes that require explicit repository binding.""" + @abstractmethod def with_form_repository( self, form_repository: object, @@ -130,22 +136,29 @@ def _normalize_human_input_runtime( class HumanInputFormStateProtocol(Protocol): @property + @abstractmethod def id(self) -> str: ... @property + @abstractmethod def rendered_content(self) -> str: ... @property + @abstractmethod def selected_action_id(self) -> str | None: ... @property + @abstractmethod def submitted_data(self) -> Mapping[str, Any] | None: ... @property + @abstractmethod def submitted(self) -> bool: ... @property + @abstractmethod def status(self) -> HumanInputFormStatus: ... @property + @abstractmethod def expiration_time(self) -> datetime: ... diff --git a/src/graphon/runtime/graph_runtime_state.py b/src/graphon/runtime/graph_runtime_state.py index c4a9221..fd0c2cd 100644 --- a/src/graphon/runtime/graph_runtime_state.py +++ b/src/graphon/runtime/graph_runtime_state.py @@ -2,6 +2,7 @@ import importlib import json +from abc import abstractmethod from collections.abc import Mapping, Sequence from contextlib import AbstractContextManager, nullcontext from copy import deepcopy @@ -25,32 +26,39 @@ class ReadyQueueProtocol(Protocol): """Structural interface required from ready queue implementations.""" + @abstractmethod def put(self, item: str) -> None: """Enqueue the identifier of a node that is ready to run.""" ... + @abstractmethod def get(self, timeout: float | None = None) -> str: """Return the next node identifier, blocking until available or timeout expires. """ ... + @abstractmethod def task_done(self) -> None: """Signal that the most recently dequeued node has completed processing.""" ... + @abstractmethod def empty(self) -> bool: """Return True when the queue contains no pending nodes.""" ... + @abstractmethod def qsize(self) -> int: """Approximate the number of pending nodes awaiting execution.""" ... + @abstractmethod def dumps(self) -> str: """Serialize the queue contents for persistence.""" ... + @abstractmethod def loads(self, data: str) -> None: """Restore the queue contents from a serialized payload.""" ... @@ -63,18 +71,22 @@ class NodeExecutionProtocol(Protocol): retry_count: int execution_id: str | None + @abstractmethod def mark_started(self, execution_id: str) -> None: """Mark the node execution as started.""" ... + @abstractmethod def mark_taken(self) -> None: """Mark the node execution as successfully completed.""" ... + @abstractmethod def mark_failed(self, error: str) -> None: """Mark the node execution as failed with an error.""" ... + @abstractmethod def increment_retry(self) -> None: """Increment the retry counter for the node execution.""" ... @@ -98,52 +110,64 @@ class GraphExecutionProtocol(Protocol): pause_reasons: list[PauseReason] @property + @abstractmethod def node_executions(self) -> Mapping[str, NodeExecutionProtocol]: """Return the persisted node execution state keyed by node id.""" ... + @abstractmethod def start(self) -> None: """Transition execution into the running state.""" ... + @abstractmethod def complete(self) -> None: """Mark execution as successfully completed.""" ... + @abstractmethod def abort(self, reason: str) -> None: """Abort execution in response to an external stop request.""" ... + @abstractmethod def pause(self, reason: PauseReason) -> None: """Pause execution with a recorded reason.""" ... + @abstractmethod def fail(self, error: Exception) -> None: """Record an unrecoverable error and end execution.""" ... + @abstractmethod def record_node_failure(self) -> None: """Increment the count of node failures observed during execution.""" ... + @abstractmethod def get_or_create_node_execution(self, node_id: str) -> NodeExecutionProtocol: """Return the execution entity for a node, creating it when needed.""" ... @property + @abstractmethod def is_paused(self) -> bool: """Return whether the execution is currently paused.""" ... @property + @abstractmethod def has_error(self) -> bool: """Return whether the execution has recorded an error.""" ... + @abstractmethod def dumps(self) -> str: """Serialize execution state into a JSON payload.""" ... + @abstractmethod def loads(self, data: str) -> None: """Restore execution state from a previously serialized payload.""" ... @@ -152,18 +176,22 @@ def loads(self, data: str) -> None: class ResponseStreamCoordinatorProtocol(Protocol): """Structural interface for response stream coordinator.""" + @abstractmethod def register(self, response_node_id: str) -> None: """Register a response node so its outputs can be streamed.""" ... + @abstractmethod def track_node_execution(self, node_id: str, execution_id: str) -> None: """Track the current execution id for a node.""" ... + @abstractmethod def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]: """Update pending response sessions after an edge is taken.""" ... + @abstractmethod def intercept_event( self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent, @@ -171,10 +199,12 @@ def intercept_event( """Translate node events into streamed response events.""" ... + @abstractmethod def loads(self, data: str) -> None: """Restore coordinator state from a serialized payload.""" ... + @abstractmethod def dumps(self) -> str: """Serialize coordinator state for persistence.""" ... @@ -188,6 +218,7 @@ class NodeProtocol(Protocol): execution_type: NodeExecutionType node_type: ClassVar[NodeType] + @abstractmethod def blocks_variable_output( self, variable_selectors: set[tuple[str, ...]], @@ -207,14 +238,24 @@ class GraphProtocol(Protocol): to the runtime state. """ - nodes: Mapping[str, NodeProtocol] - edges: Mapping[str, EdgeProtocol] - root_node: NodeProtocol + @property + @abstractmethod + def nodes(self) -> Mapping[str, NodeProtocol]: ... + + @property + @abstractmethod + def edges(self) -> Mapping[str, EdgeProtocol]: ... + + @property + @abstractmethod + def root_node(self) -> NodeProtocol: ... + @abstractmethod def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... class ChildGraphEngineBuilderProtocol(Protocol): + @abstractmethod def build_child_engine( self, *, diff --git a/src/graphon/runtime/graph_runtime_state_protocol.py b/src/graphon/runtime/graph_runtime_state_protocol.py index 7b5c9ce..f4e1e31 100644 --- a/src/graphon/runtime/graph_runtime_state_protocol.py +++ b/src/graphon/runtime/graph_runtime_state_protocol.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from collections.abc import Mapping, Sequence from typing import Protocol @@ -8,10 +9,12 @@ class ReadOnlyVariablePool(Protocol): """Read-only interface for VariablePool.""" + @abstractmethod def get(self, selector: Sequence[str], /) -> Segment | None: """Get a variable value (read-only).""" ... + @abstractmethod def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: """Get all variables stored under a given node prefix (read-only).""" ... @@ -26,49 +29,59 @@ class ReadOnlyGraphRuntimeState(Protocol): """ @property + @abstractmethod def variable_pool(self) -> ReadOnlyVariablePool: """Get read-only access to the variable pool.""" ... @property + @abstractmethod def start_at(self) -> float: """Get the start time (read-only).""" ... @property + @abstractmethod def total_tokens(self) -> int: """Get the total tokens count (read-only).""" ... @property + @abstractmethod def llm_usage(self) -> LLMUsage: """Get a copy of LLM usage info (read-only).""" ... @property + @abstractmethod def outputs(self) -> dict[str, object]: """Get a defensive copy of outputs (read-only).""" ... @property + @abstractmethod def node_run_steps(self) -> int: """Get the node run steps count (read-only).""" ... @property + @abstractmethod def ready_queue_size(self) -> int: """Get the number of nodes currently in the ready queue.""" ... @property + @abstractmethod def exceptions_count(self) -> int: """Get the number of node execution exceptions recorded.""" ... + @abstractmethod def get_output(self, key: str, default: object = None) -> object: """Get a single output value (returns a copy).""" ... + @abstractmethod def dumps(self) -> str: """Serialize the runtime state into a JSON snapshot (read-only).""" ... diff --git a/src/graphon/runtime/variable_pool.py b/src/graphon/runtime/variable_pool.py index 6fe8103..36787c3 100644 --- a/src/graphon/runtime/variable_pool.py +++ b/src/graphon/runtime/variable_pool.py @@ -4,7 +4,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from copy import deepcopy -from typing import TYPE_CHECKING, Annotated, Any, Self +from typing import Annotated, Any, Self from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -371,12 +371,3 @@ def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, obje def empty(cls) -> VariablePool: """Create an empty variable pool.""" return cls() - - -if TYPE_CHECKING: - # static assertion to ensure VariablePool implements the - # ReadOnlyVariablePool. - from .graph_runtime_state_protocol import ReadOnlyVariablePool - - def _assert_readonly_variable_pool(pool: VariablePool) -> ReadOnlyVariablePool: # pyright: ignore[reportUnusedFunction] - return pool diff --git a/tests/graph_engine/test_response_coordinator.py b/tests/graph_engine/test_response_coordinator.py index 031ade2..e05536e 100644 --- a/tests/graph_engine/test_response_coordinator.py +++ b/tests/graph_engine/test_response_coordinator.py @@ -53,14 +53,22 @@ def __init__(self) -> None: class _TestGraph(GraphProtocol): - nodes: dict[str, _TestNode] - edges: dict[str, _TestEdge] - root_node: _TestNode - def __init__(self, root_node: _TestNode) -> None: - self.nodes = {root_node.id: root_node} - self.edges: dict[str, _TestEdge] = {} - self.root_node = root_node + self._nodes = {root_node.id: root_node} + self._edges: dict[str, _TestEdge] = {} + self._root_node = root_node + + @property + def nodes(self) -> dict[str, _TestNode]: + return self._nodes + + @property + def edges(self) -> dict[str, _TestEdge]: + return self._edges + + @property + def root_node(self) -> _TestNode: + return self._root_node def get_outgoing_edges(self, node_id: str) -> Sequence[_TestEdge]: _ = node_id diff --git a/tests/test_protocol_abstract_contracts.py b/tests/test_protocol_abstract_contracts.py new file mode 100644 index 0000000..8537f8b --- /dev/null +++ b/tests/test_protocol_abstract_contracts.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import ast +import importlib +import inspect +from pathlib import Path + +import pytest + + +def _is_direct_protocol_class( + class_def: ast.ClassDef, + *, + protocol_aliases: set[str], +) -> bool: + for base in class_def.bases: + base_expr = base.value if isinstance(base, ast.Subscript) else base + if isinstance(base_expr, ast.Name) and base_expr.id in protocol_aliases: + return True + if ( + isinstance(base_expr, ast.Attribute) + and isinstance(base_expr.value, ast.Name) + and f"{base_expr.value.id}.{base_expr.attr}" in protocol_aliases + ): + return True + return False + + +def _discover_protocol_aliases(parsed: ast.Module) -> set[str]: + protocol_aliases = set[str]() + typing_aliases: set[str] = set() + + for node in parsed.body: + if isinstance(node, ast.ImportFrom) and node.module in { + "typing", + "typing_extensions", + }: + for alias in node.names: + if alias.name == "Protocol": + protocol_aliases.add(alias.asname or alias.name) + continue + + if isinstance(node, ast.Import): + for alias in node.names: + if alias.name in {"typing", "typing_extensions"}: + typing_aliases.add(alias.asname or alias.name) + + protocol_aliases.update(f"{alias}.Protocol" for alias in typing_aliases) + return protocol_aliases + + +def _has_protocol_members(class_def: ast.ClassDef) -> bool: + return any( + isinstance(member, ast.FunctionDef | ast.AsyncFunctionDef) + for member in class_def.body + ) + + +def _discover_protocol_targets() -> list[type[object]]: + src_root = Path(__file__).resolve().parents[1] / "src" / "graphon" + protocol_classes: list[type[object]] = [] + + for file_path in sorted(src_root.rglob("*.py")): + parsed = ast.parse(file_path.read_text()) + protocol_aliases = _discover_protocol_aliases(parsed) + if not protocol_aliases: + continue + + module_name = "graphon." + ".".join( + file_path.relative_to(src_root).with_suffix("").parts, + ) + module = importlib.import_module(module_name) + + for class_def in [ + node for node in parsed.body if isinstance(node, ast.ClassDef) + ]: + if not _is_direct_protocol_class( + class_def, + protocol_aliases=protocol_aliases, + ): + continue + if not _has_protocol_members(class_def): + continue + protocol_classes.append(getattr(module, class_def.name)) + + protocol_classes.sort(key=lambda cls: (cls.__module__, cls.__qualname__)) + return protocol_classes + + +def _protocol_member_names(protocol_cls: type[object]) -> list[str]: + member_names: list[str] = [] + for name, value in protocol_cls.__dict__.items(): + if name.startswith("__") and name.endswith("__"): + continue + if isinstance(value, property | classmethod | staticmethod): + member_names.append(name) + continue + if inspect.isfunction(value): + member_names.append(name) + return member_names + + +PROTOCOL_TARGETS = _discover_protocol_targets() + + +def _protocol_id(protocol_cls: type[object]) -> str: + return f"{protocol_cls.__module__}.{protocol_cls.__qualname__}" + + +def test_protocol_targets_should_be_discovered() -> None: + assert PROTOCOL_TARGETS + + +@pytest.mark.parametrize( + "protocol_cls", + PROTOCOL_TARGETS, + ids=_protocol_id, +) +def test_protocol_members_should_be_abstract(protocol_cls: type[object]) -> None: + member_names = _protocol_member_names(protocol_cls) + # This protects the test from a vacuous pass if discovery and runtime member + # detection drift apart. Discovery only targets Protocol classes with + # methods, so an empty member list means the test is no longer checking the + # contract it claims to check. + assert member_names, f"{_protocol_id(protocol_cls)} has no protocol members." + + non_abstract_members = [ + name + for name in member_names + if not getattr(protocol_cls.__dict__[name], "__isabstractmethod__", False) + ] + assert non_abstract_members == []