diff --git a/samples/cs/embeddings/Program.cs b/samples/cs/embeddings/Program.cs index 348bc346..53bbca83 100644 --- a/samples/cs/embeddings/Program.cs +++ b/samples/cs/embeddings/Program.cs @@ -13,6 +13,44 @@ // Initialize the singleton instance. await FoundryLocalManager.CreateAsync(config, Utils.GetAppLogger()); var mgr = FoundryLocalManager.Instance; + +// Discover available execution providers and their registration status. +var eps = mgr.DiscoverEps(); +int maxNameLen = 30; +Console.WriteLine("Available execution providers:"); +Console.WriteLine($" {"Name".PadRight(maxNameLen)} Registered"); +Console.WriteLine($" {new string('─', maxNameLen)} {"──────────"}"); +foreach (var ep in eps) +{ + Console.WriteLine($" {ep.Name.PadRight(maxNameLen)} {ep.IsRegistered}"); +} + +// Download and register all execution providers with per-EP progress. +// EP packages include dependencies and may be large. +// Download is only required again if a new version of the EP is released. +// For cross platform builds there is no dynamic EP download and this will return immediately. +Console.WriteLine("\nDownloading execution providers:"); +if (eps.Length > 0) +{ + string currentEp = ""; + await mgr.DownloadAndRegisterEpsAsync((epName, percent) => + { + if (epName != currentEp) + { + if (currentEp != "") + { + Console.WriteLine(); + } + currentEp = epName; + } + Console.Write($"\r {epName.PadRight(maxNameLen)} {percent,6:F1}%"); + }); + Console.WriteLine(); +} +else +{ + Console.WriteLine("No execution providers to download."); +} // // @@ -20,7 +58,7 @@ var catalog = await mgr.GetCatalogAsync(); // Get an embedding model -var model = await catalog.GetModelAsync("qwen3-0.6b-embedding") ?? throw new Exception("Embedding model not found"); +var model = await catalog.GetModelAsync("qwen3-embedding-0.6b") ?? throw new Exception("Embedding model not found"); // Download the model (the method skips download if already cached) await model.DownloadAsync(progress => @@ -69,6 +107,5 @@ await model.DownloadAsync(progress => // // Tidy up - unload the model await model.UnloadAsync(); -Console.WriteLine("\nModel unloaded."); // // diff --git a/samples/js/embeddings/app.js b/samples/js/embeddings/app.js index ea6ff185..5577566a 100644 --- a/samples/js/embeddings/app.js +++ b/samples/js/embeddings/app.js @@ -14,9 +14,39 @@ const manager = FoundryLocalManager.create({ // console.log('✓ SDK initialized successfully'); +// Discover available execution providers and their registration status. +const eps = manager.discoverEps(); +const maxNameLen = 30; +console.log('\nAvailable execution providers:'); +console.log(` ${'Name'.padEnd(maxNameLen)} Registered`); +console.log(` ${'─'.repeat(maxNameLen)} ──────────`); +for (const ep of eps) { + console.log(` ${ep.name.padEnd(maxNameLen)} ${ep.isRegistered}`); +} + +// Download and register all execution providers with per-EP progress. +// EP packages include dependencies and may be large. +// Download is only required again if a new version of the EP is released. +console.log('\nDownloading execution providers:'); +if (eps.length > 0) { + let currentEp = ''; + await manager.downloadAndRegisterEps((epName, percent) => { + if (epName !== currentEp) { + if (currentEp !== '') { + process.stdout.write('\n'); + } + currentEp = epName; + } + process.stdout.write(`\r ${epName.padEnd(maxNameLen)} ${percent.toFixed(1).padStart(5)}%`); + }); + process.stdout.write('\n'); +} else { + console.log('No execution providers to download.'); +} + // // Get an embedding model -const modelAlias = 'qwen3-0.6b-embedding'; +const modelAlias = 'qwen3-embedding-0.6b'; const model = await manager.catalog.getModel(modelAlias); // Download the model diff --git a/samples/python/embeddings/src/app.py b/samples/python/embeddings/src/app.py index 30ade4b2..f10a71e4 100644 --- a/samples/python/embeddings/src/app.py +++ b/samples/python/embeddings/src/app.py @@ -11,8 +11,35 @@ def main(): FoundryLocalManager.initialize(config) manager = FoundryLocalManager.instance + # Discover available execution providers and their registration status. + eps = manager.discover_eps() + max_name_len = 30 + print("Available execution providers:") + print(f" {'Name':<{max_name_len}} Registered") + print(f" {'─' * max_name_len} ──────────") + for ep in eps: + print(f" {ep.name:<{max_name_len}} {ep.is_registered}") + + # Download and register all execution providers. + print("\nDownloading execution providers:") + current_ep = "" + def ep_progress(ep_name: str, percent: float): + nonlocal current_ep + if ep_name != current_ep: + if current_ep: + print() + current_ep = ep_name + print(f"\r {ep_name:<{max_name_len}} {percent:5.1f}%", end="", flush=True) + + if eps: + manager.download_and_register_eps(progress_callback=ep_progress) + if current_ep: + print() + else: + print("No execution providers to download.") + # Select and load an embedding model from the catalog - model = manager.catalog.get_model("qwen3-0.6b-embedding") + model = manager.catalog.get_model("qwen3-embedding-0.6b") model.download( lambda progress: print( f"\rDownloading model: {progress:.2f}%", diff --git a/samples/rust/embeddings/src/main.rs b/samples/rust/embeddings/src/main.rs index 9b5550f0..2849edd8 100644 --- a/samples/rust/embeddings/src/main.rs +++ b/samples/rust/embeddings/src/main.rs @@ -3,10 +3,12 @@ // Licensed under the MIT License. // +use std::io::{self, Write}; + use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; // -const ALIAS: &str = "qwen3-0.6b-embedding"; +const ALIAS: &str = "qwen3-embedding-0.6b"; #[tokio::main] async fn main() -> Result<(), Box> { @@ -18,6 +20,39 @@ async fn main() -> Result<(), Box> { let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; // + // Discover available execution providers and their registration status. + let eps = manager.discover_eps()?; + let max_name_len = 30; + println!("Available execution providers:"); + println!(" {: let model = manager.catalog().get_model(ALIAS).await?; diff --git a/sdk/cpp/CMakeLists.txt b/sdk/cpp/CMakeLists.txt index 41f12c27..60fa547a 100644 --- a/sdk/cpp/CMakeLists.txt +++ b/sdk/cpp/CMakeLists.txt @@ -31,6 +31,14 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) +# Enable MSVC exception handling so wil's exception-based APIs (e.g. the +# one-arg GetModuleFileNameW template and THROW_IF_FAILED) are defined. +# Without /EHsc, WIL_ENABLE_EXCEPTIONS is not set and those declarations +# are omitted from . +if (MSVC) + add_compile_options(/EHsc) +endif() + # Optional: target Windows 10+ APIs (adjust if you need older) add_compile_definitions(_WIN32_WINNT=0x0A00 WINVER=0x0A00) @@ -56,6 +64,7 @@ add_library(CppSdk STATIC src/openai_audio_client.cpp src/openai_live_audio_types.cpp src/openai_live_audio_client.cpp + src/openai_embedding_client.cpp src/foundry_local_manager.cpp ) diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h index 01b8b98d..737b6751 100644 --- a/sdk/cpp/include/foundry_local.h +++ b/sdk/cpp/include/foundry_local.h @@ -18,3 +18,4 @@ #include "openai/openai_audio_client.h" #include "openai/openai_live_audio_types.h" #include "openai/openai_live_audio_client.h" +#include "openai/openai_embedding_client.h" diff --git a/sdk/cpp/include/model.h b/sdk/cpp/include/model.h index 9238cf12..c3aa0c3e 100644 --- a/sdk/cpp/include/model.h +++ b/sdk/cpp/include/model.h @@ -19,6 +19,7 @@ namespace foundry_local { class OpenAIChatClient; class OpenAIAudioClient; + class OpenAIEmbeddingClient; } namespace foundry_local::Internal { @@ -59,6 +60,7 @@ namespace foundry_local { friend class OpenAIChatClient; friend class OpenAIAudioClient; + friend class OpenAIEmbeddingClient; }; enum class DeviceType { diff --git a/sdk/cpp/include/openai/openai_embedding_client.h b/sdk/cpp/include/openai/openai_embedding_client.h new file mode 100644 index 00000000..795e6813 --- /dev/null +++ b/sdk/cpp/include/openai/openai_embedding_client.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace foundry_local::Internal { + struct IFoundryLocalCore; +} + +namespace foundry_local { + class ILogger; + class IModel; + + struct EmbeddingObject { + int index = 0; + std::vector embedding; + }; + + struct EmbeddingUsage { + std::optional prompt_tokens; + std::optional total_tokens; + }; + + struct EmbeddingCreateResponse { + std::string model; + std::string object; ///< Always "list" + std::vector data; + std::optional usage; + }; + + class OpenAIEmbeddingClient final { + public: + explicit OpenAIEmbeddingClient(const IModel& model); + + /// Returns the model ID this client was created for. + const std::string& GetModelId() const noexcept { return modelId_; } + + /// Generate embedding for a single input string. + EmbeddingCreateResponse GenerateEmbedding(std::string_view input) const; + + /// Generate embeddings for multiple input strings in a single request. + EmbeddingCreateResponse GenerateEmbeddings(gsl::span inputs) const; + + private: + OpenAIEmbeddingClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger); + + std::string BuildSingleRequestJson(std::string_view input) const; + std::string BuildBatchRequestJson(gsl::span inputs) const; + + std::string modelId_; + gsl::not_null core_; + gsl::not_null logger_; + }; + +} // namespace foundry_local diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index 8ccc39d8..b8f314b9 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -322,6 +322,46 @@ void ChatWithToolCalling(Manager& manager, const std::string& alias) { std::cout << "Model unloaded.\n"; } +// --------------------------------------------------------------------------- +// Example 6 – Embeddings (single and batch) +// --------------------------------------------------------------------------- +void GenerateEmbeddings(Manager& manager, const std::string& alias) { + std::cout << "\n=== Example 6: Embeddings ===\n"; + + auto& catalog = manager.GetCatalog(); + + auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; }); + std::cout << "\n"; + + model->Load(); + + OpenAIEmbeddingClient embeddings(*model); + + // Single input + auto single = embeddings.GenerateEmbedding("The quick brown fox jumps over the lazy dog"); + if (!single.data.empty()) { + std::cout << "Single embedding: dim=" << single.data[0].embedding.size() << "\n"; + } + + // Batch input + std::vector inputs = {"The capital of France is Paris", "Machine learning is a subset of AI"}; + auto batch = embeddings.GenerateEmbeddings(inputs); + std::cout << "Batch embeddings: count=" << batch.data.size(); + if (!batch.data.empty()) { + std::cout << " dim=" << batch.data[0].embedding.size(); + } + std::cout << "\n"; + + model->Unload(); + std::cout << "Model unloaded.\n"; +} + // --------------------------------------------------------------------------- // main // --------------------------------------------------------------------------- @@ -346,6 +386,9 @@ int main() { // 5. Tool calling (define tools, let the model call them, feed results back) ChatWithToolCalling(manager, "phi-3.5-mini"); + // 6. Embeddings — generate single and batch embeddings + GenerateEmbeddings(manager, "qwen3-embedding-0.6b"); + Manager::Destroy(); return 0; } diff --git a/sdk/cpp/src/core.h b/sdk/cpp/src/core.h index cc37ce9e..37804d25 100644 --- a/sdk/cpp/src/core.h +++ b/sdk/cpp/src/core.h @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // -// Core DLL interop � loads Microsoft.AI.Foundry.Local.Core.dll at runtime. +// Core DLL interop – loads Microsoft.AI.Foundry.Local.Core.dll at runtime. // Internal header, not part of the public API. #pragma once diff --git a/sdk/cpp/src/openai_embedding_client.cpp b/sdk/cpp/src/openai_embedding_client.cpp new file mode 100644 index 00000000..1ad19201 --- /dev/null +++ b/sdk/cpp/src/openai_embedding_client.cpp @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include +#include + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_interop_request.h" +#include "parser.h" +#include "logger.h" + +namespace foundry_local { + + namespace { + /// True for strings that are empty or contain only whitespace characters. + bool IsBlank(std::string_view s) { + for (char c : s) { + if (!std::isspace(static_cast(c))) { + return false; + } + } + return true; + } + } // namespace + + OpenAIEmbeddingClient::OpenAIEmbeddingClient(gsl::not_null core, + std::string_view modelId, gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) {} + + std::string OpenAIEmbeddingClient::BuildSingleRequestJson(std::string_view input) const { + nlohmann::json req = {{"model", modelId_}, {"input", std::string(input)}}; + return req.dump(); + } + + std::string OpenAIEmbeddingClient::BuildBatchRequestJson(gsl::span inputs) const { + nlohmann::json jInputs = nlohmann::json::array(); + for (const auto& s : inputs) { + jInputs.push_back(s); + } + nlohmann::json req = {{"model", modelId_}, {"input", std::move(jInputs)}}; + return req.dump(); + } + + EmbeddingCreateResponse OpenAIEmbeddingClient::GenerateEmbedding(std::string_view input) const { + if (IsBlank(input)) { + throw Exception("Embedding input must be a non-empty string.", *logger_); + } + + std::string openAiReqJson = BuildSingleRequestJson(input); + + CoreInteropRequest req("embeddings"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + + std::string json = req.ToJson(); + auto response = core_->call(req.Command(), *logger_, &json); + if (response.HasError()) { + throw Exception("Embedding generation failed: " + response.error, *logger_); + } + + return nlohmann::json::parse(response.data).get(); + } + + EmbeddingCreateResponse OpenAIEmbeddingClient::GenerateEmbeddings(gsl::span inputs) const { + if (inputs.empty()) { + throw Exception("Embedding inputs must be a non-empty array of strings.", *logger_); + } + for (const auto& s : inputs) { + if (IsBlank(s)) { + throw Exception("Each embedding input must be a non-empty string.", *logger_); + } + } + + std::string openAiReqJson = BuildBatchRequestJson(inputs); + + CoreInteropRequest req("embeddings"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + + std::string json = req.ToJson(); + auto response = core_->call(req.Command(), *logger_, &json); + if (response.HasError()) { + throw Exception("Batch embedding generation failed: " + response.error, *logger_); + } + + return nlohmann::json::parse(response.data).get(); + } + + OpenAIEmbeddingClient::OpenAIEmbeddingClient(const IModel& model) + : OpenAIEmbeddingClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, + model.GetCoreAccess().logger) { + if (!model.IsLoaded()) { + throw Exception("Model " + model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", + *model.GetCoreAccess().logger); + } + } + +} // namespace foundry_local diff --git a/sdk/cpp/src/parser.h b/sdk/cpp/src/parser.h index 3596579c..3da60271 100644 --- a/sdk/cpp/src/parser.h +++ b/sdk/cpp/src/parser.h @@ -292,6 +292,43 @@ namespace foundry_local { c.delta = j.at("delta").get(); } + inline void from_json(const nlohmann::json& j, EmbeddingObject& e) { + if (j.contains("index")) + j.at("index").get_to(e.index); + e.embedding.clear(); + if (j.contains("embedding") && j.at("embedding").is_array()) { + const auto& arr = j.at("embedding"); + e.embedding.reserve(arr.size()); + for (const auto& v : arr) { + if (v.is_number()) { + e.embedding.push_back(v.get()); + } + } + } + } + + inline void from_json(const nlohmann::json& j, EmbeddingUsage& u) { + u.prompt_tokens = ParsingUtils::get_opt_int(j, "prompt_tokens"); + u.total_tokens = ParsingUtils::get_opt_int(j, "total_tokens"); + } + + inline void from_json(const nlohmann::json& j, EmbeddingCreateResponse& r) { + r.model = ParsingUtils::get_string_or_empty(j, "model"); + r.object = ParsingUtils::get_string_or_empty(j, "object"); + + r.data.clear(); + if (j.contains("data") && j.at("data").is_array()) { + r.data = j.at("data").get>(); + } + + if (j.contains("usage") && j.at("usage").is_object()) { + r.usage = j.at("usage").get(); + } + else { + r.usage.reset(); + } + } + inline void from_json(const nlohmann::json& j, ChatCompletionCreateResponse& r) { if (j.contains("created")) j.at("created").get_to(r.created); diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp index 6f083cef..53a5353a 100644 --- a/sdk/cpp/test/client_test.cpp +++ b/sdk/cpp/test/client_test.cpp @@ -743,3 +743,177 @@ TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallRoundTrip) { EXPECT_EQ("call_1", openAiReq["messages"][3]["tool_call_id"].get()); EXPECT_EQ("auto", openAiReq["tool_choice"].get()); } + +// ===================================================================== +// OpenAIEmbeddingClient tests +// ===================================================================== + +class OpenAIEmbeddingClientTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + static std::string MakeEmbeddingResponseJson(const std::vector>& vectors, + const std::string& modelName = "embedding-model") { + nlohmann::json data = nlohmann::json::array(); + for (size_t i = 0; i < vectors.size(); ++i) { + data.push_back({{"index", static_cast(i)}, {"object", "embedding"}, {"embedding", vectors[i]}}); + } + nlohmann::json resp = {{"model", modelName}, + {"object", "list"}, + {"data", std::move(data)}, + {"usage", {{"prompt_tokens", 5}, {"total_tokens", 5}}}}; + return resp.dump(); + } + + ModelVariant MakeLoadedVariant(const std::string& name = "embedding-model") { + core_.OnCall("list_loaded_models", "[\"" + name + ":1\"]"); + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_); + } +}; + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbedding_BasicResponse) { + core_.OnCall("embeddings", MakeEmbeddingResponseJson({{0.1f, 0.2f, 0.3f, 0.4f}})); + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + auto response = client.GenerateEmbedding("hello world"); + + EXPECT_EQ("embedding-model", response.model); + EXPECT_EQ("list", response.object); + ASSERT_EQ(1u, response.data.size()); + EXPECT_EQ(0, response.data[0].index); + ASSERT_EQ(4u, response.data[0].embedding.size()); + EXPECT_NEAR(0.1f, response.data[0].embedding[0], 1e-5f); + EXPECT_NEAR(0.4f, response.data[0].embedding[3], 1e-5f); + ASSERT_TRUE(response.usage.has_value()); + EXPECT_EQ(5, *response.usage->prompt_tokens); +} + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbedding_RequestFormat) { + core_.OnCall("embeddings", MakeEmbeddingResponseJson({{0.0f}})); + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + client.GenerateEmbedding("hello world"); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("embeddings")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("embedding-model", openAiReq["model"].get()); + EXPECT_EQ("hello world", openAiReq["input"].get()); +} + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbeddings_BasicResponse) { + core_.OnCall("embeddings", MakeEmbeddingResponseJson({{0.1f, 0.2f}, {0.3f, 0.4f}, {0.5f, 0.6f}})); + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + + std::vector inputs = {"first", "second", "third"}; + auto response = client.GenerateEmbeddings(inputs); + + ASSERT_EQ(3u, response.data.size()); + EXPECT_EQ(0, response.data[0].index); + EXPECT_EQ(1, response.data[1].index); + EXPECT_EQ(2, response.data[2].index); + EXPECT_NEAR(0.5f, response.data[2].embedding[0], 1e-5f); +} + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbeddings_RequestFormat) { + core_.OnCall("embeddings", MakeEmbeddingResponseJson({{0.0f}, {0.0f}})); + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + + std::vector inputs = {"a", "b"}; + client.GenerateEmbeddings(inputs); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("embeddings")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("embedding-model", openAiReq["model"].get()); + ASSERT_TRUE(openAiReq["input"].is_array()); + ASSERT_EQ(2u, openAiReq["input"].size()); + EXPECT_EQ("a", openAiReq["input"][0].get()); + EXPECT_EQ("b", openAiReq["input"][1].get()); +} + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbedding_EmptyInput_Throws) { + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + + EXPECT_THROW(client.GenerateEmbedding(""), Exception); +} + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbedding_WhitespaceOnlyInput_Throws) { + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + + EXPECT_THROW(client.GenerateEmbedding(" \t\n "), Exception); +} + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbeddings_EmptyList_Throws) { + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + + std::vector empty; + EXPECT_THROW(client.GenerateEmbeddings(empty), Exception); +} + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbeddings_ListWithEmptyString_Throws) { + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + + std::vector inputs = {"valid", "", "also valid"}; + EXPECT_THROW(client.GenerateEmbeddings(inputs), Exception); +} + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbeddings_ListWithWhitespaceOnlyString_Throws) { + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + + std::vector inputs = {"valid", " ", "also valid"}; + EXPECT_THROW(client.GenerateEmbeddings(inputs), Exception); +} + +TEST_F(OpenAIEmbeddingClientTest, Constructor_ThrowsIfNotLoaded) { + core_.OnCall("list_loaded_models", R"([])"); + auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); + EXPECT_THROW(OpenAIEmbeddingClient client(variant), Exception); +} + +TEST_F(OpenAIEmbeddingClientTest, GetModelId) { + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + EXPECT_EQ("embedding-model", client.GetModelId()); +} + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbedding_CoreError_Throws) { + core_.OnCallThrow("embeddings", "embedding generation failed"); + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + + EXPECT_THROW(client.GenerateEmbedding("test"), Exception); +} + +TEST_F(OpenAIEmbeddingClientTest, GenerateEmbeddings_CoreError_Throws) { + core_.OnCallThrow("embeddings", "batch embedding generation failed"); + core_.OnCall("list_loaded_models", R"(["embedding-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIEmbeddingClient client(variant); + + std::vector inputs = {"a", "b"}; + EXPECT_THROW(client.GenerateEmbeddings(inputs), Exception); +} diff --git a/sdk/cpp/test/e2e_test.cpp b/sdk/cpp/test/e2e_test.cpp index 06bdc0ff..b7a08ced 100644 --- a/sdk/cpp/test/e2e_test.cpp +++ b/sdk/cpp/test/e2e_test.cpp @@ -61,6 +61,33 @@ class EndToEndTest : public ::testing::Test { static bool IsAudioModel(const std::string& alias) { return alias.find("whisper") != std::string::npos; } + static bool IsEmbeddingModel(const std::string& alias) { return alias.find("embedding") != std::string::npos; } + + + /// Find an embedding model, preferring cached. + static IModel* FindEmbeddingModel(Catalog& catalog) { + IModel* target = nullptr; + + auto cached = catalog.GetCachedModels(); + for (auto* variant : cached) { + if (IsEmbeddingModel(variant->GetAlias())) { + target = catalog.GetModel(variant->GetAlias()); + if (target) + break; + } + } + + if (!target) { + for (const auto& alias : {"qwen3-embedding-0.6b"}) { + target = catalog.GetModel(alias); + if (target) + break; + } + } + + return target; + } + /// Find a chat-capable model, preferring cached, then known small models, then any. /// Selects the CPU variant when available to avoid GPU/EP dependency issues. static IModel* FindChatModel(Catalog& catalog) { @@ -572,3 +599,47 @@ TEST_F(EndToEndTest, DISABLED_DownloadAndRemoveFromCache) { std::cout << "[E2E] RemoveFromCache completed for: " << target->GetAlias() << " (IsCached=" << (target->IsCached() ? "true" : "false") << ")\n"; } + +// =========================================================================== +// Download, load, embeddings (single and batch), unload +// =========================================================================== +// +// The embedding tests below mirror the integration test suites in the C#, JS, +// Python, and Rust SDKs. They all require a real embedding model (loaded via +// the Catalog); they are DISABLED_ by default and run only with +// --gtest_also_run_disabled_tests. +// +// Each test prepares an OpenAIEmbeddingClient over a loaded variant and relies +// on the suite's SetUp/TearDown to bring up the Manager. + +TEST_F(EndToEndTest, DISABLED_DownloadLoadEmbeddingUnload) { + if (IsRunningInCI()) GTEST_SKIP() << "Skipped in CI (requires model download)"; + auto& catalog = Manager::Instance().GetCatalog(); + auto* target = FindEmbeddingModel(catalog); + if (!target) GTEST_SKIP() << "No embedding model found in catalog"; + + std::cout << "[E2E] Using embedding model: " << target->GetAlias() + << " variant: " << target->GetId() << "\n"; + target->Download(); + EXPECT_TRUE(target->IsCached()); + target->Load(); + EXPECT_TRUE(target->IsLoaded()); + + OpenAIEmbeddingClient client(*target); + + // Single input + auto single = client.GenerateEmbedding("The capital of France is Paris"); + ASSERT_FALSE(single.data.empty()); + EXPECT_FALSE(single.data[0].embedding.empty()); + std::cout << "[E2E] Single embedding dim: " << single.data[0].embedding.size() << "\n"; + + // Batch input + std::vector inputs = {"short", "a longer sentence for embedding"}; + auto batch = client.GenerateEmbeddings(inputs); + ASSERT_EQ(2u, batch.data.size()); + EXPECT_EQ(single.data[0].embedding.size(), batch.data[0].embedding.size()); + EXPECT_EQ(single.data[0].embedding.size(), batch.data[1].embedding.size()); + + target->Unload(); + EXPECT_FALSE(target->IsLoaded()); +} diff --git a/sdk/cs/src/OpenAI/EmbeddingClient.cs b/sdk/cs/src/OpenAI/EmbeddingClient.cs index 91877f47..4486a606 100644 --- a/sdk/cs/src/OpenAI/EmbeddingClient.cs +++ b/sdk/cs/src/OpenAI/EmbeddingClient.cs @@ -77,12 +77,14 @@ private async Task GenerateEmbeddingImplAsync(string in private async Task GenerateEmbeddingsImplAsync(IEnumerable inputs, CancellationToken? ct) { - if (inputs == null || !inputs.Any()) + var inputList = inputs?.ToList(); + + if (inputList == null || inputList.Count == 0) { throw new ArgumentException("Inputs must be a non-empty array of strings.", nameof(inputs)); } - foreach (var input in inputs) + foreach (var input in inputList) { if (string.IsNullOrWhiteSpace(input)) { @@ -90,7 +92,7 @@ private async Task GenerateEmbeddingsImplAsync(IEnumera } } - var embeddingRequest = EmbeddingCreateRequestExtended.FromUserInput(_modelId, inputs); + var embeddingRequest = EmbeddingCreateRequestExtended.FromUserInput(_modelId, inputList); var embeddingRequestJson = embeddingRequest.ToJson(); var request = new CoreInteropRequest { Params = new() { { "OpenAICreateRequest", embeddingRequestJson } } }; diff --git a/sdk/cs/test/FoundryLocal.Tests/EmbeddingClientTests.cs b/sdk/cs/test/FoundryLocal.Tests/EmbeddingClientTests.cs index bed3a8ea..d4d52197 100644 --- a/sdk/cs/test/FoundryLocal.Tests/EmbeddingClientTests.cs +++ b/sdk/cs/test/FoundryLocal.Tests/EmbeddingClientTests.cs @@ -19,19 +19,14 @@ public static async Task Setup() var manager = FoundryLocalManager.Instance; // initialized by Utils var catalog = await manager.GetCatalogAsync(); - // Reduce max_length in the embedding model's genai_config.json to avoid OOM - // when allocating the KV cache. Embedding models only need a single forward pass - // so a large max_length is unnecessary. - Utils.PatchModelMaxLength("qwen3-0.6b-embedding-generic-cpu-1", "v1"); - // Load the specific cached model variant directly - var model = await catalog.GetModelVariantAsync("qwen3-0.6b-embedding-generic-cpu:1").ConfigureAwait(false); - await Assert.That(model).IsNotNull(); + var loadedModel = await catalog.GetModelVariantAsync("qwen3-0.6b-embedding-generic-cpu:1").ConfigureAwait(false); + await Assert.That(loadedModel).IsNotNull(); - await model!.LoadAsync().ConfigureAwait(false); - await Assert.That(await model.IsLoadedAsync()).IsTrue(); + await loadedModel!.LoadAsync().ConfigureAwait(false); + await Assert.That(await loadedModel.IsLoadedAsync()).IsTrue(); - EmbeddingClientTests.model = model; + EmbeddingClientTests.model = loadedModel; } [After(Class)] diff --git a/sdk/cs/test/FoundryLocal.Tests/Utils.cs b/sdk/cs/test/FoundryLocal.Tests/Utils.cs index f8969853..8b25ba05 100644 --- a/sdk/cs/test/FoundryLocal.Tests/Utils.cs +++ b/sdk/cs/test/FoundryLocal.Tests/Utils.cs @@ -483,26 +483,4 @@ private static string GetRepoRoot() throw new InvalidOperationException("Could not find git repository root from test file location"); } - - /// - /// Patches max_length in a cached model's genai_config.json to a small value. - /// ORT GenAI allocates a KV cache sized by max_length; the default (32768) can cause - /// OOM when multiple models are loaded. Embedding models only need a single forward pass - /// so a small max_length is sufficient. - /// - internal static void PatchModelMaxLength(string modelDirName, string variantSubDir, int newMaxLength = 512) - { - var repoRoot = new DirectoryInfo(GetRepoRoot()); - var configPath = Path.Combine(repoRoot.Parent!.FullName, "test-data-shared", - modelDirName, variantSubDir, "genai_config.json"); - - if (!File.Exists(configPath)) return; - - var json = File.ReadAllText(configPath); - if (json.Contains("\"max_length\": 32768")) - { - json = json.Replace("\"max_length\": 32768", $"\"max_length\": {newMaxLength}"); - File.WriteAllText(configPath, json); - } - } } diff --git a/sdk/python/src/openai/embedding_client.py b/sdk/python/src/openai/embedding_client.py index 89a3b8e5..069c6bca 100644 --- a/sdk/python/src/openai/embedding_client.py +++ b/sdk/python/src/openai/embedding_client.py @@ -98,7 +98,7 @@ def generate_embeddings(self, inputs: List[str]) -> CreateEmbeddingResponse: ValueError: If *inputs* is empty or contains empty strings. FoundryLocalException: If the underlying native embeddings command fails. """ - if not inputs or len(inputs) == 0: + if not inputs: raise ValueError("Inputs must be a non-empty list of strings.") for text in inputs: diff --git a/sdk/rust/docs/api.md b/sdk/rust/docs/api.md index 8dcb0c29..a2045f0c 100644 --- a/sdk/rust/docs/api.md +++ b/sdk/rust/docs/api.md @@ -16,7 +16,6 @@ - [ChatClient](#chatclient) - [ChatCompletionStream](#chatcompletionstream) - [EmbeddingClient](#embeddingclient) - - [EmbeddingResponse](#embeddingresponse) - [AudioClient](#audioclient) - [AudioTranscriptionStream](#audiotranscriptionstream) - [AudioTranscriptionResponse](#audiotranscriptionresponse) diff --git a/sdk/rust/src/openai/embedding_client.rs b/sdk/rust/src/openai/embedding_client.rs index 5de080a0..3215cb05 100644 --- a/sdk/rust/src/openai/embedding_client.rs +++ b/sdk/rust/src/openai/embedding_client.rs @@ -55,27 +55,23 @@ impl EmbeddingClient { .execute_command_async("embeddings".into(), Some(params)) .await?; - // Patch the response to add fields required by async_openai types - // that the server doesn't return (object on each item, usage) + // The server omits two fields that async_openai's CreateEmbeddingResponse + // requires: per-item `object` and top-level `usage`. Inject defaults before + // deserializing. let mut response_value: Value = serde_json::from_str(&raw)?; if let Some(data) = response_value .get_mut("data") .and_then(|d| d.as_array_mut()) { for item in data { - if item.get("object").is_none() { - item.as_object_mut() - .map(|m| m.insert("object".into(), json!("embedding"))); + if let Some(obj) = item.as_object_mut() { + obj.entry("object").or_insert_with(|| json!("embedding")); } } } - if response_value.get("usage").is_none() { - response_value.as_object_mut().map(|m| { - m.insert( - "usage".into(), - json!({"prompt_tokens": 0, "total_tokens": 0}), - ) - }); + if let Some(root) = response_value.as_object_mut() { + root.entry("usage") + .or_insert_with(|| json!({"prompt_tokens": 0, "total_tokens": 0})); } let parsed: CreateEmbeddingResponse = serde_json::from_value(response_value)?;