Skip to content

Commit 1fe95bd

Browse files
bmehta001Copilot
andcommitted
Add SDK download cancellation support
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 3f94e83 commit 1fe95bd

41 files changed

Lines changed: 1444 additions & 205 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

samples/cpp/live-audio-transcription/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,7 @@ g++ -std=c++20 main.cpp -lfoundry_local -o live-audio-transcription-example
2626
# Synthetic 440Hz sine wave (no microphone needed)
2727
./live-audio-transcription-example --synth
2828
```
29+
30+
Press `Ctrl+C` to request a graceful stop. The sample passes that signal to
31+
execution-provider and model downloads so long-running downloads can be
32+
cancelled before transcription starts.

samples/cpp/live-audio-transcription/main.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ int main(int argc, char* argv[]) {
122122

123123
foundry_local::Manager::Create(config);
124124
auto& manager = foundry_local::Manager::Instance();
125-
manager.EnsureEpsDownloaded();
125+
auto isCancellationRequested = [] { return !g_running.load(); };
126+
manager.EnsureEpsDownloaded(nullptr, isCancellationRequested);
126127

127128
auto& catalog = manager.GetCatalog();
128129
auto* model = catalog.GetModel("nemotron-speech-streaming-en-0.6b");
@@ -131,9 +132,11 @@ int main(int argc, char* argv[]) {
131132
}
132133

133134
std::cout << "Downloading model (if needed)..." << std::endl;
134-
model->Download([](float pct) {
135-
std::cout << "\rDownloading: " << pct << "% " << std::flush;
136-
});
135+
model->Download(
136+
[](float pct) {
137+
std::cout << "\rDownloading: " << pct << "% " << std::flush;
138+
},
139+
isCancellationRequested);
137140
std::cout << std::endl;
138141
std::cout << "Loading model..." << std::endl;
139142
model->Load();

sdk/cpp/include/foundry_local_manager.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <string>
77
#include <vector>
88
#include <memory>
9+
#include <functional>
910

1011
#include <gsl/pointers>
1112
#include <gsl/span>
@@ -20,6 +21,8 @@ namespace foundry_local::Internal {
2021

2122
namespace foundry_local {
2223

24+
using ExecutionProviderDownloadProgressCallback = std::function<void(std::string epName, double percentage)>;
25+
2326
class Manager final {
2427
public:
2528
Manager(const Manager&) = delete;
@@ -61,7 +64,11 @@ namespace foundry_local {
6164

6265
/// Ensure execution providers are downloaded and registered.
6366
/// Once downloaded, EPs are not re-downloaded unless a new version is available.
64-
void EnsureEpsDownloaded() const;
67+
/// @param onProgress Optional callback receiving execution provider name and percentage progress.
68+
/// @param isCancellationRequested Optional callback checked on each progress update.
69+
/// Return true to cancel the in-progress download.
70+
void EnsureEpsDownloaded(ExecutionProviderDownloadProgressCallback onProgress = nullptr,
71+
CancellationCallback isCancellationRequested = nullptr) const;
6572

6673
private:
6774
explicit Manager(Configuration configuration, ILogger* logger);

sdk/cpp/include/model.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <memory>
1111
#include <functional>
1212
#include <filesystem>
13+
#include <utility>
1314

1415
#include <gsl/pointers>
1516
#include <gsl/span>
@@ -33,6 +34,7 @@ namespace foundry_local {
3334
#endif
3435

3536
using DownloadProgressCallback = std::function<bool(float percentage)>;
37+
using CancellationCallback = std::function<bool()>;
3638

3739
class IModel {
3840
public:
@@ -43,7 +45,13 @@ namespace foundry_local {
4345
virtual bool IsLoaded() const = 0;
4446
virtual bool IsCached() const = 0;
4547
virtual const std::filesystem::path& GetPath() const = 0;
46-
virtual void Download(DownloadProgressCallback onProgress = nullptr) = 0;
48+
49+
/// Download the model to the local cache.
50+
/// @param onProgress Optional callback receiving percentage progress. Return true to continue.
51+
/// @param isCancellationRequested Optional callback checked on each progress update.
52+
/// Return true to cancel the in-progress download.
53+
virtual void Download(DownloadProgressCallback onProgress = nullptr,
54+
CancellationCallback isCancellationRequested = nullptr) = 0;
4755
virtual void Load() = 0;
4856
virtual void Unload() = 0;
4957
virtual void RemoveFromCache() = 0;
@@ -123,7 +131,8 @@ namespace foundry_local {
123131

124132
const ModelInfo& GetInfo() const;
125133
const std::filesystem::path& GetPath() const override;
126-
void Download(DownloadProgressCallback onProgress = nullptr) override;
134+
void Download(DownloadProgressCallback onProgress = nullptr,
135+
CancellationCallback isCancellationRequested = nullptr) override;
127136
void Load() override;
128137

129138
bool IsLoaded() const override;
@@ -158,8 +167,9 @@ namespace foundry_local {
158167
bool IsLoaded() const override { return SelectedVariant().IsLoaded(); }
159168
bool IsCached() const override { return SelectedVariant().IsCached(); }
160169
const std::filesystem::path& GetPath() const override { return SelectedVariant().GetPath(); }
161-
void Download(DownloadProgressCallback onProgress = nullptr) override {
162-
SelectedVariant().Download(std::move(onProgress));
170+
void Download(DownloadProgressCallback onProgress = nullptr,
171+
CancellationCallback isCancellationRequested = nullptr) override {
172+
SelectedVariant().Download(std::move(onProgress), std::move(isCancellationRequested));
163173
}
164174
void Load() override { SelectedVariant().Load(); }
165175
void Unload() override { SelectedVariant().Unload(); }

sdk/cpp/sample/main.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
#include "foundry_local.h"
55

6+
#include <atomic>
7+
#include <csignal>
68
#include <iostream>
79
#include <string>
810
#include <vector>
@@ -13,6 +15,18 @@
1315

1416
using namespace foundry_local;
1517

18+
namespace {
19+
std::atomic<bool> g_cancelRequested{false};
20+
21+
void SignalHandler(int /*signum*/) {
22+
g_cancelRequested.store(true);
23+
}
24+
25+
bool IsCancellationRequested() {
26+
return g_cancelRequested.load();
27+
}
28+
} // namespace
29+
1630
// ---------------------------------------------------------------------------
1731
// Logger
1832
// ---------------------------------------------------------------------------
@@ -117,7 +131,8 @@ void ChatNonStreaming(Manager& manager, const std::string& alias) {
117131
PreferCpuVariant(*concreteModel);
118132
}
119133

120-
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; return true; });
134+
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; return true; },
135+
IsCancellationRequested);
121136
std::cout << "\n";
122137

123138
model->Load();
@@ -210,7 +225,8 @@ void TranscribeAudio(Manager& manager, const std::string& alias, const std::stri
210225
PreferCpuVariant(*concreteModel);
211226
}
212227

213-
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; return true; });
228+
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; return true; },
229+
IsCancellationRequested);
214230
std::cout << "\n";
215231

216232
model->Load();
@@ -262,7 +278,8 @@ void ChatWithToolCalling(Manager& manager, const std::string& alias) {
262278
PreferCpuVariant(*concreteModel);
263279
}
264280

265-
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; return true; });
281+
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; return true; },
282+
IsCancellationRequested);
266283
std::cout << "\n";
267284

268285
model->Load();
@@ -375,6 +392,8 @@ int main(int argc, char* argv[]) {
375392
const std::string audioPath = (argc > 3) ? argv[3] : "";
376393

377394
try {
395+
std::signal(SIGINT, SignalHandler);
396+
378397
StdLogger logger;
379398
Manager::Create({"SampleApp"}, &logger);
380399
auto& manager = Manager::Instance();

sdk/cpp/src/core.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ namespace foundry_local {
187187
std::unique_ptr<ResponseBuffer, decltype(safeDeleter)> responseGuard(&response, safeDeleter);
188188

189189
if (callback != nullptr) {
190-
execCbCmd_(&request, &response, reinterpret_cast<void*>(callback), data);
190+
execCbCmd_(&request, &response, callback, data);
191191
}
192192
else {
193193
execCmd_(&request, &response);

sdk/cpp/src/core_helpers.h

Lines changed: 91 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,15 @@
66

77
#pragma once
88

9+
#include <charconv>
910
#include <string>
1011
#include <string_view>
1112
#include <vector>
1213
#include <functional>
1314
#include <exception>
15+
#include <system_error>
1416
#include <unordered_map>
17+
#include <utility>
1518

1619
#include <nlohmann/json.hpp>
1720

@@ -47,38 +50,82 @@ namespace foundry_local::detail {
4750
return core->call(command, logger, &payload, callback, userData);
4851
}
4952

53+
inline bool TryParseFloatToken(std::string_view token, float& value) {
54+
if (token.empty()) {
55+
return false;
56+
}
57+
58+
const auto* begin = token.data();
59+
const auto* end = begin + token.size();
60+
const auto result = std::from_chars(begin, end, value);
61+
return result.ec == std::errc{} && result.ptr == end;
62+
}
63+
64+
inline bool TryParseDoubleToken(std::string_view token, double& value) {
65+
if (token.empty()) {
66+
return false;
67+
}
68+
69+
const auto* begin = token.data();
70+
const auto* end = begin + token.size();
71+
const auto result = std::from_chars(begin, end, value);
72+
return result.ec == std::errc{} && result.ptr == end;
73+
}
74+
5075
// Serialize + call with a streaming chunk handler.
5176
// Wraps the caller-supplied onChunk with the native callback boilerplate
52-
// (null/length checks, exception capture, rethrow after the call).
77+
// (null/length checks, exception capture, cancellation, rethrow after the call).
5378
// The errorContext string is used to prefix any core-layer error message.
5479
inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command,
55-
const std::string& payload, ILogger& logger,
56-
const std::function<void(const std::string&)>& onChunk,
57-
std::string_view errorContext) {
80+
const std::string* payload, ILogger& logger,
81+
const std::function<bool(const std::string&)>& onChunk,
82+
std::string_view errorContext,
83+
CancellationCallback isCancellationRequested = nullptr) {
5884
struct State {
59-
const std::function<void(const std::string&)>* cb;
85+
const std::function<bool(const std::string&)>* cb;
86+
CancellationCallback isCancellationRequested;
87+
bool cancellationObserved = false;
6088
std::exception_ptr exception;
61-
} state{&onChunk, nullptr};
89+
} state{&onChunk, std::move(isCancellationRequested), false, nullptr};
6290

63-
auto nativeCallback = [](void* data, int32_t len, void* user) -> int {
64-
if (!data || len <= 0)
91+
auto nativeCallback = [](const void* data, int32_t len, void* user) -> int32_t {
92+
auto* st = static_cast<State*>(user);
93+
if (!st) {
6594
return 0;
95+
}
6696

67-
auto* st = static_cast<State*>(user);
68-
if (st->exception)
97+
if (st->exception || st->cancellationObserved) {
98+
return 1;
99+
}
100+
101+
if (!data || len <= 0)
69102
return 0;
70103

71104
try {
105+
if (st->isCancellationRequested && st->isCancellationRequested()) {
106+
st->cancellationObserved = true;
107+
return 1;
108+
}
109+
72110
std::string chunk(static_cast<const char*>(data), static_cast<size_t>(len));
73-
(*(st->cb))(chunk);
111+
if (!(*(st->cb))(chunk)) {
112+
st->cancellationObserved = true;
113+
return 1;
114+
}
74115
}
75116
catch (...) {
76117
st->exception = std::current_exception();
118+
return 1;
77119
}
120+
78121
return 0;
79122
};
80123

81-
auto response = core->call(command, logger, &payload, +nativeCallback, &state);
124+
auto response = core->call(command, logger, payload, +nativeCallback, &state);
125+
if (state.cancellationObserved) {
126+
throw Exception("Operation cancelled", logger);
127+
}
128+
82129
if (response.HasError()) {
83130
throw Exception(std::string(errorContext) + response.error, logger);
84131
}
@@ -90,6 +137,38 @@ namespace foundry_local::detail {
90137
return response;
91138
}
92139

140+
inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command,
141+
const std::string* payload, ILogger& logger,
142+
const std::function<void(const std::string&)>& onChunk,
143+
std::string_view errorContext,
144+
CancellationCallback isCancellationRequested = nullptr) {
145+
const std::function<bool(const std::string&)> continuingOnChunk =
146+
[&onChunk](const std::string& chunk) {
147+
onChunk(chunk);
148+
return true;
149+
};
150+
return CallWithStreamingCallback(core, command, payload, logger, continuingOnChunk, errorContext,
151+
std::move(isCancellationRequested));
152+
}
153+
154+
inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command,
155+
const std::string& payload, ILogger& logger,
156+
const std::function<bool(const std::string&)>& onChunk,
157+
std::string_view errorContext,
158+
CancellationCallback isCancellationRequested = nullptr) {
159+
return CallWithStreamingCallback(core, command, &payload, logger, onChunk, errorContext,
160+
std::move(isCancellationRequested));
161+
}
162+
163+
inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command,
164+
const std::string& payload, ILogger& logger,
165+
const std::function<void(const std::string&)>& onChunk,
166+
std::string_view errorContext,
167+
CancellationCallback isCancellationRequested = nullptr) {
168+
return CallWithStreamingCallback(core, command, &payload, logger, onChunk, errorContext,
169+
std::move(isCancellationRequested));
170+
}
171+
93172
// Overload: allow Params object directly
94173
inline CoreResponse CallWithParams(Internal::IFoundryLocalCore* core, std::string_view command,
95174
const nlohmann::json& params, ILogger& logger) {

sdk/cpp/src/flcore_native.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
#include <cstdint>
66
#include <type_traits>
77

8-
#ifdef _WIN32
9-
#define FL_CDECL __cdecl
10-
#else
11-
#define FL_CDECL
8+
#ifndef FL_CDECL
9+
#ifdef _WIN32
10+
#define FL_CDECL __cdecl
11+
#else
12+
#define FL_CDECL
13+
#endif
1214
#endif
1315

1416
extern "C"
@@ -29,8 +31,9 @@ extern "C"
2931
int32_t ErrorLength;
3032
};
3133

32-
// Callback signature: int(*)(void* data, int length, void* userData) — returns 0 to continue, 1 to cancel
33-
using UserCallbackFn = int(__cdecl*)(void*, int32_t, void*);
34+
// Callback signature: int32_t(*)(const void* data, int length, void* userData)
35+
// Return 0 to continue, 1 to cancel.
36+
using UserCallbackFn = int32_t(FL_CDECL*)(const void*, int32_t, void*);
3437

3538
struct StreamingRequestBuffer {
3639
const void* Command;
@@ -43,7 +46,8 @@ extern "C"
4346

4447
// Exported function pointer types
4548
using execute_command_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*);
46-
using execute_command_with_callback_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/,
49+
using execute_command_with_callback_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*,
50+
UserCallbackFn /*callback*/,
4751
void* /*userData*/);
4852
using execute_command_with_binary_fn = void(FL_CDECL*)(StreamingRequestBuffer*, ResponseBuffer*);
4953
using free_response_fn = void(FL_CDECL*)(ResponseBuffer*);

0 commit comments

Comments
 (0)