Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/dify_plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
monkey.patch_all(sys=True)

from dify_plugin.config.config import DifyPluginEnv
from dify_plugin.core.session_context import get_current_session
from dify_plugin.interfaces.agent import AgentProvider, AgentStrategy
from dify_plugin.interfaces.endpoint import Endpoint
from dify_plugin.interfaces.model import ModelProvider
Expand Down Expand Up @@ -51,4 +52,5 @@
"TextEmbeddingModel",
"Tool",
"ToolProvider",
"get_current_session",
]
177 changes: 107 additions & 70 deletions src/dify_plugin/core/plugin_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
)
from dify_plugin.core.plugin_registration import PluginRegistration
from dify_plugin.core.runtime import Session
from dify_plugin.core.session_context import _current_session
from dify_plugin.core.utils.http_parser import deserialize_request, serialize_response
from dify_plugin.entities import ParameterOption
from dify_plugin.entities.agent import AgentRuntime
Expand Down Expand Up @@ -242,16 +243,24 @@ def invoke_llm(self, session: Session, data: ModelInvokeLLMRequest) -> object:
data.model_type,
)
if isinstance(model_instance, LargeLanguageModel):
return model_instance.invoke(
data.model,
data.credentials,
data.prompt_messages,
data.model_parameters,
data.tools,
data.stop,
data.stream,
data.user_id,
)

def _with_session_context() -> Generator:
token = _current_session.set(session)
try:
yield from model_instance.invoke(
data.model,
data.credentials,
data.prompt_messages,
data.model_parameters,
data.tools,
data.stop,
data.stream,
data.user_id,
)
finally:
_current_session.reset(token)

