Skip to content

Commit ee9fbe1

Browse files
yeesiancopybara-github
authored andcommitted
fix: Standardize on the app_name in AdkApp
PiperOrigin-RevId: 890604835
1 parent 4fdae8e commit ee9fbe1

File tree

1 file changed

+17
-20
lines changed
  • vertexai/agent_engines/templates

1 file changed

+17
-20
lines changed

vertexai/agent_engines/templates/adk.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,8 @@ def __init__(
629629
if app:
630630
if app_name:
631631
raise ValueError(
632-
"When app is provided, app_name should not be provided."
632+
"When app is provided, app_name should not be provided, "
633+
"since it will be derived from app.name."
633634
)
634635
if agent:
635636
raise ValueError("When app is provided, agent should not be provided.")
@@ -656,6 +657,11 @@ def __init__(
656657
),
657658
}
658659

660+
def _app_name(self) -> str:
661+
"""Returns the app name."""
662+
app = self._tmpl_attrs.get("app")
663+
return app.name if app else self._tmpl_attrs.get("app_name")
664+
659665
async def _init_session(
660666
self,
661667
session_service: "BaseSessionService",
@@ -672,9 +678,8 @@ async def _init_session(
672678
auth = _Authorization(**auth)
673679
session_state[auth_id] = auth.access_token
674680

675-
app = self._tmpl_attrs.get("app")
676681
session = await session_service.create_session(
677-
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
682+
app_name=self._app_name(),
678683
user_id=request.user_id,
679684
state=session_state,
680685
)
@@ -694,7 +699,6 @@ async def _save_artifacts(
694699
request: _StreamRunRequest,
695700
):
696701
"""Saves the artifacts."""
697-
app = self._tmpl_attrs.get("app")
698702
if request.artifacts:
699703
for artifact in request.artifacts:
700704
artifact = _Artifact(**artifact)
@@ -703,7 +707,7 @@ async def _save_artifacts(
703707
):
704708
version_data = _ArtifactVersion(**version_data)
705709
saved_version = await artifact_service.save_artifact(
706-
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
710+
app_name=self._app_name(),
707711
user_id=request.user_id,
708712
session_id=session_id,
709713
filename=artifact.file_name,
@@ -749,7 +753,7 @@ async def _convert_response_events(
749753
_ArtifactVersion(
750754
version=version,
751755
data=await artifact_service.load_artifact(
752-
app_name=self._tmpl_attrs.get("app_name"),
756+
app_name=self._app_name(),
753757
user_id=user_id,
754758
session_id=session_id,
755759
filename=key,
@@ -1206,7 +1210,6 @@ async def streaming_agent_run_with_events(self, request_json: str):
12061210
)
12071211
):
12081212
self.set_up()
1209-
app = self._tmpl_attrs.get("app")
12101213

12111214
# Try to get the session, if it doesn't exist, create a new one.
12121215
if request.session_id:
@@ -1216,7 +1219,7 @@ async def streaming_agent_run_with_events(self, request_json: str):
12161219
session = None
12171220
try:
12181221
session = await session_service.get_session(
1219-
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
1222+
app_name=self._app_name(),
12201223
user_id=request.user_id,
12211224
session_id=request.session_id,
12221225
)
@@ -1267,9 +1270,8 @@ async def streaming_agent_run_with_events(self, request_json: str):
12671270
yield converted_event
12681271
finally:
12691272
if session and not request.session_id:
1270-
app = self._tmpl_attrs.get("app")
12711273
await session_service.delete_session(
1272-
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
1274+
app_name=self._app_name(),
12731275
user_id=request.user_id,
12741276
session_id=session.id,
12751277
)
@@ -1306,9 +1308,8 @@ async def async_get_session(
13061308
"""
13071309
if not self._tmpl_attrs.get("session_service"):
13081310
self.set_up()
1309-
app = self._tmpl_attrs.get("app")
13101311
session = await self._tmpl_attrs.get("session_service").get_session(
1311-
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
1312+
app_name=self._app_name(),
13121313
user_id=user_id,
13131314
session_id=session_id,
13141315
**kwargs,
@@ -1384,9 +1385,8 @@ async def async_list_sessions(self, *, user_id: str, **kwargs):
13841385
"""
13851386
if not self._tmpl_attrs.get("session_service"):
13861387
self.set_up()
1387-
app = self._tmpl_attrs.get("app")
13881388
return await self._tmpl_attrs.get("session_service").list_sessions(
1389-
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
1389+
app_name=self._app_name(),
13901390
user_id=user_id,
13911391
**kwargs,
13921392
)
@@ -1457,9 +1457,8 @@ async def async_create_session(
14571457
"""
14581458
if not self._tmpl_attrs.get("session_service"):
14591459
self.set_up()
1460-
app = self._tmpl_attrs.get("app")
14611460
session = await self._tmpl_attrs.get("session_service").create_session(
1462-
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
1461+
app_name=self._app_name(),
14631462
user_id=user_id,
14641463
session_id=session_id,
14651464
state=state,
@@ -1539,9 +1538,8 @@ async def async_delete_session(
15391538
"""
15401539
if not self._tmpl_attrs.get("session_service"):
15411540
self.set_up()
1542-
app = self._tmpl_attrs.get("app")
15431541
await self._tmpl_attrs.get("session_service").delete_session(
1544-
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
1542+
app_name=self._app_name(),
15451543
user_id=user_id,
15461544
session_id=session_id,
15471545
**kwargs,
@@ -1630,9 +1628,8 @@ async def async_search_memory(self, *, user_id: str, query: str):
16301628
"""
16311629
if not self._tmpl_attrs.get("memory_service"):
16321630
self.set_up()
1633-
app = self._tmpl_attrs.get("app")
16341631
return await self._tmpl_attrs.get("memory_service").search_memory(
1635-
app_name=app.name if app else self._tmpl_attrs.get("app_name"),
1632+
app_name=self._app_name(),
16361633
user_id=user_id,
16371634
query=query,
16381635
)

0 commit comments

Comments
 (0)