diff --git a/servers/Azure.Mcp.Server/changelog-entries/cosmos-mcp-toolkit-tools.yaml b/servers/Azure.Mcp.Server/changelog-entries/cosmos-mcp-toolkit-tools.yaml new file mode 100644 index 0000000000..d36a432a14 --- /dev/null +++ b/servers/Azure.Mcp.Server/changelog-entries/cosmos-mcp-toolkit-tools.yaml @@ -0,0 +1,9 @@ +changes: + - section: "Features Added" + description: | + Added five Cosmos DB MCP tools: + - `cosmos_database_container_schema_get`: Infer container schema by sampling documents + - `cosmos_database_container_item_list-recent`: Return the most recently modified documents + - `cosmos_database_container_item_get`: Look up a document by id, with optional partition key for point reads + - `cosmos_database_container_item_text-search`: FullTextContains-based property search + - `cosmos_database_container_item_vector-search`: VectorDistance similarity search with optional Azure OpenAI embedding generation via `--search-text` diff --git a/servers/Azure.Mcp.Server/docs/azmcp-commands.md b/servers/Azure.Mcp.Server/docs/azmcp-commands.md index 6a0abf0785..b87fc5237f 100644 --- a/servers/Azure.Mcp.Server/docs/azmcp-commands.md +++ b/servers/Azure.Mcp.Server/docs/azmcp-commands.md @@ -2001,6 +2001,57 @@ azmcp cosmos database container item query --subscription \ --database \ --container \ [--query "SELECT * FROM c"] + +# Infer an approximate schema for a Cosmos DB container by sampling documents. +# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp cosmos database container schema get --subscription \ + --account \ + --database \ + --container \ + [--sample-size 10] + +# Get the most recently modified documents from a Cosmos DB container. 
+# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp cosmos database container item list-recent --subscription \ + --account \ + --database \ + --container \ + [--count 10] + +# Get a single Cosmos DB document by id (point read when --partition-key is supplied). +# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp cosmos database container item get --subscription \ + --account \ + --database \ + --container \ + --id \ + [--partition-key ] + +# Search Cosmos DB documents where a property contains a phrase (FullTextContains). +# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp cosmos database container item text-search --subscription \ + --account \ + --database \ + --container \ + --property \ + --search-phrase \ + [--count 10] + +# Vector similarity search against a Cosmos DB container. Provide --embedding (CSV floats) +# or --search-text plus --openai-endpoint and --embedding-deployment to generate one. 
+# ❌ Destructive | ✅ Idempotent | ❌ OpenWorld | ✅ ReadOnly | ❌ Secret | ❌ LocalRequired +azmcp cosmos database container item vector-search --subscription \ + --account \ + --database \ + --container \ + --vector-property \ + --select-properties \ + [--count 10] \ + [--embedding "0.1,0.2,..."] \ + [--search-text "free-form text"] \ + [--openai-endpoint ] \ + [--embedding-deployment ] \ + [--embedding-dimensions ] ``` ### Azure Data Explorer Operations diff --git a/servers/Azure.Mcp.Server/docs/e2eTestPrompts.md b/servers/Azure.Mcp.Server/docs/e2eTestPrompts.md index 7235614bc5..da52991a49 100644 --- a/servers/Azure.Mcp.Server/docs/e2eTestPrompts.md +++ b/servers/Azure.Mcp.Server/docs/e2eTestPrompts.md @@ -350,6 +350,15 @@ This file contains prompts used for end-to-end testing to ensure each tool is in | cosmos_list | List all the containers in the database for the cosmosdb account | | cosmos_list | Show me the containers in the database for the cosmosdb account | | cosmos_database_container_item_query | Show me the items that contain the word in the container in the database for the cosmosdb account | +| cosmos_database_container_item_get | Get the document with id from container in database of the cosmosdb account | +| cosmos_database_container_item_get | Find the document in container from database of the cosmosdb account using partition key | +| cosmos_database_container_item_list-recent | Show me the 15 most recent documents in container of database in cosmosdb account | +| cosmos_database_container_item_list-recent | Get the latest documents from in for cosmosdb account | +| cosmos_database_container_item_text-search | Search documents in container from database of the cosmosdb account where contains "" | +| cosmos_database_container_item_vector-search | Run a vector search in container of database for cosmosdb account using vector property and the embedding | +| cosmos_database_container_item_vector-search | Find documents similar to "" in container using 
vector property with Azure OpenAI endpoint and deployment | +| cosmos_database_container_schema_get | Infer the schema of container in database for cosmosdb account | +| cosmos_database_container_schema_get | Sample documents from container in database of the cosmosdb account and tell me the property names and types | ## Azure Data Explorer diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Azure.Mcp.Tools.Cosmos.csproj b/tools/Azure.Mcp.Tools.Cosmos/src/Azure.Mcp.Tools.Cosmos.csproj index d3da9fc821..63bd1bcc5c 100644 --- a/tools/Azure.Mcp.Tools.Cosmos/src/Azure.Mcp.Tools.Cosmos.csproj +++ b/tools/Azure.Mcp.Tools.Cosmos/src/Azure.Mcp.Tools.Cosmos.csproj @@ -11,6 +11,7 @@ + diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Container/ContainerSchemaGetCommand.cs b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Container/ContainerSchemaGetCommand.cs new file mode 100644 index 0000000000..aaaa7992cd --- /dev/null +++ b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Container/ContainerSchemaGetCommand.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using System.Net; +using Azure.Mcp.Tools.Cosmos.Options; +using Azure.Mcp.Tools.Cosmos.Options.Container; +using Azure.Mcp.Tools.Cosmos.Services; +using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Extensions; +using Microsoft.Mcp.Core.Models; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.Cosmos.Commands.Container; + +[CommandMetadata( + Id = "f1c6a0e2-3d40-4b3f-9a37-2dc1f6cf4a12", + Name = "get", + Title = "Infer Cosmos DB Container Schema", + Description = "Infer an approximate schema for a Cosmos DB container by sampling documents and reporting the top-level properties along with their inferred types and the number of sampled documents in which each appeared.", + Destructive = false, + Idempotent = true, + OpenWorld = false, + ReadOnly = true, + Secret = false, + LocalRequired = false)] +public sealed class ContainerSchemaGetCommand(ILogger logger, ICosmosService cosmosService) + : BaseContainerCommand() +{ + private readonly ILogger _logger = logger; + private readonly ICosmosService _cosmosService = cosmosService; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(CosmosOptionDefinitions.SampleSize); + command.Validators.Add(result => + { + var size = result.GetValueOrDefault(CosmosOptionDefinitions.SampleSize.Name); + if (size < 1 || size > 100) + { + result.AddError("--sample-size must be between 1 and 100."); + } + }); + } + + protected override ContainerSchemaGetOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.SampleSize = parseResult.GetValueOrDefault(CosmosOptionDefinitions.SampleSize.Name); + return options; + } + + public override async Task ExecuteAsync(CommandContext context, ParseResult parseResult, CancellationToken cancellationToken) + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return 
context.Response; + } + + var options = BindOptions(parseResult); + + try + { + var schema = await _cosmosService.GetApproximateSchema( + options.Account!, + options.Database!, + options.Container!, + options.SampleSize ?? 10, + options.Subscription!, + options.AuthMethod ?? AuthMethod.Credential, + options.Tenant, + options.RetryPolicy, + cancellationToken); + + context.Response.Results = ResponseResult.Create( + new ContainerSchemaGetCommandResult(schema.SampleSize, schema.Properties), + CosmosJsonContext.Default.ContainerSchemaGetCommandResult); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in {Operation}. Account: {Account}, Database: {Database}, Container: {Container}", + Name, options.Account, options.Database, options.Container); + HandleException(context, ex); + } + + return context.Response; + } + + protected override string GetErrorMessage(Exception ex) => ex switch + { + CosmosException cosmosEx => cosmosEx.Message, + _ => base.GetErrorMessage(ex) + }; + + protected override HttpStatusCode GetStatusCode(Exception ex) => ex switch + { + CosmosException cosmosEx => cosmosEx.StatusCode, + _ => base.GetStatusCode(ex) + }; + + internal record ContainerSchemaGetCommandResult(int SampleSize, IReadOnlyList Properties); +} diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Commands/CosmosJsonContext.cs b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/CosmosJsonContext.cs index 8b04dc84ca..2839afa1c0 100644 --- a/tools/Azure.Mcp.Tools.Cosmos/src/Commands/CosmosJsonContext.cs +++ b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/CosmosJsonContext.cs @@ -2,11 +2,18 @@ // Licensed under the MIT License. 
using System.Text.Json.Serialization; +using Azure.Mcp.Tools.Cosmos.Commands.Container; +using Azure.Mcp.Tools.Cosmos.Commands.Item; namespace Azure.Mcp.Tools.Cosmos.Commands; [JsonSerializable(typeof(CosmosListCommand.CosmosListCommandResult))] [JsonSerializable(typeof(ItemQueryCommand.ItemQueryCommandResult))] +[JsonSerializable(typeof(ContainerSchemaGetCommand.ContainerSchemaGetCommandResult))] +[JsonSerializable(typeof(ItemListRecentCommand.ItemListRecentCommandResult))] +[JsonSerializable(typeof(ItemGetCommand.ItemGetCommandResult))] +[JsonSerializable(typeof(ItemTextSearchCommand.ItemTextSearchCommandResult))] +[JsonSerializable(typeof(ItemVectorSearchCommand.ItemVectorSearchCommandResult))] [JsonSourceGenerationOptions( PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull)] diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemGetCommand.cs b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemGetCommand.cs new file mode 100644 index 0000000000..7fa02bb527 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemGetCommand.cs @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using Azure.Mcp.Tools.Cosmos.Options; +using Azure.Mcp.Tools.Cosmos.Options.Item; +using Azure.Mcp.Tools.Cosmos.Services; +using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Extensions; +using Microsoft.Mcp.Core.Models; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.Cosmos.Commands.Item; + +[CommandMetadata( + Id = "d4c1b2a3-9e8f-4d7c-86b5-1a2b3c4d5e6f", + Name = "get", + Title = "Get Cosmos DB Document by Id", + Description = "Find a single Cosmos DB document by its id in the specified database and container. 
When --partition-key is supplied, an efficient point read is used; otherwise a cross-partition query is executed.", + Destructive = false, + Idempotent = true, + OpenWorld = false, + ReadOnly = true, + Secret = false, + LocalRequired = false)] +public sealed class ItemGetCommand(ILogger logger, ICosmosService cosmosService) + : BaseContainerCommand() +{ + private readonly ILogger _logger = logger; + private readonly ICosmosService _cosmosService = cosmosService; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(CosmosOptionDefinitions.ItemId); + command.Options.Add(CosmosOptionDefinitions.PartitionKey); + } + + protected override ItemGetOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.Id = parseResult.GetValueOrDefault(CosmosOptionDefinitions.ItemId.Name); + options.PartitionKey = parseResult.GetValueOrDefault(CosmosOptionDefinitions.PartitionKey.Name); + return options; + } + + public override async Task ExecuteAsync(CommandContext context, ParseResult parseResult, CancellationToken cancellationToken) + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + var options = BindOptions(parseResult); + + try + { + var item = await _cosmosService.GetItem( + options.Account!, + options.Database!, + options.Container!, + options.Id!, + options.PartitionKey, + options.Subscription!, + options.AuthMethod ?? AuthMethod.Credential, + options.Tenant, + options.RetryPolicy, + cancellationToken); + + context.Response.Results = ResponseResult.Create( + new ItemGetCommandResult(item), + CosmosJsonContext.Default.ItemGetCommandResult); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in {Operation}. 
Account: {Account}, Database: {Database}, Container: {Container}", + Name, options.Account, options.Database, options.Container); + HandleException(context, ex); + } + + return context.Response; + } + + protected override string GetErrorMessage(Exception ex) => ex switch + { + CosmosException cosmosEx => cosmosEx.Message, + _ => base.GetErrorMessage(ex) + }; + + protected override HttpStatusCode GetStatusCode(Exception ex) => ex switch + { + CosmosException cosmosEx => cosmosEx.StatusCode, + _ => base.GetStatusCode(ex) + }; + + internal record ItemGetCommandResult(JsonElement? Item); +} diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemListRecentCommand.cs b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemListRecentCommand.cs new file mode 100644 index 0000000000..bd3dd37413 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemListRecentCommand.cs @@ -0,0 +1,104 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using Azure.Mcp.Tools.Cosmos.Options; +using Azure.Mcp.Tools.Cosmos.Options.Item; +using Azure.Mcp.Tools.Cosmos.Services; +using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Extensions; +using Microsoft.Mcp.Core.Models; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.Cosmos.Commands.Item; + +[CommandMetadata( + Id = "9a1b1c2d-3e4f-4a5b-9c6d-7e8f9a0b1c2d", + Name = "list-recent", + Title = "List Recent Cosmos DB Documents", + Description = "Retrieve the most recently modified documents from a Cosmos DB container, ordered by the system timestamp (_ts) in descending order. 
Use the --count option to control how many documents are returned (1-20, default is 10).", + Destructive = false, + Idempotent = true, + OpenWorld = false, + ReadOnly = true, + Secret = false, + LocalRequired = false)] +public sealed class ItemListRecentCommand(ILogger logger, ICosmosService cosmosService) + : BaseContainerCommand() +{ + private readonly ILogger _logger = logger; + private readonly ICosmosService _cosmosService = cosmosService; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(CosmosOptionDefinitions.Count); + command.Validators.Add(result => + { + var count = result.GetValueOrDefault(CosmosOptionDefinitions.Count.Name); + if (count < 1 || count > 20) + { + result.AddError("--count must be between 1 and 20."); + } + }); + } + + protected override ItemListRecentOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.Count = parseResult.GetValueOrDefault(CosmosOptionDefinitions.Count.Name); + return options; + } + + public override async Task ExecuteAsync(CommandContext context, ParseResult parseResult, CancellationToken cancellationToken) + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + var options = BindOptions(parseResult); + + try + { + var items = await _cosmosService.GetRecentItems( + options.Account!, + options.Database!, + options.Container!, + options.Count ?? 10, + options.Subscription!, + options.AuthMethod ?? AuthMethod.Credential, + options.Tenant, + options.RetryPolicy, + cancellationToken); + + context.Response.Results = ResponseResult.Create( + new ItemListRecentCommandResult(items ?? []), + CosmosJsonContext.Default.ItemListRecentCommandResult); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in {Operation}. 
Account: {Account}, Database: {Database}, Container: {Container}", + Name, options.Account, options.Database, options.Container); + HandleException(context, ex); + } + + return context.Response; + } + + protected override string GetErrorMessage(Exception ex) => ex switch + { + CosmosException cosmosEx => cosmosEx.Message, + _ => base.GetErrorMessage(ex) + }; + + protected override HttpStatusCode GetStatusCode(Exception ex) => ex switch + { + CosmosException cosmosEx => cosmosEx.StatusCode, + _ => base.GetStatusCode(ex) + }; + + internal record ItemListRecentCommandResult(List Items); +} diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemTextSearchCommand.cs b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemTextSearchCommand.cs new file mode 100644 index 0000000000..335c8741c4 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemTextSearchCommand.cs @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using Azure.Mcp.Tools.Cosmos.Options; +using Azure.Mcp.Tools.Cosmos.Options.Item; +using Azure.Mcp.Tools.Cosmos.Services; +using Azure.Mcp.Tools.Cosmos.Validation; +using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Extensions; +using Microsoft.Mcp.Core.Models; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.Cosmos.Commands.Item; + +[CommandMetadata( + Id = "8b4c5d6e-7f80-4a91-bc23-d4e5f6a7b8c9", + Name = "text-search", + Title = "Text Search Cosmos DB Documents", + Description = "Retrieve the TOP N documents in a Cosmos DB container where the specified property contains the provided search string. Use the --count option to control how many documents are returned (1-20, default is 10). 
Requires a Cosmos DB full-text index on the property.", + Destructive = false, + Idempotent = true, + OpenWorld = false, + ReadOnly = true, + Secret = false, + LocalRequired = false)] +public sealed class ItemTextSearchCommand(ILogger logger, ICosmosService cosmosService) + : BaseContainerCommand() +{ + private readonly ILogger _logger = logger; + private readonly ICosmosService _cosmosService = cosmosService; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(CosmosOptionDefinitions.Property); + command.Options.Add(CosmosOptionDefinitions.SearchPhrase); + command.Options.Add(CosmosOptionDefinitions.Count); + command.Validators.Add(result => + { + var property = result.GetValueOrDefault(CosmosOptionDefinitions.Property.Name); + if (!PropertyValidator.IsValid(property)) + { + result.AddError("--property must use dot notation with letters, digits, and underscores only (e.g., name or profile.name)."); + } + + var count = result.GetValueOrDefault(CosmosOptionDefinitions.Count.Name); + if (count < 1 || count > 20) + { + result.AddError("--count must be between 1 and 20."); + } + }); + } + + protected override ItemTextSearchOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.Property = parseResult.GetValueOrDefault(CosmosOptionDefinitions.Property.Name); + options.SearchPhrase = parseResult.GetValueOrDefault(CosmosOptionDefinitions.SearchPhrase.Name); + options.Count = parseResult.GetValueOrDefault(CosmosOptionDefinitions.Count.Name); + return options; + } + + public override async Task ExecuteAsync(CommandContext context, ParseResult parseResult, CancellationToken cancellationToken) + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + var options = BindOptions(parseResult); + + try + { + var items = await _cosmosService.TextSearch( + options.Account!, + options.Database!, + 
options.Container!, + options.Property!, + options.SearchPhrase!, + options.Count ?? 10, + options.Subscription!, + options.AuthMethod ?? AuthMethod.Credential, + options.Tenant, + options.RetryPolicy, + cancellationToken); + + context.Response.Results = ResponseResult.Create( + new ItemTextSearchCommandResult(items ?? []), + CosmosJsonContext.Default.ItemTextSearchCommandResult); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in {Operation}. Account: {Account}, Database: {Database}, Container: {Container}, Property: {Property}", + Name, options.Account, options.Database, options.Container, options.Property); + HandleException(context, ex); + } + + return context.Response; + } + + protected override string GetErrorMessage(Exception ex) => ex switch + { + CosmosException cosmosEx => cosmosEx.Message, + _ => base.GetErrorMessage(ex) + }; + + protected override HttpStatusCode GetStatusCode(Exception ex) => ex switch + { + CosmosException cosmosEx => cosmosEx.StatusCode, + _ => base.GetStatusCode(ex) + }; + + internal record ItemTextSearchCommandResult(List Items); +} diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemVectorSearchCommand.cs b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemVectorSearchCommand.cs new file mode 100644 index 0000000000..0e119e474a --- /dev/null +++ b/tools/Azure.Mcp.Tools.Cosmos/src/Commands/Item/ItemVectorSearchCommand.cs @@ -0,0 +1,210 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using System.Globalization; +using System.Net; +using Azure.Mcp.Tools.Cosmos.Models; +using Azure.Mcp.Tools.Cosmos.Options; +using Azure.Mcp.Tools.Cosmos.Options.Item; +using Azure.Mcp.Tools.Cosmos.Services; +using Azure.Mcp.Tools.Cosmos.Validation; +using Microsoft.Azure.Cosmos; +using Microsoft.Extensions.Logging; +using Microsoft.Mcp.Core.Commands; +using Microsoft.Mcp.Core.Extensions; +using Microsoft.Mcp.Core.Models; +using Microsoft.Mcp.Core.Models.Command; + +namespace Azure.Mcp.Tools.Cosmos.Commands.Item; + +[CommandMetadata( + Id = "5e6f7a8b-9c0d-4e1f-a2b3-c4d5e6f7a8b9", + Name = "vector-search", + Title = "Vector Search Cosmos DB Documents", + Description = "Perform a vector similarity search on Cosmos DB. Use Azure OpenAI to generate an embedding by supplying --search-text along with --openai-endpoint and --embedding-deployment or provide a precomputed embedding via --embedding as comma-separated float values. The container must have a vector index on --vector-property.", + Destructive = false, + Idempotent = true, + OpenWorld = false, + ReadOnly = true, + Secret = false, + LocalRequired = false)] +public sealed class ItemVectorSearchCommand(ILogger logger, ICosmosService cosmosService) + : BaseContainerCommand() +{ + private readonly ILogger _logger = logger; + private readonly ICosmosService _cosmosService = cosmosService; + + protected override void RegisterOptions(Command command) + { + base.RegisterOptions(command); + command.Options.Add(CosmosOptionDefinitions.VectorProperty); + command.Options.Add(CosmosOptionDefinitions.SelectProperties); + command.Options.Add(CosmosOptionDefinitions.Count); + command.Options.Add(CosmosOptionDefinitions.Embedding); + command.Options.Add(CosmosOptionDefinitions.SearchText); + command.Options.Add(CosmosOptionDefinitions.OpenAIEndpoint); + command.Options.Add(CosmosOptionDefinitions.EmbeddingDeployment); + command.Options.Add(CosmosOptionDefinitions.EmbeddingDimensions); + + command.Validators.Add(result => + { + 
var vectorProperty = result.GetValueOrDefault(CosmosOptionDefinitions.VectorProperty.Name); + if (!PropertyValidator.IsValid(vectorProperty)) + { + result.AddError("--vector-property must be a dot-delimited identifier (letters, digits, and underscores only)."); + } + + var selectProperties = result.GetValueOrDefault(CosmosOptionDefinitions.SelectProperties.Name); + if (string.IsNullOrWhiteSpace(selectProperties) || selectProperties.Contains('*')) + { + result.AddError("--select-properties must be a comma-separated list of explicit property names (no '*' wildcards)."); + } + else + { + foreach (var prop in selectProperties.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)) + { + if (!PropertyValidator.IsValid(prop)) + { + result.AddError($"--select-properties contains an invalid property '{prop}'. Use letters, digits, and underscores only."); + break; + } + } + } + + var count = result.GetValueOrDefault(CosmosOptionDefinitions.Count.Name); + if (count < 1 || count > 50) + { + result.AddError("--count must be between 1 and 50."); + } + + var embedding = result.GetValueOrDefault(CosmosOptionDefinitions.Embedding.Name); + var searchText = result.GetValueOrDefault(CosmosOptionDefinitions.SearchText.Name); + + if (string.IsNullOrWhiteSpace(embedding) && string.IsNullOrWhiteSpace(searchText)) + { + result.AddError("Either --embedding or --search-text must be supplied."); + return; + } + + if (!string.IsNullOrWhiteSpace(embedding) && !string.IsNullOrWhiteSpace(searchText)) + { + result.AddError("--embedding and --search-text are mutually exclusive."); + return; + } + + if (!string.IsNullOrWhiteSpace(searchText)) + { + var endpoint = result.GetValueOrDefault(CosmosOptionDefinitions.OpenAIEndpoint.Name); + var deployment = result.GetValueOrDefault(CosmosOptionDefinitions.EmbeddingDeployment.Name); + if (string.IsNullOrWhiteSpace(endpoint) || string.IsNullOrWhiteSpace(deployment)) + { + result.AddError("--openai-endpoint and 
--embedding-deployment are required when --search-text is supplied."); + } + } + }); + } + + protected override ItemVectorSearchOptions BindOptions(ParseResult parseResult) + { + var options = base.BindOptions(parseResult); + options.VectorProperty = parseResult.GetValueOrDefault(CosmosOptionDefinitions.VectorProperty.Name); + options.SelectProperties = parseResult.GetValueOrDefault(CosmosOptionDefinitions.SelectProperties.Name); + options.Count = parseResult.GetValueOrDefault(CosmosOptionDefinitions.Count.Name); + options.Embedding = parseResult.GetValueOrDefault(CosmosOptionDefinitions.Embedding.Name); + options.SearchText = parseResult.GetValueOrDefault(CosmosOptionDefinitions.SearchText.Name); + options.OpenAIEndpoint = parseResult.GetValueOrDefault(CosmosOptionDefinitions.OpenAIEndpoint.Name); + options.EmbeddingDeployment = parseResult.GetValueOrDefault(CosmosOptionDefinitions.EmbeddingDeployment.Name); + options.EmbeddingDimensions = parseResult.GetValueOrDefault(CosmosOptionDefinitions.EmbeddingDimensions.Name); + return options; + } + + public override async Task ExecuteAsync(CommandContext context, ParseResult parseResult, CancellationToken cancellationToken) + { + if (!Validate(parseResult.CommandResult, context.Response).IsValid) + { + return context.Response; + } + + var options = BindOptions(parseResult); + + try + { + var selectProperties = options.SelectProperties! 
+ .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + + float[] embedding; + if (!string.IsNullOrWhiteSpace(options.Embedding)) + { + embedding = ParseEmbedding(options.Embedding!); + } + else + { + embedding = await _cosmosService.GenerateEmbedding( + options.SearchText!, + new EmbeddingRequest(options.OpenAIEndpoint!, options.EmbeddingDeployment!, options.EmbeddingDimensions), + options.Tenant, + cancellationToken); + } + + var items = await _cosmosService.VectorSearch( + options.Account!, + options.Database!, + options.Container!, + options.VectorProperty!, + selectProperties, + embedding, + options.Count ?? 10, + options.Subscription!, + options.AuthMethod ?? AuthMethod.Credential, + options.Tenant, + options.RetryPolicy, + cancellationToken); + + context.Response.Results = ResponseResult.Create( + new ItemVectorSearchCommandResult(items ?? []), + CosmosJsonContext.Default.ItemVectorSearchCommandResult); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error in {Operation}. 
Account: {Account}, Database: {Database}, Container: {Container}, VectorProperty: {VectorProperty}", + Name, options.Account, options.Database, options.Container, options.VectorProperty); + HandleException(context, ex); + } + + return context.Response; + } + + private static float[] ParseEmbedding(string value) + { + var parts = value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + if (parts.Length == 0) + { + throw new ArgumentException("--embedding must contain at least one number.", nameof(value)); + } + + var result = new float[parts.Length]; + for (var i = 0; i < parts.Length; i++) + { + if (!float.TryParse(parts[i], NumberStyles.Float, CultureInfo.InvariantCulture, out result[i])) + { + throw new ArgumentException($"--embedding contains a value that is not a valid number: '{parts[i]}'.", nameof(value)); + } + } + + return result; + } + + protected override string GetErrorMessage(Exception ex) => ex switch + { + CosmosException cosmosEx => cosmosEx.Message, + _ => base.GetErrorMessage(ex) + }; + + protected override HttpStatusCode GetStatusCode(Exception ex) => ex switch + { + CosmosException cosmosEx => cosmosEx.StatusCode, + _ => base.GetStatusCode(ex) + }; + + internal record ItemVectorSearchCommandResult(List Items); +} diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/CosmosSetup.cs b/tools/Azure.Mcp.Tools.Cosmos/src/CosmosSetup.cs index ff8c831a8f..ecd5ed03ac 100644 --- a/tools/Azure.Mcp.Tools.Cosmos/src/CosmosSetup.cs +++ b/tools/Azure.Mcp.Tools.Cosmos/src/CosmosSetup.cs @@ -2,6 +2,8 @@ // Licensed under the MIT License. 
using Azure.Mcp.Tools.Cosmos.Commands; +using Azure.Mcp.Tools.Cosmos.Commands.Container; +using Azure.Mcp.Tools.Cosmos.Commands.Item; using Azure.Mcp.Tools.Cosmos.Services; using Microsoft.Extensions.DependencyInjection; using Microsoft.Mcp.Core.Areas; @@ -21,6 +23,11 @@ public void ConfigureServices(IServiceCollection services) services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); } public CommandGroup RegisterCommands(IServiceProvider serviceProvider) @@ -38,11 +45,20 @@ public CommandGroup RegisterCommands(IServiceProvider serviceProvider) var cosmosContainer = new CommandGroup("container", "Cosmos DB container operations - Commands for managing containers within your Cosmos DB databases."); databases.AddSubGroup(cosmosContainer); + // Schema operations on a container + var schema = new CommandGroup("schema", "Cosmos DB container schema operations - Commands for inferring the shape of documents inside a container."); + cosmosContainer.AddSubGroup(schema); + schema.AddCommand(serviceProvider); + // Create items subgroup for Cosmos - var cosmosItem = new CommandGroup("item", "Cosmos DB item operations - Commands for querying, creating, updating, and deleting documents within your Cosmos DB containers."); + var cosmosItem = new CommandGroup("item", "Cosmos DB item operations - Commands for querying, retrieving, and searching documents within your Cosmos DB containers."); cosmosContainer.AddSubGroup(cosmosItem); cosmosItem.AddCommand(serviceProvider); + cosmosItem.AddCommand(serviceProvider); + cosmosItem.AddCommand(serviceProvider); + cosmosItem.AddCommand(serviceProvider); + cosmosItem.AddCommand(serviceProvider); return cosmos; } diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Models/ContainerSchema.cs b/tools/Azure.Mcp.Tools.Cosmos/src/Models/ContainerSchema.cs new file mode 100644 index 0000000000..bed291dcd3 --- /dev/null +++ 
/// <summary>
/// Approximate schema inferred for a Cosmos DB container.
/// </summary>
/// <param name="SampleSize">Number of documents that were sampled.</param>
/// <param name="Properties">Top-level properties discovered across the sampled documents.</param>
public sealed record ContainerSchema(int SampleSize, IReadOnlyList<SchemaProperty> Properties);
/// <summary>
/// A single property observed while sampling documents in a Cosmos DB container,
/// together with the value kinds seen for it and how often it appeared.
/// </summary>
/// <param name="Name">Property name as it appears on the document.</param>
/// <param name="Type">Pipe-delimited list of JSON value kinds observed (e.g., "string" or "number | null").</param>
/// <param name="AppearedIn">Number of sampled documents that contained this property.</param>
/// <param name="SampleSize">Total number of sampled documents.</param>
public sealed record SchemaProperty(string Name, string Type, int AppearedIn, int SampleSize);
const string SearchTextName = "search-text"; + public const string OpenAIEndpointName = "openai-endpoint"; + public const string EmbeddingDeploymentName = "embedding-deployment"; + public const string EmbeddingDimensionsName = "embedding-dimensions"; public static readonly Option Account = new( $"--{AccountName}" @@ -56,4 +69,104 @@ public static class CosmosOptionDefinitions Required = false, DefaultValueFactory = _ => "SELECT * FROM c" }; + + public static readonly Option Count = new( + $"--{CountName}" + ) + { + Description = "Maximum number of documents to return (1-20). Defaults to 10.", + Required = false, + DefaultValueFactory = _ => 10 + }; + + public static readonly Option SampleSize = new( + $"--{SampleSizeName}" + ) + { + Description = "Number of documents to sample for schema inference (1-20). Defaults to 10.", + Required = false, + DefaultValueFactory = _ => 10 + }; + + public static readonly Option ItemId = new( + $"--{ItemIdName}" + ) + { + Description = "The id of the document to retrieve.", + Required = true + }; + + public static readonly Option PartitionKey = new( + $"--{PartitionKeyName}" + ) + { + Description = "Optional partition key value for the document. When provided, a point read is used (cheaper than a cross-partition query)." + }; + + public static readonly Option Property = new( + $"--{PropertyName}" + ) + { + Description = "The document property to search. Supports dot notation (e.g., 'name' or 'profile.name'). Allowed characters: letters, digits, and underscores.", + Required = true + }; + + public static readonly Option SearchPhrase = new( + $"--{SearchPhraseName}" + ) + { + Description = "The phrase to search for. Passed as a parameterized value to a Cosmos DB FullTextContains query. 
The container must have a full-text index on the property.", + Required = true + }; + + public static readonly Option VectorProperty = new( + $"--{VectorPropertyName}" + ) + { + Description = "The document property containing the vector embedding (e.g., 'embedding' or 'metadata.vector'). The container must have a vector index on this property.", + Required = true + }; + + public static readonly Option SelectProperties = new( + $"--{SelectPropertiesName}" + ) + { + Description = "Comma-separated list of properties to project in the result (e.g., 'id,title,metadata.author'). Wildcards ('*') are not allowed; explicit property names are required.", + Required = true + }; + + public static readonly Option Embedding = new( + $"--{EmbeddingName}" + ) + { + Description = "Comma-separated list of floating-point numbers representing a precomputed embedding vector (e.g., '0.12,-0.34,0.56'). Mutually exclusive with --search-text." + }; + + public static readonly Option SearchText = new( + $"--{SearchTextName}" + ) + { + Description = "Free-form text to embed via Azure OpenAI before searching. Requires --openai-endpoint and --embedding-deployment. Mutually exclusive with --embedding." + }; + + public static readonly Option OpenAIEndpoint = new( + $"--{OpenAIEndpointName}" + ) + { + Description = "Azure OpenAI endpoint (e.g., 'https://my-endpoint.openai.azure.com/') used to generate the embedding when --search-text is supplied." + }; + + public static readonly Option EmbeddingDeployment = new( + $"--{EmbeddingDeploymentName}" + ) + { + Description = "Name of the Azure OpenAI embedding deployment (e.g., 'text-embedding-ada-002') used when --search-text is supplied." + }; + + public static readonly Option EmbeddingDimensions = new( + $"--{EmbeddingDimensionsName}" + ) + { + Description = "Optional embedding dimensions to request from the model (only honored by models that support custom dimensions, e.g., text-embedding-3-*)." 
/// <summary>
/// Options for retrieving a single Cosmos DB document by id, in addition to the
/// options inherited from <see cref="BaseContainerOptions"/>.
/// </summary>
public class ItemGetOptions : BaseContainerOptions
{
    /// <summary>
    /// Value of the <c>--id</c> option: the id of the document to retrieve.
    /// </summary>
    [JsonPropertyName(CosmosOptionDefinitions.ItemIdName)]
    public string? Id { get; set; }

    /// <summary>
    /// Value of the optional <c>--partition-key</c> option. When provided, a point read is used
    /// instead of a cross-partition query.
    /// </summary>
    [JsonPropertyName(CosmosOptionDefinitions.PartitionKeyName)]
    public string? PartitionKey { get; set; }
}
Property { get; set; } + + [JsonPropertyName(CosmosOptionDefinitions.SearchPhraseName)] + public string? SearchPhrase { get; set; } + + [JsonPropertyName(CosmosOptionDefinitions.CountName)] + public int? Count { get; set; } +} diff --git a/tools/Azure.Mcp.Tools.Cosmos/src/Options/Item/ItemVectorSearchOptions.cs b/tools/Azure.Mcp.Tools.Cosmos/src/Options/Item/ItemVectorSearchOptions.cs new file mode 100644 index 0000000000..49c342d2b9 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Cosmos/src/Options/Item/ItemVectorSearchOptions.cs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Text.Json.Serialization; + +namespace Azure.Mcp.Tools.Cosmos.Options.Item; + +public class ItemVectorSearchOptions : BaseContainerOptions +{ + [JsonPropertyName(CosmosOptionDefinitions.VectorPropertyName)] + public string? VectorProperty { get; set; } + + [JsonPropertyName(CosmosOptionDefinitions.SelectPropertiesName)] + public string? SelectProperties { get; set; } + + [JsonPropertyName(CosmosOptionDefinitions.CountName)] + public int? Count { get; set; } + + [JsonPropertyName(CosmosOptionDefinitions.EmbeddingName)] + public string? Embedding { get; set; } + + [JsonPropertyName(CosmosOptionDefinitions.SearchTextName)] + public string? SearchText { get; set; } + + [JsonPropertyName(CosmosOptionDefinitions.OpenAIEndpointName)] + public string? OpenAIEndpoint { get; set; } + + [JsonPropertyName(CosmosOptionDefinitions.EmbeddingDeploymentName)] + public string? EmbeddingDeployment { get; set; } + + [JsonPropertyName(CosmosOptionDefinitions.EmbeddingDimensionsName)] + public int? 
/// <summary>
/// Infers an approximate schema for a container by reading up to <paramref name="sampleSize"/>
/// documents and recording, for every top-level property, the JSON value kinds seen and the
/// number of sampled documents that contained it.
/// </summary>
/// <param name="accountName">Cosmos DB account name.</param>
/// <param name="databaseName">Database that owns the container.</param>
/// <param name="containerName">Container to sample.</param>
/// <param name="sampleSize">Number of documents to sample; must be between 1 and 20.</param>
/// <param name="subscription">Subscription used to resolve the account.</param>
/// <param name="authMethod">Authentication method passed to the Cosmos client factory.</param>
/// <param name="tenant">Optional tenant passed to the Cosmos client factory.</param>
/// <param name="retryPolicy">Optional retry policy passed to the Cosmos client factory.</param>
/// <param name="cancellationToken">Token used to cancel the query.</param>
/// <returns>
/// The number of documents actually sampled and the discovered properties, ordered by name (ordinal).
/// </returns>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="sampleSize"/> is outside 1-20.</exception>
/// <exception cref="CosmosException">A result page came back with a non-success status code.</exception>
public async Task<ContainerSchema> GetApproximateSchema(
    string accountName,
    string databaseName,
    string containerName,
    int sampleSize,
    string subscription,
    AuthMethod authMethod = AuthMethod.Credential,
    string? tenant = null,
    RetryPolicyOptions? retryPolicy = null,
    CancellationToken cancellationToken = default)
{
    ValidateRequiredParameters(
        (nameof(accountName), accountName),
        (nameof(databaseName), databaseName),
        (nameof(containerName), containerName),
        (nameof(subscription), subscription));

    if (sampleSize < 1 || sampleSize > 20)
    {
        throw new ArgumentOutOfRangeException(nameof(sampleSize), sampleSize, "Sample size must be between 1 and 20.");
    }

    var client = await GetCosmosClientAsync(accountName, subscription, authMethod, tenant, retryPolicy, cancellationToken);
    var container = client.GetContainer(databaseName, containerName);

    // sampleSize is a range-checked integer, so interpolating it into the query text is safe.
    var queryDef = new QueryDefinition($"SELECT TOP {sampleSize} * FROM c");

    // The feed iterator is disposable; release it when the method exits.
    using var iterator = container.GetItemQueryStreamIterator(
        queryDef,
        requestOptions: new QueryRequestOptions { MaxItemCount = sampleSize });

    var typeMap = new Dictionary<string, HashSet<string>>(StringComparer.Ordinal);
    var countMap = new Dictionary<string, int>(StringComparer.Ordinal);
    var sampled = 0;

    while (iterator.HasMoreResults && sampled < sampleSize)
    {
        using var response = await iterator.ReadNextAsync(cancellationToken);

        // Throws CosmosException, which keeps the HTTP status code of the failed page
        // (a bare Exception would discard it).
        response.EnsureSuccessStatusCode();

        using var doc = await JsonDocument.ParseAsync(response.Content, cancellationToken: cancellationToken);
        if (!doc.RootElement.TryGetProperty("Documents", out var docs))
        {
            continue;
        }

        foreach (var item in docs.EnumerateArray())
        {
            // Only JSON objects contribute properties; other values are not counted as samples.
            if (item.ValueKind != JsonValueKind.Object)
            {
                continue;
            }

            sampled++;

            foreach (var prop in item.EnumerateObject())
            {
                var typeName = prop.Value.ValueKind switch
                {
                    JsonValueKind.String => "string",
                    JsonValueKind.Number => "number",
                    JsonValueKind.True or JsonValueKind.False => "boolean",
                    JsonValueKind.Object => "object",
                    JsonValueKind.Array => "array",
                    JsonValueKind.Null => "null",
                    _ => "unknown",
                };

                if (!typeMap.TryGetValue(prop.Name, out var set))
                {
                    set = new HashSet<string>(StringComparer.Ordinal);
                    typeMap[prop.Name] = set;
                }
                set.Add(typeName);

                countMap.TryGetValue(prop.Name, out var current);
                countMap[prop.Name] = current + 1;
            }

            if (sampled >= sampleSize)
            {
                break;
            }
        }
    }

    var properties = typeMap
        .OrderBy(kvp => kvp.Key, StringComparer.Ordinal)
        .Select(kvp => new SchemaProperty(
            kvp.Key,
            string.Join(" | ", kvp.Value.OrderBy(t => t, StringComparer.Ordinal)),
            countMap.TryGetValue(kvp.Key, out var c) ? c : 0,
            sampled))
        .ToList();

    return new ContainerSchema(sampled, properties);
}
/// <summary>
/// Retrieves a single document by id. When <paramref name="partitionKey"/> is supplied a point read
/// is issued; otherwise the container is queried for the id and the first match is returned.
/// </summary>
/// <param name="accountName">Cosmos DB account name.</param>
/// <param name="databaseName">Database that owns the container.</param>
/// <param name="containerName">Container to read from.</param>
/// <param name="id">Id of the document to retrieve.</param>
/// <param name="partitionKey">Optional partition key value; it is passed to Cosmos DB as a string.</param>
/// <param name="subscription">Subscription used to resolve the account.</param>
/// <param name="authMethod">Authentication method passed to the Cosmos client factory.</param>
/// <param name="tenant">Optional tenant passed to the Cosmos client factory.</param>
/// <param name="retryPolicy">Optional retry policy passed to the Cosmos client factory.</param>
/// <param name="cancellationToken">Token used to cancel the read or query.</param>
/// <returns>The document, or <c>null</c> when no document with that id was found.</returns>
/// <exception cref="CosmosException">
/// The read or a query page came back with a non-success status code other than 404 on the point read.
/// </exception>
public async Task<JsonElement?> GetItem(
    string accountName,
    string databaseName,
    string containerName,
    string id,
    string? partitionKey,
    string subscription,
    AuthMethod authMethod = AuthMethod.Credential,
    string? tenant = null,
    RetryPolicyOptions? retryPolicy = null,
    CancellationToken cancellationToken = default)
{
    ValidateRequiredParameters(
        (nameof(accountName), accountName),
        (nameof(databaseName), databaseName),
        (nameof(containerName), containerName),
        (nameof(id), id),
        (nameof(subscription), subscription));

    var client = await GetCosmosClientAsync(accountName, subscription, authMethod, tenant, retryPolicy, cancellationToken);
    var container = client.GetContainer(databaseName, containerName);

    if (!string.IsNullOrEmpty(partitionKey))
    {
        try
        {
            using var response = await container.ReadItemStreamAsync(id, new PartitionKey(partitionKey), cancellationToken: cancellationToken);
            if (response.StatusCode == System.Net.HttpStatusCode.NotFound)
            {
                return null;
            }

            // Any other failure becomes a CosmosException that keeps its status code.
            response.EnsureSuccessStatusCode();

            using var doc = await JsonDocument.ParseAsync(response.Content, cancellationToken: cancellationToken);
            return doc.RootElement.Clone();
        }
        catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.NotFound)
        {
            return null;
        }
    }

    // No partition key: fall back to a parameterized query on the id.
    var queryDef = new QueryDefinition("SELECT * FROM c WHERE c.id = @id").WithParameter("@id", id);

    // The feed iterator is disposable; release it when the method exits.
    using var iterator = container.GetItemQueryStreamIterator(
        queryDef,
        requestOptions: new QueryRequestOptions { MaxItemCount = 1 });

    while (iterator.HasMoreResults)
    {
        using var response = await iterator.ReadNextAsync(cancellationToken);
        response.EnsureSuccessStatusCode();

        using var doc = await JsonDocument.ParseAsync(response.Content, cancellationToken: cancellationToken);

        // Pages can be empty, so keep reading until a page contains a document.
        if (doc.RootElement.TryGetProperty("Documents", out var docs) && docs.GetArrayLength() > 0)
        {
            return docs[0].Clone();
        }
    }

    return null;
}
Use dot notation with letters, digits, and underscores only (e.g., 'name' or 'profile.name').", + nameof(property)); + } + + if (count < 1 || count > 20) + { + throw new ArgumentOutOfRangeException(nameof(count), count, "Count must be between 1 and 20."); + } + + var client = await GetCosmosClientAsync(accountName, subscription, authMethod, tenant, retryPolicy, cancellationToken); + var container = client.GetContainer(databaseName, containerName); + + var queryDef = new QueryDefinition( + $"SELECT TOP {count} * FROM c WHERE FullTextContains(c.{property}, @searchPhrase)") + .WithParameter("@searchPhrase", searchPhrase); + + var iterator = container.GetItemQueryStreamIterator( + queryDef, + requestOptions: new QueryRequestOptions { MaxItemCount = count }); + + var results = new List(count); + while (iterator.HasMoreResults && results.Count < count) + { + using var response = await iterator.ReadNextAsync(cancellationToken); + if (!response.IsSuccessStatusCode) + { + throw new Exception(response.ErrorMessage); + } + + using var doc = await JsonDocument.ParseAsync(response.Content, cancellationToken: cancellationToken); + if (!doc.RootElement.TryGetProperty("Documents", out var docs)) + { + continue; + } + + foreach (var item in docs.EnumerateArray()) + { + results.Add(item.Clone()); + if (results.Count >= count) + { + break; + } + } + } + + return results; + } + + public async Task> VectorSearch( + string accountName, + string databaseName, + string containerName, + string vectorProperty, + IReadOnlyList selectProperties, + IReadOnlyList embedding, + int count, + string subscription, + AuthMethod authMethod = AuthMethod.Credential, + string? tenant = null, + RetryPolicyOptions? 
retryPolicy = null, + CancellationToken cancellationToken = default) + { + ValidateRequiredParameters( + (nameof(accountName), accountName), + (nameof(databaseName), databaseName), + (nameof(containerName), containerName), + (nameof(vectorProperty), vectorProperty), + (nameof(subscription), subscription)); + + if (!PropertyValidator.IsValid(vectorProperty)) + { + throw new ArgumentException( + "Invalid vector property name. Use dot notation with letters, digits, and underscores only (e.g., 'embedding' or 'metadata.vector').", + nameof(vectorProperty)); + } + + if (selectProperties == null || selectProperties.Count == 0) + { + throw new ArgumentException("At least one property must be supplied in properties to select.", nameof(selectProperties)); + } + + foreach (var prop in selectProperties) + { + if (!PropertyValidator.IsValid(prop)) + { + throw new ArgumentException( + $"Invalid property name '{prop}' in properties to select. Use dot notation with letters, digits, and underscores only (e.g., 'name' or 'profile.name').", + nameof(selectProperties)); + } + } + + if (embedding == null || embedding.Count == 0) + { + throw new ArgumentException("Embedding vector must contain at least one value.", nameof(embedding)); + } + + if (count < 1 || count > 50) + { + throw new ArgumentOutOfRangeException(nameof(count), count, "Count must be between 1 and 50."); + } + + var client = await GetCosmosClientAsync(accountName, subscription, authMethod, tenant, retryPolicy, cancellationToken); + var container = client.GetContainer(databaseName, containerName); + + var selectClause = string.Join(", ", selectProperties.Select(p => $"c.{p}")); + var embeddingArray = embedding is float[] arr ? 
/// <summary>
/// Generates an embedding vector for <paramref name="text"/> with an Azure OpenAI embedding deployment.
/// </summary>
/// <param name="text">Free-form text to embed.</param>
/// <param name="request">Endpoint, deployment name and optional dimensions for the embedding call.</param>
/// <param name="tenant">Optional tenant used when acquiring the credential.</param>
/// <param name="cancellationToken">Token used to cancel the request.</param>
/// <returns>The embedding returned by the service as an array of floats.</returns>
/// <exception cref="ArgumentException">
/// The endpoint in <paramref name="request"/> is not an absolute HTTPS URI.
/// </exception>
public async Task<float[]> GenerateEmbedding(
    string text,
    EmbeddingRequest request,
    string? tenant = null,
    CancellationToken cancellationToken = default)
{
    ValidateRequiredParameters(
        (nameof(text), text),
        (nameof(request.Endpoint), request?.Endpoint),
        (nameof(request.DeploymentName), request?.DeploymentName));

    // The endpoint is caller-supplied and receives a request authenticated with our credential.
    // Reject anything that is not an absolute HTTPS URI up front instead of surfacing a
    // UriFormatException or sending a token over an unencrypted connection.
    if (!Uri.TryCreate(request!.Endpoint, UriKind.Absolute, out var endpoint)
        || !string.Equals(endpoint.Scheme, Uri.UriSchemeHttps, StringComparison.Ordinal))
    {
        throw new ArgumentException(
            "The Azure OpenAI endpoint must be an absolute HTTPS URI (e.g., 'https://my-endpoint.openai.azure.com/').",
            nameof(request));
    }

    var credential = await GetCredential(tenant, cancellationToken);
    var clientOptions = new AzureOpenAIClientOptions
    {
        Transport = new System.ClientModel.Primitives.HttpClientPipelineTransport(_httpClientFactory.CreateClient()),
    };

    var openAi = new AzureOpenAIClient(endpoint, credential, clientOptions);
    var embeddingClient = openAi.GetEmbeddingClient(request.DeploymentName);

    // Only send a dimensions override when the caller asked for one.
    var response = request.Dimensions.HasValue
        ? await embeddingClient.GenerateEmbeddingAsync(
            text,
            new OpenAI.Embeddings.EmbeddingGenerationOptions { Dimensions = request.Dimensions.Value },
            cancellationToken)
        : await embeddingClient.GenerateEmbeddingAsync(text, cancellationToken: cancellationToken);

    return response.Value.ToFloats().ToArray();
}
/// <summary>
/// Validates property identifiers used inside SQL fragments that cannot be parameterized,
/// such as the property names interpolated into <c>FullTextContains(c.{property}, ...)</c>
/// and <c>VectorDistance(c.{vector}, ...)</c>.
/// </summary>
internal static class PropertyValidator
{
    // The pattern is anchored with \z rather than $: in .NET, $ also matches just before a
    // trailing '\n', which would let "name\n" through and into the interpolated query text.
    private static readonly Regex PropertyPattern = RegexHelper.CreateRegex(
        @"^[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*\z",
        RegexOptions.Compiled);

    /// <summary>
    /// Returns true if the value is a safe dot-delimited property identifier
    /// (letters, digits, and underscores only; each segment starts with a letter or underscore).
    /// </summary>
    /// <param name="value">Candidate property path, e.g. <c>name</c> or <c>profile.name</c>.</param>
    /// <returns>
    /// <c>true</c> when the whole value matches the identifier pattern; <c>false</c> for null,
    /// empty, whitespace, or any value containing other characters.
    /// </returns>
    public static bool IsValid(string? value) =>
        !string.IsNullOrWhiteSpace(value) && PropertyPattern.IsMatch(value);
}
Arg.Is("db"), + Arg.Is("c"), + Arg.Is(5), + Arg.Is("sub"), + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(schema); + + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--sample-size", "5"); + + var result = ValidateAndDeserializeResponse(response, CosmosJsonContext.Default.ContainerSchemaGetCommandResult); + Assert.Equal(5, result.SampleSize); + Assert.Equal(2, result.Properties.Count); + } + + [Theory] + [InlineData("0")] + [InlineData("101")] + public async Task ExecuteAsync_RejectsOutOfRangeSampleSize(string sampleSize) + { + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--sample-size", sampleSize); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("sample-size", response.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ExecuteAsync_PropagatesServiceErrors() + { + Service.GetApproximateSchema( + Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + .ThrowsAsync(new InvalidOperationException("boom")); + + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c"); + + Assert.NotEqual(HttpStatusCode.OK, response.Status); + Assert.Contains("boom", response.Message); + } +} diff --git a/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemGetCommandTests.cs b/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemGetCommandTests.cs new file mode 100644 index 0000000000..bab0bef46b --- /dev/null +++ b/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemGetCommandTests.cs @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using System.Net; +using System.Text.Json; +using Azure.Mcp.Tools.Cosmos.Commands; +using Azure.Mcp.Tools.Cosmos.Commands.Item; +using Azure.Mcp.Tools.Cosmos.Services; +using Microsoft.Mcp.Core.Models; +using Microsoft.Mcp.Core.Options; +using Microsoft.Mcp.Tests.Client; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.Cosmos.UnitTests.Item; + +public class ItemGetCommandTests + : CommandUnitTestsBase +{ + [Fact] + public void Name_IsCorrect() => Assert.Equal("get", Command.Name); + + [Fact] + public async Task ExecuteAsync_ReturnsItem_OnSuccess() + { + var item = JsonDocument.Parse("{\"id\":\"abc\",\"value\":42}").RootElement.Clone(); + + Service.GetItem( + Arg.Is("acct"), Arg.Is("db"), Arg.Is("c"), Arg.Is("abc"), + Arg.Is("pk1"), + Arg.Is("sub"), Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + .Returns(item); + + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--id", "abc", + "--partition-key", "pk1"); + + var result = ValidateAndDeserializeResponse(response, CosmosJsonContext.Default.ItemGetCommandResult); + Assert.NotNull(result.Item); + Assert.Equal("abc", result.Item.Value.GetProperty("id").GetString()); + } + + [Fact] + public async Task ExecuteAsync_ReturnsNullItem_WhenNotFound() + { + Service.GetItem( + Arg.Any(), Arg.Any(), Arg.Any(), Arg.Any(), + Arg.Any(), + Arg.Any(), Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + .Returns((JsonElement?)null); + + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--id", "missing"); + + var result = ValidateAndDeserializeResponse(response, CosmosJsonContext.Default.ItemGetCommandResult); + Assert.Null(result.Item); + } + + [Fact] + public async Task ExecuteAsync_FailsWithoutId() + { + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c"); + + 
Assert.Equal(HttpStatusCode.BadRequest, response.Status); + } +} diff --git a/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemListRecentCommandTests.cs b/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemListRecentCommandTests.cs new file mode 100644 index 0000000000..14a6533929 --- /dev/null +++ b/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemListRecentCommandTests.cs @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using System.Text.Json; +using Azure.Mcp.Tools.Cosmos.Commands; +using Azure.Mcp.Tools.Cosmos.Commands.Item; +using Azure.Mcp.Tools.Cosmos.Services; +using Microsoft.Mcp.Core.Models; +using Microsoft.Mcp.Core.Options; +using Microsoft.Mcp.Tests.Client; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.Cosmos.UnitTests.Item; + +public class ItemListRecentCommandTests + : CommandUnitTestsBase +{ + [Fact] + public void Name_IsCorrect() => Assert.Equal("list-recent", Command.Name); + + [Fact] + public async Task ExecuteAsync_ReturnsItems_OnSuccess() + { + var items = new List + { + JsonDocument.Parse("{\"id\":\"a\"}").RootElement.Clone(), + JsonDocument.Parse("{\"id\":\"b\"}").RootElement.Clone(), + }; + + Service.GetRecentItems( + Arg.Is("acct"), Arg.Is("db"), Arg.Is("c"), Arg.Is(2), + Arg.Is("sub"), Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + .Returns(items); + + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--count", "2"); + + var result = ValidateAndDeserializeResponse(response, CosmosJsonContext.Default.ItemListRecentCommandResult); + Assert.Equal(2, result.Items.Count); + } + + [Theory] + [InlineData("0")] + [InlineData("101")] + public async Task ExecuteAsync_RejectsOutOfRangeCount(string count) + { + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", 
"acct", + "--database", "db", + "--container", "c", + "--count", count); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("count", response.Message, StringComparison.OrdinalIgnoreCase); + } +} diff --git a/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemTextSearchCommandTests.cs b/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemTextSearchCommandTests.cs new file mode 100644 index 0000000000..2e7e698ebb --- /dev/null +++ b/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemTextSearchCommandTests.cs @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using System.Text.Json; +using Azure.Mcp.Tools.Cosmos.Commands; +using Azure.Mcp.Tools.Cosmos.Commands.Item; +using Azure.Mcp.Tools.Cosmos.Services; +using Microsoft.Mcp.Core.Models; +using Microsoft.Mcp.Core.Options; +using Microsoft.Mcp.Tests.Client; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.Cosmos.UnitTests.Item; + +public class ItemTextSearchCommandTests + : CommandUnitTestsBase +{ + [Fact] + public void Name_IsCorrect() => Assert.Equal("text-search", Command.Name); + + [Fact] + public async Task ExecuteAsync_ReturnsItems_OnSuccess() + { + var items = new List + { + JsonDocument.Parse("{\"id\":\"hit\"}").RootElement.Clone(), + }; + + Service.TextSearch( + Arg.Is("acct"), Arg.Is("db"), Arg.Is("c"), + Arg.Is("name"), Arg.Is("azure"), Arg.Is(5), + Arg.Is("sub"), Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + .Returns(items); + + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--property", "name", + "--search-phrase", "azure", + "--count", "5"); + + var result = ValidateAndDeserializeResponse(response, CosmosJsonContext.Default.ItemTextSearchCommandResult); + Assert.Single(result.Items); + } + + [Theory] + [InlineData("123name")] + 
[InlineData("name;drop")] + [InlineData("a..b")] + public async Task ExecuteAsync_RejectsInvalidProperty(string property) + { + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--property", property, + "--search-phrase", "azure"); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("property", response.Message, StringComparison.OrdinalIgnoreCase); + } +} diff --git a/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemVectorSearchCommandTests.cs b/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemVectorSearchCommandTests.cs new file mode 100644 index 0000000000..ba6376becb --- /dev/null +++ b/tools/Azure.Mcp.Tools.Cosmos/tests/Azure.Mcp.Tools.Cosmos.UnitTests/Item/ItemVectorSearchCommandTests.cs @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Net; +using System.Text.Json; +using Azure.Mcp.Tools.Cosmos.Commands; +using Azure.Mcp.Tools.Cosmos.Commands.Item; +using Azure.Mcp.Tools.Cosmos.Models; +using Azure.Mcp.Tools.Cosmos.Services; +using Microsoft.Mcp.Core.Models; +using Microsoft.Mcp.Core.Options; +using Microsoft.Mcp.Tests.Client; +using NSubstitute; +using Xunit; + +namespace Azure.Mcp.Tools.Cosmos.UnitTests.Item; + +public class ItemVectorSearchCommandTests + : CommandUnitTestsBase +{ + [Fact] + public void Name_IsCorrect() => Assert.Equal("vector-search", Command.Name); + + [Fact] + public async Task ExecuteAsync_ReturnsItems_WhenEmbeddingProvided() + { + var items = new List + { + JsonDocument.Parse("{\"id\":\"x\",\"_score\":0.1}").RootElement.Clone(), + }; + + Service.VectorSearch( + Arg.Is("acct"), Arg.Is("db"), Arg.Is("c"), + Arg.Is("embedding"), + Arg.Is>(p => p.Count == 2 && p[0] == "id" && p[1] == "title"), + Arg.Is>(v => v.Count == 3), + Arg.Is(3), + Arg.Is("sub"), Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + 
.Returns(items); + + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--vector-property", "embedding", + "--select-properties", "id,title", + "--count", "3", + "--embedding", "0.1,0.2,0.3"); + + var result = ValidateAndDeserializeResponse(response, CosmosJsonContext.Default.ItemVectorSearchCommandResult); + Assert.Single(result.Items); + } + + [Fact] + public async Task ExecuteAsync_GeneratesEmbedding_WhenSearchTextProvided() + { + Service.GenerateEmbedding( + Arg.Is("hello"), + Arg.Is(r => r.Endpoint == "https://aoai.example/" && r.DeploymentName == "ada"), + Arg.Any(), + Arg.Any()) + .Returns(new[] { 0.5f, 0.25f }); + + Service.VectorSearch( + Arg.Any(), Arg.Any(), Arg.Any(), + Arg.Any(), + Arg.Any>(), + Arg.Is>(v => v.Count == 2 && v[0] == 0.5f), + Arg.Any(), + Arg.Any(), Arg.Any(), Arg.Any(), + Arg.Any(), Arg.Any()) + .Returns([]); + + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--vector-property", "embedding", + "--select-properties", "id", + "--search-text", "hello", + "--openai-endpoint", "https://aoai.example/", + "--embedding-deployment", "ada"); + + Assert.Equal(HttpStatusCode.OK, response.Status); + } + + [Fact] + public async Task ExecuteAsync_RequiresEmbeddingOrSearchText() + { + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--vector-property", "embedding", + "--select-properties", "id"); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("embedding", response.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ExecuteAsync_RejectsBothEmbeddingAndSearchText() + { + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--vector-property", 
"embedding", + "--select-properties", "id", + "--embedding", "0.1,0.2", + "--search-text", "hi", + "--openai-endpoint", "https://aoai.example/", + "--embedding-deployment", "ada"); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("mutually exclusive", response.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task ExecuteAsync_RejectsWildcardSelectProperties() + { + var response = await ExecuteCommandAsync( + "--subscription", "sub", + "--account", "acct", + "--database", "db", + "--container", "c", + "--vector-property", "embedding", + "--select-properties", "*", + "--embedding", "0.1,0.2"); + + Assert.Equal(HttpStatusCode.BadRequest, response.Status); + Assert.Contains("wildcard", response.Message, StringComparison.OrdinalIgnoreCase); + } +}