Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions go/adk/pkg/a2a/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ import (
"fmt"
"maps"
"os"
"strings"

a2atype "github.com/a2aproject/a2a-go/a2a"
"github.com/a2aproject/a2a-go/a2asrv"
"github.com/a2aproject/a2a-go/a2asrv/eventqueue"
"github.com/go-logr/logr"
"github.com/kagent-dev/kagent/go/adk/pkg/models"
"github.com/kagent-dev/kagent/go/adk/pkg/session"
"github.com/kagent-dev/kagent/go/adk/pkg/skills"
"github.com/kagent-dev/kagent/go/adk/pkg/telemetry"
Expand Down Expand Up @@ -114,6 +116,18 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont
}
sessionID := reqCtx.ContextID

// Extract Bearer token from incoming request for API key passthrough
if callCtx, ok := a2asrv.CallContextFrom(ctx); ok {
if meta := callCtx.RequestMeta(); meta != nil {
if vals, ok := meta.Get("authorization"); ok && len(vals) > 0 && vals[0] != "" {
auth := vals[0]
if token, ok := strings.CutPrefix(auth, "Bearer "); ok {
ctx = context.WithValue(ctx, models.BearerTokenKey, token)
}
}
}
}

e.logger.Info("Execute",
"taskID", reqCtx.TaskID,
"contextID", reqCtx.ContextID,
Expand Down
62 changes: 38 additions & 24 deletions go/adk/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,26 +179,24 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
switch m := m.(type) {
case *adk.OpenAI:
cfg := &models.OpenAIConfig{
TransportConfig: transportConfigFromBase(m.BaseModel, m.Timeout),
Model: m.Model,
BaseUrl: m.BaseUrl,
Headers: extractHeaders(m.Headers),
FrequencyPenalty: m.FrequencyPenalty,
MaxTokens: m.MaxTokens,
N: m.N,
PresencePenalty: m.PresencePenalty,
ReasoningEffort: m.ReasoningEffort,
Seed: m.Seed,
Temperature: m.Temperature,
Timeout: m.Timeout,
TopP: m.TopP,
}
return models.NewOpenAIModelWithLogger(cfg, log)

case *adk.AzureOpenAI:
cfg := &models.AzureOpenAIConfig{
Model: m.Model,
Headers: extractHeaders(m.Headers),
Timeout: nil,
TransportConfig: transportConfigFromBase(m.BaseModel, nil),
Model: m.Model,
}
return models.NewAzureOpenAIModelWithLogger(cfg, log)

Expand Down Expand Up @@ -241,14 +239,13 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
modelName = DefaultAnthropicModel
}
cfg := &models.AnthropicConfig{
Model: modelName,
BaseUrl: m.BaseUrl,
Headers: extractHeaders(m.Headers),
MaxTokens: m.MaxTokens,
Temperature: m.Temperature,
TopP: m.TopP,
TopK: m.TopK,
Timeout: m.Timeout,
TransportConfig: transportConfigFromBase(m.BaseModel, m.Timeout),
Model: modelName,
BaseUrl: m.BaseUrl,
MaxTokens: m.MaxTokens,
Temperature: m.Temperature,
TopP: m.TopP,
TopK: m.TopK,
}
return models.NewAnthropicModelWithLogger(cfg, log)

Expand All @@ -257,15 +254,18 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
if baseURL == "" {
baseURL = "http://localhost:11434"
}
baseURL = strings.TrimSuffix(baseURL, "/")
if !strings.HasSuffix(baseURL, "/v1") {
baseURL += "/v1"
}
modelName := m.Model
if modelName == "" {
modelName = DefaultOllamaModel
}
return models.NewOpenAICompatibleModelWithLogger(baseURL, modelName, extractHeaders(m.Headers), "", log)
// Create OllamaConfig with native SDK support for Ollama-specific options
cfg := &models.OllamaConfig{
TransportConfig: transportConfigFromBase(m.BaseModel, nil),
Model: modelName,
Host: baseURL,
Options: m.Options,
}
return models.NewOllamaModelWithLogger(cfg, log)

case *adk.Bedrock:
region := m.Region
Expand All @@ -279,11 +279,13 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
if modelName == "" {
return nil, fmt.Errorf("bedrock requires a model name (e.g. anthropic.claude-3-sonnet-20240229-v1:0)")
}
cfg := &models.AnthropicConfig{
Model: modelName,
Headers: extractHeaders(m.Headers),
// Use Bedrock Converse API for ALL models (including Anthropic)
cfg := &models.BedrockConfig{
TransportConfig: transportConfigFromBase(m.BaseModel, nil),
Model: modelName,
Region: region,
}
return models.NewAnthropicBedrockModelWithLogger(ctx, cfg, region, log)
return models.NewBedrockModelWithLogger(ctx, cfg, log)

case *adk.GeminiAnthropic:
// GeminiAnthropic = Claude models accessed through Google Cloud Vertex AI.
Expand All @@ -301,8 +303,8 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
modelName = DefaultAnthropicModel
}
cfg := &models.AnthropicConfig{
Model: modelName,
Headers: extractHeaders(m.Headers),
TransportConfig: transportConfigFromBase(m.BaseModel, nil),
Model: modelName,
}
return models.NewAnthropicVertexAIModelWithLogger(ctx, cfg, region, project, log)

Expand All @@ -311,6 +313,18 @@ func CreateLLM(ctx context.Context, m adk.Model, log logr.Logger) (adkmodel.LLM,
}
}

