5151)
5252from dify_plugin .core .plugin_registration import PluginRegistration
5353from dify_plugin .core .runtime import Session
54+ from dify_plugin .core .session_context import _current_session
5455from dify_plugin .core .utils .http_parser import deserialize_request , serialize_response
5556from dify_plugin .entities import ParameterOption
5657from dify_plugin .entities .agent import AgentRuntime
@@ -242,16 +243,24 @@ def invoke_llm(self, session: Session, data: ModelInvokeLLMRequest) -> object:
242243 data .model_type ,
243244 )
244245 if isinstance (model_instance , LargeLanguageModel ):
245- return model_instance .invoke (
246- data .model ,
247- data .credentials ,
248- data .prompt_messages ,
249- data .model_parameters ,
250- data .tools ,
251- data .stop ,
252- data .stream ,
253- data .user_id ,
254- )
246+
247+ def _with_session_context () -> Generator :
248+ token = _current_session .set (session )
249+ try :
250+ yield from model_instance .invoke (
251+ data .model ,
252+ data .credentials ,
253+ data .prompt_messages ,
254+ data .model_parameters ,
255+ data .tools ,
256+ data .stop ,
257+ data .stream ,
258+ data .user_id ,
259+ )
260+ finally :
261+ _current_session .reset (token )
262+
263+ return _with_session_context ()
255264 msg = f"Model `{ data .model_type } ` not found for provider `{ data .provider } `"
256265 raise ValueError (
257266 msg ,
@@ -291,12 +300,16 @@ def invoke_text_embedding(
291300 data .model_type ,
292301 )
293302 if isinstance (model_instance , TextEmbeddingModel ):
294- return model_instance .invoke (
295- data .model ,
296- data .credentials ,
297- data .texts ,
298- data .user_id ,
299- )
303+ token = _current_session .set (session )
304+ try :
305+ return model_instance .invoke (
306+ data .model ,
307+ data .credentials ,
308+ data .texts ,
309+ data .user_id ,
310+ )
311+ finally :
312+ _current_session .reset (token )
300313 msg = f"Model `{ data .model_type } ` not found for provider `{ data .provider } `"
301314 raise ValueError (
302315 msg ,
@@ -312,13 +325,17 @@ def invoke_multimodal_embedding(
312325 data .model_type ,
313326 )
314327 if isinstance (model_instance , TextEmbeddingModel ):
315- return model_instance .invoke_multimodal (
316- data .model ,
317- data .credentials ,
318- data .documents ,
319- user = data .user_id ,
320- input_type = data .input_type ,
321- )
328+ token = _current_session .set (session )
329+ try :
330+ return model_instance .invoke_multimodal (
331+ data .model ,
332+ data .credentials ,
333+ data .documents ,
334+ user = data .user_id ,
335+ input_type = data .input_type ,
336+ )
337+ finally :
338+ _current_session .reset (token )
322339 msg = f"Model `{ data .model_type } ` not found for provider `{ data .provider } `"
323340 raise ValueError (
324341 msg ,
@@ -352,15 +369,19 @@ def invoke_rerank(self, session: Session, data: ModelInvokeRerankRequest) -> obj
352369 data .model_type ,
353370 )
354371 if isinstance (model_instance , RerankModel ):
355- return model_instance .invoke (
356- data .model ,
357- data .credentials ,
358- data .query ,
359- data .docs ,
360- data .score_threshold ,
361- data .top_n ,
362- data .user_id ,
363- )
372+ token = _current_session .set (session )
373+ try :
374+ return model_instance .invoke (
375+ data .model ,
376+ data .credentials ,
377+ data .query ,
378+ data .docs ,
379+ data .score_threshold ,
380+ data .top_n ,
381+ data .user_id ,
382+ )
383+ finally :
384+ _current_session .reset (token )
364385 msg = f"Model `{ data .model_type } ` not found for provider `{ data .provider } `"
365386 raise ValueError (
366387 msg ,
@@ -376,15 +397,19 @@ def invoke_multimodal_rerank(
376397 data .model_type ,
377398 )
378399 if isinstance (model_instance , RerankModel ):
379- return model_instance .invoke_multimodal (
380- data .model ,
381- data .credentials ,
382- data .query ,
383- data .docs ,
384- score_threshold = data .score_threshold ,
385- top_n = data .top_n ,
386- user = data .user_id ,
387- )
400+ token = _current_session .set (session )
401+ try :
402+ return model_instance .invoke_multimodal (
403+ data .model ,
404+ data .credentials ,
405+ data .query ,
406+ data .docs ,
407+ score_threshold = data .score_threshold ,
408+ top_n = data .top_n ,
409+ user = data .user_id ,
410+ )
411+ finally :
412+ _current_session .reset (token )
388413 msg = f"Model `{ data .model_type } ` not found for provider `{ data .provider } `"
389414 raise ValueError (
390415 msg ,
@@ -400,20 +425,24 @@ def invoke_tts(
400425 data .model_type ,
401426 )
402427 if isinstance (model_instance , TTSModel ):
403- b = model_instance .invoke (
404- data .model ,
405- data .tenant_id ,
406- data .credentials ,
407- data .content_text ,
408- data .voice ,
409- data .user_id ,
410- )
411- if isinstance (b , bytes | bytearray | memoryview ):
412- yield {"result" : binascii .hexlify (b ).decode ()}
413- return
428+ token = _current_session .set (session )
429+ try :
430+ b = model_instance .invoke (
431+ data .model ,
432+ data .tenant_id ,
433+ data .credentials ,
434+ data .content_text ,
435+ data .voice ,
436+ data .user_id ,
437+ )
438+ if isinstance (b , bytes | bytearray | memoryview ):
439+ yield {"result" : binascii .hexlify (b ).decode ()}
440+ return
414441
415- for chunk in b :
416- yield {"result" : binascii .hexlify (chunk ).decode ()}
442+ for chunk in b :
443+ yield {"result" : binascii .hexlify (chunk ).decode ()}
444+ finally :
445+ _current_session .reset (token )
417446 else :
418447 msg = f"Model `{ data .model_type } ` not found for provider `{ data .provider } `"
419448 raise ValueError (
@@ -458,14 +487,18 @@ def invoke_speech_to_text(
458487
459488 with pathlib .Path (temp .name ).open ("rb" ) as f :
460489 if isinstance (model_instance , Speech2TextModel ):
461- return {
462- "result" : model_instance .invoke (
463- data .model ,
464- data .credentials ,
465- f ,
466- data .user_id ,
467- ),
468- }
490+ token = _current_session .set (session )
491+ try :
492+ return {
493+ "result" : model_instance .invoke (
494+ data .model ,
495+ data .credentials ,
496+ f ,
497+ data .user_id ,
498+ ),
499+ }
500+ finally :
501+ _current_session .reset (token )
469502 msg = (
470503 f"Model `{ data .model_type } ` not found for provider "
471504 f"`{ data .provider } `"
@@ -506,14 +539,18 @@ def invoke_moderation(
506539 )
507540
508541 if isinstance (model_instance , ModerationModel ):
509- return {
510- "result" : model_instance .invoke (
511- data .model ,
512- data .credentials ,
513- data .text ,
514- data .user_id ,
515- ),
516- }
542+ token = _current_session .set (session )
543+ try :
544+ return {
545+ "result" : model_instance .invoke (
546+ data .model ,
547+ data .credentials ,
548+ data .text ,
549+ data .user_id ,
550+ ),
551+ }
552+ finally :
553+ _current_session .reset (token )
517554 msg = f"Model `{ data .model_type } ` not found for provider `{ data .provider } `"
518555 raise ValueError (
519556 msg ,
0 commit comments