Skip to content

Commit 541f6af

Browse files
feat: expose session context to model plugins via get_current_session()
1 parent 705ca0d commit 541f6af

5 files changed

Lines changed: 418 additions & 70 deletions

File tree

src/dify_plugin/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
monkey.patch_all(sys=True)
55

66
from dify_plugin.config.config import DifyPluginEnv
7+
from dify_plugin.core.session_context import get_current_session
78
from dify_plugin.interfaces.agent import AgentProvider, AgentStrategy
89
from dify_plugin.interfaces.endpoint import Endpoint
910
from dify_plugin.interfaces.model import ModelProvider
@@ -51,4 +52,5 @@
5152
"TextEmbeddingModel",
5253
"Tool",
5354
"ToolProvider",
55+
"get_current_session",
5456
]

src/dify_plugin/core/plugin_executor.py

Lines changed: 107 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
)
5252
from dify_plugin.core.plugin_registration import PluginRegistration
5353
from dify_plugin.core.runtime import Session
54+
from dify_plugin.core.session_context import _current_session
5455
from dify_plugin.core.utils.http_parser import deserialize_request, serialize_response
5556
from dify_plugin.entities import ParameterOption
5657
from dify_plugin.entities.agent import AgentRuntime
@@ -242,16 +243,24 @@ def invoke_llm(self, session: Session, data: ModelInvokeLLMRequest) -> object:
242243
data.model_type,
243244
)
244245
if isinstance(model_instance, LargeLanguageModel):
245-
return model_instance.invoke(
246-
data.model,
247-
data.credentials,
248-
data.prompt_messages,
249-
data.model_parameters,
250-
data.tools,
251-
data.stop,
252-
data.stream,
253-
data.user_id,
254-
)
246+
247+
def _with_session_context() -> Generator:
248+
token = _current_session.set(session)
249+
try:
250+
yield from model_instance.invoke(
251+
data.model,
252+
data.credentials,
253+
data.prompt_messages,
254+
data.model_parameters,
255+
data.tools,
256+
data.stop,
257+
data.stream,
258+
data.user_id,
259+
)
260+
finally:
261+
_current_session.reset(token)
262+
263+
return _with_session_context()
255264
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
256265
raise ValueError(
257266
msg,
@@ -291,12 +300,16 @@ def invoke_text_embedding(
291300
data.model_type,
292301
)
293302
if isinstance(model_instance, TextEmbeddingModel):
294-
return model_instance.invoke(
295-
data.model,
296-
data.credentials,
297-
data.texts,
298-
data.user_id,
299-
)
303+
token = _current_session.set(session)
304+
try:
305+
return model_instance.invoke(
306+
data.model,
307+
data.credentials,
308+
data.texts,
309+
data.user_id,
310+
)
311+
finally:
312+
_current_session.reset(token)
300313
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
301314
raise ValueError(
302315
msg,
@@ -312,13 +325,17 @@ def invoke_multimodal_embedding(
312325
data.model_type,
313326
)
314327
if isinstance(model_instance, TextEmbeddingModel):
315-
return model_instance.invoke_multimodal(
316-
data.model,
317-
data.credentials,
318-
data.documents,
319-
user=data.user_id,
320-
input_type=data.input_type,
321-
)
328+
token = _current_session.set(session)
329+
try:
330+
return model_instance.invoke_multimodal(
331+
data.model,
332+
data.credentials,
333+
data.documents,
334+
user=data.user_id,
335+
input_type=data.input_type,
336+
)
337+
finally:
338+
_current_session.reset(token)
322339
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
323340
raise ValueError(
324341
msg,
@@ -352,15 +369,19 @@ def invoke_rerank(self, session: Session, data: ModelInvokeRerankRequest) -> obj
352369
data.model_type,
353370
)
354371
if isinstance(model_instance, RerankModel):
355-
return model_instance.invoke(
356-
data.model,
357-
data.credentials,
358-
data.query,
359-
data.docs,
360-
data.score_threshold,
361-
data.top_n,
362-
data.user_id,
363-
)
372+
token = _current_session.set(session)
373+
try:
374+
return model_instance.invoke(
375+
data.model,
376+
data.credentials,
377+
data.query,
378+
data.docs,
379+
data.score_threshold,
380+
data.top_n,
381+
data.user_id,
382+
)
383+
finally:
384+
_current_session.reset(token)
364385
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
365386
raise ValueError(
366387
msg,
@@ -376,15 +397,19 @@ def invoke_multimodal_rerank(
376397
data.model_type,
377398
)
378399
if isinstance(model_instance, RerankModel):
379-
return model_instance.invoke_multimodal(
380-
data.model,
381-
data.credentials,
382-
data.query,
383-
data.docs,
384-
score_threshold=data.score_threshold,
385-
top_n=data.top_n,
386-
user=data.user_id,
387-
)
400+
token = _current_session.set(session)
401+
try:
402+
return model_instance.invoke_multimodal(
403+
data.model,
404+
data.credentials,
405+
data.query,
406+
data.docs,
407+
score_threshold=data.score_threshold,
408+
top_n=data.top_n,
409+
user=data.user_id,
410+
)
411+
finally:
412+
_current_session.reset(token)
388413
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
389414
raise ValueError(
390415
msg,
@@ -400,20 +425,24 @@ def invoke_tts(
400425
data.model_type,
401426
)
402427
if isinstance(model_instance, TTSModel):
403-
b = model_instance.invoke(
404-
data.model,
405-
data.tenant_id,
406-
data.credentials,
407-
data.content_text,
408-
data.voice,
409-
data.user_id,
410-
)
411-
if isinstance(b, bytes | bytearray | memoryview):
412-
yield {"result": binascii.hexlify(b).decode()}
413-
return
428+
token = _current_session.set(session)
429+
try:
430+
b = model_instance.invoke(
431+
data.model,
432+
data.tenant_id,
433+
data.credentials,
434+
data.content_text,
435+
data.voice,
436+
data.user_id,
437+
)
438+
if isinstance(b, bytes | bytearray | memoryview):
439+
yield {"result": binascii.hexlify(b).decode()}
440+
return
414441

415-
for chunk in b:
416-
yield {"result": binascii.hexlify(chunk).decode()}
442+
for chunk in b:
443+
yield {"result": binascii.hexlify(chunk).decode()}
444+
finally:
445+
_current_session.reset(token)
417446
else:
418447
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
419448
raise ValueError(
@@ -458,14 +487,18 @@ def invoke_speech_to_text(
458487

459488
with pathlib.Path(temp.name).open("rb") as f:
460489
if isinstance(model_instance, Speech2TextModel):
461-
return {
462-
"result": model_instance.invoke(
463-
data.model,
464-
data.credentials,
465-
f,
466-
data.user_id,
467-
),
468-
}
490+
token = _current_session.set(session)
491+
try:
492+
return {
493+
"result": model_instance.invoke(
494+
data.model,
495+
data.credentials,
496+
f,
497+
data.user_id,
498+
),
499+
}
500+
finally:
501+
_current_session.reset(token)
469502
msg = (
470503
f"Model `{data.model_type}` not found for provider "
471504
f"`{data.provider}`"
@@ -506,14 +539,18 @@ def invoke_moderation(
506539
)
507540

508541
if isinstance(model_instance, ModerationModel):
509-
return {
510-
"result": model_instance.invoke(
511-
data.model,
512-
data.credentials,
513-
data.text,
514-
data.user_id,
515-
),
516-
}
542+
token = _current_session.set(session)
543+
try:
544+
return {
545+
"result": model_instance.invoke(
546+
data.model,
547+
data.credentials,
548+
data.text,
549+
data.user_id,
550+
),
551+
}
552+
finally:
553+
_current_session.reset(token)
517554
msg = f"Model `{data.model_type}` not found for provider `{data.provider}`"
518555
raise ValueError(
519556
msg,
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""
2+
Request-scoped session context for model plugins.
3+
4+
Model plugins (LLM, Embedding, Rerank, etc.) do not receive the Session
5+
object through their ``_invoke()`` signature — unlike tool plugins which
6+
get it via their constructor. This module bridges that gap by storing
7+
the current Session in a :class:`~contextvars.ContextVar` so that model
8+
plugin code can retrieve it on demand via :func:`get_current_session`.
9+
10+
Usage in a custom model plugin::
11+
12+
from dify_plugin.core.session_context import get_current_session
13+
14+
class MyLLM(LargeLanguageModel):
15+
def _invoke(self, model, credentials, prompt_messages, ...):
16+
session = get_current_session()
17+
if session and session.app_id:
18+
# tag the request with the originating Dify app
19+
...
20+
21+
Note on ``app_id`` being ``None``:
22+
23+
``session.app_id`` is ``None`` when the model is invoked outside of
24+
an app execution context — for example, RAG routing, conversation
25+
title generation, or suggested question generation. These calls
26+
represent shared infrastructure costs not attributable to a specific
27+
app.
28+
29+
When building provider-side cost dashboards, the recommended
30+
approach is:
31+
32+
* If ``app_id`` is not ``None``, tag the request with it for
33+
per-app cost attribution.
34+
* If ``app_id`` is ``None``, either skip tagging or use a
35+
sentinel value such as ``"dify_system"`` to bucket these
36+
calls separately from external (non-Dify) traffic.
37+
"""
38+
39+
from __future__ import annotations
40+
41+
from contextvars import ContextVar
42+
from typing import TYPE_CHECKING
43+
44+
if TYPE_CHECKING:
45+
from dify_plugin.core.runtime import Session
46+
47+
_current_session: ContextVar[Session | None] = ContextVar(
48+
"_current_session", default=None
49+
)
50+
51+
52+
def get_current_session() -> Session | None:
53+
"""Return the :class:`Session` for the current model invocation, or
54+
``None`` when called outside of a plugin dispatch context.
55+
56+
Returns:
57+
The current session, or ``None``.
58+
"""
59+
return _current_session.get()

tests/core/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)