diff --git a/sdk/cpp/CMakeLists.txt b/sdk/cpp/CMakeLists.txt index 7e32b7fb..41f12c27 100644 --- a/sdk/cpp/CMakeLists.txt +++ b/sdk/cpp/CMakeLists.txt @@ -54,6 +54,8 @@ add_library(CppSdk STATIC src/catalog.cpp src/openai_chat_client.cpp src/openai_audio_client.cpp + src/openai_live_audio_types.cpp + src/openai_live_audio_client.cpp src/foundry_local_manager.cpp ) @@ -91,6 +93,7 @@ if (BUILD_TESTING) test/model_variant_test.cpp test/catalog_test.cpp test/client_test.cpp + test/live_audio_test.cpp ) target_include_directories(CppSdkTests diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h index c16337e1..01b8b98d 100644 --- a/sdk/cpp/include/foundry_local.h +++ b/sdk/cpp/include/foundry_local.h @@ -16,3 +16,5 @@ #include "openai/openai_tool_types.h" #include "openai/openai_chat_client.h" #include "openai/openai_audio_client.h" +#include "openai/openai_live_audio_types.h" +#include "openai/openai_live_audio_client.h" diff --git a/sdk/cpp/include/openai/openai_audio_client.h b/sdk/cpp/include/openai/openai_audio_client.h index ac1ce719..c58fad1c 100644 --- a/sdk/cpp/include/openai/openai_audio_client.h +++ b/sdk/cpp/include/openai/openai_audio_client.h @@ -7,6 +7,7 @@ #include #include #include +#include #include @@ -22,6 +23,8 @@ namespace foundry_local { std::string text; }; + class LiveAudioTranscriptionSession; + class OpenAIAudioClient final { public: explicit OpenAIAudioClient(const IModel& model); @@ -34,6 +37,9 @@ namespace foundry_local { using StreamCallback = std::function; void TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const; + /// Create a new live audio transcription session for streaming PCM audio. + std::unique_ptr CreateLiveTranscriptionSession() const; + private: OpenAIAudioClient(gsl::not_null core, std::string_view modelId, gsl::not_null logger); diff --git a/sdk/cpp/include/openai/openai_live_audio_client.h b/sdk/cpp/include/openai/openai_live_audio_client.h new file mode 100644 index 00000000..c65a2d63 --- /dev/null +++ b/sdk/cpp/include/openai/openai_live_audio_client.h @@ -0,0 +1,101 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "openai_live_audio_types.h" + +namespace foundry_local::Internal { + struct IFoundryLocalCore; + template class ThreadSafeQueue; +} // namespace foundry_local::Internal + +namespace foundry_local { + class ILogger; + + class LiveAudioTranscriptionSession final { + public: + LiveAudioTranscriptionSession(gsl::not_null core, + std::string modelId, + gsl::not_null logger); + ~LiveAudioTranscriptionSession() noexcept; + + // Non-copyable, non-movable + LiveAudioTranscriptionSession(const LiveAudioTranscriptionSession&) = delete; + LiveAudioTranscriptionSession& operator=(const LiveAudioTranscriptionSession&) = delete; + LiveAudioTranscriptionSession(LiveAudioTranscriptionSession&&) = delete; + LiveAudioTranscriptionSession& operator=(LiveAudioTranscriptionSession&&) = delete; + + /// Mutable settings reference; only effective before Start(). + LiveAudioTranscriptionOptions& Settings() { return settings_; } + /// Read-only settings reference. + const LiveAudioTranscriptionOptions& Settings() const { return settings_; } + /// Settings that were active when Start() was called. + const LiveAudioTranscriptionOptions& ActiveSettings() const { return activeSettings_; } + + /// Begin the streaming session. Must be called before Append/TryAppend. + void Start(); + + /// Enqueue PCM audio data. Blocks if the push queue is full. + void Append(const uint8_t* pcmData, size_t length); + + /// Try to get the next transcription result within the given timeout. + TranscriptionStatus TryGetNext(LiveAudioTranscriptionResponse& result, + std::chrono::milliseconds timeout = std::chrono::seconds(5)); + + /// Signal the end of audio input and stop the session. + void Stop(); + + /// Returns the error message if the session is in an error state. + std::string GetErrorMessage() const; + + /// Returns true if the session has been started. + bool IsStarted() const; + + /// Returns true if the session has been stopped. + bool IsStopped() const; + + private: + enum class SessionState { + Created, + Starting, + Started, + Stopped + }; + + void PushWorkerLoop(); + void StopInternal(std::unique_lock& lock); + + gsl::not_null core_; + std::string modelId_; + gsl::not_null logger_; + + LiveAudioTranscriptionOptions settings_; + LiveAudioTranscriptionOptions activeSettings_; + + mutable std::mutex mutex_; + SessionState state_ = SessionState::Created; + std::string sessionHandle_; + + using AudioChunk = std::vector; + std::unique_ptr> pushQueue_; + std::unique_ptr> resultQueue_; + + std::thread pushThread_; + std::string errorMessage_; + LiveAudioTranscriptionResponse finalResult_; + bool hasFinalResult_ = false; + }; + +} // namespace foundry_local diff --git a/sdk/cpp/include/openai/openai_live_audio_types.h b/sdk/cpp/include/openai/openai_live_audio_types.h new file mode 100644 index 00000000..d7d31f12 --- /dev/null +++ b/sdk/cpp/include/openai/openai_live_audio_types.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace foundry_local { + + struct ContentPart { + std::string text; + std::string transcript; + }; + + struct LiveAudioTranscriptionResponse { + std::string text; + bool is_final = false; + std::optional start_time; + std::optional end_time; + std::vector content; + + static LiveAudioTranscriptionResponse FromJson(const std::string& json); + }; + + struct LiveAudioTranscriptionOptions { + int sample_rate = 16000; + int channels = 1; + int bits_per_sample = 16; + std::optional language; + int push_queue_capacity = 100; + }; + + struct CoreErrorResponse { + std::string code; + std::string message; + bool is_transient = false; + + static std::optional TryParse(const std::string& error_string); + }; + + enum class TranscriptionStatus { + Result, + Timeout, + Closed, + Error + }; + +} // namespace foundry_local diff --git a/sdk/cpp/src/core.h b/sdk/cpp/src/core.h index 10feee5b..cc37ce9e 100644 --- a/sdk/cpp/src/core.h +++ b/sdk/cpp/src/core.h @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // -// Core DLL interop – loads Microsoft.AI.Foundry.Local.Core.dll at runtime. +// Core DLL interop � loads Microsoft.AI.Foundry.Local.Core.dll at runtime. // Internal header, not part of the public API. #pragma once @@ -46,6 +46,7 @@ namespace foundry_local { module_.reset(); execCmd_ = nullptr; execCbCmd_ = nullptr; + execBinaryCmd_ = nullptr; freeResCmd_ = nullptr; } @@ -91,10 +92,55 @@ namespace foundry_local { return result; } + CoreResponse callWithBinary(std::string_view command, ILogger& logger, + const std::string* dataArgument, + const uint8_t* binaryData, size_t binaryDataLength) const override { + if (!module_ || !execBinaryCmd_ || !freeResCmd_) { + throw Exception("Core is not loaded. Cannot call command: " + std::string(command), logger); + } + + StreamingRequestBuffer request{}; + request.Command = command.empty() ? nullptr : command.data(); + request.CommandLength = static_cast(command.size()); + + if (dataArgument && !dataArgument->empty()) { + request.Data = dataArgument->data(); + request.DataLength = static_cast(dataArgument->size()); + } + + if (binaryData && binaryDataLength > 0) { + if (binaryDataLength > static_cast(INT32_MAX)) { + throw Exception("Binary data length exceeds maximum supported size (INT32_MAX).", logger); + } + request.BinaryData = binaryData; + request.BinaryDataLength = static_cast(binaryDataLength); + } + + ResponseBuffer response{}; + auto safeDeleter = [fn = freeResCmd_](ResponseBuffer* buf) { + if (fn) + fn(buf); + }; + std::unique_ptr responseGuard(&response, safeDeleter); + + execBinaryCmd_(&request, &response); + + CoreResponse result; + if (response.Error && response.ErrorLength > 0) { + result.error.assign(static_cast(response.Error), response.ErrorLength); + return result; + } + if (response.Data && response.DataLength > 0) { + result.data.assign(static_cast(response.Data), response.DataLength); + } + return result; + } + private: wil::unique_hmodule module_; execute_command_fn execCmd_{}; execute_command_with_callback_fn execCbCmd_{}; + execute_command_with_binary_fn execBinaryCmd_{}; free_response_fn freeResCmd_{}; void LoadFromPath(const std::filesystem::path& path) { @@ -105,6 +151,8 @@ namespace foundry_local { execCmd_ = reinterpret_cast(RequireProc(m.get(), "execute_command")); execCbCmd_ = reinterpret_cast( RequireProc(m.get(), "execute_command_with_callback")); + execBinaryCmd_ = reinterpret_cast( + RequireProc(m.get(), "execute_command_with_binary")); freeResCmd_ = reinterpret_cast(RequireProc(m.get(), "free_response")); module_ = std::move(m); diff --git a/sdk/cpp/src/flcore_native.h b/sdk/cpp/src/flcore_native.h index b0778116..d87baa09 100644 --- a/sdk/cpp/src/flcore_native.h +++ b/sdk/cpp/src/flcore_native.h @@ -26,14 +26,25 @@ extern "C" // Callback signature: void(*)(void* data, int length, void* userData) using UserCallbackFn = void(__cdecl*)(void*, int32_t, void*); + struct StreamingRequestBuffer { + const void* Command; + int32_t CommandLength; + const void* Data; + int32_t DataLength; + const void* BinaryData; + int32_t BinaryDataLength; + }; + // Exported function pointer types using execute_command_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*); using execute_command_with_callback_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/, void* /*userData*/); + using execute_command_with_binary_fn = void(__cdecl*)(StreamingRequestBuffer*, ResponseBuffer*); using free_response_fn = void(__cdecl*)(ResponseBuffer*); static_assert(std::is_standard_layout::value, "RequestBuffer must be standard layout"); static_assert(std::is_standard_layout::value, "ResponseBuffer must be standard layout"); + static_assert(std::is_standard_layout::value, "StreamingRequestBuffer must be standard layout"); #pragma pack(pop) } diff --git a/sdk/cpp/src/foundry_local_internal_core.h b/sdk/cpp/src/foundry_local_internal_core.h index 1e5af79d..f6c2af77 100644 --- a/sdk/cpp/src/foundry_local_internal_core.h +++ b/sdk/cpp/src/foundry_local_internal_core.h @@ -31,6 +31,11 @@ namespace foundry_local { virtual CoreResponse call(std::string_view command, ILogger& logger, const std::string* dataArgument = nullptr, NativeCallbackFn callback = nullptr, void* data = nullptr) const = 0; + + virtual CoreResponse callWithBinary(std::string_view command, ILogger& logger, + const std::string* dataArgument, + const uint8_t* binaryData, size_t binaryDataLength) const = 0; + virtual void unload() = 0; }; diff --git a/sdk/cpp/src/openai_audio_client.cpp b/sdk/cpp/src/openai_audio_client.cpp index d4409d1f..42b1c6a6 100644 --- a/sdk/cpp/src/openai_audio_client.cpp +++ b/sdk/cpp/src/openai_audio_client.cpp @@ -16,6 +16,8 @@ #include "core_helpers.h" #include "logger.h" +#include "openai/openai_live_audio_client.h" + namespace foundry_local { OpenAIAudioClient::OpenAIAudioClient(gsl::not_null core, std::string_view modelId, @@ -67,4 +69,8 @@ namespace foundry_local { } } + std::unique_ptr OpenAIAudioClient::CreateLiveTranscriptionSession() const { + return std::make_unique(core_, modelId_, logger_); + } + } // namespace foundry_local diff --git a/sdk/cpp/src/openai_live_audio_client.cpp b/sdk/cpp/src/openai_live_audio_client.cpp new file mode 100644 index 00000000..51ea8be6 --- /dev/null +++ b/sdk/cpp/src/openai_live_audio_client.cpp @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +#include + +#include "openai/openai_live_audio_client.h" +#include "openai/openai_live_audio_types.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_interop_request.h" +#include "thread_safe_queue.h" +#include "logger.h" + +namespace foundry_local { + + LiveAudioTranscriptionSession::LiveAudioTranscriptionSession( + gsl::not_null core, + std::string modelId, + gsl::not_null logger) + : core_(core), modelId_(std::move(modelId)), logger_(logger) {} + + LiveAudioTranscriptionSession::~LiveAudioTranscriptionSession() noexcept { + try { + std::unique_lock lock(mutex_); + if (state_ == SessionState::Started) { + StopInternal(lock); + } + } + catch (...) { + // Suppress exceptions in destructor + } + } + + void LiveAudioTranscriptionSession::Start() { + std::unique_lock lock(mutex_); + if (state_ != SessionState::Created) { + throw Exception("Session has already been started.", *logger_); + } + + // Transition to Starting state before releasing lock for FFI call + state_ = SessionState::Starting; + activeSettings_ = settings_; + + // Validate queue capacity early + if (activeSettings_.push_queue_capacity <= 0) { + state_ = SessionState::Created; + throw Exception("push_queue_capacity must be greater than 0.", *logger_); + } + + // Build the start command + CoreInteropRequest req("audio_stream_start"); + req.AddParam("Model", modelId_); + req.AddParam("SampleRate", std::to_string(activeSettings_.sample_rate)); + req.AddParam("Channels", std::to_string(activeSettings_.channels)); + req.AddParam("BitsPerSample", std::to_string(activeSettings_.bits_per_sample)); + if (activeSettings_.language.has_value()) { + req.AddParam("Language", activeSettings_.language.value()); + } + std::string json = req.ToJson(); + + // Release lock during FFI call to avoid holding mutex across boundary + lock.unlock(); + + auto response = core_->call(req.Command(), *logger_, &json); + + lock.lock(); + + if (response.HasError()) { + state_ = SessionState::Created; + throw Exception("Failed to start audio stream: " + response.error, *logger_); + } + + sessionHandle_ = std::move(response.data); + if (sessionHandle_.empty()) { + state_ = SessionState::Created; + throw Exception("audio_stream_start returned an empty session handle.", *logger_); + } + + // Validate queue capacity + const size_t queueCapacity = static_cast(activeSettings_.push_queue_capacity); + + // Create the queues + pushQueue_ = std::make_unique>(queueCapacity); + resultQueue_ = std::make_unique>(queueCapacity); + + state_ = SessionState::Started; + + // Start the push worker thread + pushThread_ = std::thread([this] { PushWorkerLoop(); }); + } + + void LiveAudioTranscriptionSession::Append(const uint8_t* pcmData, size_t length) { + { + std::lock_guard lock(mutex_); + if (state_ != SessionState::Started) { + throw Exception( + state_ == SessionState::Stopped + ? "Session has already been stopped." + : "Session is not started. Call Start() first.", + *logger_); + } + } + + AudioChunk chunk(pcmData, pcmData + length); + if (!pushQueue_->Push(std::move(chunk))) { + throw Exception("Failed to enqueue audio data: session is closed.", *logger_); + } + } + + TranscriptionStatus LiveAudioTranscriptionSession::TryGetNext(LiveAudioTranscriptionResponse& result, + std::chrono::milliseconds timeout) { + { + std::lock_guard lock(mutex_); + if (state_ != SessionState::Started && state_ != SessionState::Stopped) { + throw Exception("Session is not started. Call Start() first.", *logger_); + } + } + + auto status = resultQueue_->TryPop(result, timeout); + switch (status) { + case Internal::DequeueStatus::Item: + return TranscriptionStatus::Result; + case Internal::DequeueStatus::Timeout: + return TranscriptionStatus::Timeout; + case Internal::DequeueStatus::Closed: { + // Return the final result from Stop() if available + std::lock_guard lock(mutex_); + if (hasFinalResult_) { + result = std::move(finalResult_); + hasFinalResult_ = false; + return TranscriptionStatus::Result; + } + return TranscriptionStatus::Closed; + } + case Internal::DequeueStatus::Error: + return TranscriptionStatus::Error; + default: + return TranscriptionStatus::Error; + } + } + + void LiveAudioTranscriptionSession::Stop() { + std::unique_lock lock(mutex_); + if (state_ != SessionState::Started) { + return; + } + StopInternal(lock); + } + + void LiveAudioTranscriptionSession::StopInternal(std::unique_lock& lock) { + state_ = SessionState::Stopped; + std::string handle = sessionHandle_; + + // Close the push queue to signal the worker thread to finish + if (pushQueue_) { + pushQueue_->Close(); + } + + // Close the result queue to unblock any blocked Push() in the worker thread, + // preventing a deadlock when joining below. + if (resultQueue_) { + resultQueue_->Close(); + } + + lock.unlock(); + + // Wait for the push thread to finish (safe now — worker is unblocked) + if (pushThread_.joinable()) { + pushThread_.join(); + } + + // Send stop command to core + CoreInteropRequest req("audio_stream_stop"); + req.AddParam("SessionHandle", handle); + std::string json = req.ToJson(); + + auto response = core_->call(req.Command(), *logger_, &json); + + // Store the final result or error for retrieval via TryGetNext + if (response.HasError()) { + if (resultQueue_) { + resultQueue_->CloseWithError("audio_stream_stop failed: " + response.error); + } + } + else if (!response.data.empty()) { + try { + finalResult_ = LiveAudioTranscriptionResponse::FromJson(response.data); + hasFinalResult_ = true; + } + catch (const std::exception& e) { + logger_->Log(LogLevel::Warning, + std::string("Failed to parse final transcription response: ") + e.what()); + } + } + + lock.lock(); + } + + void LiveAudioTranscriptionSession::PushWorkerLoop() { + AudioChunk chunk; + while (true) { + auto status = pushQueue_->Pop(chunk); + if (status != Internal::DequeueStatus::Item) { + break; + } + + std::string handle; + { + std::lock_guard lock(mutex_); + handle = sessionHandle_; + } + + CoreInteropRequest req("audio_stream_push"); + req.AddParam("SessionHandle", handle); + std::string json = req.ToJson(); + + auto response = core_->callWithBinary(req.Command(), *logger_, &json, + chunk.data(), chunk.size()); + + if (response.HasError()) { + auto coreError = CoreErrorResponse::TryParse(response.error); + std::string msg = + (coreError.has_value() && !coreError->message.empty()) + ? coreError->message + : response.error; + + logger_->Log(LogLevel::Error, "audio_stream_push failed: " + msg); + pushQueue_->Close(); + resultQueue_->CloseWithError(msg); + + std::lock_guard lock(mutex_); + errorMessage_ = std::move(msg); + return; + } + + // Parse the response as a transcription result if there is data + if (!response.data.empty()) { + try { + auto result = LiveAudioTranscriptionResponse::FromJson(response.data); + if (!resultQueue_->TryPush(std::move(result))) { + logger_->Log( + LogLevel::Warning, + "Dropping transcription result because the result queue is full."); + } + } + catch (const std::exception& e) { + logger_->Log(LogLevel::Warning, + std::string("Failed to parse transcription response: ") + e.what()); + } + } + } + } + + std::string LiveAudioTranscriptionSession::GetErrorMessage() const { + std::lock_guard lock(mutex_); + return errorMessage_; + } + + bool LiveAudioTranscriptionSession::IsStarted() const { + std::lock_guard lock(mutex_); + return state_ == SessionState::Started; + } + + bool LiveAudioTranscriptionSession::IsStopped() const { + std::lock_guard lock(mutex_); + return state_ == SessionState::Stopped; + } + +} // namespace foundry_local diff --git a/sdk/cpp/src/openai_live_audio_types.cpp b/sdk/cpp/src/openai_live_audio_types.cpp new file mode 100644 index 00000000..f781a992 --- /dev/null +++ b/sdk/cpp/src/openai_live_audio_types.cpp @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include + +#include "openai/openai_live_audio_types.h" + +namespace foundry_local { + + LiveAudioTranscriptionResponse LiveAudioTranscriptionResponse::FromJson(const std::string& json) { + auto j = nlohmann::json::parse(json); + LiveAudioTranscriptionResponse response; + + if (j.contains("text") && j["text"].is_string()) { + response.text = j["text"].get(); + } + + if (j.contains("is_final") && j["is_final"].is_boolean()) { + response.is_final = j["is_final"].get(); + } + else if (j.contains("isFinal") && j["isFinal"].is_boolean()) { + response.is_final = j["isFinal"].get(); + } + + if (j.contains("start_time") && j["start_time"].is_number()) { + response.start_time = j["start_time"].get(); + } + else if (j.contains("startTime") && j["startTime"].is_number()) { + response.start_time = j["startTime"].get(); + } + + if (j.contains("end_time") && j["end_time"].is_number()) { + response.end_time = j["end_time"].get(); + } + else if (j.contains("endTime") && j["endTime"].is_number()) { + response.end_time = j["endTime"].get(); + } + + if (j.contains("content") && j["content"].is_array()) { + for (const auto& item : j["content"]) { + ContentPart part; + if (item.contains("text") && item["text"].is_string()) { + part.text = item["text"].get(); + } + if (item.contains("transcript") && item["transcript"].is_string()) { + part.transcript = item["transcript"].get(); + } + response.content.push_back(std::move(part)); + } + } + + return response; + } + + std::optional CoreErrorResponse::TryParse(const std::string& error_string) { + try { + auto j = nlohmann::json::parse(error_string); + CoreErrorResponse response; + + if (j.contains("code") && j["code"].is_string()) { + response.code = j["code"].get(); + } + if (j.contains("message") && j["message"].is_string()) { + response.message = j["message"].get(); + } + if (j.contains("is_transient") && j["is_transient"].is_boolean()) { + response.is_transient = j["is_transient"].get(); + } + else if (j.contains("isTransient") && j["isTransient"].is_boolean()) { + response.is_transient = j["isTransient"].get(); + } + + return response; + } + catch (const nlohmann::json::exception&) { + return std::nullopt; + } + } + +} // namespace foundry_local diff --git a/sdk/cpp/src/thread_safe_queue.h b/sdk/cpp/src/thread_safe_queue.h new file mode 100644 index 00000000..c6ea7446 --- /dev/null +++ b/sdk/cpp/src/thread_safe_queue.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace foundry_local::Internal { + + enum class DequeueStatus { + Item, + Timeout, + Closed, + Error + }; + + /// A bounded, thread-safe queue with graceful close/error semantics. + template class ThreadSafeQueue final { + public: + explicit ThreadSafeQueue(size_t capacity) : capacity_(capacity) {} + + /// Blocking push. Waits until space is available or the queue is closed. + /// Returns true if the item was enqueued, false if the queue was closed. + bool Push(T item) { + std::unique_lock lock(mutex_); + notFull_.wait(lock, [this] { return queue_.size() < capacity_ || closed_; }); + if (closed_) { + return false; + } + queue_.push(std::move(item)); + notEmpty_.notify_one(); + return true; + } + + /// Non-blocking push. Returns true if the item was enqueued. + bool TryPush(T item) { + std::lock_guard lock(mutex_); + if (closed_ || queue_.size() >= capacity_) { + return false; + } + queue_.push(std::move(item)); + notEmpty_.notify_one(); + return true; + } + + /// Timed push. Returns true if the item was enqueued within the timeout. + bool TryPushFor(T item, std::chrono::milliseconds timeout) { + std::unique_lock lock(mutex_); + if (!notFull_.wait_for(lock, timeout, [this] { return queue_.size() < capacity_ || closed_; })) { + return false; + } + if (closed_) { + return false; + } + queue_.push(std::move(item)); + notEmpty_.notify_one(); + return true; + } + + /// Blocking pop. Waits until an item is available or the queue is closed/errored. + DequeueStatus Pop(T& item) { + std::unique_lock lock(mutex_); + notEmpty_.wait(lock, [this] { return !queue_.empty() || closed_ || hasError_; }); + if (hasError_ && queue_.empty()) { + return DequeueStatus::Error; + } + if (queue_.empty()) { + return DequeueStatus::Closed; + } + item = std::move(queue_.front()); + queue_.pop(); + notFull_.notify_one(); + return DequeueStatus::Item; + } + + /// Timed pop. Returns the dequeue status. + DequeueStatus TryPop(T& item, std::chrono::milliseconds timeout) { + std::unique_lock lock(mutex_); + if (!notEmpty_.wait_for(lock, timeout, [this] { return !queue_.empty() || closed_ || hasError_; })) { + return DequeueStatus::Timeout; + } + if (hasError_ && queue_.empty()) { + return DequeueStatus::Error; + } + if (queue_.empty()) { + return DequeueStatus::Closed; + } + item = std::move(queue_.front()); + queue_.pop(); + notFull_.notify_one(); + return DequeueStatus::Item; + } + + /// Close the queue gracefully. No more items can be pushed. + void Close() { + std::lock_guard lock(mutex_); + closed_ = true; + notEmpty_.notify_all(); + notFull_.notify_all(); + } + + /// Close the queue with an error message. + void CloseWithError(std::string errorMessage) { + std::lock_guard lock(mutex_); + closed_ = true; + hasError_ = true; + errorMessage_ = std::move(errorMessage); + notEmpty_.notify_all(); + notFull_.notify_all(); + } + + bool IsClosed() const { + std::lock_guard lock(mutex_); + return closed_; + } + + bool HasError() const { + std::lock_guard lock(mutex_); + return hasError_; + } + + std::string GetErrorMessage() const { + std::lock_guard lock(mutex_); + return errorMessage_; + } + + private: + const size_t capacity_; + std::queue queue_; + mutable std::mutex mutex_; + std::condition_variable notEmpty_; + std::condition_variable notFull_; + bool closed_ = false; + bool hasError_ = false; + std::string errorMessage_; + }; + +} // namespace foundry_local::Internal diff --git a/sdk/cpp/test/live_audio_test.cpp b/sdk/cpp/test/live_audio_test.cpp new file mode 100644 index 00000000..c6fc10b4 --- /dev/null +++ b/sdk/cpp/test/live_audio_test.cpp @@ -0,0 +1,334 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "foundry_local_exception.h" + +#include "openai/openai_live_audio_types.h" +#include "openai/openai_live_audio_client.h" + +#include + +#include +#include +#include +#include + +using namespace foundry_local; +using namespace foundry_local::Testing; + +// --------------------------------------------------------------------------- +// LiveAudioTranscriptionResponse parsing tests +// --------------------------------------------------------------------------- + +TEST(LiveAudioTypesTest, FromJson_BasicResponse) { + nlohmann::json j = { + {"text", "hello world"}, + {"is_final", true}, + {"start_time", 0.5}, + {"end_time", 1.5}}; + + auto resp = LiveAudioTranscriptionResponse::FromJson(j.dump()); + EXPECT_EQ("hello world", resp.text); + EXPECT_TRUE(resp.is_final); + ASSERT_TRUE(resp.start_time.has_value()); + EXPECT_DOUBLE_EQ(0.5, resp.start_time.value()); + ASSERT_TRUE(resp.end_time.has_value()); + EXPECT_DOUBLE_EQ(1.5, resp.end_time.value()); +} + +TEST(LiveAudioTypesTest, FromJson_CamelCaseFields) { + nlohmann::json j = { + {"text", "test"}, + {"isFinal", false}, + {"startTime", 1.0}, + {"endTime", 2.0}}; + + auto resp = LiveAudioTranscriptionResponse::FromJson(j.dump()); + EXPECT_EQ("test", resp.text); + EXPECT_FALSE(resp.is_final); + ASSERT_TRUE(resp.start_time.has_value()); + EXPECT_DOUBLE_EQ(1.0, resp.start_time.value()); +} + +TEST(LiveAudioTypesTest, FromJson_WithContent) { + nlohmann::json j = { + {"text", "hello"}, + {"is_final", true}, + {"content", {{{"text", "hi"}, {"transcript", "hi there"}}}}}; + + auto resp = LiveAudioTranscriptionResponse::FromJson(j.dump()); + ASSERT_EQ(1u, resp.content.size()); + EXPECT_EQ("hi", resp.content[0].text); + EXPECT_EQ("hi there", resp.content[0].transcript); +} + +TEST(LiveAudioTypesTest, FromJson_EmptyJson) { + auto resp = LiveAudioTranscriptionResponse::FromJson("{}"); + EXPECT_TRUE(resp.text.empty()); + EXPECT_FALSE(resp.is_final); + EXPECT_FALSE(resp.start_time.has_value()); + EXPECT_FALSE(resp.end_time.has_value()); + EXPECT_TRUE(resp.content.empty()); +} + +TEST(LiveAudioTypesTest, CoreErrorResponse_TryParse_Valid) { + nlohmann::json j = { + {"code", "RATE_LIMITED"}, + {"message", "Too many requests"}, + {"is_transient", true}}; + + auto result = CoreErrorResponse::TryParse(j.dump()); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ("RATE_LIMITED", result->code); + EXPECT_EQ("Too many requests", result->message); + EXPECT_TRUE(result->is_transient); +} + +TEST(LiveAudioTypesTest, CoreErrorResponse_TryParse_Invalid) { + auto result = CoreErrorResponse::TryParse("not json"); + EXPECT_FALSE(result.has_value()); +} + +// --------------------------------------------------------------------------- +// LiveAudioTranscriptionSession tests +// --------------------------------------------------------------------------- + +class LiveAudioSessionTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + void SetUpStartHandlers(const std::string& sessionHandle = "session-123") { + core_.OnCall("audio_stream_start", sessionHandle); + } + + void SetUpPushHandler(const std::string& responseJson = "") { + core_.OnCall("audio_stream_push", + [responseJson](std::string_view, const std::string*, NativeCallbackFn, void*) { + return responseJson; + }); + } + + void SetUpStopHandler() { + core_.OnCall("audio_stream_stop", ""); + } + + void SetUpAllHandlers(const std::string& pushResponse = "") { + SetUpStartHandlers(); + SetUpPushHandler(pushResponse); + SetUpStopHandler(); + } +}; + +TEST_F(LiveAudioSessionTest, ConstructorDefaults) { + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + EXPECT_FALSE(session.IsStarted()); + EXPECT_FALSE(session.IsStopped()); + EXPECT_EQ(16000, session.Settings().sample_rate); + EXPECT_EQ(1, session.Settings().channels); + EXPECT_EQ(16, session.Settings().bits_per_sample); +} + +TEST_F(LiveAudioSessionTest, SettingsCanBeModifiedBeforeStart) { + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Settings().sample_rate = 44100; + session.Settings().channels = 2; + session.Settings().language = "en"; + + EXPECT_EQ(44100, session.Settings().sample_rate); + EXPECT_EQ(2, session.Settings().channels); + EXPECT_EQ("en", session.Settings().language.value()); +} + +TEST_F(LiveAudioSessionTest, Start_Success) { + SetUpAllHandlers(); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Start(); + + EXPECT_TRUE(session.IsStarted()); + EXPECT_FALSE(session.IsStopped()); + EXPECT_EQ(16000, session.ActiveSettings().sample_rate); + + session.Stop(); + EXPECT_TRUE(session.IsStopped()); +} + +TEST_F(LiveAudioSessionTest, Start_WithCustomSettings) { + SetUpAllHandlers(); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Settings().sample_rate = 44100; + session.Settings().language = "fr"; + session.Start(); + + EXPECT_EQ(44100, session.ActiveSettings().sample_rate); + EXPECT_EQ("fr", session.ActiveSettings().language.value()); + + // Verify the request included our settings + auto lastArg = core_.GetLastDataArg("audio_stream_start"); + auto parsed = nlohmann::json::parse(lastArg); + EXPECT_EQ("44100", parsed["Params"]["SampleRate"].get()); + EXPECT_EQ("fr", parsed["Params"]["Language"].get()); + + session.Stop(); +} + +TEST_F(LiveAudioSessionTest, Start_Failure) { + core_.OnCallThrow("audio_stream_start", "Connection refused"); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + EXPECT_THROW(session.Start(), Exception); + EXPECT_FALSE(session.IsStarted()); +} + +TEST_F(LiveAudioSessionTest, Start_EmptyHandle) { + core_.OnCall("audio_stream_start", ""); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + EXPECT_THROW(session.Start(), Exception); + EXPECT_FALSE(session.IsStarted()); +} + +TEST_F(LiveAudioSessionTest, DoubleStartThrows) { + SetUpAllHandlers(); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Start(); + EXPECT_THROW(session.Start(), Exception); + + session.Stop(); +} + +TEST_F(LiveAudioSessionTest, AppendBeforeStartThrows) { + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + std::vector data = {0, 1, 2, 3}; + EXPECT_THROW(session.Append(data.data(), data.size()), Exception); +} + +TEST_F(LiveAudioSessionTest, AppendAfterStopThrows) { + SetUpAllHandlers(); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Start(); + session.Stop(); + std::vector data = {0, 1, 2, 3}; + EXPECT_THROW(session.Append(data.data(), data.size()), Exception); +} + +TEST_F(LiveAudioSessionTest, Start_InvalidCapacityThrows) { + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Settings().push_queue_capacity = 0; + EXPECT_THROW(session.Start(), Exception); +} + +TEST_F(LiveAudioSessionTest, StopParseFinalResponse) { + SetUpStartHandlers(); + SetUpPushHandler(); + + // audio_stream_stop returns a final transcription result + nlohmann::json finalResponse = { + {"text", "final result"}, + {"is_final", true}}; + core_.OnCall("audio_stream_stop", finalResponse.dump()); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Start(); + session.Stop(); + + // The final result should be retrievable from the result queue + LiveAudioTranscriptionResponse result; + auto status = session.TryGetNext(result, std::chrono::milliseconds(100)); + EXPECT_EQ(TranscriptionStatus::Result, status); + EXPECT_EQ("final result", result.text); + EXPECT_TRUE(result.is_final); +} + +TEST_F(LiveAudioSessionTest, AppendAndGetResult) { + nlohmann::json pushResponse = { + {"text", "hello"}, + {"is_final", false}}; + SetUpAllHandlers(pushResponse.dump()); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Start(); + + // Append some data + std::vector data(320, 0); + session.Append(data.data(), data.size()); + + // Try to get a result + LiveAudioTranscriptionResponse result; + auto status = session.TryGetNext(result, std::chrono::seconds(2)); + + if (status == TranscriptionStatus::Result) { + EXPECT_EQ("hello", result.text); + EXPECT_FALSE(result.is_final); + } + + session.Stop(); +} + +TEST_F(LiveAudioSessionTest, StopSendsCommand) { + SetUpAllHandlers(); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Start(); + session.Stop(); + + EXPECT_EQ(1, core_.GetCallCount("audio_stream_stop")); + + auto lastArg = core_.GetLastDataArg("audio_stream_stop"); + auto parsed = nlohmann::json::parse(lastArg); + EXPECT_EQ("session-123", parsed["Params"]["SessionHandle"].get()); +} + +TEST_F(LiveAudioSessionTest, StopWhenNotStartedIsNoop) { + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Stop(); // Should not throw + EXPECT_EQ(0, core_.GetCallCount("audio_stream_stop")); +} + +TEST_F(LiveAudioSessionTest, DoubleStopIsNoop) { + SetUpAllHandlers(); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Start(); + session.Stop(); + session.Stop(); // Should not throw or send a second command + EXPECT_EQ(1, core_.GetCallCount("audio_stream_stop")); +} + +TEST_F(LiveAudioSessionTest, DestructorStopsSession) { + SetUpAllHandlers(); + + { + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Start(); + // Destructor should call Stop + } + + EXPECT_EQ(1, core_.GetCallCount("audio_stream_stop")); +} + +TEST_F(LiveAudioSessionTest, TryGetNextTimeout) { + SetUpAllHandlers(); + + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + session.Start(); + + LiveAudioTranscriptionResponse result; + auto status = session.TryGetNext(result, std::chrono::milliseconds(50)); + EXPECT_EQ(TranscriptionStatus::Timeout, status); + + session.Stop(); +} + +TEST_F(LiveAudioSessionTest, GetErrorMessage_NoError) { + LiveAudioTranscriptionSession session(&core_, "whisper-model", &logger_); + EXPECT_TRUE(session.GetErrorMessage().empty()); +} diff --git a/sdk/cpp/test/mock_core.h b/sdk/cpp/test/mock_core.h index f89af91a..e7b5f84c 100644 --- a/sdk/cpp/test/mock_core.h +++ b/sdk/cpp/test/mock_core.h @@ -81,6 +81,13 @@ namespace foundry_local::Testing { return resp; } + CoreResponse callWithBinary(std::string_view command, ILogger& logger, + const std::string* dataArgument, + const uint8_t* /*binaryData*/, size_t /*binaryDataLength*/) const override { + // Route through regular call() for testing + return call(command, logger, dataArgument); + } + void unload() override {} private: @@ -147,6 +154,12 @@ namespace foundry_local::Testing { return resp; } + CoreResponse callWithBinary(std::string_view command, ILogger& logger, + const std::string* dataArgument, + const uint8_t* /*binaryData*/, size_t /*binaryDataLength*/) const override { + return call(command, logger, dataArgument); + } + void unload() override {} private: