diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index e1b475240..c265634e0 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -289,6 +289,16 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM, } return models.NewAnthropicVertexAIModelWithLogger(ctx, cfg, region, project, log) + case *adk.SAPAICore: + cfg := models.SAPAICoreConfig{ + Model: m.Model, + BaseUrl: m.BaseUrl, + ResourceGroup: m.ResourceGroup, + AuthUrl: m.AuthUrl, + Headers: extractHeaders(m.Headers), + } + return models.NewSAPAICoreModelWithLogger(cfg, log) + default: return nil, fmt.Errorf("unsupported model type: %s", m.GetType()) } diff --git a/go/adk/pkg/models/sapaicore.go b/go/adk/pkg/models/sapaicore.go new file mode 100644 index 000000000..380e4f6e7 --- /dev/null +++ b/go/adk/pkg/models/sapaicore.go @@ -0,0 +1,188 @@ +package models + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "github.com/go-logr/logr" +) + +type SAPAICoreConfig struct { + Model string + BaseUrl string + ResourceGroup string + AuthUrl string + Headers map[string]string +} + +type SAPAICoreModel struct { + Config SAPAICoreConfig + Logger logr.Logger + + mu sync.Mutex + token string + tokenExpiresAt time.Time + deploymentURL string + deploymentURLAt time.Time + httpClient *http.Client +} + +func NewSAPAICoreModelWithLogger(config SAPAICoreConfig, logger logr.Logger) (*SAPAICoreModel, error) { + if config.BaseUrl == "" { + return nil, fmt.Errorf("SAP AI Core requires base_url") + } + if config.ResourceGroup == "" { + config.ResourceGroup = "default" + } + return &SAPAICoreModel{ + Config: config, + Logger: logger, + httpClient: &http.Client{Timeout: 5 * time.Minute}, + }, nil +} + +func (m *SAPAICoreModel) ensureToken(ctx context.Context) (string, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.token != "" && time.Now().Before(m.tokenExpiresAt.Add(-2*time.Minute)) { + return m.token, nil + } + + clientID := os.Getenv("SAP_AI_CORE_CLIENT_ID") + clientSecret := os.Getenv("SAP_AI_CORE_CLIENT_SECRET") + if m.Config.AuthUrl == "" || clientID == "" || clientSecret == "" { + return "", fmt.Errorf("SAP AI Core requires auth_url + SAP_AI_CORE_CLIENT_ID/SECRET env vars") + } + + tokenURL := strings.TrimRight(m.Config.AuthUrl, "/") + if !strings.HasSuffix(tokenURL, "/oauth/token") { + tokenURL += "/oauth/token" + } + + formData := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {clientID}, + "client_secret": {clientSecret}, + } + req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(formData.Encode())) + if err != nil { + return "", fmt.Errorf("failed to create OAuth2 token request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := m.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("OAuth2 token request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", &orchHTTPError{StatusCode: resp.StatusCode, URL: tokenURL} + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + } + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return "", fmt.Errorf("failed to decode OAuth2 token response: %w", err) + } + + m.token = tokenResp.AccessToken + if tokenResp.ExpiresIn > 0 { + m.tokenExpiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } else { + m.tokenExpiresAt = time.Now().Add(12 * time.Hour) + } + return m.token, nil +} + +func (m *SAPAICoreModel) invalidateToken() { + m.mu.Lock() + defer m.mu.Unlock() + m.token = "" + m.tokenExpiresAt = time.Time{} +} + +func (m *SAPAICoreModel) resolveDeploymentURL(ctx context.Context) (string, error) { + m.mu.Lock() + if m.deploymentURL != "" && time.Now().Before(m.deploymentURLAt.Add(time.Hour)) { + u := m.deploymentURL + m.mu.Unlock() + return u, nil + } + m.mu.Unlock() + + token, err := m.ensureToken(ctx) + if err != nil { + return "", err + } + + reqURL := fmt.Sprintf("%s/v2/lm/deployments", m.Config.BaseUrl) + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("AI-Resource-Group", m.Config.ResourceGroup) + + resp, err := m.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to list deployments: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", &orchHTTPError{StatusCode: resp.StatusCode, URL: reqURL} + } + + var result struct { + Resources []struct { + ID string `json:"id"` + ScenarioID string `json:"scenarioId"` + Status string `json:"status"` + DeploymentURL string `json:"deploymentUrl"` + CreatedAt string `json:"createdAt"` + } `json:"resources"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to decode deployments: %w", err) + } + + var best string + var bestCreated string + for _, d := range result.Resources { + if d.ScenarioID == "orchestration" && d.Status == "RUNNING" && d.DeploymentURL != "" { + if d.CreatedAt > bestCreated { + best = d.DeploymentURL + bestCreated = d.CreatedAt + } + } + } + if best == "" { + return "", fmt.Errorf("no running orchestration deployment found in SAP AI Core") + } + + m.mu.Lock() + m.deploymentURL = best + m.deploymentURLAt = time.Now() + m.mu.Unlock() + + m.Logger.Info("Resolved SAP AI Core orchestration deployment", "url", best) + return best, nil +} + +func (m *SAPAICoreModel) invalidateDeploymentURL() { + m.mu.Lock() + defer m.mu.Unlock() + m.deploymentURL = "" + m.deploymentURLAt = time.Time{} +} diff --git a/go/adk/pkg/models/sapaicore_adk.go b/go/adk/pkg/models/sapaicore_adk.go new file mode 100644 index 000000000..ba03c391e --- /dev/null +++ b/go/adk/pkg/models/sapaicore_adk.go @@ -0,0 +1,535 @@ +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "iter" + "net/http" + "slices" + "strings" + + "google.golang.org/adk/model" + "google.golang.org/genai" +) + +func (m *SAPAICoreModel) Name() string { + return m.Config.Model +} + +func (m *SAPAICoreModel) GenerateContent(ctx context.Context, req *model.LLMRequest, stream bool) iter.Seq2[*model.LLMResponse, error] { + return func(yield func(*model.LLMResponse, error) bool) { + resp, err := m.doRequest(ctx, req, stream) + if err != nil { + if isRetryableError(err) { + m.invalidateToken() + m.invalidateDeploymentURL() + var he *orchHTTPError + if errors.As(err, &he) { + m.Logger.Info("SAP AI Core request failed, retrying", "status", he.StatusCode, "url", he.URL) + } else { + m.Logger.Info("SAP AI Core request failed, retrying", "error", err) + } + resp, err = m.doRequest(ctx, req, stream) + if err != nil { + yield(nil, fmt.Errorf("SAP AI Core retry failed: %w", err)) + return + } + } else { + yield(nil, fmt.Errorf("SAP AI Core request failed: %w", err)) + return + } + } + defer resp.Body.Close() + + if stream { + m.handleStream(ctx, resp.Body, yield) + } else { + m.handleNonStream(resp.Body, yield) + } + } +} + +func (m *SAPAICoreModel) doRequest(ctx context.Context, req *model.LLMRequest, stream bool) (*http.Response, error) { + deploymentURL, err := m.resolveDeploymentURL(ctx) + if err != nil { + return nil, err + } + + token, err := m.ensureToken(ctx) + if err != nil { + return nil, err + } + + body := m.buildOrchestrationBody(req, stream) + bodyBytes, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + url := deploymentURL + "/v2/completion" + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, err + } + httpReq.Header.Set("Authorization", "Bearer "+token) + httpReq.Header.Set("AI-Resource-Group", m.Config.ResourceGroup) + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := m.httpClient.Do(httpReq) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + defer resp.Body.Close() + errBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) + return nil, &orchHTTPError{StatusCode: resp.StatusCode, Body: string(errBody), URL: url} + } + + return resp, nil +} + +type orchHTTPError struct { + StatusCode int + Body string + URL string +} + +func (e *orchHTTPError) Error() string { + return fmt.Sprintf("SAP AI Core returned HTTP %d (url: %s): %s", e.StatusCode, e.URL, e.Body) +} + +func isRetryableError(err error) bool { + if he, ok := err.(*orchHTTPError); ok { + switch he.StatusCode { + case 401, 403, 404, 502, 503, 504: + return true + } + } + return false +} + +func (m *SAPAICoreModel) buildOrchestrationBody(req *model.LLMRequest, stream bool) map[string]any { + messages, systemInstruction := genaiContentsToOrchTemplate(req.Contents, req.Config) + if systemInstruction != "" { + messages = append([]map[string]any{{"role": "system", "content": systemInstruction}}, messages...) + } + + modelName := req.Model + if modelName == "" { + modelName = m.Config.Model + } + + params := map[string]any{} + if req.Config != nil { + if req.Config.Temperature != nil { + params["temperature"] = *req.Config.Temperature + } + if req.Config.MaxOutputTokens > 0 { + params["max_tokens"] = req.Config.MaxOutputTokens + } + if req.Config.TopP != nil { + params["top_p"] = *req.Config.TopP + } + } + + promptConfig := map[string]any{ + "template": messages, + } + + if req.Config != nil && len(req.Config.Tools) > 0 { + tools := genaiToolsToOrchTools(req.Config.Tools) + if len(tools) > 0 { + promptConfig["tools"] = tools + } + } + + return map[string]any{ + "config": map[string]any{ + "modules": map[string]any{ + "prompt_templating": map[string]any{ + "prompt": promptConfig, + "model": map[string]any{ + "name": modelName, + "params": params, + "version": "latest", + }, + }, + }, + "stream": map[string]any{ + "enabled": stream, + }, + }, + } +} + +func genaiContentsToOrchTemplate(contents []*genai.Content, config *genai.GenerateContentConfig) ([]map[string]any, string) { + var systemBuilder strings.Builder + if config != nil && config.SystemInstruction != nil { + for _, p := range config.SystemInstruction.Parts { + if p != nil && p.Text != "" { + systemBuilder.WriteString(p.Text) + systemBuilder.WriteByte('\n') + } + } + } + systemInstruction := strings.TrimSpace(systemBuilder.String()) + + functionResponses := make(map[string]*genai.FunctionResponse) + for _, c := range contents { + if c == nil { + continue + } + for _, p := range c.Parts { + if p != nil && p.FunctionResponse != nil { + functionResponses[p.FunctionResponse.ID] = p.FunctionResponse + } + } + } + + var messages []map[string]any + for _, content := range contents { + if content == nil || strings.TrimSpace(content.Role) == "system" { + continue + } + role := "user" + if content.Role == "model" || content.Role == "assistant" { + role = "assistant" + } + + var textParts []string + var functionCalls []*genai.FunctionCall + + for _, part := range content.Parts { + if part == nil { + continue + } + if part.Text != "" { + textParts = append(textParts, part.Text) + } else if part.FunctionCall != nil { + functionCalls = append(functionCalls, part.FunctionCall) + } + } + + if len(functionCalls) > 0 && role == "assistant" { + toolCalls := make([]map[string]any, 0, len(functionCalls)) + var toolResponses []map[string]any + for _, fc := range functionCalls { + argsJSON, _ := json.Marshal(fc.Args) + tc := map[string]any{ + "type": "function", + "function": map[string]any{ + "name": fc.Name, + "arguments": string(argsJSON), + }, + } + if fc.ID != "" { + tc["id"] = fc.ID + } + toolCalls = append(toolCalls, tc) + + respContent := "No response available." + if fr := functionResponses[fc.ID]; fr != nil { + respContent = functionResponseContentString(fr.Response) + } + toolResponses = append(toolResponses, map[string]any{ + "role": "tool", + "tool_call_id": fc.ID, + "content": respContent, + }) + } + + msg := map[string]any{"role": "assistant", "tool_calls": toolCalls} + if len(textParts) > 0 { + msg["content"] = strings.Join(textParts, "\n") + } else { + msg["content"] = "" + } + messages = append(messages, msg) + messages = append(messages, toolResponses...) + } else if len(textParts) > 0 { + messages = append(messages, map[string]any{ + "role": role, + "content": strings.Join(textParts, "\n"), + }) + } + } + + return messages, systemInstruction +} + +func genaiToolsToOrchTools(tools []*genai.Tool) []map[string]any { + var out []map[string]any + for _, t := range tools { + if t == nil || t.FunctionDeclarations == nil { + continue + } + for _, fd := range t.FunctionDeclarations { + if fd == nil { + continue + } + params := map[string]any{"type": "object", "properties": map[string]any{}} + if fd.ParametersJsonSchema != nil { + if m, ok := fd.ParametersJsonSchema.(map[string]any); ok { + params = m + } + } else if fd.Parameters != nil { + if m := genaiSchemaToMap(fd.Parameters); m != nil { + params = m + } + } + out = append(out, map[string]any{ + "type": "function", + "function": map[string]any{ + "name": fd.Name, + "description": fd.Description, + "parameters": params, + }, + }) + } + } + return out +} + +func (m *SAPAICoreModel) handleStream(ctx context.Context, body io.Reader, yield func(*model.LLMResponse, error) bool) { + scanner := bufio.NewScanner(body) + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + + var aggregatedText string + toolCallsAcc := make(map[int64]map[string]any) + var finishReason string + var promptTokens, completionTokens int64 + + for scanner.Scan() { + if ctx.Err() != nil { + return + } + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + payload := line + if strings.HasPrefix(line, "data: ") { + payload = line[6:] + } + if payload == "[DONE]" { + break + } + + var event map[string]any + if err := json.Unmarshal([]byte(payload), &event); err != nil { + continue + } + + if _, ok := event["code"]; ok { + yield(nil, fmt.Errorf("SAP AI Core stream error: %s", payload)) + return + } + + chunk := parseOrchChunk(event) + if chunk == nil { + continue + } + + choices, _ := chunk["choices"].([]any) + for _, c := range choices { + choice, ok := c.(map[string]any) + if !ok { + continue + } + delta, _ := choice["delta"].(map[string]any) + if content, ok := delta["content"].(string); ok && content != "" { + aggregatedText += content + if !yield(&model.LLMResponse{ + Partial: true, + TurnComplete: false, + Content: &genai.Content{Role: string(genai.RoleModel), Parts: []*genai.Part{{Text: content}}}, + }, nil) { + return + } + } + + if tcs, ok := delta["tool_calls"].([]any); ok { + for _, tcRaw := range tcs { + tc, ok := tcRaw.(map[string]any) + if !ok { + continue + } + idx := int64(0) + if v, ok := tc["index"].(float64); ok { + idx = int64(v) + } + if toolCallsAcc[idx] == nil { + toolCallsAcc[idx] = map[string]any{"id": "", "name": "", "arguments": ""} + } + if id, ok := tc["id"].(string); ok && id != "" { + toolCallsAcc[idx]["id"] = id + } + if fn, ok := tc["function"].(map[string]any); ok { + if name, ok := fn["name"].(string); ok && name != "" { + toolCallsAcc[idx]["name"] = name + } + if args, ok := fn["arguments"].(string); ok && args != "" { + prev, _ := toolCallsAcc[idx]["arguments"].(string) + toolCallsAcc[idx]["arguments"] = prev + args + } + } + } + } + + if fr, ok := choice["finish_reason"].(string); ok && fr != "" { + finishReason = fr + } + } + + if usage, ok := chunk["usage"].(map[string]any); ok { + if v, ok := usage["prompt_tokens"].(float64); ok { + promptTokens = int64(v) + } + if v, ok := usage["completion_tokens"].(float64); ok { + completionTokens = int64(v) + } + } + } + + if err := scanner.Err(); err != nil { + if ctx.Err() == context.Canceled { + return + } + yield(nil, fmt.Errorf("SAP AI Core stream read error: %w", err)) + return + } + + indices := make([]int64, 0, len(toolCallsAcc)) + for k := range toolCallsAcc { + indices = append(indices, k) + } + slices.Sort(indices) + + finalParts := make([]*genai.Part, 0, 1+len(toolCallsAcc)) + if aggregatedText != "" { + finalParts = append(finalParts, &genai.Part{Text: aggregatedText}) + } + for _, idx := range indices { + tc := toolCallsAcc[idx] + argsStr, _ := tc["arguments"].(string) + var args map[string]any + if argsStr != "" { + _ = json.Unmarshal([]byte(argsStr), &args) + } + name, _ := tc["name"].(string) + id, _ := tc["id"].(string) + if name != "" || id != "" { + p := genai.NewPartFromFunctionCall(name, args) + p.FunctionCall.ID = id + finalParts = append(finalParts, p) + } + } + + var usage *genai.GenerateContentResponseUsageMetadata + if promptTokens > 0 || completionTokens > 0 { + usage = &genai.GenerateContentResponseUsageMetadata{ + PromptTokenCount: int32(promptTokens), + CandidatesTokenCount: int32(completionTokens), + } + } + + yield(&model.LLMResponse{ + Partial: false, + TurnComplete: true, + FinishReason: openAIFinishReasonToGenai(finishReason), + UsageMetadata: usage, + Content: &genai.Content{Role: string(genai.RoleModel), Parts: finalParts}, + }, nil) +} + +func (m *SAPAICoreModel) handleNonStream(body io.Reader, yield func(*model.LLMResponse, error) bool) { + var data map[string]any + if err := json.NewDecoder(body).Decode(&data); err != nil { + yield(nil, fmt.Errorf("failed to decode SAP AI Core response: %w", err)) + return + } + + result, ok := data["final_result"].(map[string]any) + if !ok { + result = data + } + + choices, _ := result["choices"].([]any) + if len(choices) == 0 { + yield(&model.LLMResponse{ErrorCode: "API_ERROR", ErrorMessage: "No choices in response"}, nil) + return + } + + parts := make([]*genai.Part, 0) + firstChoice, _ := choices[0].(map[string]any) + msg, _ := firstChoice["message"].(map[string]any) + + if content, ok := msg["content"].(string); ok && content != "" { + parts = append(parts, &genai.Part{Text: content}) + } + + if toolCalls, ok := msg["tool_calls"].([]any); ok { + for _, tcRaw := range toolCalls { + tc, ok := tcRaw.(map[string]any) + if !ok { + continue + } + fn, _ := tc["function"].(map[string]any) + name, _ := fn["name"].(string) + argsStr, _ := fn["arguments"].(string) + var args map[string]any + if argsStr != "" { + _ = json.Unmarshal([]byte(argsStr), &args) + } + id, _ := tc["id"].(string) + p := genai.NewPartFromFunctionCall(name, args) + p.FunctionCall.ID = id + parts = append(parts, p) + } + } + + var usage *genai.GenerateContentResponseUsageMetadata + if u, ok := result["usage"].(map[string]any); ok { + pt, _ := u["prompt_tokens"].(float64) + ct, _ := u["completion_tokens"].(float64) + if pt > 0 || ct > 0 { + usage = &genai.GenerateContentResponseUsageMetadata{ + PromptTokenCount: int32(pt), + CandidatesTokenCount: int32(ct), + } + } + } + + fr := "stop" + if f, ok := firstChoice["finish_reason"].(string); ok { + fr = f + } + + yield(&model.LLMResponse{ + Partial: false, + TurnComplete: true, + FinishReason: openAIFinishReasonToGenai(fr), + UsageMetadata: usage, + Content: &genai.Content{Role: string(genai.RoleModel), Parts: parts}, + }, nil) +} + +func parseOrchChunk(event map[string]any) map[string]any { + if r, ok := event["orchestration_result"].(map[string]any); ok { + return r + } + if r, ok := event["final_result"].(map[string]any); ok { + return r + } + if _, ok := event["choices"]; ok { + return event + } + return nil +} diff --git a/go/adk/pkg/models/sapaicore_adk_test.go b/go/adk/pkg/models/sapaicore_adk_test.go new file mode 100644 index 000000000..ade2c909d --- /dev/null +++ b/go/adk/pkg/models/sapaicore_adk_test.go @@ -0,0 +1,1163 @@ +package models + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/go-logr/logr" + "google.golang.org/adk/model" + "google.golang.org/genai" +) + +// ---- helpers ---- + +func newTestSAPModel(t *testing.T, baseURL, authURL string) *SAPAICoreModel { + t.Helper() + m, err := NewSAPAICoreModelWithLogger(SAPAICoreConfig{ + Model: "test-model", + BaseUrl: baseURL, + ResourceGroup: "default", + AuthUrl: authURL, + }, logr.Discard()) + if err != nil { + t.Fatalf("NewSAPAICoreModelWithLogger: %v", err) + } + return m +} + +func oauthServer(t *testing.T, token string) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "access_token": token, + "expires_in": 3600, + }) + })) +} + +func deploymentServerWith(urls ...string) *httptest.Server { + resources := make([]map[string]any, 0, len(urls)) + for i, u := range urls { + resources = append(resources, map[string]any{ + "id": fmt.Sprintf("dep-%d", i), + "scenarioId": "orchestration", + "status": "RUNNING", + "deploymentUrl": u, + "createdAt": fmt.Sprintf("2024-01-%02d", i+1), + }) + } + body, _ := json.Marshal(map[string]any{"resources": resources}) + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write(body) + })) +} + +// ---- genaiContentsToOrchTemplate ---- + +func TestGenaiContentsToOrchTemplate_Empty(t *testing.T) { + msgs, sys := genaiContentsToOrchTemplate(nil, nil) + if len(msgs) != 0 { + t.Errorf("len(msgs) = %d, want 0", len(msgs)) + } + if sys != "" { + t.Errorf("sys = %q, want empty", sys) + } +} + +func TestGenaiContentsToOrchTemplate_SystemInstruction(t *testing.T) { + config := &genai.GenerateContentConfig{ + SystemInstruction: &genai.Content{ + Parts: []*genai.Part{ + {Text: "You are helpful."}, + {Text: "Be concise."}, + }, + }, + } + _, sys := genaiContentsToOrchTemplate(nil, config) + want := "You are helpful.\nBe concise." + if sys != want { + t.Errorf("sys = %q, want %q", sys, want) + } +} + +func TestGenaiContentsToOrchTemplate_TextMessages(t *testing.T) { + contents := []*genai.Content{ + {Role: "user", Parts: []*genai.Part{{Text: "Hello"}}}, + {Role: "model", Parts: []*genai.Part{{Text: "Hi there"}}}, + } + msgs, sys := genaiContentsToOrchTemplate(contents, nil) + if sys != "" { + t.Errorf("sys = %q, want empty", sys) + } + if len(msgs) != 2 { + t.Fatalf("len(msgs) = %d, want 2", len(msgs)) + } + if msgs[0]["role"] != "user" || msgs[0]["content"] != "Hello" { + t.Errorf("msgs[0] = %v, want {role:user, content:Hello}", msgs[0]) + } + if msgs[1]["role"] != "assistant" || msgs[1]["content"] != "Hi there" { + t.Errorf("msgs[1] = %v, want {role:assistant, content:Hi there}", msgs[1]) + } +} + +func TestGenaiContentsToOrchTemplate_SkipsSystemRole(t *testing.T) { + contents := []*genai.Content{ + {Role: "system", Parts: []*genai.Part{{Text: "ignored"}}}, + {Role: "user", Parts: []*genai.Part{{Text: "hello"}}}, + } + msgs, _ := genaiContentsToOrchTemplate(contents, nil) + if len(msgs) != 1 { + t.Errorf("len(msgs) = %d, want 1 (system role skipped)", len(msgs)) + } +} + +func TestGenaiContentsToOrchTemplate_SkipsNilContent(t *testing.T) { + contents := []*genai.Content{ + nil, + {Role: "user", Parts: []*genai.Part{{Text: "hello"}}}, + } + msgs, _ := genaiContentsToOrchTemplate(contents, nil) + if len(msgs) != 1 { + t.Errorf("len(msgs) = %d, want 1 (nil content skipped)", len(msgs)) + } +} + +func TestGenaiContentsToOrchTemplate_ToolCall(t *testing.T) { + fc := genai.NewPartFromFunctionCall("get_weather", map[string]any{"city": "Berlin"}) + fc.FunctionCall.ID = "call_1" + + contents := []*genai.Content{ + {Role: "model", Parts: []*genai.Part{fc}}, + } + msgs, _ := genaiContentsToOrchTemplate(contents, nil) + if len(msgs) == 0 { + t.Fatal("expected at least 1 message") + } + msg := msgs[0] + if msg["role"] != "assistant" { + t.Errorf("role = %v, want assistant", msg["role"]) + } + toolCalls, ok := msg["tool_calls"].([]map[string]any) + if !ok || len(toolCalls) == 0 { + t.Fatalf("tool_calls = %v, want non-empty slice", msg["tool_calls"]) + } + if toolCalls[0]["id"] != "call_1" { + t.Errorf("tool_calls[0].id = %v, want call_1", toolCalls[0]["id"]) + } +} + +func TestGenaiContentsToOrchTemplate_FunctionResponse(t *testing.T) { + fc := genai.NewPartFromFunctionCall("get_weather", map[string]any{"city": "Berlin"}) + fc.FunctionCall.ID = "call_1" + fr := genai.NewPartFromFunctionResponse("get_weather", map[string]any{"temp": "20C"}) + fr.FunctionResponse.ID = "call_1" + + contents := []*genai.Content{ + {Role: "model", Parts: []*genai.Part{fc}}, + {Role: "user", Parts: []*genai.Part{fr}}, + } + msgs, _ := genaiContentsToOrchTemplate(contents, nil) + + if len(msgs) < 2 { + t.Fatalf("len(msgs) = %d, want >= 2", len(msgs)) + } + toolMsg := msgs[len(msgs)-1] + if toolMsg["role"] != "tool" { + t.Errorf("last msg role = %v, want tool", toolMsg["role"]) + } + if toolMsg["tool_call_id"] != "call_1" { + t.Errorf("tool_call_id = %v, want call_1", toolMsg["tool_call_id"]) + } +} + +// ---- buildOrchestrationBody ---- + +func TestBuildOrchestrationBody_Basic(t *testing.T) { + m := &SAPAICoreModel{Config: SAPAICoreConfig{Model: "my-model"}} + req := &model.LLMRequest{Model: "my-model"} + body := m.buildOrchestrationBody(req, false) + + cfg, ok := body["config"].(map[string]any) + if !ok { + t.Fatalf("body[config] missing or wrong type") + } + modules, ok := cfg["modules"].(map[string]any) + if !ok { + t.Fatalf("config[modules] missing") + } + if _, ok := modules["prompt_templating"]; !ok { + t.Error("modules[prompt_templating] missing") + } + stream, ok := cfg["stream"].(map[string]any) + if !ok { + t.Fatalf("config[stream] missing") + } + if stream["enabled"] != false { + t.Errorf("stream.enabled = %v, want false", stream["enabled"]) + } +} + +func TestBuildOrchestrationBody_StreamEnabled(t *testing.T) { + m := &SAPAICoreModel{Config: SAPAICoreConfig{Model: "my-model"}} + body := m.buildOrchestrationBody(&model.LLMRequest{}, true) + cfg := body["config"].(map[string]any) + stream := cfg["stream"].(map[string]any) + if stream["enabled"] != true { + t.Errorf("stream.enabled = %v, want true", stream["enabled"]) + } +} + +func TestBuildOrchestrationBody_Params(t *testing.T) { + m := &SAPAICoreModel{Config: SAPAICoreConfig{Model: "my-model"}} + temp := float32(0.7) + topP := float32(0.9) + req := &model.LLMRequest{ + Config: &genai.GenerateContentConfig{ + Temperature: &temp, + MaxOutputTokens: 512, + TopP: &topP, + }, + } + body := m.buildOrchestrationBody(req, false) + + cfg := body["config"].(map[string]any) + modules := cfg["modules"].(map[string]any) + pt := modules["prompt_templating"].(map[string]any) + modelBlock := pt["model"].(map[string]any) + params := modelBlock["params"].(map[string]any) + + if params["temperature"] != float32(0.7) { + t.Errorf("temperature = %v, want 0.7", params["temperature"]) + } + if params["max_tokens"] != int32(512) { + t.Errorf("max_tokens = %v, want 512", params["max_tokens"]) + } + if params["top_p"] != float32(0.9) { + t.Errorf("top_p = %v, want 0.9", params["top_p"]) + } +} + +func TestBuildOrchestrationBody_WithTools(t *testing.T) { + m := &SAPAICoreModel{Config: SAPAICoreConfig{Model: "my-model"}} + req := &model.LLMRequest{ + Config: &genai.GenerateContentConfig{ + Tools: []*genai.Tool{{ + FunctionDeclarations: []*genai.FunctionDeclaration{{ + Name: "list_pods", + Description: "List pods", + }}, + }}, + }, + } + body := m.buildOrchestrationBody(req, false) + cfg := body["config"].(map[string]any) + modules := cfg["modules"].(map[string]any) + pt := modules["prompt_templating"].(map[string]any) + prompt := pt["prompt"].(map[string]any) + tools, ok := prompt["tools"].([]map[string]any) + if !ok || len(tools) == 0 { + t.Fatalf("prompt[tools] = %v, want non-empty", prompt["tools"]) + } + fn := tools[0]["function"].(map[string]any) + if fn["name"] != "list_pods" { + t.Errorf("tool name = %v, want list_pods", fn["name"]) + } +} + +// ---- parseOrchChunk ---- + +func TestParseOrchChunk(t *testing.T) { + tests := []struct { + name string + event map[string]any + wantNil bool + wantKey string + }{ + { + name: "orchestration_result", + event: map[string]any{"orchestration_result": map[string]any{"choices": []any{}}}, + wantNil: false, + wantKey: "choices", + }, + { + name: "final_result", + event: map[string]any{"final_result": map[string]any{"choices": []any{}}}, + wantNil: false, + wantKey: "choices", + }, + { + name: "direct choices", + event: map[string]any{"choices": []any{}, "other": "data"}, + wantNil: false, + wantKey: "choices", + }, + { + name: "unrecognized", + event: map[string]any{"foo": "bar"}, + wantNil: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := parseOrchChunk(tt.event) + if tt.wantNil { + if got != nil { + t.Errorf("parseOrchChunk() = %v, want nil", got) + } + return + } + if got == nil { + t.Fatal("parseOrchChunk() = nil, want non-nil") + } + if _, ok := got[tt.wantKey]; !ok { + t.Errorf("result missing key %q", tt.wantKey) + } + }) + } +} + +// ---- isRetryableError ---- + +func TestIsRetryableError(t *testing.T) { + retryable := []int{401, 403, 404, 502, 503, 504} + for _, code := range retryable { + t.Run(fmt.Sprintf("HTTP_%d_retryable", code), func(t *testing.T) { + if !isRetryableError(&orchHTTPError{StatusCode: code}) { + t.Errorf("isRetryableError(HTTP %d) = false, want true", code) + } + }) + } + nonRetryable := []int{400, 422, 500} + for _, code := range nonRetryable { + t.Run(fmt.Sprintf("HTTP_%d_not_retryable", code), func(t *testing.T) { + if isRetryableError(&orchHTTPError{StatusCode: code}) { + t.Errorf("isRetryableError(HTTP %d) = true, want false", code) + } + }) + } + t.Run("non-HTTP error not retryable", func(t *testing.T) { + if isRetryableError(fmt.Errorf("network error")) { + t.Error("isRetryableError(non-HTTP) = true, want false") + } + }) +} + +// ---- ensureToken (OAuth token caching) ---- + +func TestEnsureToken_CachesToken(t *testing.T) { + callCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "access_token": "tok-cached", + "expires_in": 3600, + }) + })) + defer srv.Close() + + t.Setenv("SAP_AI_CORE_CLIENT_ID", "id") + t.Setenv("SAP_AI_CORE_CLIENT_SECRET", "secret") + + m := newTestSAPModel(t, "http://base", srv.URL) + ctx := context.Background() + + tok1, err := m.ensureToken(ctx) + if err != nil { + t.Fatalf("ensureToken first call: %v", err) + } + tok2, err := m.ensureToken(ctx) + if err != nil { + t.Fatalf("ensureToken second call: %v", err) + } + if tok1 != "tok-cached" || tok2 != "tok-cached" { + t.Errorf("tokens = %q, %q, want tok-cached for both", tok1, tok2) + } + if callCount != 1 { + t.Errorf("auth server called %d times, want 1 (cached)", callCount) + } +} + +func TestEnsureToken_RefreshesExpiredToken(t *testing.T) { + callCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{ + "access_token": fmt.Sprintf("tok-%d", callCount), + "expires_in": 3600, + }) + })) + defer srv.Close() + + t.Setenv("SAP_AI_CORE_CLIENT_ID", "id") + t.Setenv("SAP_AI_CORE_CLIENT_SECRET", "secret") + + m := newTestSAPModel(t, "http://base", srv.URL) + ctx := context.Background() + + if _, err := m.ensureToken(ctx); err != nil { + t.Fatalf("first ensureToken: %v", err) + } + + // Force expiry + m.mu.Lock() + m.tokenExpiresAt = time.Now().Add(-1 * time.Second) + m.mu.Unlock() + + if _, err := m.ensureToken(ctx); err != nil { + t.Fatalf("second ensureToken: %v", err) + } + if callCount != 2 { + t.Errorf("auth server called %d times, want 2 (expired token refreshed)", callCount) + } +} + +func TestEnsureToken_MissingEnvVarsReturnsError(t *testing.T) { + os.Unsetenv("SAP_AI_CORE_CLIENT_ID") + os.Unsetenv("SAP_AI_CORE_CLIENT_SECRET") + + m := newTestSAPModel(t, "http://base", "http://auth") + _, err := m.ensureToken(context.Background()) + if err == nil { + t.Error("ensureToken() = nil, want error when env vars missing") + } +} + +func TestInvalidateToken_ClearsCache(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + m.token = "old-token" + m.tokenExpiresAt = time.Now().Add(time.Hour) + m.invalidateToken() + if m.token != "" { + t.Errorf("token = %q after invalidate, want empty", m.token) + } + if !m.tokenExpiresAt.IsZero() { + t.Errorf("tokenExpiresAt = %v after invalidate, want zero", m.tokenExpiresAt) + } +} + +// ---- resolveDeploymentURL (URL discovery & caching) ---- + +func TestResolveDeploymentURL_CachesURL(t *testing.T) { + authSrv := oauthServer(t, "tok") + defer authSrv.Close() + + callCount := 0 + depSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + body, _ := json.Marshal(map[string]any{"resources": []map[string]any{{ + "scenarioId": "orchestration", "status": "RUNNING", + "deploymentUrl": "https://dep.example.com", "createdAt": "2024-01-01", + }}}) + w.Write(body) + })) + defer depSrv.Close() + + t.Setenv("SAP_AI_CORE_CLIENT_ID", "id") + t.Setenv("SAP_AI_CORE_CLIENT_SECRET", "secret") + + m := newTestSAPModel(t, depSrv.URL, authSrv.URL) + ctx := context.Background() + + url1, err := m.resolveDeploymentURL(ctx) + if err != nil { + t.Fatalf("first resolveDeploymentURL: %v", err) + } + url2, err := m.resolveDeploymentURL(ctx) + if err != nil { + t.Fatalf("second resolveDeploymentURL: %v", err) + } + if url1 != "https://dep.example.com" || url2 != "https://dep.example.com" { + t.Errorf("urls = %q, %q, want https://dep.example.com for both", url1, url2) + } + if callCount != 1 { + t.Errorf("deployment API called %d times, want 1 (cached)", callCount) + } +} + +func TestResolveDeploymentURL_PicksLatestCreated(t *testing.T) { + authSrv := oauthServer(t, "tok") + defer authSrv.Close() + + depSrv := deploymentServerWith("https://older.example.com", "https://newer.example.com") + defer depSrv.Close() + + t.Setenv("SAP_AI_CORE_CLIENT_ID", "id") + t.Setenv("SAP_AI_CORE_CLIENT_SECRET", "secret") + + m := newTestSAPModel(t, depSrv.URL, authSrv.URL) + url, err := m.resolveDeploymentURL(context.Background()) + if err != nil { + t.Fatalf("resolveDeploymentURL: %v", err) + } + // "2024-01-02" > "2024-01-01" — newer should win + if url != "https://newer.example.com" { + t.Errorf("url = %q, want https://newer.example.com", url) + } +} + +func TestResolveDeploymentURL_NoRunningDeploymentError(t *testing.T) { + authSrv := oauthServer(t, "tok") + defer authSrv.Close() + + depSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := json.Marshal(map[string]any{"resources": []map[string]any{{ + "scenarioId": "other-scenario", "status": "RUNNING", + "deploymentUrl": "https://x.example.com", "createdAt": "2024-01-01", + }}}) + w.Write(body) + })) + defer depSrv.Close() + + t.Setenv("SAP_AI_CORE_CLIENT_ID", "id") + t.Setenv("SAP_AI_CORE_CLIENT_SECRET", "secret") + + m := newTestSAPModel(t, depSrv.URL, authSrv.URL) + _, err := m.resolveDeploymentURL(context.Background()) + if err == nil { + t.Error("resolveDeploymentURL() = nil, want error for no running orchestration deployments") + } +} + +func TestResolveDeploymentURL_ExpiresAfterOneHour(t *testing.T) { + authSrv := oauthServer(t, "tok") + defer authSrv.Close() + + callCount := 0 + depSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + body, _ := json.Marshal(map[string]any{"resources": []map[string]any{{ + "scenarioId": "orchestration", "status": "RUNNING", + "deploymentUrl": "https://dep.example.com", "createdAt": "2024-01-01", + }}}) + w.Write(body) + })) + defer depSrv.Close() + + t.Setenv("SAP_AI_CORE_CLIENT_ID", "id") + t.Setenv("SAP_AI_CORE_CLIENT_SECRET", "secret") + + m := newTestSAPModel(t, depSrv.URL, authSrv.URL) + ctx := context.Background() + + // First call — populates cache. + if _, err := m.resolveDeploymentURL(ctx); err != nil { + t.Fatalf("first resolveDeploymentURL: %v", err) + } + + // Expire the cache by backdating the timestamp. + m.mu.Lock() + m.deploymentURLAt = time.Now().Add(-2 * time.Hour) + m.mu.Unlock() + + // Second call — cache expired, must re-fetch. + url, err := m.resolveDeploymentURL(ctx) + if err != nil { + t.Fatalf("second resolveDeploymentURL: %v", err) + } + if url != "https://dep.example.com" { + t.Errorf("url = %q, want https://dep.example.com", url) + } + if callCount != 2 { + t.Errorf("deployment API called %d times, want 2 (cache expired)", callCount) + } +} + +func TestInvalidateDeploymentURL_ClearsCache(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + m.deploymentURL = "https://old.example.com" + m.deploymentURLAt = time.Now() + m.invalidateDeploymentURL() + if m.deploymentURL != "" { + t.Errorf("deploymentURL = %q after invalidate, want empty", m.deploymentURL) + } +} + +// ---- context cancellation ---- + +func TestEnsureToken_RespectsContextCancellation(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + return + case <-time.After(500 * time.Millisecond): + json.NewEncoder(w).Encode(map[string]any{"access_token": "tok", "expires_in": 3600}) + } + })) + defer srv.Close() + + t.Setenv("SAP_AI_CORE_CLIENT_ID", "id") + t.Setenv("SAP_AI_CORE_CLIENT_SECRET", "secret") + + m := newTestSAPModel(t, "http://base", srv.URL) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err := m.ensureToken(ctx) + if err == nil { + t.Error("ensureToken() = nil, want error when context cancelled") + } +} + +// ---- concurrent token access ---- + +func TestEnsureToken_ConcurrentAccess(t *testing.T) { + var mu sync.Mutex + callCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + callCount++ + mu.Unlock() + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]any{"access_token": "tok", "expires_in": 3600}) + })) + defer srv.Close() + + t.Setenv("SAP_AI_CORE_CLIENT_ID", "id") + t.Setenv("SAP_AI_CORE_CLIENT_SECRET", "secret") + + m := newTestSAPModel(t, "http://base", srv.URL) + ctx := context.Background() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + if _, err := m.ensureToken(ctx); err != nil { + t.Errorf("ensureToken concurrent: %v", err) + } + }() + } + wg.Wait() + + // ensureToken holds the mutex while doing HTTP, so only 1 request is made + // even under concurrent access. + mu.Lock() + defer mu.Unlock() + if callCount != 1 { + t.Errorf("auth server called %d times, want 1 (mutex serializes concurrent requests)", callCount) + } +} + +// ---- handleNonStream ---- + +func TestHandleNonStream_TextResponse(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + body := map[string]any{ + "final_result": map[string]any{ + "choices": []any{ + map[string]any{ + "finish_reason": "stop", + "message": map[string]any{ + "role": "assistant", + "content": "Hello from SAP AI Core", + }, + }, + }, + }, + } + bodyBytes, _ := json.Marshal(body) + + var got *model.LLMResponse + m.handleNonStream(jsonReader(bodyBytes), func(r *model.LLMResponse, err error) bool { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + got = r + return true + }) + if got == nil { + t.Fatal("got nil response") + } + if got.Content == nil || len(got.Content.Parts) == 0 { + t.Fatal("got empty content parts") + } + if got.Content.Parts[0].Text != "Hello from SAP AI Core" { + t.Errorf("text = %q, want Hello from SAP AI Core", got.Content.Parts[0].Text) + } +} + +func TestHandleNonStream_ToolCallResponse(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + body := map[string]any{ + "choices": []any{ + map[string]any{ + "finish_reason": "tool_calls", + "message": map[string]any{ + "role": "assistant", + "content": "", + "tool_calls": []any{ + map[string]any{ + "id": "call_1", + "type": "function", + "function": map[string]any{ + "name": "get_weather", + "arguments": `{"city":"Berlin"}`, + }, + }, + }, + }, + }, + }, + } + bodyBytes, _ := json.Marshal(body) + + var got *model.LLMResponse + m.handleNonStream(jsonReader(bodyBytes), func(r *model.LLMResponse, err error) bool { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + got = r + return true + }) + if got == nil || got.Content == nil { + t.Fatal("got nil response or content") + } + var fc *genai.FunctionCall + for _, p := range got.Content.Parts { + if p.FunctionCall != nil { + fc = p.FunctionCall + break + } + } + if fc == nil { + t.Fatal("no function call part in response") + } + if fc.Name != "get_weather" { + t.Errorf("function name = %q, want get_weather", fc.Name) + } + if fc.ID != "call_1" { + t.Errorf("function ID = %q, want call_1", fc.ID) + } +} + +func TestHandleNonStream_NoChoicesError(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + body := map[string]any{"choices": []any{}} + bodyBytes, _ := json.Marshal(body) + + var got *model.LLMResponse + m.handleNonStream(jsonReader(bodyBytes), func(r *model.LLMResponse, err error) bool { + got = r + return true + }) + if got == nil { + t.Fatal("expected error response, got nil") + } + if got.ErrorCode == "" { + t.Error("expected non-empty ErrorCode for empty choices") + } +} + +// ---- genaiToolsToOrchTools ---- + +func TestGenaiToolsToOrchTools_NilAndEmpty(t *testing.T) { + if got := genaiToolsToOrchTools(nil); len(got) != 0 { + t.Errorf("genaiToolsToOrchTools(nil) = %v, want empty", got) + } + if got := genaiToolsToOrchTools([]*genai.Tool{}); len(got) != 0 { + t.Errorf("genaiToolsToOrchTools([]) = %v, want empty", got) + } +} + +func TestGenaiToolsToOrchTools_SkipsNilTool(t *testing.T) { + tools := []*genai.Tool{ + nil, + {FunctionDeclarations: []*genai.FunctionDeclaration{{Name: "foo"}}}, + } + got := genaiToolsToOrchTools(tools) + if len(got) != 1 { + t.Errorf("len(got) = %d, want 1 (nil tool skipped)", len(got)) + } +} + +func TestGenaiToolsToOrchTools_WithJsonSchema(t *testing.T) { + tools := []*genai.Tool{{ + FunctionDeclarations: []*genai.FunctionDeclaration{{ + Name: "list_namespaces", + Description: "List K8s namespaces", + ParametersJsonSchema: map[string]any{ + "type": "object", + "properties": map[string]any{"label": map[string]any{"type": "string"}}, + }, + }}, + }} + got := genaiToolsToOrchTools(tools) + if len(got) != 1 { + t.Fatalf("len(got) = %d, want 1", len(got)) + } + fn := got[0]["function"].(map[string]any) + if fn["name"] != "list_namespaces" { + t.Errorf("name = %v, want list_namespaces", fn["name"]) + } + if fn["description"] != "List K8s namespaces" { + t.Errorf("description = %v", fn["description"]) + } +} + +// ---- handleStream ---- + +// sseBody builds a minimal SSE byte stream from a slice of JSON-serialisable +// payloads. Each entry is written as "data: \n\n"; the stream ends with +// "data: [DONE]\n\n". +func sseBody(t *testing.T, payloads ...any) *strings.Reader { + t.Helper() + var b strings.Builder + for _, p := range payloads { + raw, err := json.Marshal(p) + if err != nil { + t.Fatalf("sseBody marshal: %v", err) + } + b.WriteString("data: ") + b.Write(raw) + b.WriteString("\n\n") + } + b.WriteString("data: [DONE]\n\n") + return strings.NewReader(b.String()) +} + +// orchChunk wraps a choices slice in the orchestration_result envelope that +// the SAP Orchestration Service uses for streaming responses. +func orchChunk(choices []any) map[string]any { + return map[string]any{ + "orchestration_result": map[string]any{ + "choices": choices, + }, + } +} + +func textDelta(text string) map[string]any { + return map[string]any{ + "delta": map[string]any{"content": text}, + } +} + +func finishDelta(reason string) map[string]any { + return map[string]any{ + "delta": map[string]any{}, + "finish_reason": reason, + } +} + +func TestHandleStream_TextChunks(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + + body := sseBody(t, + orchChunk([]any{textDelta("Hello")}), + orchChunk([]any{textDelta(", world")}), + orchChunk([]any{finishDelta("stop")}), + ) + + var responses []*model.LLMResponse + m.handleStream(context.Background(), body, func(r *model.LLMResponse, err error) bool { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + responses = append(responses, r) + return true + }) + + // Expect two partial responses + one final response. + if len(responses) < 3 { + t.Fatalf("got %d responses, want >= 3 (2 partials + 1 final)", len(responses)) + } + + // First two are partial text chunks. + if !responses[0].Partial || responses[0].Content.Parts[0].Text != "Hello" { + t.Errorf("responses[0]: partial=%v text=%q, want partial=true text=Hello", + responses[0].Partial, responses[0].Content.Parts[0].Text) + } + if !responses[1].Partial || responses[1].Content.Parts[0].Text != ", world" { + t.Errorf("responses[1]: partial=%v text=%q, want partial=true text=', world'", + responses[1].Partial, responses[1].Content.Parts[0].Text) + } + + // Last response is the final aggregated one. + final := responses[len(responses)-1] + if final.Partial || !final.TurnComplete { + t.Errorf("final: partial=%v turnComplete=%v, want partial=false turnComplete=true", + final.Partial, final.TurnComplete) + } + if final.Content == nil || len(final.Content.Parts) == 0 { + t.Fatal("final response has no content parts") + } + if final.Content.Parts[0].Text != "Hello, world" { + t.Errorf("final aggregated text = %q, want 'Hello, world'", final.Content.Parts[0].Text) + } +} + +func TestHandleStream_ToolCall(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + + // Two chunks that together build up a single tool call (as real SSE streams do). + chunk1 := orchChunk([]any{map[string]any{ + "delta": map[string]any{ + "tool_calls": []any{map[string]any{ + "index": float64(0), + "id": "call_42", + "function": map[string]any{ + "name": "get_weather", + "arguments": `{"city":`, + }, + }}, + }, + }}) + chunk2 := orchChunk([]any{map[string]any{ + "delta": map[string]any{ + "tool_calls": []any{map[string]any{ + "index": float64(0), + "function": map[string]any{ + "arguments": `"Berlin"}`, + }, + }}, + }, + "finish_reason": "tool_calls", + }}) + + body := sseBody(t, chunk1, chunk2) + + var responses []*model.LLMResponse + m.handleStream(context.Background(), body, func(r *model.LLMResponse, err error) bool { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + responses = append(responses, r) + return true + }) + + if len(responses) == 0 { + t.Fatal("got 0 responses") + } + + final := responses[len(responses)-1] + if final.Partial { + t.Error("final response should not be partial") + } + + var fc *genai.FunctionCall + for _, p := range final.Content.Parts { + if p.FunctionCall != nil { + fc = p.FunctionCall + break + } + } + if fc == nil { + t.Fatal("no function call part in final response") + } + if fc.Name != "get_weather" { + t.Errorf("function name = %q, want get_weather", fc.Name) + } + if fc.ID != "call_42" { + t.Errorf("function ID = %q, want call_42", fc.ID) + } + if city, ok := fc.Args["city"].(string); !ok || city != "Berlin" { + t.Errorf("args[city] = %v, want Berlin", fc.Args["city"]) + } +} + +func TestHandleStream_UsageMetadata(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + + body := sseBody(t, + orchChunk([]any{textDelta("hi")}), + // final_result envelope carries usage + map[string]any{ + "final_result": map[string]any{ + "choices": []any{finishDelta("stop")}, + "usage": map[string]any{ + "prompt_tokens": float64(10), + "completion_tokens": float64(5), + }, + }, + }, + ) + + var responses []*model.LLMResponse + m.handleStream(context.Background(), body, func(r *model.LLMResponse, err error) bool { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + responses = append(responses, r) + return true + }) + + final := responses[len(responses)-1] + if final.UsageMetadata == nil { + t.Fatal("expected UsageMetadata in final response, got nil") + } + if final.UsageMetadata.PromptTokenCount != 10 { + t.Errorf("PromptTokenCount = %d, want 10", final.UsageMetadata.PromptTokenCount) + } + if final.UsageMetadata.CandidatesTokenCount != 5 { + t.Errorf("CandidatesTokenCount = %d, want 5", final.UsageMetadata.CandidatesTokenCount) + } +} + +func TestHandleStream_ErrorEvent(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + + body := sseBody(t, + map[string]any{"code": "500", "message": "internal error"}, + ) + + var gotErr error + m.handleStream(context.Background(), body, func(r *model.LLMResponse, err error) bool { + if err != nil { + gotErr = err + } + return true + }) + + if gotErr == nil { + t.Error("handleStream() error = nil, want error for stream error event") + } +} + +func TestHandleStream_IgnoresMalformedLines(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + + // Inject a non-JSON line between valid chunks; it must be silently skipped. + var b strings.Builder + b.WriteString("data: not-valid-json\n\n") + raw, _ := json.Marshal(orchChunk([]any{textDelta("ok")})) + b.WriteString("data: ") + b.Write(raw) + b.WriteString("\n\n") + b.WriteString("data: [DONE]\n\n") + + var responses []*model.LLMResponse + m.handleStream(context.Background(), strings.NewReader(b.String()), func(r *model.LLMResponse, err error) bool { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + responses = append(responses, r) + return true + }) + + // Should still produce at least the partial + final for the valid chunk. + if len(responses) < 2 { + t.Errorf("got %d responses, want >= 2 (malformed line skipped)", len(responses)) + } +} + +func TestHandleStream_ContextCancellation(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + + // Build a stream with many chunks; cancel context after the first yield. + chunks := make([]any, 20) + for i := range chunks { + chunks[i] = orchChunk([]any{textDelta("x")}) + } + body := sseBody(t, chunks...) + + ctx, cancel := context.WithCancel(context.Background()) + count := 0 + m.handleStream(ctx, body, func(r *model.LLMResponse, err error) bool { + count++ + cancel() // cancel after receiving the first chunk + return true + }) + + // After cancellation the loop should stop; we should not receive all 20 partials. + if count >= 20 { + t.Errorf("received %d responses after cancel, expected early stop", count) + } +} + +func TestHandleStream_EmptyStream(t *testing.T) { + m := newTestSAPModel(t, "http://base", "http://auth") + + body := sseBody(t) // only [DONE] + + var responses []*model.LLMResponse + m.handleStream(context.Background(), body, func(r *model.LLMResponse, err error) bool { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + responses = append(responses, r) + return true + }) + + // An empty stream yields one final TurnComplete response with no parts. + if len(responses) != 1 { + t.Fatalf("got %d responses, want 1 (empty final)", len(responses)) + } + if responses[0].Partial || !responses[0].TurnComplete { + t.Errorf("empty stream final: partial=%v turnComplete=%v, want false/true", + responses[0].Partial, responses[0].TurnComplete) + } +} + +func TestGenaiToolsToOrchTools_WithParameters(t *testing.T) { + // fd.Parameters (genai.Schema) path — used when ParametersJsonSchema is nil. + tools := []*genai.Tool{{ + FunctionDeclarations: []*genai.FunctionDeclaration{{ + Name: "scale_deployment", + Description: "Scale a K8s deployment", + Parameters: &genai.Schema{ + Type: genai.TypeObject, + Properties: map[string]*genai.Schema{ + "replicas": {Type: genai.TypeInteger}, + }, + }, + }}, + }} + got := genaiToolsToOrchTools(tools) + if len(got) != 1 { + t.Fatalf("len(got) = %d, want 1", len(got)) + } + fn := got[0]["function"].(map[string]any) + if fn["name"] != "scale_deployment" { + t.Errorf("name = %v, want scale_deployment", fn["name"]) + } + params, ok := fn["parameters"].(map[string]any) + if !ok { + t.Fatalf("parameters not a map: %v", fn["parameters"]) + } + if params["type"] != "object" { + t.Errorf("parameters.type = %v, want object", params["type"]) + } +} + +func TestHandleStream_FinishReasonAtChoiceLevel(t *testing.T) { + // finish_reason sits at the choice top level (not inside delta), + // which is what SAP sends in the final chunk (sapaicore_adk.go:386). + m := newTestSAPModel(t, "http://base", "http://auth") + + chunk := orchChunk([]any{map[string]any{ + "delta": map[string]any{"content": "done"}, + "finish_reason": "length", // top-level, not inside delta + }}) + body := sseBody(t, chunk) + + var responses []*model.LLMResponse + m.handleStream(context.Background(), body, func(r *model.LLMResponse, err error) bool { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + responses = append(responses, r) + return true + }) + + final := responses[len(responses)-1] + if final.FinishReason != openAIFinishReasonToGenai("length") { + t.Errorf("FinishReason = %v, want MAX_TOKENS (length)", final.FinishReason) + } +} + +// ---- helpers ---- + +func jsonReader(b []byte) *strings.Reader { + return strings.NewReader(string(b)) +} diff --git a/go/api/adk/types.go b/go/api/adk/types.go index aee673f09..4f11226fb 100644 --- a/go/api/adk/types.go +++ b/go/api/adk/types.go @@ -86,6 +86,7 @@ const ( ModelTypeOllama = "ollama" ModelTypeGemini = "gemini" ModelTypeBedrock = "bedrock" + ModelTypeSAPAICore = "sap_ai_core" ) func (o *OpenAI) MarshalJSON() ([]byte, error) { @@ -246,6 +247,28 @@ func (b *Bedrock) GetType() string { return ModelTypeBedrock } +type SAPAICore struct { + BaseModel + BaseUrl string `json:"base_url"` + ResourceGroup string `json:"resource_group,omitempty"` + AuthUrl string `json:"auth_url,omitempty"` +} + +func (s *SAPAICore) MarshalJSON() ([]byte, error) { + type Alias SAPAICore + return json.Marshal(&struct { + Type string `json:"type"` + *Alias + }{ + Type: ModelTypeSAPAICore, + Alias: (*Alias)(s), + }) +} + +func (s *SAPAICore) GetType() string { + return ModelTypeSAPAICore +} + // GenericModel is a catch-all model type used by the Go ADK when the model // type doesn't match any known constant. type GenericModel struct { @@ -308,6 +331,12 @@ func ParseModel(bytes []byte) (Model, error) { return nil, err } return &bedrock, nil + case ModelTypeSAPAICore: + var sapAICore SAPAICore + if err := json.Unmarshal(bytes, &sapAICore); err != nil { + return nil, err + } + return &sapAICore, nil } return nil, fmt.Errorf("unknown model type: %s", model.Type) } @@ -373,6 +402,9 @@ func ModelToEmbeddingConfig(m Model) *EmbeddingConfig { e.Model = v.Model case *Bedrock: e.Model = v.Model + case *SAPAICore: + e.Model = v.Model + e.BaseUrl = v.BaseUrl default: e.Model = "" } diff --git a/go/api/config/crd/bases/kagent.dev_modelconfigs.yaml b/go/api/config/crd/bases/kagent.dev_modelconfigs.yaml index 08cdab1ec..3b5342021 100644 --- a/go/api/config/crd/bases/kagent.dev_modelconfigs.yaml +++ b/go/api/config/crd/bases/kagent.dev_modelconfigs.yaml @@ -438,12 +438,15 @@ spec: Mutually exclusive with apiKeySecret. type: boolean apiKeySecret: - description: The name of the secret that contains the API key. Must - be a reference to the name of a secret in the same namespace as - the referencing ModelConfig + description: |- + The name of the secret that contains the API key. Must be a reference to the name of a secret in the same namespace as the referencing ModelConfig. + For the SAPAICore provider, the secret must contain two keys: "client_id" and "client_secret" + (the OAuth2 client credentials for SAP AI Core). The apiKeySecretKey field is not used for SAPAICore. type: string apiKeySecretKey: - description: The key in the secret that contains the API key + description: |- + The key in the secret that contains the API key. + Not used for the SAPAICore provider (which always reads "client_id" and "client_secret" from the secret). type: string azureOpenAI: description: Azure OpenAI-specific configuration @@ -594,7 +597,24 @@ spec: - GeminiVertexAI - AnthropicVertexAI - Bedrock + - SAPAICore type: string + sapAICore: + description: SAP AI Core-specific configuration + properties: + authUrl: + description: OAuth2 token endpoint URL (e.g., https://tenant.authentication.eu10.hana.ondemand.com) + type: string + baseUrl: + description: Base URL for the SAP AI Core API (e.g., https://api.ai.prod.eu-central-1.aws.ml.hana.ondemand.com) + type: string + resourceGroup: + default: default + description: Resource group in SAP AI Core + type: string + required: + - baseUrl + type: object tls: description: |- TLS configuration for provider connections. @@ -656,12 +676,14 @@ spec: rule: '!(has(self.anthropicVertexAI) && self.provider != ''AnthropicVertexAI'')' - message: provider.bedrock must be nil if the provider is not Bedrock rule: '!(has(self.bedrock) && self.provider != ''Bedrock'')' + - message: provider.sapAICore must be nil if the provider is not SAPAICore + rule: '!(has(self.sapAICore) && self.provider != ''SAPAICore'')' - message: apiKeySecret must be set if apiKeySecretKey is set rule: '!(has(self.apiKeySecretKey) && !has(self.apiKeySecret))' - message: apiKeySecretKey must be set if apiKeySecret is set (except - for Bedrock provider) + for Bedrock and SAPAICore providers) rule: '!(has(self.apiKeySecret) && !has(self.apiKeySecretKey) && self.provider - != ''Bedrock'')' + != ''Bedrock'' && self.provider != ''SAPAICore'')' - message: apiKeyPassthrough and apiKeySecret are mutually exclusive rule: '!(has(self.apiKeyPassthrough) && self.apiKeyPassthrough && has(self.apiKeySecret) && size(self.apiKeySecret) > 0)' diff --git a/go/api/config/crd/bases/kagent.dev_modelproviderconfigs.yaml b/go/api/config/crd/bases/kagent.dev_modelproviderconfigs.yaml index ff922c140..493e817e9 100644 --- a/go/api/config/crd/bases/kagent.dev_modelproviderconfigs.yaml +++ b/go/api/config/crd/bases/kagent.dev_modelproviderconfigs.yaml @@ -90,6 +90,7 @@ spec: - GeminiVertexAI - AnthropicVertexAI - Bedrock + - SAPAICore type: string required: - type diff --git a/go/api/httpapi/types.go b/go/api/httpapi/types.go index 2e9ee9f76..c9f538472 100644 --- a/go/api/httpapi/types.go +++ b/go/api/httpapi/types.go @@ -66,6 +66,7 @@ type CreateModelConfigRequest struct { GeminiParams *v1alpha2.GeminiConfig `json:"gemini,omitempty"` GeminiVertexAIParams *v1alpha2.GeminiVertexAIConfig `json:"geminiVertexAI,omitempty"` AnthropicVertexAIParams *v1alpha2.AnthropicVertexAIConfig `json:"anthropicVertexAI,omitempty"` + SAPAICoreParams *v1alpha2.SAPAICoreConfig `json:"sapAICore,omitempty"` } // UpdateModelConfigRequest represents a request to update a model configuration @@ -80,6 +81,7 @@ type UpdateModelConfigRequest struct { GeminiParams *v1alpha2.GeminiConfig `json:"gemini,omitempty"` GeminiVertexAIParams *v1alpha2.GeminiVertexAIConfig `json:"geminiVertexAI,omitempty"` AnthropicVertexAIParams *v1alpha2.AnthropicVertexAIConfig `json:"anthropicVertexAI,omitempty"` + SAPAICoreParams *v1alpha2.SAPAICoreConfig `json:"sapAICore,omitempty"` } // Agent types diff --git a/go/api/v1alpha2/modelconfig_types.go b/go/api/v1alpha2/modelconfig_types.go index d23a9d079..45750bba3 100644 --- a/go/api/v1alpha2/modelconfig_types.go +++ b/go/api/v1alpha2/modelconfig_types.go @@ -25,7 +25,7 @@ const ( ) // ModelProvider represents the model provider type -// +kubebuilder:validation:Enum=Anthropic;OpenAI;AzureOpenAI;Ollama;Gemini;GeminiVertexAI;AnthropicVertexAI;Bedrock +// +kubebuilder:validation:Enum=Anthropic;OpenAI;AzureOpenAI;Ollama;Gemini;GeminiVertexAI;AnthropicVertexAI;Bedrock;SAPAICore type ModelProvider string const ( @@ -37,6 +37,7 @@ const ( ModelProviderGeminiVertexAI ModelProvider = "GeminiVertexAI" ModelProviderAnthropicVertexAI ModelProvider = "AnthropicVertexAI" ModelProviderBedrock ModelProvider = "Bedrock" + ModelProviderSAPAICore ModelProvider = "SAPAICore" ) type BaseVertexAIConfig struct { @@ -218,6 +219,22 @@ type BedrockConfig struct { Region string `json:"region"` } +// SAPAICoreConfig contains SAP AI Core-specific configuration options. +type SAPAICoreConfig struct { + // Base URL for the SAP AI Core API (e.g., https://api.ai.prod.eu-central-1.aws.ml.hana.ondemand.com) + // +required + BaseURL string `json:"baseUrl"` + + // Resource group in SAP AI Core + // +kubebuilder:default="default" + // +optional + ResourceGroup string `json:"resourceGroup,omitempty"` + + // OAuth2 token endpoint URL (e.g., https://tenant.authentication.eu10.hana.ondemand.com) + // +optional + AuthURL string `json:"authUrl,omitempty"` +} + // TLSConfig contains TLS/SSL configuration options for model provider connections. // This enables agents to connect to internal LiteLLM gateways or other providers // that use self-signed certificates or custom certificate authorities. @@ -264,8 +281,9 @@ type TLSConfig struct { // +kubebuilder:validation:XValidation:message="provider.geminiVertexAI must be nil if the provider is not GeminiVertexAI",rule="!(has(self.geminiVertexAI) && self.provider != 'GeminiVertexAI')" // +kubebuilder:validation:XValidation:message="provider.anthropicVertexAI must be nil if the provider is not AnthropicVertexAI",rule="!(has(self.anthropicVertexAI) && self.provider != 'AnthropicVertexAI')" // +kubebuilder:validation:XValidation:message="provider.bedrock must be nil if the provider is not Bedrock",rule="!(has(self.bedrock) && self.provider != 'Bedrock')" +// +kubebuilder:validation:XValidation:message="provider.sapAICore must be nil if the provider is not SAPAICore",rule="!(has(self.sapAICore) && self.provider != 'SAPAICore')" // +kubebuilder:validation:XValidation:message="apiKeySecret must be set if apiKeySecretKey is set",rule="!(has(self.apiKeySecretKey) && !has(self.apiKeySecret))" -// +kubebuilder:validation:XValidation:message="apiKeySecretKey must be set if apiKeySecret is set (except for Bedrock provider)",rule="!(has(self.apiKeySecret) && !has(self.apiKeySecretKey) && self.provider != 'Bedrock')" +// +kubebuilder:validation:XValidation:message="apiKeySecretKey must be set if apiKeySecret is set (except for Bedrock and SAPAICore providers)",rule="!(has(self.apiKeySecret) && !has(self.apiKeySecretKey) && self.provider != 'Bedrock' && self.provider != 'SAPAICore')" // +kubebuilder:validation:XValidation:message="apiKeyPassthrough and apiKeySecret are mutually exclusive",rule="!(has(self.apiKeyPassthrough) && self.apiKeyPassthrough && has(self.apiKeySecret) && size(self.apiKeySecret) > 0)" // +kubebuilder:validation:XValidation:message="apiKeyPassthrough must be false if provider is Gemini;GeminiVertexAI;AnthropicVertexAI",rule="!(has(self.apiKeyPassthrough) && self.apiKeyPassthrough && (self.provider == 'Gemini' || self.provider == 'GeminiVertexAI' || self.provider == 'AnthropicVertexAI'))" // +kubebuilder:validation:XValidation:message="caCertSecretKey requires caCertSecretRef",rule="!(has(self.tls) && has(self.tls.caCertSecretKey) && size(self.tls.caCertSecretKey) > 0 && (!has(self.tls.caCertSecretRef) || size(self.tls.caCertSecretRef) == 0))" @@ -274,11 +292,14 @@ type TLSConfig struct { type ModelConfigSpec struct { Model string `json:"model"` - // The name of the secret that contains the API key. Must be a reference to the name of a secret in the same namespace as the referencing ModelConfig + // The name of the secret that contains the API key. Must be a reference to the name of a secret in the same namespace as the referencing ModelConfig. + // For the SAPAICore provider, the secret must contain two keys: "client_id" and "client_secret" + // (the OAuth2 client credentials for SAP AI Core). The apiKeySecretKey field is not used for SAPAICore. // +optional APIKeySecret string `json:"apiKeySecret"` - // The key in the secret that contains the API key + // The key in the secret that contains the API key. + // Not used for the SAPAICore provider (which always reads "client_id" and "client_secret" from the secret). // +optional APIKeySecretKey string `json:"apiKeySecretKey"` @@ -328,6 +349,10 @@ type ModelConfigSpec struct { // +optional Bedrock *BedrockConfig `json:"bedrock,omitempty"` + // SAP AI Core-specific configuration + // +optional + SAPAICore *SAPAICoreConfig `json:"sapAICore,omitempty"` + // TLS configuration for provider connections. // Enables agents to connect to internal LiteLLM gateways or other providers // that use self-signed certificates or custom certificate authorities. diff --git a/go/api/v1alpha2/zz_generated.deepcopy.go b/go/api/v1alpha2/zz_generated.deepcopy.go index ed186ec34..fcb7e21e4 100644 --- a/go/api/v1alpha2/zz_generated.deepcopy.go +++ b/go/api/v1alpha2/zz_generated.deepcopy.go @@ -734,6 +734,11 @@ func (in *ModelConfigSpec) DeepCopyInto(out *ModelConfigSpec) { *out = new(BedrockConfig) **out = **in } + if in.SAPAICore != nil { + in, out := &in.SAPAICore, &out.SAPAICore + *out = new(SAPAICoreConfig) + **out = **in + } if in.TLS != nil { in, out := &in.TLS, &out.TLS *out = new(TLSConfig) @@ -1110,6 +1115,21 @@ func (in *RemoteMCPServerStatus) DeepCopy() *RemoteMCPServerStatus { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *SAPAICoreConfig) DeepCopyInto(out *SAPAICoreConfig) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new SAPAICoreConfig. +func (in *SAPAICoreConfig) DeepCopy() *SAPAICoreConfig { + if in == nil { + return nil + } + out := new(SAPAICoreConfig) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *SandboxAgent) DeepCopyInto(out *SandboxAgent) { *out = *in diff --git a/go/core/internal/controller/translator/agent/adk_api_translator.go b/go/core/internal/controller/translator/agent/adk_api_translator.go index c4873e35c..ead5c6533 100644 --- a/go/core/internal/controller/translator/agent/adk_api_translator.go +++ b/go/core/internal/controller/translator/agent/adk_api_translator.go @@ -640,6 +640,55 @@ func (a *adkApiTranslator) translateModel(ctx context.Context, namespace, modelC bedrock.APIKeyPassthrough = model.Spec.APIKeyPassthrough return bedrock, modelDeploymentData, secretHashBytes, nil + case v1alpha2.ModelProviderSAPAICore: + if model.Spec.SAPAICore == nil { + return nil, nil, nil, fmt.Errorf("sapAICore model config is required") + } + + if !model.Spec.APIKeyPassthrough && model.Spec.APIKeySecret != "" { + secret := &corev1.Secret{} + if err := a.kube.Get(ctx, types.NamespacedName{Namespace: namespace, Name: model.Spec.APIKeySecret}, secret); err != nil { + return nil, nil, nil, fmt.Errorf("failed to get SAP AI Core credentials secret: %w", err) + } + + modelDeploymentData.EnvVars = append(modelDeploymentData.EnvVars, corev1.EnvVar{ + Name: env.SAPAICoreClientID.Name(), + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: model.Spec.APIKeySecret, + }, + Key: "client_id", + }, + }, + }) + modelDeploymentData.EnvVars = append(modelDeploymentData.EnvVars, corev1.EnvVar{ + Name: env.SAPAICoreClientSecret.Name(), + ValueFrom: &corev1.EnvVarSource{ + SecretKeyRef: &corev1.SecretKeySelector{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: model.Spec.APIKeySecret, + }, + Key: "client_secret", + }, + }, + }) + } + + sapAICore := &adk.SAPAICore{ + BaseModel: adk.BaseModel{ + Model: model.Spec.Model, + Headers: model.Spec.DefaultHeaders, + }, + BaseUrl: model.Spec.SAPAICore.BaseURL, + ResourceGroup: model.Spec.SAPAICore.ResourceGroup, + AuthUrl: model.Spec.SAPAICore.AuthURL, + } + + populateTLSFields(&sapAICore.BaseModel, model.Spec.TLS) + sapAICore.APIKeyPassthrough = model.Spec.APIKeyPassthrough + + return sapAICore, modelDeploymentData, secretHashBytes, nil default: return nil, nil, nil, fmt.Errorf("unsupported model provider: %s", model.Spec.Provider) } diff --git a/go/core/internal/httpserver/handlers/modelconfig.go b/go/core/internal/httpserver/handlers/modelconfig.go index 9e022f9f9..250b26895 100644 --- a/go/core/internal/httpserver/handlers/modelconfig.go +++ b/go/core/internal/httpserver/handlers/modelconfig.go @@ -60,6 +60,9 @@ func (h *ModelConfigHandler) HandleListModelConfigs(w ErrorResponseWriter, r *ht if config.Spec.Ollama != nil { FlattenStructToMap(config.Spec.Ollama, modelParams) } + if config.Spec.SAPAICore != nil { + FlattenStructToMap(config.Spec.SAPAICore, modelParams) + } responseItem := api.ModelConfigResponse{ Ref: common.GetObjectRef(&config), @@ -143,6 +146,9 @@ func (h *ModelConfigHandler) HandleGetModelConfig(w ErrorResponseWriter, r *http if modelConfig.Spec.Ollama != nil { FlattenStructToMap(modelConfig.Spec.Ollama, modelParams) } + if modelConfig.Spec.SAPAICore != nil { + FlattenStructToMap(modelConfig.Spec.SAPAICore, modelParams) + } responseItem := api.ModelConfigResponse{ Ref: common.GetObjectRef(modelConfig), @@ -311,6 +317,13 @@ func (h *ModelConfigHandler) HandleCreateModelConfig(w ErrorResponseWriter, r *h } else { log.V(1).Info("No AnthropicVertexAI params provided in create.") } + case v1alpha2.ModelProviderSAPAICore: + if req.SAPAICoreParams != nil { + modelConfig.Spec.SAPAICore = req.SAPAICoreParams + log.V(1).Info("Assigned SAPAICore params to spec") + } else { + log.V(1).Info("No SAPAICore params provided in create.") + } default: providerConfigErr = fmt.Errorf("unsupported provider type: %s", req.Provider.Type) } @@ -431,6 +444,7 @@ func (h *ModelConfigHandler) HandleUpdateModelConfig(w ErrorResponseWriter, r *h Gemini: nil, GeminiVertexAI: nil, AnthropicVertexAI: nil, + SAPAICore: nil, } // --- Update Secret if API Key is provided (and not Ollama) --- @@ -510,6 +524,13 @@ func (h *ModelConfigHandler) HandleUpdateModelConfig(w ErrorResponseWriter, r *h } else { log.V(1).Info("No AnthropicVertexAI params provided in update.") } + case v1alpha2.ModelProviderSAPAICore: + if req.SAPAICoreParams != nil { + modelConfig.Spec.SAPAICore = req.SAPAICoreParams + log.V(1).Info("Assigned updated SAPAICore params to spec") + } else { + log.V(1).Info("No SAPAICore params provided in update.") + } default: providerConfigErr = fmt.Errorf("unsupported provider type specified: %s", req.Provider.Type) } @@ -535,6 +556,8 @@ func (h *ModelConfigHandler) HandleUpdateModelConfig(w ErrorResponseWriter, r *h FlattenStructToMap(modelConfig.Spec.AzureOpenAI, updatedParams) } else if modelConfig.Spec.Ollama != nil { FlattenStructToMap(modelConfig.Spec.Ollama, updatedParams) + } else if modelConfig.Spec.SAPAICore != nil { + FlattenStructToMap(modelConfig.Spec.SAPAICore, updatedParams) } responseItem := api.ModelConfigResponse{ diff --git a/go/core/internal/httpserver/handlers/modelproviderconfig.go b/go/core/internal/httpserver/handlers/modelproviderconfig.go index 42a9eb735..226df54d3 100644 --- a/go/core/internal/httpserver/handlers/modelproviderconfig.go +++ b/go/core/internal/httpserver/handlers/modelproviderconfig.go @@ -48,8 +48,9 @@ func getRequiredKeysForModelProvider(providerType v1alpha2.ModelProvider) []stri // Based on the +required comments in the AzureOpenAIConfig struct definition return []string{"azureEndpoint", "apiVersion"} case v1alpha2.ModelProviderBedrock: - // Region is required for Bedrock return []string{"region"} + case v1alpha2.ModelProviderSAPAICore: + return []string{"baseUrl"} case v1alpha2.ModelProviderOpenAI, v1alpha2.ModelProviderAnthropic, v1alpha2.ModelProviderOllama: // These providers currently have no fields marked as strictly required in the API definition return []string{} @@ -126,6 +127,7 @@ func (h *ModelProviderConfigHandler) HandleListSupportedModelProviders(w ErrorRe {v1alpha2.ModelProviderGeminiVertexAI, reflect.TypeFor[v1alpha2.GeminiVertexAIConfig]()}, {v1alpha2.ModelProviderAnthropicVertexAI, reflect.TypeFor[v1alpha2.AnthropicVertexAIConfig]()}, {v1alpha2.ModelProviderBedrock, reflect.TypeFor[v1alpha2.BedrockConfig]()}, + {v1alpha2.ModelProviderSAPAICore, reflect.TypeFor[v1alpha2.SAPAICoreConfig]()}, } providersResponse := []map[string]any{} diff --git a/go/core/internal/httpserver/handlers/models.go b/go/core/internal/httpserver/handlers/models.go index e7b7b1e33..62c1b7b69 100644 --- a/go/core/internal/httpserver/handlers/models.go +++ b/go/core/internal/httpserver/handlers/models.go @@ -107,6 +107,56 @@ func (h *ModelHandler) HandleListSupportedModels(w ErrorResponseWriter, r *http. {Name: "global.anthropic.claude-opus-4-5-20251101-v1:0", FunctionCalling: true}, {Name: "us.amazon.nova-2-lite-v1:0", FunctionCalling: false}, }, + v1alpha2.ModelProviderSAPAICore: { + {Name: "anthropic--claude-4.6-sonnet", FunctionCalling: true}, + {Name: "anthropic--claude-4.6-opus", FunctionCalling: true}, + {Name: "anthropic--claude-4.5-sonnet", FunctionCalling: true}, + {Name: "anthropic--claude-4.5-opus", FunctionCalling: true}, + {Name: "anthropic--claude-4.5-haiku", FunctionCalling: true}, + {Name: "anthropic--claude-4-sonnet", FunctionCalling: true}, + {Name: "anthropic--claude-4-opus", FunctionCalling: true}, + {Name: "anthropic--claude-3.7-sonnet", FunctionCalling: true}, + {Name: "anthropic--claude-3.5-sonnet", FunctionCalling: true}, + {Name: "anthropic--claude-3-haiku", FunctionCalling: true}, + {Name: "gpt-5.2", FunctionCalling: true}, + {Name: "gpt-5", FunctionCalling: true}, + {Name: "gpt-5-mini", FunctionCalling: true}, + {Name: "gpt-5-nano", FunctionCalling: true}, + {Name: "gpt-4o", FunctionCalling: true}, + {Name: "gpt-4o-mini", FunctionCalling: true}, + {Name: "gpt-4.1", FunctionCalling: true}, + {Name: "gpt-4.1-mini", FunctionCalling: true}, + {Name: "gpt-4.1-nano", FunctionCalling: true}, + {Name: "o1", FunctionCalling: true}, + {Name: "o3", FunctionCalling: true}, + {Name: "o3-mini", FunctionCalling: true}, + {Name: "o4-mini", FunctionCalling: true}, + {Name: "gemini-3-pro-preview", FunctionCalling: true}, + {Name: "gemini-2.5-pro", FunctionCalling: true}, + {Name: "gemini-2.5-flash", FunctionCalling: true}, + {Name: "gemini-2.5-flash-lite", FunctionCalling: true}, + {Name: "gemini-2.0-flash", FunctionCalling: true}, + {Name: "gemini-2.0-flash-lite", FunctionCalling: true}, + {Name: "amazon--nova-premier", FunctionCalling: true}, + {Name: "amazon--nova-pro", FunctionCalling: true}, + {Name: "amazon--nova-lite", FunctionCalling: false}, + {Name: "amazon--nova-micro", FunctionCalling: false}, + {Name: "meta--llama3-70b-instruct", FunctionCalling: false}, + {Name: "mistralai--mistral-large-instruct", FunctionCalling: true}, + {Name: "mistralai--mistral-small-instruct", FunctionCalling: true}, + {Name: "mistralai--mistral-medium-instruct", FunctionCalling: true}, + {Name: "cohere--command-a-reasoning", FunctionCalling: true}, + {Name: "deepseek-v3.2", FunctionCalling: true}, + {Name: "deepseek-r1-0528", FunctionCalling: true}, + {Name: "qwen3-max", FunctionCalling: true}, + {Name: "qwen3.5-plus", FunctionCalling: true}, + {Name: "qwen-turbo", FunctionCalling: true}, + {Name: "qwen-flash", FunctionCalling: true}, + {Name: "sonar-deep-research", FunctionCalling: false}, + {Name: "sonar-pro", FunctionCalling: false}, + {Name: "sonar", FunctionCalling: false}, + {Name: "sap-abap-1", FunctionCalling: false}, + }, } log.Info("Successfully listed supported models", "count", len(supportedModels)) diff --git a/go/core/pkg/env/providers.go b/go/core/pkg/env/providers.go index 7250fb29c..5f264de09 100644 --- a/go/core/pkg/env/providers.go +++ b/go/core/pkg/env/providers.go @@ -153,3 +153,20 @@ var ( ComponentAgentRuntime, ) ) + +// SAP AI Core +var ( + SAPAICoreClientID = RegisterStringVar( + "SAP_AI_CORE_CLIENT_ID", + "", + "OAuth2 client ID for SAP AI Core authentication.", + ComponentAgentRuntime, + ) + + SAPAICoreClientSecret = RegisterStringVar( + "SAP_AI_CORE_CLIENT_SECRET", + "", + "OAuth2 client secret for SAP AI Core authentication.", + ComponentAgentRuntime, + ) +) diff --git a/helm/kagent-crds/templates/kagent.dev_modelconfigs.yaml b/helm/kagent-crds/templates/kagent.dev_modelconfigs.yaml index 08cdab1ec..3b5342021 100644 --- a/helm/kagent-crds/templates/kagent.dev_modelconfigs.yaml +++ b/helm/kagent-crds/templates/kagent.dev_modelconfigs.yaml @@ -438,12 +438,15 @@ spec: Mutually exclusive with apiKeySecret. type: boolean apiKeySecret: - description: The name of the secret that contains the API key. Must - be a reference to the name of a secret in the same namespace as - the referencing ModelConfig + description: |- + The name of the secret that contains the API key. Must be a reference to the name of a secret in the same namespace as the referencing ModelConfig. + For the SAPAICore provider, the secret must contain two keys: "client_id" and "client_secret" + (the OAuth2 client credentials for SAP AI Core). The apiKeySecretKey field is not used for SAPAICore. type: string apiKeySecretKey: - description: The key in the secret that contains the API key + description: |- + The key in the secret that contains the API key. + Not used for the SAPAICore provider (which always reads "client_id" and "client_secret" from the secret). type: string azureOpenAI: description: Azure OpenAI-specific configuration @@ -594,7 +597,24 @@ spec: - GeminiVertexAI - AnthropicVertexAI - Bedrock + - SAPAICore type: string + sapAICore: + description: SAP AI Core-specific configuration + properties: + authUrl: + description: OAuth2 token endpoint URL (e.g., https://tenant.authentication.eu10.hana.ondemand.com) + type: string + baseUrl: + description: Base URL for the SAP AI Core API (e.g., https://api.ai.prod.eu-central-1.aws.ml.hana.ondemand.com) + type: string + resourceGroup: + default: default + description: Resource group in SAP AI Core + type: string + required: + - baseUrl + type: object tls: description: |- TLS configuration for provider connections. @@ -656,12 +676,14 @@ spec: rule: '!(has(self.anthropicVertexAI) && self.provider != ''AnthropicVertexAI'')' - message: provider.bedrock must be nil if the provider is not Bedrock rule: '!(has(self.bedrock) && self.provider != ''Bedrock'')' + - message: provider.sapAICore must be nil if the provider is not SAPAICore + rule: '!(has(self.sapAICore) && self.provider != ''SAPAICore'')' - message: apiKeySecret must be set if apiKeySecretKey is set rule: '!(has(self.apiKeySecretKey) && !has(self.apiKeySecret))' - message: apiKeySecretKey must be set if apiKeySecret is set (except - for Bedrock provider) + for Bedrock and SAPAICore providers) rule: '!(has(self.apiKeySecret) && !has(self.apiKeySecretKey) && self.provider - != ''Bedrock'')' + != ''Bedrock'' && self.provider != ''SAPAICore'')' - message: apiKeyPassthrough and apiKeySecret are mutually exclusive rule: '!(has(self.apiKeyPassthrough) && self.apiKeyPassthrough && has(self.apiKeySecret) && size(self.apiKeySecret) > 0)' diff --git a/helm/kagent-crds/templates/kagent.dev_modelproviderconfigs.yaml b/helm/kagent-crds/templates/kagent.dev_modelproviderconfigs.yaml index ff922c140..493e817e9 100644 --- a/helm/kagent-crds/templates/kagent.dev_modelproviderconfigs.yaml +++ b/helm/kagent-crds/templates/kagent.dev_modelproviderconfigs.yaml @@ -90,6 +90,7 @@ spec: - GeminiVertexAI - AnthropicVertexAI - Bedrock + - SAPAICore type: string required: - type diff --git a/python/packages/kagent-adk/src/kagent/adk/models/__init__.py b/python/packages/kagent-adk/src/kagent/adk/models/__init__.py index a8fc43a68..d79d9c2a1 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/__init__.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/__init__.py @@ -2,5 +2,6 @@ from ._bedrock import KAgentBedrockLlm from ._ollama import KAgentOllamaLlm from ._openai import AzureOpenAI, OpenAI +from ._sap_ai_core import KAgentSAPAICoreLlm -__all__ = ["OpenAI", "AzureOpenAI", "KAgentAnthropicLlm", "KAgentBedrockLlm", "KAgentOllamaLlm"] +__all__ = ["OpenAI", "AzureOpenAI", "KAgentAnthropicLlm", "KAgentBedrockLlm", "KAgentOllamaLlm", "KAgentSAPAICoreLlm"] diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py b/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py new file mode 100644 index 000000000..e90bbfdd9 --- /dev/null +++ b/python/packages/kagent-adk/src/kagent/adk/models/_sap_ai_core.py @@ -0,0 +1,500 @@ +"""SAP AI Core model implementation via Orchestration Service.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import ssl +import time +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional + +import httpx +from google.adk.models import BaseLlm +from google.adk.models.llm_response import LlmResponse +from google.genai import types + +from ._openai import _convert_tools_to_openai + +if TYPE_CHECKING: + from google.adk.models.llm_request import LlmRequest + +logger = logging.getLogger(__name__) + + +async def _fetch_oauth_token(auth_url: str, client_id: str, client_secret: str) -> tuple[str, float]: + """Fetch a new OAuth2 token from the auth server. No caching — callers manage expiry.""" + token_url = auth_url.rstrip("/") + if not token_url.endswith("/oauth/token"): + token_url += "/oauth/token" + + def _sync_fetch() -> tuple[str, float]: + resp = httpx.post( + token_url, + data={ + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": client_secret, + }, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=30, + ) + resp.raise_for_status() + data = resp.json() + token = data["access_token"] + expires_at = time.time() + data.get("expires_in", 43200) + return token, expires_at + + return await asyncio.to_thread(_sync_fetch) + + +def _build_orchestration_template( + messages: list[types.Content], + system_instruction: Optional[str] = None, +) -> list[dict[str, Any]]: + template: list[dict[str, Any]] = [] + if system_instruction: + template.append({"role": "system", "content": system_instruction}) + + for content in messages: + role = "assistant" if content.role in ("model", "assistant") else "user" + text_parts: list[str] = [] + tool_calls: list[dict[str, Any]] = [] + function_responses: list[tuple[str, str]] = [] + + for part in content.parts or []: + if part.text: + text_parts.append(part.text) + elif part.function_call: + fc = part.function_call + tc: dict[str, Any] = { + "type": "function", + "function": { + "name": fc.name or "", + "arguments": json.dumps(fc.args) if fc.args else "{}", + }, + } + if fc.id: + tc["id"] = fc.id + tool_calls.append(tc) + elif part.function_response: + fr = part.function_response + resp_content = "" + if fr.response: + resp_content = json.dumps(fr.response) if isinstance(fr.response, dict) else str(fr.response) + function_responses.append((fr.id or fr.name or "", resp_content)) + + if tool_calls: + msg: dict[str, Any] = {"role": "assistant"} + if text_parts: + msg["content"] = "\n".join(text_parts) + else: + msg["content"] = "" + msg["tool_calls"] = tool_calls + template.append(msg) + elif function_responses: + if text_parts: + template.append({"role": role, "content": "\n".join(text_parts)}) + for tool_call_id, resp_content in function_responses: + template.append( + { + "role": "tool", + "tool_call_id": tool_call_id, + "content": resp_content, + } + ) + elif text_parts: + template.append({"role": role, "content": "\n".join(text_parts)}) + + return template + + +def _build_orchestration_tools( + tools: list[types.Tool], +) -> list[dict[str, Any]]: + openai_tools = _convert_tools_to_openai(tools) + result = [] + for t in openai_tools: + result.append( + { + "type": "function", + "function": { + "name": t["function"]["name"], + "description": t["function"].get("description", ""), + "parameters": t["function"].get("parameters", {"type": "object", "properties": {}}), + }, + } + ) + return result + + +def _parse_orchestration_chunk(event_data: dict[str, Any]) -> Optional[dict[str, Any]]: + if "orchestration_result" in event_data: + return event_data["orchestration_result"] + if "final_result" in event_data: + fr = event_data["final_result"] + if "object" not in fr: + fr["object"] = "chat.completion.chunk" + return fr + if "choices" in event_data and "object" in event_data: + return event_data + return None + + +_RETRYABLE_STATUS_CODES = {401, 403, 404, 502, 503, 504} + + +class KAgentSAPAICoreLlm(BaseLlm): + """SAP AI Core LLM via Orchestration Service. + + Supports all model families (OpenAI, Anthropic, Gemini, etc.) through + SAP's unified Orchestration endpoint. + """ + + base_url: Optional[str] = None + resource_group: str = "default" + auth_url: Optional[str] = None + api_key_passthrough: Optional[bool] = None + + tls_disable_verify: bool = False + tls_ca_cert_path: Optional[str] = None + tls_disable_system_cas: bool = False + + _passthrough_key: Optional[str] = None + _http_client: Optional[httpx.AsyncClient] = None + _token: Optional[str] = None + _token_expires_at: float = 0.0 + _deployment_url: Optional[str] = None + _deployment_url_expires_at: float = 0.0 + + model_config = {"arbitrary_types_allowed": True} + + @classmethod + def supported_models(cls) -> list[str]: + return [] + + def set_passthrough_key(self, token: str) -> None: + self._passthrough_key = token + self._http_client = None + + def _create_ssl_context(self) -> Optional[ssl.SSLContext]: + if not self.tls_disable_verify and not self.tls_ca_cert_path and not self.tls_disable_system_cas: + return None + ctx = ssl.create_default_context() + if self.tls_disable_verify: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + elif self.tls_disable_system_cas: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if self.tls_ca_cert_path: + ctx.load_verify_locations(self.tls_ca_cert_path) + elif self.tls_ca_cert_path: + ctx.load_verify_locations(self.tls_ca_cert_path) + return ctx + + def _get_http_client(self) -> httpx.AsyncClient: + if self._http_client is not None: + return self._http_client + ssl_ctx = self._create_ssl_context() + kwargs: dict[str, Any] = {"timeout": 300} + if ssl_ctx is not None: + kwargs["verify"] = ssl_ctx + self._http_client = httpx.AsyncClient(**kwargs) + return self._http_client + + def _invalidate_token(self) -> None: + self._token = None + self._token_expires_at = 0.0 + + def _invalidate_deployment_url(self) -> None: + self._deployment_url = None + self._deployment_url_expires_at = 0.0 + + async def _ensure_token(self) -> str: + if self._passthrough_key: + return self._passthrough_key + + now = time.time() + if self._token and now < self._token_expires_at - 120: + return self._token + + client_id = os.environ.get("SAP_AI_CORE_CLIENT_ID", "") + client_secret = os.environ.get("SAP_AI_CORE_CLIENT_SECRET", "") + + if self.auth_url and client_id and client_secret: + token, expires_at = await _fetch_oauth_token(self.auth_url, client_id, client_secret) + self._token = token + self._token_expires_at = expires_at + return token + raise ValueError("SAP AI Core requires auth_url + SAP_AI_CORE_CLIENT_ID/SECRET env vars") + + async def _get_headers(self) -> dict[str, str]: + token = await self._ensure_token() + return { + "Authorization": f"Bearer {token}", + "AI-Resource-Group": self.resource_group, + "Content-Type": "application/json", + } + + async def _resolve_deployment_url(self) -> str: + now = time.time() + if self._deployment_url and now < self._deployment_url_expires_at: + return self._deployment_url + + if not self.base_url: + raise ValueError("SAP AI Core requires base_url") + + base = self.base_url.rstrip("/") + headers = await self._get_headers() + client = self._get_http_client() + resp = await client.get(f"{base}/v2/lm/deployments", headers=headers) + resp.raise_for_status() + deployments = resp.json() + + valid: list[tuple[str, str]] = [] + for dep in deployments.get("resources", []): + if dep.get("scenarioId") == "orchestration" and dep.get("status") == "RUNNING": + url = dep.get("deploymentUrl", "") + created = dep.get("createdAt", "") + if url: + valid.append((url, created)) + + if not valid: + raise ValueError("No running orchestration deployment found in SAP AI Core") + + self._deployment_url = sorted(valid, key=lambda x: x[1], reverse=True)[0][0] + self._deployment_url_expires_at = now + 3600 + return self._deployment_url + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + deployment_url = await self._resolve_deployment_url() + url = f"{deployment_url}/v2/completion" + headers = await self._get_headers() + + system_instruction = None + if llm_request.config and llm_request.config.system_instruction: + si = llm_request.config.system_instruction + if isinstance(si, str): + system_instruction = si + elif hasattr(si, "parts"): + parts = getattr(si, "parts", None) or [] + text_parts = [p.text for p in parts if hasattr(p, "text") and p.text] + system_instruction = "\n".join(text_parts) if text_parts else None + + template = _build_orchestration_template(llm_request.contents or [], system_instruction) + + model_params: dict[str, Any] = {} + if llm_request.config: + if llm_request.config.temperature is not None: + model_params["temperature"] = llm_request.config.temperature + if llm_request.config.max_output_tokens is not None: + model_params["max_tokens"] = llm_request.config.max_output_tokens + if llm_request.config.top_p is not None: + model_params["top_p"] = llm_request.config.top_p + + prompt_config: dict[str, Any] = {"template": template} + + if llm_request.config and llm_request.config.tools: + genai_tools: list[types.Tool] = [ + t for t in llm_request.config.tools if isinstance(t, types.Tool) and hasattr(t, "function_declarations") + ] + if genai_tools: + orch_tools = _build_orchestration_tools(genai_tools) + if orch_tools: + prompt_config["tools"] = orch_tools + + body: dict[str, Any] = { + "config": { + "modules": { + "prompt_templating": { + "prompt": prompt_config, + "model": { + "name": llm_request.model or self.model, + "params": model_params, + "version": "latest", + }, + }, + }, + "stream": {"enabled": stream}, + } + } + + try: + if stream: + async for llm_resp in self._stream_request(url, headers, body): + yield llm_resp + else: + yield await self._non_stream_request(url, headers, body) + except httpx.HTTPStatusError as e: + status = e.response.status_code + if status in _RETRYABLE_STATUS_CODES: + if status in (401, 403): + self._invalidate_token() + self._invalidate_deployment_url() + logger.warning("SAP AI Core returned %d from %s, invalidated caches. Retrying once.", status, e.response.url) + try: + headers = await self._get_headers() + deployment_url = await self._resolve_deployment_url() + url = f"{deployment_url}/v2/completion" + if stream: + async for llm_resp in self._stream_request(url, headers, body): + yield llm_resp + else: + yield await self._non_stream_request(url, headers, body) + except Exception as retry_err: + logger.error("SAP AI Core retry failed: %s", retry_err) + yield LlmResponse(error_code="API_ERROR", error_message=str(retry_err)) + else: + logger.error("SAP AI Core error: %s", e) + yield LlmResponse(error_code="API_ERROR", error_message=str(e)) + except Exception as e: + logger.error("SAP AI Core Orchestration error: %s", e) + yield LlmResponse(error_code="API_ERROR", error_message=str(e)) + + async def _stream_request( + self, url: str, headers: dict[str, str], body: dict[str, Any] + ) -> AsyncGenerator[LlmResponse, None]: + aggregated_text = "" + tool_calls_acc: dict[int, dict[str, Any]] = {} + finish_reason_str: Optional[str] = None + usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None + + client = self._get_http_client() + async with client.stream("POST", url, headers=headers, json=body) as resp: + resp.raise_for_status() + async for line in resp.aiter_lines(): + line = line.strip() + if not line: + continue + + payload = line[len("data: ") :] if line.startswith("data: ") else line + if payload == "[DONE]": + break + + try: + event = json.loads(payload) + except json.JSONDecodeError: + continue + + if "code" in event or "error" in event: + raise RuntimeError(json.dumps(event)) + + chunk = _parse_orchestration_chunk(event) + if not chunk: + continue + + for choice in chunk.get("choices", []): + delta = choice.get("delta", {}) + if delta.get("content"): + text = delta["content"] + aggregated_text += text + yield LlmResponse( + content=types.Content(role="model", parts=[types.Part.from_text(text=text)]), + partial=True, + turn_complete=False, + ) + + if delta.get("tool_calls"): + for tc in delta["tool_calls"]: + idx = tc.get("index", 0) + if idx not in tool_calls_acc: + tool_calls_acc[idx] = {"id": "", "name": "", "arguments": ""} + if tc.get("id"): + tool_calls_acc[idx]["id"] = tc["id"] + func = tc.get("function", {}) + if func.get("name"): + tool_calls_acc[idx]["name"] = func["name"] + if func.get("arguments"): + tool_calls_acc[idx]["arguments"] += func["arguments"] + + if choice.get("finish_reason"): + finish_reason_str = choice["finish_reason"] + + usage = chunk.get("usage") + if usage: + usage_metadata = types.GenerateContentResponseUsageMetadata( + prompt_token_count=usage.get("prompt_tokens"), + candidates_token_count=usage.get("completion_tokens"), + total_token_count=usage.get("total_tokens"), + ) + + final_parts: list[types.Part] = [] + if aggregated_text: + final_parts.append(types.Part.from_text(text=aggregated_text)) + for idx in sorted(tool_calls_acc.keys()): + tc = tool_calls_acc[idx] + try: + args = json.loads(tc["arguments"]) if tc["arguments"] else {} + except json.JSONDecodeError: + args = {} + part = types.Part.from_function_call(name=tc["name"], args=args) + if part.function_call: + part.function_call.id = tc["id"] + final_parts.append(part) + + fr = self._map_finish_reason(finish_reason_str) + + yield LlmResponse( + content=types.Content(role="model", parts=final_parts), + partial=False, + turn_complete=True, + finish_reason=fr, + usage_metadata=usage_metadata, + ) + + async def _non_stream_request(self, url: str, headers: dict[str, str], body: dict[str, Any]) -> LlmResponse: + client = self._get_http_client() + resp = await client.post(url, headers=headers, json=body) + resp.raise_for_status() + data = resp.json() + + result = data.get("final_result", data) + parts: list[types.Part] = [] + + for choice in result.get("choices", []): + msg = choice.get("message", {}) + if msg.get("content"): + parts.append(types.Part.from_text(text=msg["content"])) + for tc in msg.get("tool_calls", []): + func = tc.get("function", {}) + try: + args = json.loads(func.get("arguments", "{}")) + except json.JSONDecodeError: + args = {} + part = types.Part.from_function_call(name=func.get("name", ""), args=args) + if part.function_call: + part.function_call.id = tc.get("id", "") + parts.append(part) + + usage = result.get("usage", {}) + usage_metadata = ( + types.GenerateContentResponseUsageMetadata( + prompt_token_count=usage.get("prompt_tokens"), + candidates_token_count=usage.get("completion_tokens"), + total_token_count=usage.get("total_tokens"), + ) + if usage + else None + ) + + stop_reason = result.get("choices", [{}])[0].get("finish_reason", "stop") + fr = self._map_finish_reason(stop_reason) + + return LlmResponse( + content=types.Content(role="model", parts=parts), + finish_reason=fr, + usage_metadata=usage_metadata, + ) + + @staticmethod + def _map_finish_reason(reason: Optional[str]) -> types.FinishReason: + if reason == "length": + return types.FinishReason.MAX_TOKENS + if reason == "content_filter": + return types.FinishReason.SAFETY + if reason == "tool_calls": + return types.FinishReason.STOP + return types.FinishReason.STOP diff --git a/python/packages/kagent-adk/src/kagent/adk/types.py b/python/packages/kagent-adk/src/kagent/adk/types.py index 74ef7f46f..4e29c4a45 100644 --- a/python/packages/kagent-adk/src/kagent/adk/types.py +++ b/python/packages/kagent-adk/src/kagent/adk/types.py @@ -223,7 +223,14 @@ class Bedrock(BaseLLM): type: Literal["bedrock"] -ModelUnion = Union[OpenAI, Anthropic, GeminiVertexAI, GeminiAnthropic, Ollama, AzureOpenAI, Gemini, Bedrock] +class SAPAICore(BaseLLM): + base_url: str | None = None + resource_group: str = "default" + auth_url: str | None = None + type: Literal["sap_ai_core"] + + +ModelUnion = Union[OpenAI, Anthropic, GeminiVertexAI, GeminiAnthropic, Ollama, AzureOpenAI, Gemini, Bedrock, SAPAICore] class ContextCompressionSettings(BaseModel): @@ -515,11 +522,23 @@ def _create_llm_from_model_config(model_config: ModelUnion): if model_config.type == "gemini": return model_config.model if model_config.type == "bedrock": - # api key passthrough is not applicable for bedrock return KAgentBedrockLlm( model=model_config.model, extra_headers=extra_headers, ) + if model_config.type == "sap_ai_core": + from .models._sap_ai_core import KAgentSAPAICoreLlm + + return KAgentSAPAICoreLlm( + model=model_config.model, + base_url=base_url, + resource_group=model_config.resource_group, + auth_url=model_config.auth_url, + api_key_passthrough=model_config.api_key_passthrough, + tls_disable_verify=model_config.tls_disable_verify or False, + tls_ca_cert_path=model_config.tls_ca_cert_path, + tls_disable_system_cas=model_config.tls_disable_system_cas or False, + ) raise ValueError(f"Invalid model type: {model_config.type}") diff --git a/python/packages/kagent-adk/tests/unittests/models/test_sap_ai_core.py b/python/packages/kagent-adk/tests/unittests/models/test_sap_ai_core.py new file mode 100644 index 000000000..fce65b433 --- /dev/null +++ b/python/packages/kagent-adk/tests/unittests/models/test_sap_ai_core.py @@ -0,0 +1,752 @@ +"""Tests for KAgentSAPAICoreLlm (SAP AI Core via Orchestration Service).""" + +import json +import time +from unittest import mock +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from google.genai import types + +from kagent.adk.models._sap_ai_core import ( + KAgentSAPAICoreLlm, + _build_orchestration_template, + _build_orchestration_tools, + _fetch_oauth_token, + _parse_orchestration_chunk, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_llm(base_url="https://api.example.com", auth_url="https://auth.example.com"): + return KAgentSAPAICoreLlm( + model="anthropic--claude-3.5-sonnet", + base_url=base_url, + resource_group="default", + auth_url=auth_url, + ) + + +def _content(role: str, text: str) -> types.Content: + return types.Content(role=role, parts=[types.Part.from_text(text=text)]) + + +def _sse_body(*payloads: dict) -> AsyncMock: + """Return a mock httpx streaming response that yields SSE lines.""" + lines = [] + for p in payloads: + lines.append(f"data: {json.dumps(p)}") + lines.append("data: [DONE]") + + async def _aiter_lines(): + for line in lines: + yield line + + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.aiter_lines = _aiter_lines + return resp + + +def _make_request(contents=None, model="anthropic--claude-3.5-sonnet", config=None): + req = MagicMock() + req.model = model + req.contents = contents or [] + req.config = config + return req + + +# --------------------------------------------------------------------------- +# _build_orchestration_template +# --------------------------------------------------------------------------- + + +class TestBuildOrchestrationTemplate: + def test_empty_input(self): + result = _build_orchestration_template([]) + assert result == [] + + def test_system_instruction_prepended(self): + result = _build_orchestration_template([], system_instruction="Be helpful.") + assert result[0] == {"role": "system", "content": "Be helpful."} + + def test_user_and_assistant_messages(self): + contents = [ + _content("user", "Hello"), + _content("model", "Hi there"), + ] + result = _build_orchestration_template(contents) + assert len(result) == 2 + assert result[0] == {"role": "user", "content": "Hello"} + assert result[1] == {"role": "assistant", "content": "Hi there"} + + def test_assistant_role_alias(self): + contents = [_content("assistant", "Reply")] + result = _build_orchestration_template(contents) + assert result[0]["role"] == "assistant" + + def test_tool_call_message(self): + fc_part = types.Part.from_function_call(name="get_weather", args={"city": "Berlin"}) + fc_part.function_call.id = "call_1" + content = types.Content(role="model", parts=[fc_part]) + result = _build_orchestration_template([content]) + + assert len(result) == 1 + msg = result[0] + assert msg["role"] == "assistant" + assert msg["content"] == "" + tool_calls = msg["tool_calls"] + assert len(tool_calls) == 1 + assert tool_calls[0]["id"] == "call_1" + assert tool_calls[0]["function"]["name"] == "get_weather" + assert json.loads(tool_calls[0]["function"]["arguments"]) == {"city": "Berlin"} + + def test_function_response_message(self): + fr_part = types.Part.from_function_response( + name="get_weather", response={"temp": "20C"} + ) + fr_part.function_response.id = "call_1" + content = types.Content(role="user", parts=[fr_part]) + result = _build_orchestration_template([content]) + + assert len(result) == 1 + assert result[0]["role"] == "tool" + assert result[0]["tool_call_id"] == "call_1" + assert "20C" in result[0]["content"] + + def test_mixed_text_and_tool_call(self): + """Tool call with accompanying text — text goes into content field.""" + fc_part = types.Part.from_function_call(name="search", args={}) + fc_part.function_call.id = "call_2" + text_part = types.Part.from_text(text="Searching now…") + content = types.Content(role="model", parts=[text_part, fc_part]) + result = _build_orchestration_template([content]) + + assert result[0]["content"] == "Searching now…" + assert "tool_calls" in result[0] + + def test_empty_parts_skipped(self): + content = types.Content(role="user", parts=[]) + result = _build_orchestration_template([content]) + assert result == [] + + +# --------------------------------------------------------------------------- +# _build_orchestration_tools +# --------------------------------------------------------------------------- + + +class TestBuildOrchestrationTools: + def test_empty_input(self): + assert _build_orchestration_tools([]) == [] + + def test_single_function(self): + tool = types.Tool( + function_declarations=[ + types.FunctionDeclaration( + name="list_pods", + description="List Kubernetes pods", + parameters=types.Schema( + type=types.Type.OBJECT, + properties={"namespace": types.Schema(type=types.Type.STRING)}, + ), + ) + ] + ) + result = _build_orchestration_tools([tool]) + assert len(result) == 1 + fn = result[0]["function"] + assert fn["name"] == "list_pods" + assert fn["description"] == "List Kubernetes pods" + assert "namespace" in fn["parameters"]["properties"] + + def test_multiple_declarations(self): + tool = types.Tool( + function_declarations=[ + types.FunctionDeclaration(name="fn_a", description="A"), + types.FunctionDeclaration(name="fn_b", description="B"), + ] + ) + result = _build_orchestration_tools([tool]) + names = [r["function"]["name"] for r in result] + assert names == ["fn_a", "fn_b"] + + +# --------------------------------------------------------------------------- +# _parse_orchestration_chunk +# --------------------------------------------------------------------------- + + +class TestParseOrchestrationChunk: + def test_orchestration_result_envelope(self): + event = {"orchestration_result": {"choices": []}} + result = _parse_orchestration_chunk(event) + assert result is not None + assert "choices" in result + + def test_final_result_envelope(self): + event = {"final_result": {"choices": [], "object": "chat.completion.chunk"}} + result = _parse_orchestration_chunk(event) + assert result is not None + assert "choices" in result + + def test_final_result_adds_object_field(self): + event = {"final_result": {"choices": []}} + result = _parse_orchestration_chunk(event) + assert result["object"] == "chat.completion.chunk" + + def test_direct_choices_with_object(self): + event = {"choices": [], "object": "chat.completion.chunk"} + result = _parse_orchestration_chunk(event) + assert result is event + + def test_unrecognized_returns_none(self): + assert _parse_orchestration_chunk({"foo": "bar"}) is None + + +# --------------------------------------------------------------------------- +# OAuth token caching (_ensure_token) +# --------------------------------------------------------------------------- + + +class TestEnsureToken: + @pytest.mark.asyncio + async def test_fetches_token_on_first_call(self, monkeypatch): + llm = _make_llm() + monkeypatch.setenv("SAP_AI_CORE_CLIENT_ID", "cid") + monkeypatch.setenv("SAP_AI_CORE_CLIENT_SECRET", "csecret") + + with patch( + "kagent.adk.models._sap_ai_core._fetch_oauth_token", + new_callable=AsyncMock, + return_value=("tok-1", time.time() + 3600), + ) as mock_fetch: + token = await llm._ensure_token() + + assert token == "tok-1" + mock_fetch.assert_awaited_once() + + @pytest.mark.asyncio + async def test_caches_valid_token(self, monkeypatch): + llm = _make_llm() + llm._token = "cached-tok" + llm._token_expires_at = time.time() + 3600 + monkeypatch.setenv("SAP_AI_CORE_CLIENT_ID", "cid") + monkeypatch.setenv("SAP_AI_CORE_CLIENT_SECRET", "csecret") + + with patch( + "kagent.adk.models._sap_ai_core._fetch_oauth_token", + new_callable=AsyncMock, + ) as mock_fetch: + token = await llm._ensure_token() + + assert token == "cached-tok" + mock_fetch.assert_not_awaited() + + @pytest.mark.asyncio + async def test_refreshes_expired_token(self, monkeypatch): + llm = _make_llm() + llm._token = "old-tok" + llm._token_expires_at = time.time() - 1 # already expired + monkeypatch.setenv("SAP_AI_CORE_CLIENT_ID", "cid") + monkeypatch.setenv("SAP_AI_CORE_CLIENT_SECRET", "csecret") + + with patch( + "kagent.adk.models._sap_ai_core._fetch_oauth_token", + new_callable=AsyncMock, + return_value=("new-tok", time.time() + 3600), + ) as mock_fetch: + token = await llm._ensure_token() + + assert token == "new-tok" + mock_fetch.assert_awaited_once() + + @pytest.mark.asyncio + async def test_raises_when_env_vars_missing(self, monkeypatch): + llm = _make_llm() + monkeypatch.delenv("SAP_AI_CORE_CLIENT_ID", raising=False) + monkeypatch.delenv("SAP_AI_CORE_CLIENT_SECRET", raising=False) + + with pytest.raises(ValueError, match="SAP_AI_CORE_CLIENT"): + await llm._ensure_token() + + def test_set_passthrough_key_invalidates_http_client(self): + llm = _make_llm() + # Force creation of the http client. + client = llm._get_http_client() + assert client is not None + assert llm._http_client is not None + + llm.set_passthrough_key("my-bearer-token") + + assert llm._passthrough_key == "my-bearer-token" + # Client must be cleared so a new one is created with fresh config. + assert llm._http_client is None + + @pytest.mark.asyncio + async def test_passthrough_key_skips_oauth(self, monkeypatch): + llm = _make_llm() + llm.set_passthrough_key("bearer-pass") + monkeypatch.delenv("SAP_AI_CORE_CLIENT_ID", raising=False) + + with patch( + "kagent.adk.models._sap_ai_core._fetch_oauth_token", + new_callable=AsyncMock, + ) as mock_fetch: + token = await llm._ensure_token() + + assert token == "bearer-pass" + mock_fetch.assert_not_awaited() + + @pytest.mark.asyncio + async def test_raises_when_auth_url_missing(self, monkeypatch): + """auth_url=None with no passthrough key should raise ValueError.""" + llm = KAgentSAPAICoreLlm(model="test", base_url="https://api.example.com", auth_url=None) + monkeypatch.setenv("SAP_AI_CORE_CLIENT_ID", "cid") + monkeypatch.setenv("SAP_AI_CORE_CLIENT_SECRET", "csecret") + + with pytest.raises(ValueError, match="SAP_AI_CORE_CLIENT"): + await llm._ensure_token() + + def test_invalidate_clears_token(self): + llm = _make_llm() + llm._token = "tok" + llm._token_expires_at = time.time() + 3600 + llm._invalidate_token() + assert llm._token is None + assert llm._token_expires_at == 0.0 + + +# --------------------------------------------------------------------------- +# Deployment URL resolution and caching (_resolve_deployment_url) +# --------------------------------------------------------------------------- + + +class TestResolveDeploymentURL: + def _dep_response(self, *urls): + resources = [ + { + "scenarioId": "orchestration", + "status": "RUNNING", + "deploymentUrl": u, + "createdAt": f"2024-01-{i+1:02d}T00:00:00Z", + } + for i, u in enumerate(urls) + ] + return {"resources": resources} + + @pytest.mark.asyncio + async def test_resolves_and_caches_url(self, monkeypatch): + llm = _make_llm() + monkeypatch.setenv("SAP_AI_CORE_CLIENT_ID", "cid") + monkeypatch.setenv("SAP_AI_CORE_CLIENT_SECRET", "csecret") + + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = self._dep_response("https://dep.example.com") + + with ( + patch.object(llm, "_ensure_token", new_callable=AsyncMock, return_value="tok"), + patch.object(llm._get_http_client(), "get", new_callable=AsyncMock, return_value=mock_resp) as mock_get, + ): + url1 = await llm._resolve_deployment_url() + url2 = await llm._resolve_deployment_url() + + assert url1 == "https://dep.example.com" + assert url2 == "https://dep.example.com" + # Second call must use the cache — HTTP GET called only once. + mock_get.assert_awaited_once() + + @pytest.mark.asyncio + async def test_picks_most_recently_created(self, monkeypatch): + llm = _make_llm() + monkeypatch.setenv("SAP_AI_CORE_CLIENT_ID", "cid") + monkeypatch.setenv("SAP_AI_CORE_CLIENT_SECRET", "csecret") + + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = self._dep_response( + "https://older.example.com", # createdAt 2024-01-01 + "https://newer.example.com", # createdAt 2024-01-02 + ) + + with ( + patch.object(llm, "_ensure_token", new_callable=AsyncMock, return_value="tok"), + patch.object(llm._get_http_client(), "get", new_callable=AsyncMock, return_value=mock_resp), + ): + url = await llm._resolve_deployment_url() + + assert url == "https://newer.example.com" + + @pytest.mark.asyncio + async def test_raises_when_no_running_deployment(self, monkeypatch): + llm = _make_llm() + monkeypatch.setenv("SAP_AI_CORE_CLIENT_ID", "cid") + monkeypatch.setenv("SAP_AI_CORE_CLIENT_SECRET", "csecret") + + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = {"resources": [ + {"scenarioId": "other", "status": "RUNNING", "deploymentUrl": "https://x.example.com"} + ]} + + with ( + patch.object(llm, "_ensure_token", new_callable=AsyncMock, return_value="tok"), + patch.object(llm._get_http_client(), "get", new_callable=AsyncMock, return_value=mock_resp), + ): + with pytest.raises(ValueError, match="No running orchestration"): + await llm._resolve_deployment_url() + + @pytest.mark.asyncio + async def test_expires_and_refreshes(self, monkeypatch): + llm = _make_llm() + monkeypatch.setenv("SAP_AI_CORE_CLIENT_ID", "cid") + monkeypatch.setenv("SAP_AI_CORE_CLIENT_SECRET", "csecret") + + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = self._dep_response("https://dep.example.com") + + with ( + patch.object(llm, "_ensure_token", new_callable=AsyncMock, return_value="tok"), + patch.object(llm._get_http_client(), "get", new_callable=AsyncMock, return_value=mock_resp) as mock_get, + ): + await llm._resolve_deployment_url() + + # Expire the cache. + llm._deployment_url_expires_at = time.time() - 1 + + await llm._resolve_deployment_url() + + assert mock_get.await_count == 2 + + def test_invalidate_clears_url(self): + llm = _make_llm() + llm._deployment_url = "https://old.example.com" + llm._deployment_url_expires_at = time.time() + 3600 + llm._invalidate_deployment_url() + assert llm._deployment_url is None + assert llm._deployment_url_expires_at == 0.0 + + +# --------------------------------------------------------------------------- +# _non_stream_request +# --------------------------------------------------------------------------- + + +class TestNonStreamRequest: + @pytest.mark.asyncio + async def test_text_response(self): + llm = _make_llm() + data = { + "final_result": { + "choices": [{"finish_reason": "stop", "message": {"content": "Hello!"}}], + } + } + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = data + + with patch.object(llm._get_http_client(), "post", new_callable=AsyncMock, return_value=mock_resp): + result = await llm._non_stream_request("https://dep/v2/completion", {}, {}) + + assert result.content.parts[0].text == "Hello!" + + @pytest.mark.asyncio + async def test_tool_call_response(self): + llm = _make_llm() + data = { + "choices": [{ + "finish_reason": "tool_calls", + "message": { + "content": "", + "tool_calls": [{ + "id": "call_99", + "type": "function", + "function": {"name": "get_pods", "arguments": '{"ns":"default"}'}, + }], + }, + }] + } + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = data + + with patch.object(llm._get_http_client(), "post", new_callable=AsyncMock, return_value=mock_resp): + result = await llm._non_stream_request("https://dep/v2/completion", {}, {}) + + fc = next((p.function_call for p in result.content.parts if p.function_call), None) + assert fc is not None + assert fc.name == "get_pods" + assert fc.id == "call_99" + assert fc.args == {"ns": "default"} + + @pytest.mark.asyncio + async def test_usage_metadata(self): + llm = _make_llm() + data = { + "choices": [{"finish_reason": "stop", "message": {"content": "ok"}}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.json.return_value = data + + with patch.object(llm._get_http_client(), "post", new_callable=AsyncMock, return_value=mock_resp): + result = await llm._non_stream_request("https://dep/v2/completion", {}, {}) + + assert result.usage_metadata.prompt_token_count == 10 + assert result.usage_metadata.candidates_token_count == 5 + assert result.usage_metadata.total_token_count == 15 + + +# --------------------------------------------------------------------------- +# _stream_request +# --------------------------------------------------------------------------- + + +class TestStreamRequest: + def _orch_chunk(self, choices): + return {"orchestration_result": {"choices": choices}} + + def _text_delta(self, text): + return {"delta": {"content": text}} + + def _finish_delta(self, reason): + return {"delta": {}, "finish_reason": reason} + + @pytest.mark.asyncio + async def test_text_chunks_and_final_aggregation(self): + llm = _make_llm() + payloads = [ + self._orch_chunk([self._text_delta("Hello")]), + self._orch_chunk([self._text_delta(", world")]), + self._orch_chunk([self._finish_delta("stop")]), + ] + mock_resp = _sse_body(*payloads) + + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_resp) + cm.__aexit__ = AsyncMock(return_value=False) + + with patch.object(llm._get_http_client(), "stream", return_value=cm): + responses = [r async for r in llm._stream_request("https://dep/v2/completion", {}, {})] + + partials = [r for r in responses if r.partial] + assert len(partials) == 2 + assert partials[0].content.parts[0].text == "Hello" + assert partials[1].content.parts[0].text == ", world" + + final = responses[-1] + assert not final.partial + assert final.turn_complete + assert final.content.parts[0].text == "Hello, world" + + @pytest.mark.asyncio + async def test_tool_call_assembled_across_chunks(self): + llm = _make_llm() + payloads = [ + self._orch_chunk([{"delta": {"tool_calls": [ + {"index": 0, "id": "call_5", "function": {"name": "list_pods", "arguments": '{"ns":'}} + ]}}]), + self._orch_chunk([{"delta": {"tool_calls": [ + {"index": 0, "function": {"arguments": '"default"}'}} + ]}, "finish_reason": "tool_calls"}]), + ] + mock_resp = _sse_body(*payloads) + + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_resp) + cm.__aexit__ = AsyncMock(return_value=False) + + with patch.object(llm._get_http_client(), "stream", return_value=cm): + responses = [r async for r in llm._stream_request("https://dep/v2/completion", {}, {})] + + final = responses[-1] + fc = next((p.function_call for p in final.content.parts if p.function_call), None) + assert fc is not None + assert fc.name == "list_pods" + assert fc.id == "call_5" + assert fc.args == {"ns": "default"} + + @pytest.mark.asyncio + async def test_usage_metadata_in_final_result_envelope(self): + llm = _make_llm() + payloads = [ + self._orch_chunk([self._text_delta("hi")]), + {"final_result": { + "choices": [self._finish_delta("stop")], + "usage": {"prompt_tokens": 8, "completion_tokens": 3, "total_tokens": 11}, + }}, + ] + mock_resp = _sse_body(*payloads) + + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_resp) + cm.__aexit__ = AsyncMock(return_value=False) + + with patch.object(llm._get_http_client(), "stream", return_value=cm): + responses = [r async for r in llm._stream_request("https://dep/v2/completion", {}, {})] + + final = responses[-1] + assert final.usage_metadata is not None + assert final.usage_metadata.prompt_token_count == 8 + assert final.usage_metadata.candidates_token_count == 3 + + @pytest.mark.asyncio + async def test_error_event_raises(self): + llm = _make_llm() + mock_resp = _sse_body({"code": "500", "message": "internal error"}) + + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_resp) + cm.__aexit__ = AsyncMock(return_value=False) + + with patch.object(llm._get_http_client(), "stream", return_value=cm): + with pytest.raises(RuntimeError, match="internal error"): + async for _ in llm._stream_request("https://dep/v2/completion", {}, {}): + pass + + @pytest.mark.asyncio + async def test_malformed_lines_skipped(self): + llm = _make_llm() + + async def _aiter_lines(): + yield "data: not-valid-json" + yield f"data: {json.dumps(self._orch_chunk([self._text_delta('ok')]))}" + yield "data: [DONE]" + + mock_resp = MagicMock() + mock_resp.raise_for_status = MagicMock() + mock_resp.aiter_lines = _aiter_lines + + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_resp) + cm.__aexit__ = AsyncMock(return_value=False) + + with patch.object(llm._get_http_client(), "stream", return_value=cm): + responses = [r async for r in llm._stream_request("https://dep/v2/completion", {}, {})] + + partials = [r for r in responses if r.partial] + assert len(partials) == 1 + assert partials[0].content.parts[0].text == "ok" + + +# --------------------------------------------------------------------------- +# generate_content_async — retry on retryable status codes +# --------------------------------------------------------------------------- + + +class TestGenerateContentAsyncRetry: + @pytest.mark.asyncio + async def test_retries_on_401(self, monkeypatch): + llm = _make_llm() + monkeypatch.setenv("SAP_AI_CORE_CLIENT_ID", "cid") + monkeypatch.setenv("SAP_AI_CORE_CLIENT_SECRET", "csecret") + + ok_resp = MagicMock() + ok_resp.raise_for_status = MagicMock() + ok_resp.json.return_value = { + "choices": [{"finish_reason": "stop", "message": {"content": "retry ok"}}] + } + + error_resp = MagicMock() + error_resp.status_code = 401 + http_error = httpx.HTTPStatusError("401", request=MagicMock(), response=error_resp) + + call_count = 0 + + async def mock_non_stream(url, headers, body): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise http_error + return (await _make_real_non_stream(url, headers, body)) + + from google.adk.models.llm_response import LlmResponse + from google.genai import types as gtypes + + async def _make_real_non_stream(url, headers, body): + return LlmResponse( + content=gtypes.Content(role="model", parts=[gtypes.Part.from_text(text="retry ok")]) + ) + + with ( + patch.object(llm, "_resolve_deployment_url", new_callable=AsyncMock, return_value="https://dep/"), + patch.object(llm, "_get_headers", new_callable=AsyncMock, return_value={}), + patch.object(llm, "_non_stream_request", side_effect=mock_non_stream), + patch.object(llm, "_invalidate_token") as mock_inv_tok, + patch.object(llm, "_invalidate_deployment_url") as mock_inv_dep, + ): + responses = [r async for r in llm.generate_content_async(_make_request(), stream=False)] + + assert call_count == 2 + mock_inv_tok.assert_called_once() + mock_inv_dep.assert_called_once() + assert responses[-1].content.parts[0].text == "retry ok" + + @pytest.mark.asyncio + async def test_no_retry_on_400(self, monkeypatch): + llm = _make_llm() + monkeypatch.setenv("SAP_AI_CORE_CLIENT_ID", "cid") + monkeypatch.setenv("SAP_AI_CORE_CLIENT_SECRET", "csecret") + + error_resp = MagicMock() + error_resp.status_code = 400 + http_error = httpx.HTTPStatusError("400", request=MagicMock(), response=error_resp) + + call_count = 0 + + async def mock_non_stream(url, headers, body): + nonlocal call_count + call_count += 1 + raise http_error + + with ( + patch.object(llm, "_resolve_deployment_url", new_callable=AsyncMock, return_value="https://dep/"), + patch.object(llm, "_get_headers", new_callable=AsyncMock, return_value={}), + patch.object(llm, "_non_stream_request", side_effect=mock_non_stream), + ): + responses = [r async for r in llm.generate_content_async(_make_request(), stream=False)] + + assert call_count == 1 # no retry + assert responses[-1].error_code == "API_ERROR" + + +# --------------------------------------------------------------------------- +# _create_llm_from_model_config integration +# --------------------------------------------------------------------------- + + +class TestCreateLlmFromModelConfig: + def test_returns_kagent_sap_ai_core_llm(self): + from kagent.adk.types import SAPAICore, _create_llm_from_model_config + + config = SAPAICore( + type="sap_ai_core", + model="anthropic--claude-3.5-sonnet", + base_url="https://api.example.com", + resource_group="my-group", + auth_url="https://auth.example.com", + ) + result = _create_llm_from_model_config(config) + assert isinstance(result, KAgentSAPAICoreLlm) + assert result.model == "anthropic--claude-3.5-sonnet" + assert result.base_url == "https://api.example.com" + assert result.resource_group == "my-group" + assert result.auth_url == "https://auth.example.com" + + def test_default_resource_group(self): + from kagent.adk.types import SAPAICore, _create_llm_from_model_config + + config = SAPAICore( + type="sap_ai_core", + model="anthropic--claude-3.5-sonnet", + ) + result = _create_llm_from_model_config(config) + assert result.resource_group == "default" diff --git a/ui/src/app/models/new/page.tsx b/ui/src/app/models/new/page.tsx index 772944ff3..9178ecf25 100644 --- a/ui/src/app/models/new/page.tsx +++ b/ui/src/app/models/new/page.tsx @@ -18,7 +18,8 @@ import type { ProviderModelsResponse, GeminiConfigPayload, GeminiVertexAIConfigPayload, - AnthropicVertexAIConfigPayload + AnthropicVertexAIConfigPayload, + SAPAICoreConfigPayload } from "@/types"; import { toast } from "sonner"; import { isResourceNameValid, createRFC1123ValidName } from "@/lib/utils"; @@ -592,6 +593,9 @@ function ModelPageContent() { case 'AnthropicVertexAI': payload.anthropicVertexAI = providerParams as AnthropicVertexAIConfigPayload; break; + case 'SAPAICore': + payload.sapAICore = providerParams as SAPAICoreConfigPayload; + break; default: console.error("Unsupported provider type during payload construction:", providerType); toast.error("Internal error: Unsupported provider type."); @@ -610,6 +614,7 @@ function ModelPageContent() { anthropic: payload.anthropic, azureOpenAI: payload.azureOpenAI, ollama: payload.ollama, + sapAICore: payload.sapAICore, }; const modelConfigRef = k8sRefUtils.toRef(modelConfigNamespace || '', modelConfigName); response = await updateModelConfig(modelConfigRef, updatePayload); diff --git a/ui/src/components/ModelProviderCombobox.tsx b/ui/src/components/ModelProviderCombobox.tsx index 98ede6b02..ddfd32ff9 100644 --- a/ui/src/components/ModelProviderCombobox.tsx +++ b/ui/src/components/ModelProviderCombobox.tsx @@ -11,6 +11,7 @@ import { Anthropic } from './icons/Anthropic'; import { Ollama } from './icons/Ollama'; import { Azure } from './icons/Azure'; import { Gemini } from './icons/Gemini'; +import { SAPAICore } from './icons/SAPAICore'; interface ComboboxOption { label: string; // e.g., "OpenAI - gpt-4o" @@ -64,6 +65,7 @@ export function ModelProviderCombobox({ 'Gemini': Gemini, 'GeminiVertexAI': Gemini, 'AnthropicVertexAI': Anthropic, + 'SAPAICore': SAPAICore, }; if (!providerKey || !PROVIDER_ICONS[providerKey]) { return null; diff --git a/ui/src/components/ProviderCombobox.tsx b/ui/src/components/ProviderCombobox.tsx index 65d8bce84..ff332149c 100644 --- a/ui/src/components/ProviderCombobox.tsx +++ b/ui/src/components/ProviderCombobox.tsx @@ -11,6 +11,7 @@ import { Anthropic } from './icons/Anthropic'; import { Ollama } from './icons/Ollama'; import { Azure } from './icons/Azure'; import { Gemini } from './icons/Gemini'; +import { SAPAICore } from './icons/SAPAICore'; const PROVIDER_ICONS: Record> = { 'OpenAI': OpenAI, @@ -20,6 +21,7 @@ const PROVIDER_ICONS: Record + + + ); +} diff --git a/ui/src/lib/providers.ts b/ui/src/lib/providers.ts index 79246c174..7ce6f3ef2 100644 --- a/ui/src/lib/providers.ts +++ b/ui/src/lib/providers.ts @@ -1,6 +1,6 @@ -export type BackendModelProviderType = "OpenAI" | "AzureOpenAI" | "Anthropic" | "Ollama" | "Gemini" | "GeminiVertexAI" | "AnthropicVertexAI"; -export const modelProviders = ["OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "GeminiVertexAI", "AnthropicVertexAI"] as const; +export type BackendModelProviderType = "OpenAI" | "AzureOpenAI" | "Anthropic" | "Ollama" | "Gemini" | "GeminiVertexAI" | "AnthropicVertexAI" | "SAPAICore"; +export const modelProviders = ["OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "GeminiVertexAI", "AnthropicVertexAI", "SAPAICore"] as const; export type ModelProviderKey = typeof modelProviders[number]; @@ -62,6 +62,13 @@ export const PROVIDERS_INFO: { modelDocsLink: "https://cloud.google.com/vertex-ai/docs", help: "Configure your Google Cloud project and credentials for Vertex AI." }, + SAPAICore: { + name: "SAP AI Core", + type: "SAPAICore", + apiKeyLink: "https://help.sap.com/docs/sap-ai-core", + modelDocsLink: "https://help.sap.com/docs/sap-ai-core/sap-ai-core-service-guide/models-and-scenarios-in-generative-ai-hub", + help: "Create a K8s Secret with client_id and client_secret from your SAP AI Core service key." + }, }; export const isValidProviderInfoKey = (key: string): key is ModelProviderKey => { diff --git a/ui/src/types/index.ts b/ui/src/types/index.ts index 87fe16176..9a9986189 100644 --- a/ui/src/types/index.ts +++ b/ui/src/types/index.ts @@ -122,6 +122,12 @@ export interface AnthropicVertexAIConfigPayload { topK?: number; } +export interface SAPAICoreConfigPayload { + baseUrl: string; + resourceGroup?: string; + authUrl?: string; +} + export interface CreateModelConfigRequest { ref: string; provider: Pick; @@ -134,6 +140,7 @@ export interface CreateModelConfigRequest { gemini?: GeminiConfigPayload; geminiVertexAI?: GeminiVertexAIConfigPayload; anthropicVertexAI?: AnthropicVertexAIConfigPayload; + sapAICore?: SAPAICoreConfigPayload; } export interface UpdateModelConfigPayload { @@ -147,6 +154,7 @@ export interface UpdateModelConfigPayload { gemini?: GeminiConfigPayload; geminiVertexAI?: GeminiVertexAIConfigPayload; anthropicVertexAI?: AnthropicVertexAIConfigPayload; + sapAICore?: SAPAICoreConfigPayload; } /**