@@ -629,7 +629,8 @@ def __init__(
629629 if app :
630630 if app_name :
631631 raise ValueError (
632- "When app is provided, app_name should not be provided."
632+ "When app is provided, app_name should not be provided, "
633+ "since it will be derived from app.name."
633634 )
634635 if agent :
635636 raise ValueError ("When app is provided, agent should not be provided." )
@@ -656,6 +657,11 @@ def __init__(
656657 ),
657658 }
658659
660+ def _app_name (self ) -> str :
661+ """Returns the app name."""
662+ app = self ._tmpl_attrs .get ("app" )
663+ return app .name if app else self ._tmpl_attrs .get ("app_name" )
664+
659665 async def _init_session (
660666 self ,
661667 session_service : "BaseSessionService" ,
@@ -672,9 +678,8 @@ async def _init_session(
672678 auth = _Authorization (** auth )
673679 session_state [auth_id ] = auth .access_token
674680
675- app = self ._tmpl_attrs .get ("app" )
676681 session = await session_service .create_session (
677- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
682+ app_name = self ._app_name ( ),
678683 user_id = request .user_id ,
679684 state = session_state ,
680685 )
@@ -694,7 +699,6 @@ async def _save_artifacts(
694699 request : _StreamRunRequest ,
695700 ):
696701 """Saves the artifacts."""
697- app = self ._tmpl_attrs .get ("app" )
698702 if request .artifacts :
699703 for artifact in request .artifacts :
700704 artifact = _Artifact (** artifact )
@@ -703,7 +707,7 @@ async def _save_artifacts(
703707 ):
704708 version_data = _ArtifactVersion (** version_data )
705709 saved_version = await artifact_service .save_artifact (
706- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
710+ app_name = self ._app_name ( ),
707711 user_id = request .user_id ,
708712 session_id = session_id ,
709713 filename = artifact .file_name ,
@@ -749,7 +753,7 @@ async def _convert_response_events(
749753 _ArtifactVersion (
750754 version = version ,
751755 data = await artifact_service .load_artifact (
752- app_name = self ._tmpl_attrs . get ( "app_name" ),
756+ app_name = self ._app_name ( ),
753757 user_id = user_id ,
754758 session_id = session_id ,
755759 filename = key ,
@@ -1206,7 +1210,6 @@ async def streaming_agent_run_with_events(self, request_json: str):
12061210 )
12071211 ):
12081212 self .set_up ()
1209- app = self ._tmpl_attrs .get ("app" )
12101213
12111214 # Try to get the session, if it doesn't exist, create a new one.
12121215 if request .session_id :
@@ -1216,7 +1219,7 @@ async def streaming_agent_run_with_events(self, request_json: str):
12161219 session = None
12171220 try :
12181221 session = await session_service .get_session (
1219- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1222+ app_name = self ._app_name ( ),
12201223 user_id = request .user_id ,
12211224 session_id = request .session_id ,
12221225 )
@@ -1267,9 +1270,8 @@ async def streaming_agent_run_with_events(self, request_json: str):
12671270 yield converted_event
12681271 finally :
12691272 if session and not request .session_id :
1270- app = self ._tmpl_attrs .get ("app" )
12711273 await session_service .delete_session (
1272- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1274+ app_name = self ._app_name ( ),
12731275 user_id = request .user_id ,
12741276 session_id = session .id ,
12751277 )
@@ -1306,9 +1308,8 @@ async def async_get_session(
13061308 """
13071309 if not self ._tmpl_attrs .get ("session_service" ):
13081310 self .set_up ()
1309- app = self ._tmpl_attrs .get ("app" )
13101311 session = await self ._tmpl_attrs .get ("session_service" ).get_session (
1311- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1312+ app_name = self ._app_name ( ),
13121313 user_id = user_id ,
13131314 session_id = session_id ,
13141315 ** kwargs ,
@@ -1384,9 +1385,8 @@ async def async_list_sessions(self, *, user_id: str, **kwargs):
13841385 """
13851386 if not self ._tmpl_attrs .get ("session_service" ):
13861387 self .set_up ()
1387- app = self ._tmpl_attrs .get ("app" )
13881388 return await self ._tmpl_attrs .get ("session_service" ).list_sessions (
1389- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1389+ app_name = self ._app_name ( ),
13901390 user_id = user_id ,
13911391 ** kwargs ,
13921392 )
@@ -1457,9 +1457,8 @@ async def async_create_session(
14571457 """
14581458 if not self ._tmpl_attrs .get ("session_service" ):
14591459 self .set_up ()
1460- app = self ._tmpl_attrs .get ("app" )
14611460 session = await self ._tmpl_attrs .get ("session_service" ).create_session (
1462- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1461+ app_name = self ._app_name ( ),
14631462 user_id = user_id ,
14641463 session_id = session_id ,
14651464 state = state ,
@@ -1539,9 +1538,8 @@ async def async_delete_session(
15391538 """
15401539 if not self ._tmpl_attrs .get ("session_service" ):
15411540 self .set_up ()
1542- app = self ._tmpl_attrs .get ("app" )
15431541 await self ._tmpl_attrs .get ("session_service" ).delete_session (
1544- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1542+ app_name = self ._app_name ( ),
15451543 user_id = user_id ,
15461544 session_id = session_id ,
15471545 ** kwargs ,
@@ -1630,9 +1628,8 @@ async def async_search_memory(self, *, user_id: str, query: str):
16301628 """
16311629 if not self ._tmpl_attrs .get ("memory_service" ):
16321630 self .set_up ()
1633- app = self ._tmpl_attrs .get ("app" )
16341631 return await self ._tmpl_attrs .get ("memory_service" ).search_memory (
1635- app_name = app . name if app else self ._tmpl_attrs . get ( "app_name" ),
1632+ app_name = self ._app_name ( ),
16361633 user_id = user_id ,
16371634 query = query ,
16381635 )
0 commit comments