diff --git a/cmd/internal/imports.go b/cmd/internal/imports.go index 126f9451a0a0..002c83b2e320 100644 --- a/cmd/internal/imports.go +++ b/cmd/internal/imports.go @@ -224,6 +224,10 @@ import ( _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqliteexecutesql" _ "github.com/googleapis/genai-toolbox/internal/tools/sqlite/sqlitesql" _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbexecutesql" + _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbgetqueryplan" + _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidblistactivequeries" + _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidblisttables" + _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidblisttiflashreplicas" _ "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidbsql" _ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinoexecutesql" _ "github.com/googleapis/genai-toolbox/internal/tools/trino/trinosql" diff --git a/internal/prebuiltconfigs/tools/tidb.yaml b/internal/prebuiltconfigs/tools/tidb.yaml new file mode 100644 index 000000000000..3c4a2dc51c28 --- /dev/null +++ b/internal/prebuiltconfigs/tools/tidb.yaml @@ -0,0 +1,77 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TiDB Prebuilt Configuration +# +# TiDB is a distributed SQL database that supports Hybrid Transactional +# and Analytical Processing (HTAP) workloads. It is MySQL-compatible and +# provides horizontal scalability, strong consistency, and high availability. +# +# Key TiDB-specific features exposed by these tools: +# - TiFlash: Columnar storage for real-time analytics (requires TiDB 4.0+) +# - Distributed query execution with transparent scaling +# +# Environment variables: +# TIDB_HOST - TiDB server host (default: localhost) +# TIDB_PORT - TiDB server port (default: 4000) +# TIDB_DATABASE - Database name (required) +# TIDB_USER - Database user (required) +# TIDB_PASSWORD - Database password (required) +# +# For TiDB Cloud, SSL is automatically enabled when the host matches +# the TiDB Cloud gateway pattern (gateway*.tidbcloud.com). + +sources: + tidb-source: + kind: tidb + host: ${TIDB_HOST:localhost} + port: ${TIDB_PORT:4000} + database: ${TIDB_DATABASE} + user: ${TIDB_USER} + password: ${TIDB_PASSWORD} +tools: + execute_sql: + kind: tidb-execute-sql + source: tidb-source + description: Execute arbitrary SQL statements on TiDB. Supports SELECT, INSERT, UPDATE, DELETE, and DDL statements. Use with caution for data-modifying operations. + list_tables: + kind: tidb-list-tables + source: tidb-source + description: "Lists detailed schema information (columns, constraints, indexes, TiFlash replica count) as JSON for user-created tables. Filters by a comma-separated list of names. If names are omitted, lists all tables in user schemas. Excludes system schemas (mysql, information_schema, performance_schema, sys, METRICS_SCHEMA, INSPECTION_SCHEMA)." + get_query_plan: + kind: tidb-get-query-plan + source: tidb-source + description: "Provide information about how TiDB executes a SQL statement using EXPLAIN. Common use cases include: 1) analyze query plan to improve performance, 2) determine effectiveness of existing indexes, 3) identify if TiFlash is being used for HTAP queries. Supports 'default', 'analyze' (actual execution stats, SELECT only for safety), and 'verbose' (detailed cost info) explain types. WARNING: EXPLAIN ANALYZE actually executes the query." + list_active_queries: + kind: tidb-list-active-queries + source: tidb-source + description: Lists top N (default 10) ongoing queries from TiDB's processlist, ordered by execution time in descending order. Returns detailed information including process id, user, host, database, command, execution time, state, query text (truncated to 1000 chars), memory usage, and transaction start timestamp. + list_tiflash_replicas: + kind: tidb-list-tiflash-replicas + source: tidb-source + description: "Lists TiFlash replica status for all tables that have TiFlash replicas configured. TiFlash is TiDB's columnar storage engine that enables real-time HTAP analytics (requires TiDB 4.0+). Returns replica count, availability status, and sync progress for each table. Useful for monitoring TiFlash deployment health and identifying tables ready for analytical queries." +toolsets: + data: + - execute_sql + - list_tables + - get_query_plan + - list_active_queries + monitor: + - get_query_plan + - list_active_queries + - list_tiflash_replicas + htap: + - execute_sql + - list_tables + - list_tiflash_replicas diff --git a/internal/tools/tidb/tidbgetqueryplan/tidbgetqueryplan.go b/internal/tools/tidb/tidbgetqueryplan/tidbgetqueryplan.go new file mode 100644 index 000000000000..1afb3f44f79a --- /dev/null +++ b/internal/tools/tidb/tidbgetqueryplan/tidbgetqueryplan.go @@ -0,0 +1,211 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tidbgetqueryplan + +import ( + "context" + "database/sql" + "fmt" + "net/http" + "regexp" + "strings" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +const resourceType string = "tidb-get-query-plan" + +// stripSQLComments removes SQL comments (both -- and /* */) and leading/trailing whitespace +func stripSQLComments(sql string) string { + // Remove multi-line comments /* ... */ + reMultiLine := regexp.MustCompile(`/\*[\s\S]*?\*/`) + sql = reMultiLine.ReplaceAllString(sql, "") + // Remove single-line comments -- ... + reSingleLine := regexp.MustCompile(`--[^\n]*`) + sql = reSingleLine.ReplaceAllString(sql, "") + return strings.TrimSpace(sql) +} + +// isSelectOrWithStatement checks if the SQL is a SELECT or WITH (CTE) statement +func isSelectOrWithStatement(sql string) bool { + normalized := strings.ToUpper(stripSQLComments(sql)) + return strings.HasPrefix(normalized, "SELECT") || strings.HasPrefix(normalized, "WITH") +} + +// containsMultipleStatements checks if SQL contains multiple statements (semicolon) +func containsMultipleStatements(sql string) bool { + // Strip comments first to avoid false positives from semicolons in comments + sql = stripSQLComments(sql) + // Remove string literals to avoid false positives + reString := regexp.MustCompile(`'[^']*'|"[^"]*"`) + cleaned := reString.ReplaceAllString(sql, "") + return strings.Contains(cleaned, ";") +} + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + TiDBPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + allParameters := parameters.Parameters{ + parameters.NewStringParameter("sql", "The SQL query to analyze. Must be a SELECT, INSERT, UPDATE, or DELETE statement."), + parameters.NewStringParameterWithDefault("explain_type", "default", "Optional: The type of EXPLAIN output. Options: 'default' (basic plan), 'analyze' (actual execution stats - SELECT only), 'verbose' (detailed cost info)."), + } + paramManifest := allParameters.Manifest() + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + // finish tool setup + t := Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + paramsMap := params.AsMap() + + sqlStr, ok := paramsMap["sql"].(string) + if !ok || strings.TrimSpace(sqlStr) == "" { + return nil, util.NewAgentError("'sql' parameter is required and must be a non-empty string", nil) + } + + // Security check: reject multiple statements to prevent injection like "SELECT 1; DELETE FROM t" + if containsMultipleStatements(sqlStr) { + return nil, util.NewAgentError("multiple SQL statements are not allowed; remove any semicolons from your query", nil) + } + + explainType, _ := paramsMap["explain_type"].(string) + if explainType == "" { + explainType = "default" + } + + // Build the EXPLAIN statement based on the type + var explainStmt string + switch strings.ToLower(explainType) { + case "analyze": + // EXPLAIN ANALYZE actually executes the query and shows real execution stats + // For safety, only allow SELECT or WITH (CTE) statements + if !isSelectOrWithStatement(sqlStr) { + return nil, util.NewAgentError("EXPLAIN ANALYZE only supports SELECT statements (including WITH/CTE) for safety reasons; use 'default' or 'verbose' for other statement types", nil) + } + explainStmt = fmt.Sprintf("EXPLAIN ANALYZE %s", sqlStr) + case "verbose": + // EXPLAIN FORMAT='verbose' shows detailed cost estimation + explainStmt = fmt.Sprintf("EXPLAIN FORMAT='verbose' %s", sqlStr) + case "default": + explainStmt = fmt.Sprintf("EXPLAIN %s", sqlStr) + default: + return nil, util.NewAgentError(fmt.Sprintf("invalid value for explain_type: must be 'default', 'analyze', or 'verbose', but got %q", explainType), nil) + } + + // Log the query for debugging + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) + } + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", resourceType, explainStmt)) + + resp, err := source.RunSQL(ctx, explainStmt, nil) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + return resp, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.AllParams +} diff --git a/internal/tools/tidb/tidbgetqueryplan/tidbgetqueryplan_test.go b/internal/tools/tidb/tidbgetqueryplan/tidbgetqueryplan_test.go new file mode 100644 index 000000000000..9b8d13fcc05f --- /dev/null +++ b/internal/tools/tidb/tidbgetqueryplan/tidbgetqueryplan_test.go @@ -0,0 +1,208 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tidbgetqueryplan + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" +) + +func TestParseFromYaml(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tool + name: get_query_plan + type: tidb-get-query-plan + source: my-tidb-instance + description: Get query execution plan + `, + want: server.ToolConfigs{ + "get_query_plan": Config{ + Name: "get_query_plan", + Type: "tidb-get-query-plan", + Source: "my-tidb-instance", + Description: "Get query execution plan", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} + +func TestStripSQLComments(t *testing.T) { + tcs := []struct { + desc string + in string + want string + }{ + { + desc: "no comments", + in: "SELECT * FROM users", + want: "SELECT * FROM users", + }, + { + desc: "single line comment", + in: "-- this is a comment\nSELECT * FROM users", + want: "SELECT * FROM users", + }, + { + desc: "multi-line comment", + in: "/* comment */ SELECT * FROM users", + want: "SELECT * FROM users", + }, + { + desc: "comment at end", + in: "SELECT * FROM users -- trailing comment", + want: "SELECT * FROM users", + }, + { + desc: "comment with semicolon", + in: "SELECT 1 -- comment;", + want: "SELECT 1", + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := stripSQLComments(tc.in) + if got != tc.want { + t.Errorf("stripSQLComments(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +func TestIsSelectOrWithStatement(t *testing.T) { + tcs := []struct { + desc string + in string + want bool + }{ + { + desc: "simple select", + in: "SELECT * FROM users", + want: true, + }, + { + desc: "select with comment prefix", + in: "/* hint */ SELECT * FROM users", + want: true, + }, + { + desc: "WITH CTE", + in: "WITH cte AS (SELECT 1) SELECT * FROM cte", + want: true, + }, + { + desc: "lowercase select", + in: "select * from users", + want: true, + }, + { + desc: "DELETE statement", + in: "DELETE FROM users WHERE id = 1", + want: false, + }, + { + desc: "INSERT statement", + in: "INSERT INTO users VALUES (1, 'test')", + want: false, + }, + { + desc: "UPDATE statement", + in: "UPDATE users SET name = 'test'", + want: false, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := isSelectOrWithStatement(tc.in) + if got != tc.want { + t.Errorf("isSelectOrWithStatement(%q) = %v, want %v", tc.in, got, tc.want) + } + }) + } +} + +func TestContainsMultipleStatements(t *testing.T) { + tcs := []struct { + desc string + in string + want bool + }{ + { + desc: "single statement", + in: "SELECT * FROM users", + want: false, + }, + { + desc: "multiple statements", + in: "SELECT * FROM users; DELETE FROM users", + want: true, + }, + { + desc: "semicolon in string literal", + in: "SELECT * FROM users WHERE name = 'test;value'", + want: false, + }, + { + desc: "semicolon in double quoted string", + in: `SELECT * FROM users WHERE name = "test;value"`, + want: false, + }, + { + desc: "semicolon in single line comment", + in: "SELECT 1 -- comment;", + want: false, + }, + { + desc: "semicolon in multi-line comment", + in: "SELECT 1 /* comment; */", + want: false, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + got := containsMultipleStatements(tc.in) + if got != tc.want { + t.Errorf("containsMultipleStatements(%q) = %v, want %v", tc.in, got, tc.want) + } + }) + } +} diff --git a/internal/tools/tidb/tidblistactivequeries/tidblistactivequeries.go b/internal/tools/tidb/tidblistactivequeries/tidblistactivequeries.go new file mode 100644 index 000000000000..92463a5c87bb --- /dev/null +++ b/internal/tools/tidb/tidblistactivequeries/tidblistactivequeries.go @@ -0,0 +1,194 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tidblistactivequeries + +import ( + "context" + "database/sql" + "fmt" + "net/http" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +const resourceType string = "tidb-list-active-queries" + +// listActiveQueriesStatement queries active queries from TiDB's processlist and transaction info. +// TiDB uses INFORMATION_SCHEMA.PROCESSLIST and INFORMATION_SCHEMA.CLUSTER_PROCESSLIST +// for distributed query visibility. +const listActiveQueriesStatement = ` + SELECT + JSON_OBJECT( + 'process_id', P.ID, + 'user', P.USER, + 'host', P.HOST, + 'db', P.DB, + 'command', P.COMMAND, + 'time_seconds', P.TIME, + 'state', P.STATE, + 'info', LEFT(P.INFO, 1000), + 'mem_bytes', P.MEM, + 'txn_start_ts', P.TxnStart, + 'session_alias', P.SESSION_ALIAS + ) AS query_info + FROM + INFORMATION_SCHEMA.PROCESSLIST P + WHERE + P.COMMAND != 'Sleep' + AND P.INFO IS NOT NULL + AND P.INFO != '' + AND P.USER != 'system user' + ORDER BY + P.TIME DESC + LIMIT ?; +` + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + TiDBPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + allParameters := parameters.Parameters{ + parameters.NewIntParameterWithDefault("limit", 10, "Optional: Maximum number of active queries to return (default: 10)."), + } + paramManifest := allParameters.Manifest() + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + // finish tool setup + t := Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + paramsMap := params.AsMap() + + // Framework guarantees type safety for typed constructors like NewIntParameterWithDefault + limit := paramsMap["limit"].(int) + if limit <= 0 { + limit = 10 + } + + // Cap limit to prevent excessive results + if limit > 100 { + limit = 100 + } + + // Log the query for debugging + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) + } + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool with limit: %d", resourceType, limit)) + + resp, err := source.RunSQL(ctx, listActiveQueriesStatement, []any{limit}) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + + // if there's no results, return empty list instead of null + resSlice, ok := resp.([]any) + if !ok || len(resSlice) == 0 { + return []any{}, nil + } + return resp, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.AllParams +} diff --git a/internal/tools/tidb/tidblistactivequeries/tidblistactivequeries_test.go b/internal/tools/tidb/tidblistactivequeries/tidblistactivequeries_test.go new file mode 100644 index 000000000000..66b4615bf693 --- /dev/null +++ b/internal/tools/tidb/tidblistactivequeries/tidblistactivequeries_test.go @@ -0,0 +1,68 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tidblistactivequeries_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidblistactivequeries" +) + +func TestParseFromYaml(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tool + name: list_active_queries + type: tidb-list-active-queries + source: my-tidb-instance + description: List active queries in TiDB + `, + want: server.ToolConfigs{ + "list_active_queries": tidblistactivequeries.Config{ + Name: "list_active_queries", + Type: "tidb-list-active-queries", + Source: "my-tidb-instance", + Description: "List active queries in TiDB", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/tidb/tidblisttables/tidblisttables.go b/internal/tools/tidb/tidblisttables/tidblisttables.go new file mode 100644 index 000000000000..e153d6ee944c --- /dev/null +++ b/internal/tools/tidb/tidblisttables/tidblisttables.go @@ -0,0 +1,284 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tidblisttables + +import ( + "context" + "database/sql" + "fmt" + "net/http" + + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +const resourceType string = "tidb-list-tables" + +// listTablesStatement queries table metadata from TiDB's INFORMATION_SCHEMA. +// TiDB is MySQL-compatible but has some differences in system tables. +// This query provides detailed schema information including TiDB-specific features. +const listTablesStatement = ` + SELECT + T.TABLE_SCHEMA AS schema_name, + T.TABLE_NAME AS object_name, + CASE + WHEN @output_format = 'simple' THEN + JSON_OBJECT('name', T.TABLE_NAME) + ELSE + CONVERT( + JSON_OBJECT( + 'schema_name', T.TABLE_SCHEMA, + 'object_name', T.TABLE_NAME, + 'object_type', 'TABLE', + 'comment', IFNULL(T.TABLE_COMMENT, ''), + 'tiflash_replica_count', IFNULL(T.TIFLASH_REPLICA_COUNT, 0), + 'columns', ( + SELECT + IFNULL( + JSON_ARRAYAGG( + JSON_OBJECT( + 'column_name', C.COLUMN_NAME, + 'data_type', C.COLUMN_TYPE, + 'ordinal_position', C.ORDINAL_POSITION, + 'is_not_nullable', IF(C.IS_NULLABLE = 'NO', TRUE, FALSE), + 'column_default', C.COLUMN_DEFAULT, + 'column_comment', IFNULL(C.COLUMN_COMMENT, '') + ) + ), + JSON_ARRAY() + ) + FROM + INFORMATION_SCHEMA.COLUMNS C + WHERE + C.TABLE_SCHEMA = T.TABLE_SCHEMA AND C.TABLE_NAME = T.TABLE_NAME + ORDER BY C.ORDINAL_POSITION + ), + 'constraints', ( + SELECT + IFNULL( + JSON_ARRAYAGG( + JSON_OBJECT( + 'constraint_name', TC.CONSTRAINT_NAME, + 'constraint_type', + CASE TC.CONSTRAINT_TYPE + WHEN 'PRIMARY KEY' THEN 'PRIMARY KEY' + WHEN 'FOREIGN KEY' THEN 'FOREIGN KEY' + WHEN 'UNIQUE' THEN 'UNIQUE' + ELSE TC.CONSTRAINT_TYPE + END, + 'constraint_columns', ( + SELECT + IFNULL(JSON_ARRAYAGG(KCU.COLUMN_NAME), JSON_ARRAY()) + FROM + INFORMATION_SCHEMA.KEY_COLUMN_USAGE KCU + WHERE + KCU.CONSTRAINT_SCHEMA = TC.CONSTRAINT_SCHEMA + AND KCU.CONSTRAINT_NAME = TC.CONSTRAINT_NAME + AND KCU.TABLE_NAME = TC.TABLE_NAME + ORDER BY KCU.ORDINAL_POSITION + ), + 'foreign_key_referenced_table', IF(TC.CONSTRAINT_TYPE = 'FOREIGN KEY', RC.REFERENCED_TABLE_NAME, NULL), + 'foreign_key_referenced_columns', IF(TC.CONSTRAINT_TYPE = 'FOREIGN KEY', + (SELECT IFNULL(JSON_ARRAYAGG(FKCU.REFERENCED_COLUMN_NAME), JSON_ARRAY()) + FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE FKCU + WHERE FKCU.CONSTRAINT_SCHEMA = TC.CONSTRAINT_SCHEMA + AND FKCU.CONSTRAINT_NAME = TC.CONSTRAINT_NAME + AND FKCU.TABLE_NAME = TC.TABLE_NAME + AND FKCU.REFERENCED_TABLE_NAME IS NOT NULL + ORDER BY FKCU.ORDINAL_POSITION), + NULL + ) + ) + ), + JSON_ARRAY() + ) + FROM + INFORMATION_SCHEMA.TABLE_CONSTRAINTS TC + LEFT JOIN + INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS RC + ON TC.CONSTRAINT_SCHEMA = RC.CONSTRAINT_SCHEMA + AND TC.CONSTRAINT_NAME = RC.CONSTRAINT_NAME + AND TC.TABLE_NAME = RC.TABLE_NAME + WHERE + TC.TABLE_SCHEMA = T.TABLE_SCHEMA AND TC.TABLE_NAME = T.TABLE_NAME + ), + 'indexes', ( + SELECT + IFNULL( + JSON_ARRAYAGG( + JSON_OBJECT( + 'index_name', IndexData.INDEX_NAME, + 'is_unique', IF(IndexData.NON_UNIQUE = 0, TRUE, FALSE), + 'is_primary', IF(IndexData.INDEX_NAME = 'PRIMARY', TRUE, FALSE), + 'index_columns', IFNULL(IndexData.INDEX_COLUMNS_ARRAY, JSON_ARRAY()) + ) + ), + JSON_ARRAY() + ) + FROM ( + SELECT + S.TABLE_SCHEMA, + S.TABLE_NAME, + S.INDEX_NAME, + MIN(S.NON_UNIQUE) AS NON_UNIQUE, + JSON_ARRAYAGG(S.COLUMN_NAME) AS INDEX_COLUMNS_ARRAY + FROM + INFORMATION_SCHEMA.STATISTICS S + GROUP BY + S.TABLE_SCHEMA, S.TABLE_NAME, S.INDEX_NAME + ) AS IndexData + WHERE IndexData.TABLE_SCHEMA = T.TABLE_SCHEMA AND IndexData.TABLE_NAME = T.TABLE_NAME + ORDER BY IndexData.INDEX_NAME + ) + ) + USING utf8mb4) + END AS object_details + FROM + INFORMATION_SCHEMA.TABLES T + CROSS JOIN (SELECT @table_names := ?, @output_format := ?) AS variables + WHERE + T.TABLE_SCHEMA NOT IN ('mysql', 'information_schema', 'performance_schema', 'sys', 'METRICS_SCHEMA', 'INSPECTION_SCHEMA') + AND (NULLIF(TRIM(@table_names), '') IS NULL OR FIND_IN_SET(T.TABLE_NAME, @table_names)) + AND T.TABLE_TYPE = 'BASE TABLE' + ORDER BY + T.TABLE_SCHEMA, T.TABLE_NAME; +` + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + TiDBPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + allParameters := parameters.Parameters{ + parameters.NewStringParameterWithDefault("table_names", "", "Optional: A comma-separated list of table names. If empty, details for all tables will be listed."), + parameters.NewStringParameterWithDefault("output_format", "detailed", "Optional: Use 'simple' for names only or 'detailed' for full info including TiFlash replica count."), + } + paramManifest := allParameters.Manifest() + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + // finish tool setup + t := Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + paramsMap := params.AsMap() + + tableNames, ok := paramsMap["table_names"].(string) + if !ok { + return nil, util.NewAgentError(fmt.Sprintf("invalid 'table_names' parameter; expected a string but got %T", paramsMap["table_names"]), nil) + } + outputFormat, _ := paramsMap["output_format"].(string) + if outputFormat != "simple" && outputFormat != "detailed" { + return nil, util.NewAgentError(fmt.Sprintf("invalid value for output_format: must be 'simple' or 'detailed', but got %q", outputFormat), nil) + } + resp, err := source.RunSQL(ctx, listTablesStatement, []any{tableNames, outputFormat}) + if err != nil { + return nil, util.ProcessGeneralError(err) + } + // if there's no results, return empty list instead of null + resSlice, ok := resp.([]any) + if !ok || len(resSlice) == 0 { + return []any{}, nil + } + return resp, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.AllParams +} diff --git a/internal/tools/tidb/tidblisttables/tidblisttables_test.go b/internal/tools/tidb/tidblisttables/tidblisttables_test.go new file mode 100644 index 000000000000..77914e461905 --- /dev/null +++ b/internal/tools/tidb/tidblisttables/tidblisttables_test.go @@ -0,0 +1,89 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tidblisttables_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidblisttables" +) + +func TestParseFromYaml(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tool + name: list_tables + type: tidb-list-tables + source: my-tidb-instance + description: List all tables in the TiDB database + `, + want: server.ToolConfigs{ + "list_tables": tidblisttables.Config{ + Name: "list_tables", + Type: "tidb-list-tables", + Source: "my-tidb-instance", + Description: "List all tables in the TiDB database", + AuthRequired: []string{}, + }, + }, + }, + { + desc: "with auth required", + in: ` + kind: tool + name: list_tables + type: tidb-list-tables + source: my-tidb-instance + description: List all tables in the TiDB database + authRequired: + - my-auth-service + `, + want: server.ToolConfigs{ + "list_tables": tidblisttables.Config{ + Name: "list_tables", + Type: "tidb-list-tables", + Source: "my-tidb-instance", + Description: "List all tables in the TiDB database", + AuthRequired: []string{"my-auth-service"}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +} diff --git a/internal/tools/tidb/tidblisttiflashreplicas/tidblisttiflashreplicas.go b/internal/tools/tidb/tidblisttiflashreplicas/tidblisttiflashreplicas.go new file mode 100644 index 000000000000..dd760ea178bf --- /dev/null +++ b/internal/tools/tidb/tidblisttiflashreplicas/tidblisttiflashreplicas.go @@ -0,0 +1,203 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tidblisttiflashreplicas provides a tool to list TiFlash replica status. +// TiFlash is TiDB's columnar storage engine for real-time analytics. +// Note: TiFlash is available in TiDB 4.0+. This tool will return an empty list +// or an error on older versions. +package tidblisttiflashreplicas + +import ( + "context" + "errors" + "database/sql" + "fmt" + "net/http" + "strings" + + "github.com/go-sql-driver/mysql" + yaml "github.com/goccy/go-yaml" + "github.com/googleapis/genai-toolbox/internal/embeddingmodels" + "github.com/googleapis/genai-toolbox/internal/sources" + "github.com/googleapis/genai-toolbox/internal/tools" + "github.com/googleapis/genai-toolbox/internal/util" + "github.com/googleapis/genai-toolbox/internal/util/parameters" +) + +const resourceType string = "tidb-list-tiflash-replicas" + +// MySQL error code for "Unknown column" +const mysqlErrUnknownColumn = 1054 + +// listTiFlashReplicasStatement queries TiFlash replica status from TiDB. +// This is a TiDB-specific feature not available in MySQL. +// Uses IFNULL to handle potential NULL values gracefully. +const listTiFlashReplicasStatement = ` + SELECT + JSON_OBJECT( + 'table_schema', T.TABLE_SCHEMA, + 'table_name', T.TABLE_NAME, + 'replica_count', IFNULL(T.TIFLASH_REPLICA_COUNT, 0), + 'available', IFNULL(TR.AVAILABLE, 0), + 'progress', IFNULL(TR.PROGRESS, 0) + ) AS tiflash_info + FROM + INFORMATION_SCHEMA.TABLES T + LEFT JOIN + INFORMATION_SCHEMA.TIFLASH_REPLICA TR + ON T.TABLE_SCHEMA = TR.TABLE_SCHEMA AND T.TABLE_NAME = TR.TABLE_NAME + WHERE + T.TABLE_SCHEMA NOT IN ('mysql', 'information_schema', 'performance_schema', 'sys', 'METRICS_SCHEMA', 'INSPECTION_SCHEMA') + AND T.TABLE_TYPE = 'BASE TABLE' + AND IFNULL(T.TIFLASH_REPLICA_COUNT, 0) > 0 + ORDER BY + T.TABLE_SCHEMA, T.TABLE_NAME; +` + +func init() { + if !tools.Register(resourceType, newConfig) { + panic(fmt.Sprintf("tool type %q already registered", resourceType)) + } +} + +func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) { + actual := Config{Name: name} + if err := decoder.DecodeContext(ctx, &actual); err != nil { + return nil, err + } + return actual, nil +} + +type compatibleSource interface { + TiDBPool() *sql.DB + RunSQL(context.Context, string, []any) (any, error) +} + +type Config struct { + Name string `yaml:"name" validate:"required"` + Type string `yaml:"type" validate:"required"` + Source string `yaml:"source" validate:"required"` + Description string `yaml:"description" validate:"required"` + AuthRequired []string `yaml:"authRequired"` +} + +// validate interface +var _ tools.ToolConfig = Config{} + +func (cfg Config) ToolConfigType() string { + return resourceType +} + +func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) { + // No parameters needed - this tool returns all TiFlash replicas + allParameters := parameters.Parameters{} + paramManifest := allParameters.Manifest() + mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil) + + // finish tool setup + t := Tool{ + Config: cfg, + AllParams: allParameters, + manifest: tools.Manifest{Description: cfg.Description, Parameters: paramManifest, AuthRequired: cfg.AuthRequired}, + mcpManifest: mcpManifest, + } + return t, nil +} + +// validate interface +var _ tools.Tool = Tool{} + +type Tool struct { + Config + AllParams parameters.Parameters `yaml:"allParams"` + + manifest tools.Manifest + mcpManifest tools.McpManifest +} + +// isTiFlashUnsupportedError checks if the error indicates TiFlash is not available +// (older TiDB version or TiFlash not deployed). Uses MySQL error code 1054 (Unknown column) +// and checks if the missing column is TiFlash-related. +func isTiFlashUnsupportedError(err error) bool { + var mysqlErr *mysql.MySQLError + if ok := errors.As(err, &mysqlErr); ok { + if mysqlErr.Number == mysqlErrUnknownColumn { + msg := strings.ToUpper(mysqlErr.Message) + return strings.Contains(msg, "TIFLASH_REPLICA_COUNT") || strings.Contains(msg, "TIFLASH_REPLICA") + } + } + return false +} + +func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, util.ToolboxError) { + source, err := tools.GetCompatibleSource[compatibleSource](resourceMgr, t.Source, t.Name, t.Type) + if err != nil { + return nil, util.NewClientServerError("source used is not compatible with the tool", http.StatusInternalServerError, err) + } + + // Log the query for debugging + logger, err := util.LoggerFromContext(ctx) + if err != nil { + return nil, util.NewClientServerError("error getting logger", http.StatusInternalServerError, err) + } + logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool", resourceType)) + + resp, err := source.RunSQL(ctx, listTiFlashReplicasStatement, nil) + if err != nil { + // Check for TiFlash-related "Unknown column" errors (MySQL error 1054) + if isTiFlashUnsupportedError(err) { + return nil, util.NewAgentError("TiFlash is not available on this TiDB version (requires TiDB 4.0+) or TiFlash is not deployed", err) + } + return nil, util.ProcessGeneralError(err) + } + + // if there's no results, return empty list instead of null + resSlice, ok := resp.([]any) + if !ok || len(resSlice) == 0 { + return []any{}, nil + } + return resp, nil +} + +func (t Tool) EmbedParams(ctx context.Context, paramValues parameters.ParamValues, embeddingModelsMap map[string]embeddingmodels.EmbeddingModel) (parameters.ParamValues, error) { + return parameters.EmbedParams(ctx, t.AllParams, paramValues, embeddingModelsMap, nil) +} + +func (t Tool) Manifest() tools.Manifest { + return t.manifest +} + +func (t Tool) McpManifest() tools.McpManifest { + return t.mcpManifest +} + +func (t Tool) Authorized(verifiedAuthServices []string) bool { + return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices) +} + +func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) (bool, error) { + return false, nil +} + +func (t Tool) ToConfig() tools.ToolConfig { + return t.Config +} + +func (t Tool) GetAuthTokenHeaderName(resourceMgr tools.SourceProvider) (string, error) { + return "Authorization", nil +} + +func (t Tool) GetParameters() parameters.Parameters { + return t.AllParams +} diff --git a/internal/tools/tidb/tidblisttiflashreplicas/tidblisttiflashreplicas_test.go b/internal/tools/tidb/tidblisttiflashreplicas/tidblisttiflashreplicas_test.go new file mode 100644 index 000000000000..f572ca61677a --- /dev/null +++ b/internal/tools/tidb/tidblisttiflashreplicas/tidblisttiflashreplicas_test.go @@ -0,0 +1,68 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tidblisttiflashreplicas_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/googleapis/genai-toolbox/internal/server" + "github.com/googleapis/genai-toolbox/internal/testutils" + "github.com/googleapis/genai-toolbox/internal/tools/tidb/tidblisttiflashreplicas" +) + +func TestParseFromYaml(t *testing.T) { + ctx, err := testutils.ContextWithNewLogger() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + tcs := []struct { + desc string + in string + want server.ToolConfigs + }{ + { + desc: "basic example", + in: ` + kind: tool + name: list_tiflash_replicas + type: tidb-list-tiflash-replicas + source: my-tidb-instance + description: List TiFlash replica status + `, + want: server.ToolConfigs{ + "list_tiflash_replicas": tidblisttiflashreplicas.Config{ + Name: "list_tiflash_replicas", + Type: "tidb-list-tiflash-replicas", + Source: "my-tidb-instance", + Description: "List TiFlash replica status", + AuthRequired: []string{}, + }, + }, + }, + } + for _, tc := range tcs { + t.Run(tc.desc, func(t *testing.T) { + // Parse contents + _, _, _, got, _, _, err := server.UnmarshalResourceConfig(ctx, testutils.FormatYaml(tc.in)) + if err != nil { + t.Fatalf("unable to unmarshal: %s", err) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Fatalf("incorrect parse: diff %v", diff) + } + }) + } +}