// transportConfigFromBase assembles the models.TransportConfig shared by all
// providers from the common BaseModel fields, plus an optional per-request
// timeout supplied by the caller (nil when the provider has no timeout knob).
func transportConfigFromBase(b adk.BaseModel, timeout *int) models.TransportConfig {
	cfg := models.TransportConfig{
		Headers:               extractHeaders(b.Headers),
		Timeout:               timeout,
		APIKeyPassthrough:     b.APIKeyPassthrough,
		TLSInsecureSkipVerify: b.TLSInsecureSkipVerify,
		TLSCACertPath:         b.TLSCACertPath,
		TLSDisableSystemCAs:   b.TLSDisableSystemCAs,
	}
	return cfg
}

// extractHeaders returns an empty map if nil, the original map otherwise.
func extractHeaders(headers map[string]string) map[string]string {
if headers == nil {
Expand Down
4 changes: 3 additions & 1 deletion go/adk/pkg/agent/createllm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ func TestAgent_OpenAI_WithParams(t *testing.T) {
}

func TestAgent_Ollama(t *testing.T) {
// mockllm does not support the native Ollama /api/chat endpoint,
// so we test with an OpenAI-compatible model pointing at the mock.
baseURL := startMock(t, "testdata/mock_openai.json")
t.Setenv("OLLAMA_API_BASE", baseURL)
t.Setenv("OPENAI_API_KEY", "ollama") // placeholder, Ollama ignores it

cfg := loadConfig(t, "testdata/config_ollama.json", baseURL)
text := runAgent(t, cfg, "What is 2+2?")
Expand Down
5 changes: 3 additions & 2 deletions go/adk/pkg/agent/testdata/config_ollama.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{
"model": {
"type": "ollama",
"model": "llama3.2"
"type": "openai",
"model": "llama3.2",
"base_url": "{{BASE_URL}}/v1"
},
"description": "test",
"instruction": "Answer concisely."
Expand Down
191 changes: 190 additions & 1 deletion go/adk/pkg/embedding/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ import (
"math"
"net/http"
"os"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
"github.com/go-logr/logr"
"github.com/kagent-dev/kagent/go/api/adk"
"google.golang.org/genai"
)

const (
Expand Down Expand Up @@ -69,8 +74,15 @@ func (c *Client) Generate(ctx context.Context, texts []string) ([][]float32, err
return c.generateOpenAI(ctx, texts)
case "azure_openai":
return c.generateAzureOpenAI(ctx, texts)
case "ollama":
return c.generateOllama(ctx, texts)
case "gemini", "vertex_ai":
return c.generateGemini(ctx, texts)
case "bedrock":
return c.generateBedrock(ctx, texts)
default:
return nil, fmt.Errorf("unsupported embedding provider: %s", c.config.Provider)
// Unknown provider - try OpenAI-compatible as fallback
return c.generateOpenAI(ctx, texts)
}
}

Expand Down Expand Up @@ -210,6 +222,183 @@ func (c *Client) generateAzureOpenAI(ctx context.Context, texts []string) ([][]f
return embeddings, nil
}

// generateOllama generates embeddings by calling Ollama's OpenAI-compatible
// /v1/embeddings endpoint.
//
// The base URL is resolved in order: explicit config value, the
// OLLAMA_API_BASE environment variable, then the default local address.
// Vectors longer than TargetDimension are truncated and re-normalized;
// vectors shorter than TargetDimension are rejected with an error.
func (c *Client) generateOllama(ctx context.Context, texts []string) ([][]float32, error) {
	log := logr.FromContextOrDiscard(ctx)

	// Resolve the Ollama endpoint: config > env var > local default.
	base := c.config.BaseUrl
	if base == "" {
		base = os.Getenv("OLLAMA_API_BASE")
	}
	if base == "" {
		base = "http://localhost:11434"
	}
	endpoint := strings.TrimSuffix(base, "/") + "/v1/embeddings"

	payload, err := json.Marshal(map[string]any{
		"input": texts,
		"model": c.config.Model,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	// NOTE: Ollama does not require authentication, so no Authorization
	// header is set here.

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
	}

	var parsed openAIEmbeddingResponse
	if err := json.NewDecoder(resp.Body).Decode(&parsed); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}

	out := make([][]float32, 0, len(parsed.Data))
	for _, item := range parsed.Data {
		vec := item.Embedding
		switch {
		case len(vec) > TargetDimension:
			// Truncate oversized vectors, then re-normalize so magnitude stays 1.
			log.V(1).Info("Truncating embedding", "from", len(vec), "to", TargetDimension)
			vec = normalizeL2(vec[:TargetDimension])
		case len(vec) < TargetDimension:
			return nil, fmt.Errorf("embedding dimension %d is less than required %d", len(vec), TargetDimension)
		}
		out = append(out, vec)
	}

	log.Info("Successfully generated embeddings with Ollama", "count", len(out))
	return out, nil
}

// generateGemini generates embeddings using the Google Gemini/Vertex AI API
// via the genai SDK, issuing one EmbedContent call per input text.
//
// OutputDimensionality asks the service to return vectors of exactly
// TargetDimension, so no client-side truncation is performed here
// (unlike the OpenAI/Ollama/Bedrock paths).
func (c *Client) generateGemini(ctx context.Context, texts []string) ([][]float32, error) {
	log := logr.FromContextOrDiscard(ctx)

	// Create the genai client, authenticated via the GOOGLE_API_KEY env var.
	client, err := genai.NewClient(ctx, &genai.ClientConfig{
		APIKey: os.Getenv("GOOGLE_API_KEY"),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create genai client: %w", err)
	}

	targetDim := int32(TargetDimension)
	embeddingResults := make([][]float32, len(texts))

	for i, text := range texts {
		content := genai.Text(text)
		result, err := client.Models.EmbedContent(ctx, c.config.Model, content, &genai.EmbedContentConfig{
			OutputDimensionality: &targetDim,
		})
		if err != nil {
			return nil, fmt.Errorf("failed to generate embedding for text %d: %w", i, err)
		}
		// Bug fix: an empty Embeddings list previously left a nil row in the
		// result silently; surface it as an explicit error instead.
		if len(result.Embeddings) == 0 {
			return nil, fmt.Errorf("no embedding returned for text %d", i)
		}

		embedding := result.Embeddings[0].Values
		// Copy into a fresh slice so we never alias SDK-owned memory.
		emb32 := make([]float32, len(embedding))
		for j, v := range embedding {
			emb32[j] = float32(v)
		}
		embeddingResults[i] = emb32
	}

	log.Info("Successfully generated embeddings with Gemini", "count", len(embeddingResults))
	return embeddingResults, nil
}

// generateBedrock generates embeddings with the AWS Bedrock Titan Embedding
// API. The Titan API accepts a single inputText per InvokeModel call, so each
// text is embedded in its own request.
//
// Region resolution order: AWS_DEFAULT_REGION, then AWS_REGION, then a
// hard-coded us-east-1 fallback. Credentials come from the default AWS
// credential chain.
func (c *Client) generateBedrock(ctx context.Context, texts []string) ([][]float32, error) {
	log := logr.FromContextOrDiscard(ctx)

	region := os.Getenv("AWS_DEFAULT_REGION")
	if region == "" {
		if region = os.Getenv("AWS_REGION"); region == "" {
			region = "us-east-1"
		}
	}

	awsCfg, err := awsconfig.LoadDefaultConfig(ctx, awsconfig.WithRegion(region))
	if err != nil {
		return nil, fmt.Errorf("failed to load AWS config: %w", err)
	}
	runtime := bedrockruntime.NewFromConfig(awsCfg)

	results := make([][]float32, 0, len(texts))
	for i, text := range texts {
		body, err := json.Marshal(map[string]string{"inputText": text})
		if err != nil {
			return nil, fmt.Errorf("failed to marshal request for text %d: %w", i, err)
		}

		out, err := runtime.InvokeModel(ctx, &bedrockruntime.InvokeModelInput{
			ModelId:     aws.String(c.config.Model),
			Body:        body,
			ContentType: aws.String("application/json"),
			Accept:      aws.String("application/json"),
		})
		if err != nil {
			return nil, fmt.Errorf("failed to invoke Bedrock model for text %d: %w", i, err)
		}

		var decoded bedrockEmbeddingResponse
		if err := json.Unmarshal(out.Body, &decoded); err != nil {
			return nil, fmt.Errorf("failed to decode Bedrock response for text %d: %w", i, err)
		}

		vec := decoded.Embedding
		switch {
		case len(vec) > TargetDimension:
			// Truncate oversized vectors, then re-normalize so magnitude stays 1.
			log.V(1).Info("Truncating embedding", "from", len(vec), "to", TargetDimension)
			vec = normalizeL2(vec[:TargetDimension])
		case len(vec) < TargetDimension:
			return nil, fmt.Errorf("embedding dimension %d is less than required %d", len(vec), TargetDimension)
		}
		results = append(results, vec)
	}

	log.Info("Successfully generated embeddings with Bedrock", "count", len(results))
	return results, nil
}

// bedrockEmbeddingResponse models the JSON body returned by Bedrock Titan
// embedding models: a single "embedding" array of floats.
type bedrockEmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
}

// normalizeL2 normalizes a vector to unit length using L2 norm.
func normalizeL2(vec []float32) []float32 {
var sum float64
Expand Down
Loading