return _with_session_context()
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
raise ValueError(
msg,
Expand Down Expand Up @@ -291,12 +300,16 @@ def invoke_text_embedding(
data.model_type,
)
if isinstance(model_instance, TextEmbeddingModel):
return model_instance.invoke(
data.model,
data.credentials,
data.texts,
data.user_id,
)
token = _current_session.set(session)
try:
return model_instance.invoke(
data.model,
data.credentials,
data.texts,
data.user_id,
)
finally:
_current_session.reset(token)
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
raise ValueError(
msg,
Expand All @@ -312,13 +325,17 @@ def invoke_multimodal_embedding(
data.model_type,
)
if isinstance(model_instance, TextEmbeddingModel):
return model_instance.invoke_multimodal(
data.model,
data.credentials,
data.documents,
user=data.user_id,
input_type=data.input_type,
)
token = _current_session.set(session)
try:
return model_instance.invoke_multimodal(
data.model,
data.credentials,
data.documents,
user=data.user_id,
input_type=data.input_type,
)
finally:
_current_session.reset(token)
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
raise ValueError(
msg,
Expand Down Expand Up @@ -352,15 +369,19 @@ def invoke_rerank(self, session: Session, data: ModelInvokeRerankRequest) -> obj
data.model_type,
)
if isinstance(model_instance, RerankModel):
return model_instance.invoke(
data.model,
data.credentials,
data.query,
data.docs,
data.score_threshold,
data.top_n,
data.user_id,
)
token = _current_session.set(session)
try:
return model_instance.invoke(
data.model,
data.credentials,
data.query,
data.docs,
data.score_threshold,
data.top_n,
data.user_id,
)
finally:
_current_session.reset(token)
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
raise ValueError(
msg,
Expand All @@ -376,15 +397,19 @@ def invoke_multimodal_rerank(
data.model_type,
)
if isinstance(model_instance, RerankModel):
return model_instance.invoke_multimodal(
data.model,
data.credentials,
data.query,
data.docs,
score_threshold=data.score_threshold,
top_n=data.top_n,
user=data.user_id,
)
token = _current_session.set(session)
try:
return model_instance.invoke_multimodal(
data.model,
data.credentials,
data.query,
data.docs,
score_threshold=data.score_threshold,
top_n=data.top_n,
user=data.user_id,
)
finally:
_current_session.reset(token)
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
raise ValueError(
msg,
Expand All @@ -400,20 +425,24 @@ def invoke_tts(
data.model_type,
)
if isinstance(model_instance, TTSModel):
b = model_instance.invoke(
data.model,
data.tenant_id,
data.credentials,
data.content_text,
data.voice,
data.user_id,
)
if isinstance(b, bytes | bytearray | memoryview):
yield {"result": binascii.hexlify(b).decode()}
return
token = _current_session.set(session)
try:
b = model_instance.invoke(
data.model,
data.tenant_id,
data.credentials,
data.content_text,
data.voice,
data.user_id,
)
if isinstance(b, bytes | bytearray | memoryview):
yield {"result": binascii.hexlify(b).decode()}
return

for chunk in b:
yield {"result": binascii.hexlify(chunk).decode()}
for chunk in b:
yield {"result": binascii.hexlify(chunk).decode()}
finally:
_current_session.reset(token)
else:
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
raise ValueError(
Expand Down Expand Up @@ -458,14 +487,18 @@ def invoke_speech_to_text(

with pathlib.Path(temp.name).open("rb") as f:
if isinstance(model_instance, Speech2TextModel):
return {
"result": model_instance.invoke(
data.model,
data.credentials,
f,
data.user_id,
),
}
token = _current_session.set(session)
try:
return {
"result": model_instance.invoke(
data.model,
data.credentials,
f,
data.user_id,
),
}
finally:
_current_session.reset(token)
msg = (
f"Model `{data.model_type}` not found for provider "
f"`{data.provider}`"
Expand Down Expand Up @@ -506,14 +539,18 @@ def invoke_moderation(
)

if isinstance(model_instance, ModerationModel):
return {
"result": model_instance.invoke(
data.model,
data.credentials,
data.text,
data.user_id,
),
}
token = _current_session.set(session)
try:
return {
"result": model_instance.invoke(
data.model,
data.credentials,
data.text,
data.user_id,
),
}
finally:
_current_session.reset(token)
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
raise ValueError(
msg,
Expand Down
59 changes: 59 additions & 0 deletions src/dify_plugin/core/session_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Request-scoped session context for model plugins.

Model plugins (LLM, Embedding, Rerank, etc.) do not receive the Session
object through their ``_invoke()`` signature — unlike tool plugins which
get it via their constructor. This module bridges that gap by storing
the current Session in a :class:`~contextvars.ContextVar` so that model
plugin code can retrieve it on demand via :func:`get_current_session`.

Usage in a custom model plugin::

from dify_plugin.core.session_context import get_current_session

class MyLLM(LargeLanguageModel):
def _invoke(self, model, credentials, prompt_messages, ...):
session = get_current_session()
if session and session.app_id:
# tag the request with the originating Dify app
...

Note on ``app_id`` being ``None``:

``session.app_id`` is ``None`` when the model is invoked outside of
an app execution context — for example, RAG routing, conversation
title generation, or suggested question generation. These calls
represent shared infrastructure costs not attributable to a specific
app.

When building provider-side cost dashboards, the recommended
approach is:

* If ``app_id`` is not ``None``, tag the request with it for
per-app cost attribution.
* If ``app_id`` is ``None``, either skip tagging or use a
sentinel value such as ``"dify_system"`` to bucket these
calls separately from external (non-Dify) traffic.
"""

from __future__ import annotations

from contextvars import ContextVar
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dify_plugin.core.runtime import Session

_current_session: ContextVar[Session | None] = ContextVar(
"_current_session", default=None
)


def get_current_session() -> Session | None:
"""Return the :class:`Session` for the current model invocation, or
``None`` when called outside of a plugin dispatch context.

Returns:
The current session, or ``None``.
"""
return _current_session.get()
Empty file added tests/core/__init__.py
Empty file.
Loading