diff --git a/go/adk/cmd/main.go b/go/adk/cmd/main.go index 876be030f..3a61b4166 100644 --- a/go/adk/cmd/main.go +++ b/go/adk/cmd/main.go @@ -12,6 +12,7 @@ import ( "github.com/go-logr/logr" "github.com/go-logr/zapr" "github.com/kagent-dev/kagent/go/adk/pkg/a2a" + agentpkg "github.com/kagent-dev/kagent/go/adk/pkg/agent" "github.com/kagent-dev/kagent/go/adk/pkg/app" "github.com/kagent-dev/kagent/go/adk/pkg/auth" "github.com/kagent-dev/kagent/go/adk/pkg/config" @@ -20,6 +21,7 @@ import ( "github.com/kagent-dev/kagent/go/adk/pkg/session" "go.uber.org/zap" "go.uber.org/zap/zapcore" + adkmodel "google.golang.org/adk/model" ) func setupLogger(logLevel string) (logr.Logger, *zap.Logger) { @@ -149,6 +151,16 @@ func main() { } stream := agentConfig.GetStream() + + var sessionNameLLM adkmodel.LLM + if sessionService != nil { + if llm, err := agentpkg.CreateLLM(ctx, agentConfig.Model, logger); err == nil { + sessionNameLLM = llm + } else { + logger.Info("Could not create LLM for session name generation, names will not be set", "error", err) + } + } + executor := a2a.NewKAgentExecutor(a2a.KAgentExecutorConfig{ RunnerConfig: runnerConfig, SubagentSessionIDs: subagentSessionIDs, @@ -156,6 +168,7 @@ func main() { Stream: stream, AppName: appName, Logger: logger, + SessionNameLLM: sessionNameLLM, }) // Build the agent card. diff --git a/go/adk/pkg/a2a/executor.go b/go/adk/pkg/a2a/executor.go index 149992710..714770799 100644 --- a/go/adk/pkg/a2a/executor.go +++ b/go/adk/pkg/a2a/executor.go @@ -5,6 +5,8 @@ import ( "fmt" "maps" "os" + "strings" + "time" a2atype "github.com/a2aproject/a2a-go/a2a" "github.com/a2aproject/a2a-go/a2asrv" @@ -14,14 +16,22 @@ import ( "github.com/kagent-dev/kagent/go/adk/pkg/skills" "github.com/kagent-dev/kagent/go/adk/pkg/telemetry" adkagent "google.golang.org/adk/agent" + "google.golang.org/adk/model" "google.golang.org/adk/runner" "google.golang.org/adk/server/adka2a" + "google.golang.org/genai" ) const ( - defaultSkillsDirectory = "/skills" - envSkillsFolder = "KAGENT_SKILLS_FOLDER" - sessionNameMaxLength = 20 + defaultSkillsDirectory = "/skills" + envSkillsFolder = "KAGENT_SKILLS_FOLDER" + envSessionNameUpdateInterval = "KAGENT_SESSION_NAME_UPDATE_INTERVAL" +) + +const ( + sessionNameSummarizationPrompt = ` + Generate a short title (5-7 words max, no quotes or punctuation) for a conversation that starts with this message: %s\nRespond with only the title, nothing else. + ` ) // KAgentExecutorConfig holds the configuration for KAgentExecutor @@ -33,17 +43,27 @@ type KAgentExecutorConfig struct { AppName string SkillsDirectory string Logger logr.Logger + SessionNameLLM model.LLM +} + +// sessionNameMeta holds per-request metadata used for session name generation. +type sessionNameMeta struct { + userID string + updatedAt time.Time + messageText string } // KAgentExecutor implements a2asrv.AgentExecutor type KAgentExecutor struct { - runnerConfig runner.Config - subagentSessionIDs map[string]string - sessionService *session.KAgentSessionService - stream bool - appName string - skillsDirectory string - logger logr.Logger + runnerConfig runner.Config + subagentSessionIDs map[string]string + sessionService *session.KAgentSessionService + stream bool + appName string + skillsDirectory string + logger logr.Logger + sessionNameLLM model.LLM + sessionNameUpdateInterval time.Duration } var _ a2asrv.AgentExecutor = (*KAgentExecutor)(nil) @@ -57,14 +77,24 @@ func NewKAgentExecutor(cfg KAgentExecutorConfig) *KAgentExecutor { if skillsDir == "" { skillsDir = defaultSkillsDirectory } + + var sessionNameUpdateInterval time.Duration + if intervalStr := os.Getenv(envSessionNameUpdateInterval); intervalStr != "" { + if d, err := time.ParseDuration(intervalStr); err == nil { + sessionNameUpdateInterval = d + } + } + return &KAgentExecutor{ - runnerConfig: cfg.RunnerConfig, - subagentSessionIDs: cfg.SubagentSessionIDs, - sessionService: cfg.SessionService, - stream: cfg.Stream, - appName: cfg.AppName, - skillsDirectory: skillsDir, - logger: cfg.Logger.WithName("kagent-executor"), + runnerConfig: cfg.RunnerConfig, + subagentSessionIDs: cfg.SubagentSessionIDs, + sessionService: cfg.SessionService, + stream: cfg.Stream, + appName: cfg.AppName, + skillsDirectory: skillsDir, + logger: cfg.Logger.WithName("kagent-executor"), + sessionNameLLM: cfg.SessionNameLLM, + sessionNameUpdateInterval: sessionNameUpdateInterval, } } @@ -140,22 +170,36 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont } // 4. Create / lookup session via sessionService. + var meta *sessionNameMeta if e.sessionService != nil { sess, err := e.sessionService.GetSession(ctx, e.appName, userID, sessionID) if err != nil { e.logger.V(1).Info("Session lookup failed, will create", "error", err, "sessionID", sessionID) sess = nil } + + // Track the session's last update time for post-execution name generation. + // For new sessions, updatedAt stays zero which always exceeds any interval. + var updatedAt time.Time + if sess != nil { + type timeProvider interface { + LastUpdateTime() time.Time + } + if tp, ok := sess.(timeProvider); ok { + updatedAt = tp.LastUpdateTime() + } + } + if sess == nil { - sessionName := extractSessionName(reqCtx.Message) + sessionName := extractMessageText(reqCtx.Message) state := make(map[string]any) if sessionName != "" { state[StateKeySessionName] = sessionName } // Propagate x-kagent-source so the session is tagged in the DB. if callCtx, ok := a2asrv.CallContextFrom(ctx); ok { - if meta := callCtx.RequestMeta(); meta != nil { - if vals, ok := meta.Get("x-kagent-source"); ok && len(vals) > 0 && vals[0] != "" { + if callMeta := callCtx.RequestMeta(); callMeta != nil { + if vals, ok := callMeta.Get("x-kagent-source"); ok && len(vals) > 0 && vals[0] != "" { state[StateKeySource] = vals[0] } } @@ -164,6 +208,14 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont return fmt.Errorf("failed to create session: %w", err) } } + + if e.sessionNameLLM != nil { + meta = &sessionNameMeta{ + userID: userID, + updatedAt: updatedAt, + messageText: extractMessageText(reqCtx.Message), + } + } } // 5. Detect HITL decision and build the resume message if needed. @@ -398,6 +450,21 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont return queue.Write(ctx, inputRequired) } + // Generate session name via LLM if the update interval has elapsed. + if meta != nil && e.sessionNameLLM != nil && e.sessionNameUpdateInterval > 0 { + if time.Since(meta.updatedAt) >= e.sessionNameUpdateInterval && meta.messageText != "" { + if name := e.generateSessionName(ctx, meta.messageText); name != "" { + if updateErr := e.sessionService.UpdateSessionName(ctx, meta.userID, sessionID, name); updateErr != nil { + e.logger.V(1).Info("Failed to update session name", "error", updateErr, "sessionID", sessionID) + } else { + e.logger.Info("Session name updated", "sessionID", sessionID, "name", name) + finalMeta[GetKAgentMetadataKey("session_name")] = name + finalMeta[GetKAgentMetadataKey("session_id")] = sessionID + } + } + } + } + // completed: inject last text into final status if no message present. var finalMsg *a2atype.Message if len(lastTextParts) > 0 { @@ -419,16 +486,44 @@ func (e *KAgentExecutor) Cancel(ctx context.Context, reqCtx *a2asrv.RequestConte return queue.Write(ctx, event) } -// extractSessionName extracts session name from the first text part of a message. -func extractSessionName(message *a2atype.Message) string { +// generateSessionName calls the LLM to produce a short title from the first user message. +func (e *KAgentExecutor) generateSessionName(ctx context.Context, messageText string) string { + if e.sessionNameLLM == nil || messageText == "" { + return "" + } + + prompt := fmt.Sprintf(sessionNameSummarizationPrompt, messageText) + + req := &model.LLMRequest{ + Contents: []*genai.Content{ + {Role: "user", Parts: []*genai.Part{{Text: prompt}}}, + }, + } + + var name string + for resp, err := range e.sessionNameLLM.GenerateContent(ctx, req, false) { + if err != nil { + e.logger.V(1).Info("LLM error during session name generation", "error", err) + return "" + } + if resp != nil && resp.Content != nil { + for _, part := range resp.Content.Parts { + if part != nil && part.Text != "" { + name = strings.TrimSpace(part.Text) + } + } + } + } + return name +} + +// extractMessageText returns the first text part of a message. +func extractMessageText(message *a2atype.Message) string { if message == nil { return "" } for _, part := range message.Parts { if tp, ok := part.(a2atype.TextPart); ok && tp.Text != "" { - if len(tp.Text) > sessionNameMaxLength { - return tp.Text[:sessionNameMaxLength] + "..." - } return tp.Text } } diff --git a/go/adk/pkg/session/session.go b/go/adk/pkg/session/session.go index 7b0206c3f..411048efc 100644 --- a/go/adk/pkg/session/session.go +++ b/go/adk/pkg/session/session.go @@ -138,8 +138,10 @@ func (s *KAgentSessionService) Get(ctx context.Context, req *adksession.GetReque var result struct { Data struct { Session struct { - ID string `json:"id"` - UserID string `json:"user_id"` + ID string `json:"id"` + UserID string `json:"user_id"` + Name *string `json:"name"` + UpdatedAt time.Time `json:"updated_at"` } `json:"session"` Events []struct { Data json.RawMessage `json:"data"` @@ -172,6 +174,8 @@ func (s *KAgentSessionService) Get(ctx context.Context, req *adksession.GetReque adkEvents = append(adkEvents, e) } + log.V(1).Info("Parsed session events", "totalEvents", len(result.Data.Events), "outputEvents", len(adkEvents)) + return &adksession.GetResponse{ Session: &localSession{ appName: req.AppName, @@ -179,6 +183,7 @@ func (s *KAgentSessionService) Get(ctx context.Context, req *adksession.GetReque sessionID: result.Data.Session.ID, events: adkEvents, state: make(map[string]any), + updatedAt: result.Data.Session.UpdatedAt, }, }, nil } @@ -310,6 +315,37 @@ func (s *KAgentSessionService) CreateSession(ctx context.Context, appName, userI return err } +// UpdateSessionName updates the display name of a session via the KAgent API. +func (s *KAgentSessionService) UpdateSessionName(ctx context.Context, userID, sessionID, name string) error { + log := logr.FromContextOrDiscard(ctx) + log.V(1).Info("Updating session name", "sessionID", sessionID, "userID", userID, "name", name) + + body, err := json.Marshal(map[string]string{"name": name}) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPatch, s.BaseURL+"/api/sessions/"+sessionID, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-User-ID", userID) + + resp, err := s.Client.Do(req) + if err != nil { + return fmt.Errorf("failed to execute update session name request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + return fmt.Errorf("failed to update session name: status %d - %s", resp.StatusCode, string(bodyBytes)) + } + + log.V(1).Info("Session name updated successfully", "sessionID", sessionID) + return nil +} + // normalizeConfirmationEventRole fixes the role on adk_request_confirmation // functionCall events from "model" to "user". // diff --git a/go/api/database/client.go b/go/api/database/client.go index 6f3bfedde..19c70114b 100644 --- a/go/api/database/client.go +++ b/go/api/database/client.go @@ -39,6 +39,7 @@ type Client interface { // Get methods GetSession(ctx context.Context, sessionID string, userID string) (*Session, error) + GetSessionByID(ctx context.Context, sessionID string) (*Session, error) GetAgent(ctx context.Context, name string) (*Agent, error) GetTask(ctx context.Context, id string) (*protocol.Task, error) GetTool(ctx context.Context, name string) (*Tool, error) diff --git a/go/core/internal/database/client.go b/go/core/internal/database/client.go index 8a38366a4..0379b147b 100644 --- a/go/core/internal/database/client.go +++ b/go/core/internal/database/client.go @@ -114,6 +114,13 @@ func (c *clientImpl) GetSession(ctx context.Context, sessionID string, userID st Clause{Key: "user_id", Value: userID}) } +// GetSessionByID retrieves a session by id only, without filtering by user ID. +// Use this for internal/agent callers that need cross-user session visibility. +func (c *clientImpl) GetSessionByID(ctx context.Context, sessionID string) (*dbpkg.Session, error) { + return get[dbpkg.Session](c.db.WithContext(ctx), + Clause{Key: "id", Value: sessionID}) +} + // GetAgent retrieves an agent by name and user ID func (c *clientImpl) GetAgent(ctx context.Context, agentID string) (*dbpkg.Agent, error) { return get[dbpkg.Agent](c.db.WithContext(ctx), Clause{Key: "id", Value: agentID}) diff --git a/go/core/internal/database/fake/client.go b/go/core/internal/database/fake/client.go index 11157c478..4ca9416d6 100644 --- a/go/core/internal/database/fake/client.go +++ b/go/core/internal/database/fake/client.go @@ -252,6 +252,19 @@ func (c *InMemoryFakeClient) GetSession(_ context.Context, sessionID string, use return session, nil } +// GetSessionByID retrieves a session by ID only, without filtering by user ID. +func (c *InMemoryFakeClient) GetSessionByID(_ context.Context, sessionID string) (*database.Session, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + for _, session := range c.sessions { + if session.ID == sessionID { + return session, nil + } + } + return nil, gorm.ErrRecordNotFound +} + // GetAgent retrieves an agent by name func (c *InMemoryFakeClient) GetAgent(_ context.Context, agentName string) (*database.Agent, error) { c.mu.RLock() diff --git a/go/core/internal/httpserver/handlers/sessions.go b/go/core/internal/httpserver/handlers/sessions.go index eb26e6714..0b970107a 100644 --- a/go/core/internal/httpserver/handlers/sessions.go +++ b/go/core/internal/httpserver/handlers/sessions.go @@ -177,7 +177,16 @@ func (h *SessionsHandler) HandleGetSession(w ErrorResponseWriter, r *http.Reques log = log.WithValues("userID", userID) log.V(1).Info("Getting session from database") - session, err := h.DatabaseService.GetSession(r.Context(), sessionID, userID) + // Agent callers (the ADK runtime) need to find sessions regardless of which + // user created them, since the A2A user ID is synthetic and won't match the + // session's real user_id. Users are restricted to their own sessions. + var session *database.Session + principal, _ := GetPrincipal(r) + if principal.Agent.ID != "" { + session, err = h.DatabaseService.GetSessionByID(r.Context(), sessionID) + } else { + session, err = h.DatabaseService.GetSession(r.Context(), sessionID, userID) + } if err != nil { w.RespondWithError(errors.NewNotFoundError("Session not found", err)) return @@ -208,7 +217,7 @@ func (h *SessionsHandler) HandleGetSession(w ErrorResponseWriter, r *http.Reques } } - events, err := h.DatabaseService.ListEventsForSession(r.Context(), sessionID, userID, queryOptions) + events, err := h.DatabaseService.ListEventsForSession(r.Context(), sessionID, session.UserID, queryOptions) if err != nil { w.RespondWithError(errors.NewInternalServerError("Failed to get events for session", err)) return @@ -274,6 +283,53 @@ func (h *SessionsHandler) HandleUpdateSession(w ErrorResponseWriter, r *http.Req RespondWithJSON(w, http.StatusOK, data) } +// HandlePatchSession handles PATCH /api/sessions/{session_id} to update the session name. +func (h *SessionsHandler) HandlePatchSession(w ErrorResponseWriter, r *http.Request) { + log := ctrllog.FromContext(r.Context()).WithName("sessions-handler").WithValues("operation", "patch-db") + + sessionID, err := GetPathParam(r, "session_id") + if err != nil { + w.RespondWithError(errors.NewBadRequestError("Failed to get session ID from path", err)) + return + } + log = log.WithValues("session_id", sessionID) + + userID, err := getUserIDOrAgentUser(r) + if err != nil { + w.RespondWithError(errors.NewBadRequestError("Failed to get user ID", err)) + return + } + log = log.WithValues("userID", userID) + + var patchReq struct { + Name string `json:"name"` + } + if err := DecodeJSONBody(r, &patchReq); err != nil { + w.RespondWithError(errors.NewBadRequestError("Invalid request body", err)) + return + } + if patchReq.Name == "" { + w.RespondWithError(errors.NewBadRequestError("name is required", nil)) + return + } + + session, err := h.DatabaseService.GetSession(r.Context(), sessionID, userID) + if err != nil { + w.RespondWithError(errors.NewNotFoundError("Session not found", err)) + return + } + + session.Name = &patchReq.Name + if err := h.DatabaseService.StoreSession(r.Context(), session); err != nil { + w.RespondWithError(errors.NewInternalServerError("Failed to update session name", err)) + return + } + + log.Info("Successfully updated session name", "name", patchReq.Name) + data := api.NewResponse(session, "Session name updated successfully", false) + RespondWithJSON(w, http.StatusOK, data) +} + // HandleDeleteSession handles DELETE /api/sessions/{session_id} requests using database func (h *SessionsHandler) HandleDeleteSession(w ErrorResponseWriter, r *http.Request) { log := ctrllog.FromContext(r.Context()).WithName("sessions-handler").WithValues("operation", "delete-db") diff --git a/go/core/internal/httpserver/server.go b/go/core/internal/httpserver/server.go index bdbfde190..4886bc44f 100644 --- a/go/core/internal/httpserver/server.go +++ b/go/core/internal/httpserver/server.go @@ -219,6 +219,7 @@ func (s *HTTPServer) setupRoutes() { s.router.HandleFunc(APIPathSessions+"/{session_id}/tasks", adaptHandler(s.handlers.Sessions.HandleListTasksForSession)).Methods(http.MethodGet) s.router.HandleFunc(APIPathSessions+"/{session_id}", adaptHandler(s.handlers.Sessions.HandleDeleteSession)).Methods(http.MethodDelete) s.router.HandleFunc(APIPathSessions+"/{session_id}", adaptHandler(s.handlers.Sessions.HandleUpdateSession)).Methods(http.MethodPut) + s.router.HandleFunc(APIPathSessions+"/{session_id}", adaptHandler(s.handlers.Sessions.HandlePatchSession)).Methods(http.MethodPatch) s.router.HandleFunc(APIPathSessions+"/{session_id}/events", adaptHandler(s.handlers.Sessions.HandleAddEventToSession)).Methods(http.MethodPost) // Tasks diff --git a/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py b/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py index 726e5ad46..0e9a7f555 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py +++ b/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py @@ -3,6 +3,7 @@ import asyncio import inspect import logging +import os import uuid from contextlib import suppress from datetime import datetime, timezone @@ -60,6 +61,22 @@ logger = logging.getLogger("kagent_adk." + __name__) +session_name_summarization_prompt = """ +Generate a short title (5-7 words max, no quotes or punctuation) for a conversation that starts with this message: {message_text}\nRespond with only the title, nothing else. +""" + + +def _parse_duration_seconds(duration_str: str) -> float: + """Parse a Go-style duration string (e.g. '5m', '1h', '30s', '300') into seconds.""" + s = duration_str.strip() + if s.endswith("h"): + return float(s[:-1]) * 3600 + if s.endswith("m"): + return float(s[:-1]) * 60 + if s.endswith("s"): + return float(s[:-1]) + return float(s) + class A2aAgentExecutorConfig(BaseModel): """Configuration for the KAgent A2aAgentExecutor.""" @@ -125,6 +142,15 @@ def __init__( self._kagent_config = config self._task_store = task_store + # Parse session name update interval from env var (0 = only set on first message). + self._session_name_update_interval: float = 0.0 + interval_str = os.environ.get("KAGENT_SESSION_NAME_UPDATE_INTERVAL", "") + if interval_str: + try: + self._session_name_update_interval = _parse_duration_seconds(interval_str) + except ValueError: + logger.warning("Invalid KAGENT_SESSION_NAME_UPDATE_INTERVAL value: %s", interval_str) + @override async def _resolve_runner(self) -> Runner: """Resolve the runner from the callable. @@ -317,6 +343,76 @@ async def _safe_close_runner(self, runner: Runner): continue raise result + async def _generate_session_name(self, message_text: str, model: Any) -> Optional[str]: + """Call the agent's LLM to produce a short title from the user message.""" + if not model or not message_text: + return None + try: + from google.adk.models.llm_request import LlmRequest + from google.genai.types import Content + from google.genai.types import Part as GenaiPart + + prompt = session_name_summarization_prompt.format(message_text=message_text) + llm_request = LlmRequest(model=model.model, contents=[Content(role="user", parts=[GenaiPart(text=prompt)])]) + name = "" + async for chunk in model.generate_content_async(llm_request, stream=False): + if chunk.content and chunk.content.parts: + name += "".join(p.text for p in chunk.content.parts if hasattr(p, "text") and p.text) + return name.strip() or None + except Exception as e: + logger.warning("Failed to generate session name via LLM: %s", e, exc_info=True) + return None + + def _should_update_session_name(self, session: Any) -> bool: + """Return True if the session name should be generated""" + session_name = session.state.get("_kagent_session_name") + if not session_name: + return True + if self._session_name_update_interval > 0: + updated_at_str = session.state.get("_kagent_session_updated_at") + if updated_at_str: + try: + updated_at = datetime.fromisoformat(updated_at_str.replace("Z", "+00:00")) + elapsed = (datetime.now(timezone.utc) - updated_at).total_seconds() + return elapsed >= self._session_name_update_interval + except (ValueError, TypeError): + pass + return False + + async def _update_session_name( + self, + session: Any, + runner: Runner, + message: Any, + ) -> Optional[str]: + """Generate and persist a session name from the user message.""" + try: + # Extract first text part from the user message. + message_text = "" + if message and message.parts: + for part in message.parts: + if isinstance(part, Part): + root_part = part.root + if isinstance(root_part, TextPart) and root_part.text: + message_text = root_part.text.strip() + break + + if not message_text: + return None + + model = getattr(runner.agent, "model", None) + name = await self._generate_session_name(message_text, model) + if not name: + return None + + if hasattr(runner.session_service, "update_session_name"): + await runner.session_service.update_session_name(session.id, session.user_id, name) + logger.info("Session name updated: session_id=%s name=%s", session.id, name) + return name + except Exception as e: + logger.warning("Failed to update session name: %s", e) + return None + async def _publish_failed_status_event( self, context: RequestContext, @@ -632,6 +728,12 @@ async def _handle_request( if last_usage_metadata is not None: run_metadata[get_kagent_metadata_key("usage_metadata")] = serialize_metadata_value(last_usage_metadata) + # Update session name if needed (skip for HITL continuation turns). + if not decision and self._should_update_session_name(session): + new_name = await self._update_session_name(session, runner, context.message) + if new_name: + run_metadata[get_kagent_metadata_key("session_name")] = new_name + # publish the task result event - this is final if ( task_result_aggregator.task_state == TaskState.working @@ -690,20 +792,7 @@ async def _prepare_session(self, context: RequestContext, run_args: dict[str, An ) if session is None: - # Extract session name from the first TextPart (like the UI does) - session_name = None - if context.message and context.message.parts: - for part in context.message.parts: - # A2A parts have a .root property that contains the actual part (TextPart, FilePart, etc.) - if isinstance(part, Part): - root_part = part.root - if isinstance(root_part, TextPart) and root_part.text: - # Take first 20 chars + "..." if longer (matching UI behavior) - text = root_part.text.strip() - session_name = text[:20] + ("..." if len(text) > 20 else "") - break - - state: dict[str, Any] = {"session_name": session_name} + state: dict[str, Any] = {} # Propagate source (e.g. "agent") so the session is tagged in the DB. source = None if context.call_context and context.call_context.state: diff --git a/python/packages/kagent-adk/src/kagent/adk/_session_service.py b/python/packages/kagent-adk/src/kagent/adk/_session_service.py index da08895a5..03d4d197e 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_session_service.py +++ b/python/packages/kagent-adk/src/kagent/adk/_session_service.py @@ -122,6 +122,12 @@ async def get_session( for event in events: await super().append_event(session, event) + # Store DB session metadata in state for session name tracking. + # Use underscore-prefixed keys to avoid conflicts with agent state. + if session_data.get("name"): + session.state["_kagent_session_name"] = session_data["name"] + session.state["_kagent_session_updated_at"] = session_data.get("updated_at", "") + return session except httpx.HTTPStatusError as e: if e.response.status_code == 404: @@ -157,6 +163,15 @@ async def delete_session(self, *, app_name: str, user_id: str, session_id: str) ) response.raise_for_status() + async def update_session_name(self, session_id: str, user_id: str, name: str) -> None: + """Update the session name via the kagent PATCH endpoint.""" + response = await self.client.patch( + f"/api/sessions/{session_id}", + json={"name": name}, + headers={"X-User-ID": user_id}, + ) + response.raise_for_status() + @override async def append_event(self, session: Session, event: Event) -> Event: if event.partial: diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index 74ef7f46f..db0b1c0b2 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -9,6 +9,7 @@ from google.adk.agents.readonly_context import ReadonlyContext from google.adk.agents.remote_a2a_agent import AGENT_CARD_WELL_KNOWN_PATH, DEFAULT_TIMEOUT from google.adk.models.anthropic_llm import Claude as ClaudeLLM +from google.adk.models.base_llm import BaseLlm from google.adk.models.google_llm import Gemini as GeminiLLM from google.adk.tools.mcp_tool import SseConnectionParams, StreamableHTTPConnectionParams from pydantic import BaseModel, Field @@ -459,7 +460,7 @@ async def auto_save_session_to_memory_callback(callback_context: CallbackContext logger.error("Failed to inject memory configuration: %s", e) -def _create_llm_from_model_config(model_config: ModelUnion): +def _create_llm_from_model_config(model_config: ModelUnion) -> BaseLlm: extra_headers = model_config.headers or {} base_url = getattr(model_config, "base_url", None) @@ -513,7 +514,7 @@ def _create_llm_from_model_config(model_config: ModelUnion): api_key_passthrough=model_config.api_key_passthrough, ) if model_config.type == "gemini": - return model_config.model + return GeminiLLM(model=model_config.model) if model_config.type == "bedrock": # api key passthrough is not applicable for bedrock return KAgentBedrockLlm( diff --git a/ui/src/components/chat/ChatInterface.tsx b/ui/src/components/chat/ChatInterface.tsx index c1be21249..70a84e3c3 100644 --- a/ui/src/components/chat/ChatInterface.tsx +++ b/ui/src/components/chat/ChatInterface.tsx @@ -93,7 +93,10 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se agentContext: { namespace: selectedNamespace, agentName: selectedAgentName - } + }, + onSessionNameUpdate: (sessionId, name) => { + window.dispatchEvent(new CustomEvent('session-name-updated', { detail: { sessionId, name } })); + }, }), [selectedNamespace, selectedAgentName]); useEffect(() => { diff --git a/ui/src/components/chat/ChatLayoutUI.tsx b/ui/src/components/chat/ChatLayoutUI.tsx index 10561dfcb..157294fca 100644 --- a/ui/src/components/chat/ChatLayoutUI.tsx +++ b/ui/src/components/chat/ChatLayoutUI.tsx @@ -91,6 +91,18 @@ export default function ChatLayoutUI({ }; }, [agentName, namespace]); + useEffect(() => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const handleSessionNameUpdate = (event: any) => { + const { sessionId, name } = event.detail; + setSessions(prev => prev.map(s => s.id === sessionId ? { ...s, name } : s)); + }; + window.addEventListener('session-name-updated', handleSessionNameUpdate); + return () => { + window.removeEventListener('session-name-updated', handleSessionNameUpdate); + }; + }, []); + return ( <> void; }; export const createMessageHandlers = (handlers: MessageHandlers) => { @@ -929,6 +930,11 @@ export const createMessageHandlers = (handlers: MessageHandlers) => { } if (statusUpdate.final) { + const sessionName = getMetadataValue(adkMetadata as Record, "session_name"); + const sessionId = getMetadataValue(adkMetadata as Record, "session_id"); + if (sessionName && sessionId && handlers.onSessionNameUpdate) { + handlers.onSessionNameUpdate(sessionId, sessionName); + } finalizeStreaming(); } } catch (error) {