diff --git a/src/dify_plugin/__init__.py b/src/dify_plugin/__init__.py index f35c25fd..101bdfe0 100644 --- a/src/dify_plugin/__init__.py +++ b/src/dify_plugin/__init__.py @@ -4,6 +4,7 @@ monkey.patch_all(sys=True) from dify_plugin.config.config import DifyPluginEnv +from dify_plugin.core.session_context import get_current_session from dify_plugin.interfaces.agent import AgentProvider, AgentStrategy from dify_plugin.interfaces.endpoint import Endpoint from dify_plugin.interfaces.model import ModelProvider @@ -51,4 +52,5 @@ "TextEmbeddingModel", "Tool", "ToolProvider", + "get_current_session", ] diff --git a/src/dify_plugin/core/plugin_executor.py b/src/dify_plugin/core/plugin_executor.py index 1aa8668a..b2ac252e 100644 --- a/src/dify_plugin/core/plugin_executor.py +++ b/src/dify_plugin/core/plugin_executor.py @@ -51,6 +51,7 @@ ) from dify_plugin.core.plugin_registration import PluginRegistration from dify_plugin.core.runtime import Session +from dify_plugin.core.session_context import _current_session from dify_plugin.core.utils.http_parser import deserialize_request, serialize_response from dify_plugin.entities import ParameterOption from dify_plugin.entities.agent import AgentRuntime @@ -242,16 +243,24 @@ def invoke_llm(self, session: Session, data: ModelInvokeLLMRequest) -> object: data.model_type, ) if isinstance(model_instance, LargeLanguageModel): - return model_instance.invoke( - data.model, - data.credentials, - data.prompt_messages, - data.model_parameters, - data.tools, - data.stop, - data.stream, - data.user_id, - ) + + def _with_session_context() -> Generator: + token = _current_session.set(session) + try: + yield from model_instance.invoke( + data.model, + data.credentials, + data.prompt_messages, + data.model_parameters, + data.tools, + data.stop, + data.stream, + data.user_id, + ) + finally: + _current_session.reset(token) + + return _with_session_context() msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" raise ValueError( msg, @@ -291,12 +300,16 @@ def invoke_text_embedding( data.model_type, ) if isinstance(model_instance, TextEmbeddingModel): - return model_instance.invoke( - data.model, - data.credentials, - data.texts, - data.user_id, - ) + token = _current_session.set(session) + try: + return model_instance.invoke( + data.model, + data.credentials, + data.texts, + data.user_id, + ) + finally: + _current_session.reset(token) msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" raise ValueError( msg, @@ -312,13 +325,17 @@ def invoke_multimodal_embedding( data.model_type, ) if isinstance(model_instance, TextEmbeddingModel): - return model_instance.invoke_multimodal( - data.model, - data.credentials, - data.documents, - user=data.user_id, - input_type=data.input_type, - ) + token = _current_session.set(session) + try: + return model_instance.invoke_multimodal( + data.model, + data.credentials, + data.documents, + user=data.user_id, + input_type=data.input_type, + ) + finally: + _current_session.reset(token) msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" raise ValueError( msg, @@ -352,15 +369,19 @@ def invoke_rerank(self, session: Session, data: ModelInvokeRerankRequest) -> obj data.model_type, ) if isinstance(model_instance, RerankModel): - return model_instance.invoke( - data.model, - data.credentials, - data.query, - data.docs, - data.score_threshold, - data.top_n, - data.user_id, - ) + token = _current_session.set(session) + try: + return model_instance.invoke( + data.model, + data.credentials, + data.query, + data.docs, + data.score_threshold, + data.top_n, + data.user_id, + ) + finally: + _current_session.reset(token) msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" raise ValueError( msg, @@ -376,15 +397,19 @@ def invoke_multimodal_rerank( data.model_type, ) if isinstance(model_instance, RerankModel): - return model_instance.invoke_multimodal( - data.model, - data.credentials, - data.query, - data.docs, - score_threshold=data.score_threshold, - top_n=data.top_n, - user=data.user_id, - ) + token = _current_session.set(session) + try: + return model_instance.invoke_multimodal( + data.model, + data.credentials, + data.query, + data.docs, + score_threshold=data.score_threshold, + top_n=data.top_n, + user=data.user_id, + ) + finally: + _current_session.reset(token) msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" raise ValueError( msg, @@ -400,20 +425,24 @@ def invoke_tts( data.model_type, ) if isinstance(model_instance, TTSModel): - b = model_instance.invoke( - data.model, - data.tenant_id, - data.credentials, - data.content_text, - data.voice, - data.user_id, - ) - if isinstance(b, bytes | bytearray | memoryview): - yield {"result": binascii.hexlify(b).decode()} - return + token = _current_session.set(session) + try: + b = model_instance.invoke( + data.model, + data.tenant_id, + data.credentials, + data.content_text, + data.voice, + data.user_id, + ) + if isinstance(b, bytes | bytearray | memoryview): + yield {"result": binascii.hexlify(b).decode()} + return - for chunk in b: - yield {"result": binascii.hexlify(chunk).decode()} + for chunk in b: + yield {"result": binascii.hexlify(chunk).decode()} + finally: + _current_session.reset(token) else: msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" raise ValueError( @@ -458,14 +487,18 @@ def invoke_speech_to_text( with pathlib.Path(temp.name).open("rb") as f: if isinstance(model_instance, Speech2TextModel): - return { - "result": model_instance.invoke( - data.model, - data.credentials, - f, - data.user_id, - ), - } + token = _current_session.set(session) + try: + return { + "result": model_instance.invoke( + data.model, + data.credentials, + f, + data.user_id, + ), + } + finally: + _current_session.reset(token) msg = ( f"Model `{data.model_type}` not found for provider " f"`{data.provider}`" @@ -506,14 +539,18 @@ def invoke_moderation( ) if isinstance(model_instance, ModerationModel): - return { - "result": model_instance.invoke( - data.model, - data.credentials, - data.text, - data.user_id, - ), - } + token = _current_session.set(session) + try: + return { + "result": model_instance.invoke( + data.model, + data.credentials, + data.text, + data.user_id, + ), + } + finally: + _current_session.reset(token) msg = f"Model `{data.model_type}` not found for provider `{data.provider}`" raise ValueError( msg, diff --git a/src/dify_plugin/core/session_context.py b/src/dify_plugin/core/session_context.py new file mode 100644 index 00000000..86a64283 --- /dev/null +++ b/src/dify_plugin/core/session_context.py @@ -0,0 +1,59 @@ +""" +Request-scoped session context for model plugins. + +Model plugins (LLM, Embedding, Rerank, etc.) do not receive the Session +object through their ``_invoke()`` signature — unlike tool plugins which +get it via their constructor. This module bridges that gap by storing +the current Session in a :class:`~contextvars.ContextVar` so that model +plugin code can retrieve it on demand via :func:`get_current_session`. + +Usage in a custom model plugin:: + + from dify_plugin.core.session_context import get_current_session + + class MyLLM(LargeLanguageModel): + def _invoke(self, model, credentials, prompt_messages, ...): + session = get_current_session() + if session and session.app_id: + # tag the request with the originating Dify app + ... + +Note on ``app_id`` being ``None``: + + ``session.app_id`` is ``None`` when the model is invoked outside of + an app execution context — for example, RAG routing, conversation + title generation, or suggested question generation. These calls + represent shared infrastructure costs not attributable to a specific + app. + + When building provider-side cost dashboards, the recommended + approach is: + + * If ``app_id`` is not ``None``, tag the request with it for + per-app cost attribution. + * If ``app_id`` is ``None``, either skip tagging or use a + sentinel value such as ``"dify_system"`` to bucket these + calls separately from external (non-Dify) traffic. +""" + +from __future__ import annotations + +from contextvars import ContextVar +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dify_plugin.core.runtime import Session + +_current_session: ContextVar[Session | None] = ContextVar( + "_current_session", default=None +) + + +def get_current_session() -> Session | None: + """Return the :class:`Session` for the current model invocation, or + ``None`` when called outside of a plugin dispatch context. + + Returns: + The current session, or ``None``. + """ + return _current_session.get() diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/test_session_context.py b/tests/core/test_session_context.py new file mode 100644 index 00000000..4bad8459 --- /dev/null +++ b/tests/core/test_session_context.py @@ -0,0 +1,250 @@ +"""Tests for dify_plugin.core.session_context — ContextVar-based session propagation. + +Best practices applied: +- Each test cleans up ContextVar state via an autouse fixture. +- ContextVar tests and plugin_executor integration tests are separated. +- Thread isolation is verified with ThreadPoolExecutor. +""" + +from __future__ import annotations + +import contextlib +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock + +import pytest + +from dify_plugin.core.runtime import Session +from dify_plugin.core.session_context import ( + _current_session, # noqa: PLC2701 + get_current_session, +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _clean_context_var() -> None: + """Ensure _current_session is None before and after every test.""" + # Reset to default before the test + token = _current_session.set(None) # type: ignore[arg-type] + _current_session.reset(token) + yield # type: ignore[misc] + # Reset to default after the test (in case a test forgot to clean up) + with contextlib.suppress(ValueError): + _current_session.set(None) # type: ignore[arg-type] + + +def _make_session(app_id: str | None = "app-test-123") -> Session: + """Create a minimal Session with the given app_id.""" + return Session( + session_id="sess-1", + executor=ThreadPoolExecutor(max_workers=1), + reader=MagicMock(), + writer=MagicMock(), + app_id=app_id, + ) + + +# --------------------------------------------------------------------------- +# 1. get_current_session() — basic ContextVar behaviour +# --------------------------------------------------------------------------- + + +class TestGetCurrentSessionBasic: + """Verify the raw ContextVar get/set/reset semantics.""" + + def test_returns_none_by_default(self) -> None: + assert get_current_session() is None + + def test_returns_session_after_set(self) -> None: + session = _make_session() + token = _current_session.set(session) + try: + assert get_current_session() is session + finally: + _current_session.reset(token) + + def test_returns_none_after_reset(self) -> None: + session = _make_session() + token = _current_session.set(session) + _current_session.reset(token) + assert get_current_session() is None + + def test_nested_set_restores_previous_on_reset(self) -> None: + """ContextVar spec: reset restores the value before the corresponding set.""" + session_a = _make_session(app_id="app-a") + session_b = _make_session(app_id="app-b") + + token_a = _current_session.set(session_a) + token_b = _current_session.set(session_b) + + assert get_current_session() is session_b + + _current_session.reset(token_b) + assert get_current_session() is session_a + + _current_session.reset(token_a) + assert get_current_session() is None + + def test_thread_isolation(self) -> None: + """A session set in the main thread must not be visible in a worker thread.""" + session = _make_session() + token = _current_session.set(session) + + results: list[Session | None] = [] + + def _worker() -> None: + results.append(get_current_session()) + + with ThreadPoolExecutor(max_workers=1) as pool: + pool.submit(_worker).result() + + _current_session.reset(token) + assert results == [None] + + +# --------------------------------------------------------------------------- +# 2. app_id access through the session +# --------------------------------------------------------------------------- + + +class TestAppIdAccess: + def test_app_id_is_accessible(self) -> None: + session = _make_session(app_id="app-456") + token = _current_session.set(session) + try: + current = get_current_session() + assert current is not None + assert current.app_id == "app-456" + finally: + _current_session.reset(token) + + def test_app_id_none_for_out_of_app_context(self) -> None: + session = _make_session(app_id=None) + token = _current_session.set(session) + try: + current = get_current_session() + assert current is not None + assert current.app_id is None + finally: + _current_session.reset(token) + + +# --------------------------------------------------------------------------- +# 3. plugin_executor integration +# +# Instead of instantiating real LargeLanguageModel subclasses (which pull in +# graphon and require complex setup), we directly test that the ContextVar +# is set/reset around the model invocation by patching the model instance. +# --------------------------------------------------------------------------- + + +class TestPluginExecutorSessionPropagation: + """Verify plugin_executor.invoke_llm sets/resets the ContextVar.""" + + @staticmethod + def _make_executor_with_mock_llm( + invoke_side_effect: object = None, + ) -> tuple[object, MagicMock]: + """Return (executor, mock_model) with mock_model passing isinstance checks.""" + from dify_plugin.core.plugin_executor import PluginExecutor + from dify_plugin.interfaces.model.large_language_model import ( + LargeLanguageModel, + ) + + mock_model = MagicMock(spec=LargeLanguageModel) + if invoke_side_effect is not None: + mock_model.invoke.side_effect = invoke_side_effect + else: + mock_model.invoke.return_value = iter([]) + + config = MagicMock() + registration = MagicMock() + registration.get_model_instance.return_value = mock_model + + executor = PluginExecutor(config=config, registration=registration) + return executor, mock_model + + @staticmethod + def _make_llm_data() -> MagicMock: + data = MagicMock() + data.provider = "test-provider" + data.model_type = "llm" + data.model = "gpt-test" + data.credentials = {} + data.prompt_messages = [] + data.model_parameters = {} + data.tools = None + data.stop = None + data.stream = False + data.user_id = "user-1" + return data + + def test_session_is_set_during_invoke(self) -> None: + """get_current_session() returns the session while invoke is running.""" + captured: list[Session | None] = [] + + def _capture_invoke(*_a: object, **_kw: object) -> list: + captured.append(get_current_session()) + return iter([]) # type: ignore[return-value] + + executor, _ = self._make_executor_with_mock_llm( + invoke_side_effect=_capture_invoke, + ) + session = _make_session(app_id="app-during") + data = self._make_llm_data() + + result = executor.invoke_llm(session, data) + if hasattr(result, "__iter__") and not isinstance(result, str | bytes): + list(result) + + assert len(captured) == 1 + assert captured[0] is session + assert captured[0].app_id == "app-during" + + def test_session_is_reset_after_invoke(self) -> None: + """After invoke_llm returns, get_current_session() is None.""" + executor, _ = self._make_executor_with_mock_llm() + session = _make_session(app_id="app-after") + data = self._make_llm_data() + + result = executor.invoke_llm(session, data) + if hasattr(result, "__iter__") and not isinstance(result, str | bytes): + list(result) + + assert get_current_session() is None + + def test_session_is_reset_on_exception(self) -> None: + """If invoke raises during iteration, the ContextVar is still cleaned up.""" + executor, _ = self._make_executor_with_mock_llm( + invoke_side_effect=RuntimeError("boom"), + ) + session = _make_session(app_id="app-error") + data = self._make_llm_data() + + result = executor.invoke_llm(session, data) + # Exception occurs when the generator is consumed + with pytest.raises(RuntimeError, match="boom"): + list(result) + + assert get_current_session() is None + + +# --------------------------------------------------------------------------- +# 4. Public API surface +# --------------------------------------------------------------------------- + + +class TestPublicApi: + def test_importable_from_top_level(self) -> None: + from dify_plugin import get_current_session as fn + + assert callable(fn) + + def test_listed_in_all(self) -> None: + import dify_plugin + + assert "get_current_session" in dify_plugin.__all__