-
Notifications
You must be signed in to change notification settings - Fork 512
feat(adk): Allow session name generation using LLM #1611
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,8 @@ import ( | |
| "fmt" | ||
| "maps" | ||
| "os" | ||
| "strings" | ||
| "time" | ||
|
|
||
| a2atype "github.com/a2aproject/a2a-go/a2a" | ||
| "github.com/a2aproject/a2a-go/a2asrv" | ||
|
|
@@ -14,14 +16,22 @@ import ( | |
| "github.com/kagent-dev/kagent/go/adk/pkg/skills" | ||
| "github.com/kagent-dev/kagent/go/adk/pkg/telemetry" | ||
| adkagent "google.golang.org/adk/agent" | ||
| "google.golang.org/adk/model" | ||
| "google.golang.org/adk/runner" | ||
| "google.golang.org/adk/server/adka2a" | ||
| "google.golang.org/genai" | ||
| ) | ||
|
|
||
| const ( | ||
| defaultSkillsDirectory = "/skills" | ||
| envSkillsFolder = "KAGENT_SKILLS_FOLDER" | ||
| sessionNameMaxLength = 20 | ||
| defaultSkillsDirectory = "/skills" | ||
| envSkillsFolder = "KAGENT_SKILLS_FOLDER" | ||
| envSessionNameUpdateInterval = "KAGENT_SESSION_NAME_UPDATE_INTERVAL" | ||
| ) | ||
|
|
||
| const ( | ||
| sessionNameSummarizationPrompt = ` | ||
| Generate a short title (5-7 words max, no quotes or punctuation) for a conversation that starts with this message: %s\nRespond with only the title, nothing else. | ||
| ` | ||
| ) | ||
|
|
||
| // KAgentExecutorConfig holds the configuration for KAgentExecutor | ||
|
|
@@ -33,17 +43,27 @@ type KAgentExecutorConfig struct { | |
| AppName string | ||
| SkillsDirectory string | ||
| Logger logr.Logger | ||
| SessionNameLLM model.LLM | ||
| } | ||
|
|
||
| // sessionNameMeta holds per-request metadata used for session name generation. | ||
| type sessionNameMeta struct { | ||
| userID string | ||
| updatedAt time.Time | ||
| messageText string | ||
| } | ||
|
|
||
| // KAgentExecutor implements a2asrv.AgentExecutor | ||
| type KAgentExecutor struct { | ||
| runnerConfig runner.Config | ||
| subagentSessionIDs map[string]string | ||
| sessionService *session.KAgentSessionService | ||
| stream bool | ||
| appName string | ||
| skillsDirectory string | ||
| logger logr.Logger | ||
| runnerConfig runner.Config | ||
| subagentSessionIDs map[string]string | ||
| sessionService *session.KAgentSessionService | ||
| stream bool | ||
| appName string | ||
| skillsDirectory string | ||
| logger logr.Logger | ||
| sessionNameLLM model.LLM | ||
| sessionNameUpdateInterval time.Duration | ||
| } | ||
|
|
||
| var _ a2asrv.AgentExecutor = (*KAgentExecutor)(nil) | ||
|
|
@@ -57,14 +77,24 @@ func NewKAgentExecutor(cfg KAgentExecutorConfig) *KAgentExecutor { | |
| if skillsDir == "" { | ||
| skillsDir = defaultSkillsDirectory | ||
| } | ||
|
|
||
| var sessionNameUpdateInterval time.Duration | ||
| if intervalStr := os.Getenv(envSessionNameUpdateInterval); intervalStr != "" { | ||
| if d, err := time.ParseDuration(intervalStr); err == nil { | ||
| sessionNameUpdateInterval = d | ||
| } | ||
| } | ||
|
|
||
| return &KAgentExecutor{ | ||
| runnerConfig: cfg.RunnerConfig, | ||
| subagentSessionIDs: cfg.SubagentSessionIDs, | ||
| sessionService: cfg.SessionService, | ||
| stream: cfg.Stream, | ||
| appName: cfg.AppName, | ||
| skillsDirectory: skillsDir, | ||
| logger: cfg.Logger.WithName("kagent-executor"), | ||
| runnerConfig: cfg.RunnerConfig, | ||
| subagentSessionIDs: cfg.SubagentSessionIDs, | ||
| sessionService: cfg.SessionService, | ||
| stream: cfg.Stream, | ||
| appName: cfg.AppName, | ||
| skillsDirectory: skillsDir, | ||
| logger: cfg.Logger.WithName("kagent-executor"), | ||
| sessionNameLLM: cfg.SessionNameLLM, | ||
| sessionNameUpdateInterval: sessionNameUpdateInterval, | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -140,22 +170,36 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont | |
| } | ||
|
|
||
| // 4. Create / lookup session via sessionService. | ||
| var meta *sessionNameMeta | ||
| if e.sessionService != nil { | ||
| sess, err := e.sessionService.GetSession(ctx, e.appName, userID, sessionID) | ||
| if err != nil { | ||
| e.logger.V(1).Info("Session lookup failed, will create", "error", err, "sessionID", sessionID) | ||
| sess = nil | ||
| } | ||
|
|
||
| // Track the session's last update time for post-execution name generation. | ||
| // For new sessions, updatedAt stays zero which always exceeds any interval. | ||
| var updatedAt time.Time | ||
| if sess != nil { | ||
| type timeProvider interface { | ||
| LastUpdateTime() time.Time | ||
| } | ||
| if tp, ok := sess.(timeProvider); ok { | ||
| updatedAt = tp.LastUpdateTime() | ||
| } | ||
| } | ||
|
|
||
| if sess == nil { | ||
| sessionName := extractSessionName(reqCtx.Message) | ||
| sessionName := extractMessageText(reqCtx.Message) | ||
| state := make(map[string]any) | ||
| if sessionName != "" { | ||
| state[StateKeySessionName] = sessionName | ||
| } | ||
|
Comment on lines
193
to
198
|
||
| // Propagate x-kagent-source so the session is tagged in the DB. | ||
| if callCtx, ok := a2asrv.CallContextFrom(ctx); ok { | ||
| if meta := callCtx.RequestMeta(); meta != nil { | ||
| if vals, ok := meta.Get("x-kagent-source"); ok && len(vals) > 0 && vals[0] != "" { | ||
| if callMeta := callCtx.RequestMeta(); callMeta != nil { | ||
| if vals, ok := callMeta.Get("x-kagent-source"); ok && len(vals) > 0 && vals[0] != "" { | ||
| state[StateKeySource] = vals[0] | ||
| } | ||
| } | ||
|
|
@@ -164,6 +208,14 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont | |
| return fmt.Errorf("failed to create session: %w", err) | ||
| } | ||
| } | ||
|
|
||
| if e.sessionNameLLM != nil { | ||
| meta = &sessionNameMeta{ | ||
| userID: userID, | ||
| updatedAt: updatedAt, | ||
| messageText: extractMessageText(reqCtx.Message), | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // 5. Detect HITL decision and build the resume message if needed. | ||
|
|
@@ -398,6 +450,21 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont | |
| return queue.Write(ctx, inputRequired) | ||
| } | ||
|
|
||
| // Generate session name via LLM if the update interval has elapsed. | ||
| if meta != nil && e.sessionNameLLM != nil && e.sessionNameUpdateInterval > 0 { | ||
| if time.Since(meta.updatedAt) >= e.sessionNameUpdateInterval && meta.messageText != "" { | ||
| if name := e.generateSessionName(ctx, meta.messageText); name != "" { | ||
| if updateErr := e.sessionService.UpdateSessionName(ctx, meta.userID, sessionID, name); updateErr != nil { | ||
| e.logger.V(1).Info("Failed to update session name", "error", updateErr, "sessionID", sessionID) | ||
| } else { | ||
| e.logger.Info("Session name updated", "sessionID", sessionID, "name", name) | ||
| finalMeta[GetKAgentMetadataKey("session_name")] = name | ||
| finalMeta[GetKAgentMetadataKey("session_id")] = sessionID | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| // completed: inject last text into final status if no message present. | ||
| var finalMsg *a2atype.Message | ||
| if len(lastTextParts) > 0 { | ||
|
|
@@ -419,16 +486,44 @@ func (e *KAgentExecutor) Cancel(ctx context.Context, reqCtx *a2asrv.RequestConte | |
| return queue.Write(ctx, event) | ||
| } | ||
|
|
||
| // extractSessionName extracts session name from the first text part of a message. | ||
| func extractSessionName(message *a2atype.Message) string { | ||
| // generateSessionName calls the LLM to produce a short title from the first user message. | ||
| func (e *KAgentExecutor) generateSessionName(ctx context.Context, messageText string) string { | ||
| if e.sessionNameLLM == nil || messageText == "" { | ||
| return "" | ||
| } | ||
|
|
||
| prompt := fmt.Sprintf(sessionNameSummarizationPrompt, messageText) | ||
|
|
||
| req := &model.LLMRequest{ | ||
| Contents: []*genai.Content{ | ||
| {Role: "user", Parts: []*genai.Part{{Text: prompt}}}, | ||
| }, | ||
| } | ||
|
|
||
| var name string | ||
| for resp, err := range e.sessionNameLLM.GenerateContent(ctx, req, false) { | ||
| if err != nil { | ||
| e.logger.V(1).Info("LLM error during session name generation", "error", err) | ||
| return "" | ||
| } | ||
| if resp != nil && resp.Content != nil { | ||
| for _, part := range resp.Content.Parts { | ||
| if part != nil && part.Text != "" { | ||
| name = strings.TrimSpace(part.Text) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| return name | ||
|
Comment on lines
+503
to
+517
|
||
| } | ||
|
|
||
| // extractMessageText returns the first text part of a message. | ||
| func extractMessageText(message *a2atype.Message) string { | ||
| if message == nil { | ||
| return "" | ||
| } | ||
| for _, part := range message.Parts { | ||
| if tp, ok := part.(a2atype.TextPart); ok && tp.Text != "" { | ||
| if len(tp.Text) > sessionNameMaxLength { | ||
| return tp.Text[:sessionNameMaxLength] + "..." | ||
| } | ||
| return tp.Text | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -138,8 +138,10 @@ func (s *KAgentSessionService) Get(ctx context.Context, req *adksession.GetReque | |
| var result struct { | ||
| Data struct { | ||
| Session struct { | ||
| ID string `json:"id"` | ||
| UserID string `json:"user_id"` | ||
| ID string `json:"id"` | ||
| UserID string `json:"user_id"` | ||
| Name *string `json:"name"` | ||
| UpdatedAt time.Time `json:"updated_at"` | ||
| } `json:"session"` | ||
| Events []struct { | ||
| Data json.RawMessage `json:"data"` | ||
|
|
@@ -172,13 +174,16 @@ func (s *KAgentSessionService) Get(ctx context.Context, req *adksession.GetReque | |
| adkEvents = append(adkEvents, e) | ||
| } | ||
|
|
||
| log.V(1).Info("Parsed session events", "totalEvents", len(result.Data.Events), "outputEvents", len(adkEvents)) | ||
|
|
||
| return &adksession.GetResponse{ | ||
| Session: &localSession{ | ||
| appName: req.AppName, | ||
| userID: result.Data.Session.UserID, | ||
| sessionID: result.Data.Session.ID, | ||
| events: adkEvents, | ||
| state: make(map[string]any), | ||
| updatedAt: result.Data.Session.UpdatedAt, | ||
| }, | ||
| }, nil | ||
| } | ||
|
|
@@ -310,6 +315,37 @@ func (s *KAgentSessionService) CreateSession(ctx context.Context, appName, userI | |
| return err | ||
| } | ||
|
|
||
| // UpdateSessionName updates the display name of a session via the KAgent API. | ||
| func (s *KAgentSessionService) UpdateSessionName(ctx context.Context, userID, sessionID, name string) error { | ||
| log := logr.FromContextOrDiscard(ctx) | ||
| log.V(1).Info("Updating session name", "sessionID", sessionID, "userID", userID, "name", name) | ||
|
|
||
| body, err := json.Marshal(map[string]string{"name": name}) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to marshal request: %w", err) | ||
| } | ||
| req, err := http.NewRequestWithContext(ctx, http.MethodPatch, s.BaseURL+"/api/sessions/"+sessionID, bytes.NewReader(body)) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to create request: %w", err) | ||
| } | ||
| req.Header.Set("Content-Type", "application/json") | ||
| req.Header.Set("X-User-ID", userID) | ||
|
|
||
| resp, err := s.Client.Do(req) | ||
| if err != nil { | ||
| return fmt.Errorf("failed to execute update session name request: %w", err) | ||
| } | ||
| defer resp.Body.Close() | ||
|
|
||
| if resp.StatusCode != http.StatusOK { | ||
| bodyBytes, _ := io.ReadAll(resp.Body) | ||
| return fmt.Errorf("failed to update session name: status %d - %s", resp.StatusCode, string(bodyBytes)) | ||
| } | ||
|
|
||
| log.V(1).Info("Session name updated successfully", "sessionID", sessionID) | ||
| return nil | ||
| } | ||
|
Comment on lines
+318
to
+347
|
||
|
|
||
| // normalizeConfirmationEventRole fixes the role on adk_request_confirmation | ||
| // functionCall events from "model" to "user". | ||
| // | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NewKAgentExecutor silently ignores invalid KAGENT_SESSION_NAME_UPDATE_INTERVAL values (ParseDuration error just leaves interval at 0). This makes misconfiguration hard to diagnose. Log a warning when parsing fails so operators know why session name generation isn't running.