Skip to content

Commit 2796134

Browse files
bmehta001Copilot
andcommitted
Add SDK download cancellation support
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 58bba93 commit 2796134

40 files changed

Lines changed: 1294 additions & 180 deletions

samples/cpp/live-audio-transcription/main.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ int main(int argc, char* argv[]) {
122122

123123
foundry_local::Manager::Create(config);
124124
auto& manager = foundry_local::Manager::Instance();
125-
manager.EnsureEpsDownloaded();
125+
auto isCancellationRequested = [] { return !g_running.load(); };
126+
manager.EnsureEpsDownloaded(nullptr, isCancellationRequested);
126127

127128
auto& catalog = manager.GetCatalog();
128129
auto* model = catalog.GetModel("nemotron-speech-streaming-en-0.6b");
@@ -131,9 +132,11 @@ int main(int argc, char* argv[]) {
131132
}
132133

133134
std::cout << "Downloading model (if needed)..." << std::endl;
134-
model->Download([](float pct) {
135-
std::cout << "\rDownloading: " << pct << "% " << std::flush;
136-
});
135+
model->Download(
136+
[](float pct) {
137+
std::cout << "\rDownloading: " << pct << "% " << std::flush;
138+
},
139+
isCancellationRequested);
137140
std::cout << std::endl;
138141
std::cout << "Loading model..." << std::endl;
139142
model->Load();

sdk/cpp/include/foundry_local_manager.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <string>
66
#include <vector>
77
#include <memory>
8+
#include <functional>
89

910
#include <gsl/pointers>
1011
#include <gsl/span>
@@ -19,6 +20,8 @@ namespace foundry_local::Internal {
1920

2021
namespace foundry_local {
2122

23+
using ExecutionProviderDownloadProgressCallback = std::function<void(std::string epName, double percentage)>;
24+
2225
class Manager final {
2326
public:
2427
Manager(const Manager&) = delete;
@@ -61,7 +64,8 @@ namespace foundry_local {
6164

6265
/// Ensure execution providers are downloaded and registered.
6366
/// Once downloaded, EPs are not re-downloaded unless a new version is available.
64-
void EnsureEpsDownloaded() const;
67+
void EnsureEpsDownloaded(ExecutionProviderDownloadProgressCallback onProgress = nullptr,
68+
CancellationCallback isCancellationRequested = nullptr) const;
6569

6670
private:
6771
explicit Manager(Configuration configuration, ILogger* logger);

sdk/cpp/include/model.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <memory>
1111
#include <functional>
1212
#include <filesystem>
13+
#include <utility>
1314

1415
#include <gsl/pointers>
1516
#include <gsl/span>
@@ -33,6 +34,7 @@ namespace foundry_local {
3334
#endif
3435

3536
using DownloadProgressCallback = std::function<void(float percentage)>;
37+
using CancellationCallback = std::function<bool()>;
3638

3739
class IModel {
3840
public:
@@ -43,7 +45,8 @@ namespace foundry_local {
4345
virtual bool IsLoaded() const = 0;
4446
virtual bool IsCached() const = 0;
4547
virtual const std::filesystem::path& GetPath() const = 0;
46-
virtual void Download(DownloadProgressCallback onProgress = nullptr) = 0;
48+
virtual void Download(DownloadProgressCallback onProgress = nullptr,
49+
CancellationCallback isCancellationRequested = nullptr) = 0;
4750
virtual void Load() = 0;
4851
virtual void Unload() = 0;
4952
virtual void RemoveFromCache() = 0;
@@ -123,7 +126,8 @@ namespace foundry_local {
123126

124127
const ModelInfo& GetInfo() const;
125128
const std::filesystem::path& GetPath() const override;
126-
void Download(DownloadProgressCallback onProgress = nullptr) override;
129+
void Download(DownloadProgressCallback onProgress = nullptr,
130+
CancellationCallback isCancellationRequested = nullptr) override;
127131
void Load() override;
128132

129133
bool IsLoaded() const override;
@@ -158,8 +162,9 @@ namespace foundry_local {
158162
bool IsLoaded() const override { return SelectedVariant().IsLoaded(); }
159163
bool IsCached() const override { return SelectedVariant().IsCached(); }
160164
const std::filesystem::path& GetPath() const override { return SelectedVariant().GetPath(); }
161-
void Download(DownloadProgressCallback onProgress = nullptr) override {
162-
SelectedVariant().Download(std::move(onProgress));
165+
void Download(DownloadProgressCallback onProgress = nullptr,
166+
CancellationCallback isCancellationRequested = nullptr) override {
167+
SelectedVariant().Download(std::move(onProgress), std::move(isCancellationRequested));
163168
}
164169
void Load() override { SelectedVariant().Load(); }
165170
void Unload() override { SelectedVariant().Unload(); }

sdk/cpp/sample/main.cpp

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,26 @@
33

44
#include "foundry_local.h"
55

6+
#include <atomic>
7+
#include <csignal>
68
#include <iostream>
79
#include <string>
810
#include <vector>
911

1012
using namespace foundry_local;
1113

14+
namespace {
15+
std::atomic<bool> g_cancelRequested{false};
16+
17+
void SignalHandler(int /*signum*/) {
18+
g_cancelRequested.store(true);
19+
}
20+
21+
bool IsCancellationRequested() {
22+
return g_cancelRequested.load();
23+
}
24+
} // namespace
25+
1226
// ---------------------------------------------------------------------------
1327
// Logger
1428
// ---------------------------------------------------------------------------
@@ -93,7 +107,8 @@ void ChatNonStreaming(Manager& manager, const std::string& alias) {
93107
return;
94108
}
95109

96-
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; });
110+
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; },
111+
IsCancellationRequested);
97112
std::cout << "\n";
98113

99114
model->Load();
@@ -176,7 +191,8 @@ void TranscribeAudio(Manager& manager, const std::string& alias, const std::stri
176191
return;
177192
}
178193

179-
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; });
194+
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; },
195+
IsCancellationRequested);
180196
std::cout << "\n";
181197

182198
model->Load();
@@ -223,7 +239,8 @@ void ChatWithToolCalling(Manager& manager, const std::string& alias) {
223239
return;
224240
}
225241

226-
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; });
242+
model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; },
243+
IsCancellationRequested);
227244
std::cout << "\n";
228245

229246
model->Load();
@@ -327,6 +344,8 @@ void ChatWithToolCalling(Manager& manager, const std::string& alias) {
327344
// ---------------------------------------------------------------------------
328345
int main() {
329346
try {
347+
std::signal(SIGINT, SignalHandler);
348+
330349
StdLogger logger;
331350
Manager::Create({"SampleApp"}, &logger);
332351
auto& manager = Manager::Instance();

sdk/cpp/src/core.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ namespace foundry_local {
7373
std::unique_ptr<ResponseBuffer, decltype(safeDeleter)> responseGuard(&response, safeDeleter);
7474

7575
if (callback != nullptr) {
76-
execCbCmd_(&request, &response, reinterpret_cast<void*>(callback), data);
76+
execCbCmd_(&request, &response, callback, data);
7777
}
7878
else {
7979
execCmd_(&request, &response);

sdk/cpp/src/core_helpers.h

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <functional>
1313
#include <exception>
1414
#include <unordered_map>
15+
#include <utility>
1516

1617
#include <nlohmann/json.hpp>
1718

@@ -47,37 +48,91 @@ namespace foundry_local::detail {
4748
return core->call(command, logger, &payload, callback, userData);
4849
}
4950

51+
inline bool TryParseFloatToken(std::string_view token, float& value) {
52+
if (token.empty()) {
53+
return false;
54+
}
55+
56+
const std::string text(token);
57+
size_t processed = 0;
58+
try {
59+
value = std::stof(text, &processed);
60+
}
61+
catch (...) {
62+
return false;
63+
}
64+
65+
return processed == text.size();
66+
}
67+
68+
inline bool TryParseDoubleToken(std::string_view token, double& value) {
69+
if (token.empty()) {
70+
return false;
71+
}
72+
73+
const std::string text(token);
74+
size_t processed = 0;
75+
try {
76+
value = std::stod(text, &processed);
77+
}
78+
catch (...) {
79+
return false;
80+
}
81+
82+
return processed == text.size();
83+
}
84+
5085
// Serialize + call with a streaming chunk handler.
5186
// Wraps the caller-supplied onChunk with the native callback boilerplate
5287
// (null/length checks, exception capture, rethrow after the call).
5388
// The errorContext string is used to prefix any core-layer error message.
5489
inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command,
55-
const std::string& payload, ILogger& logger,
56-
const std::function<void(const std::string&)>& onChunk,
57-
std::string_view errorContext) {
90+
const std::string* payload, ILogger& logger,
91+
const std::function<void(const std::string&)>& onChunk,
92+
std::string_view errorContext,
93+
CancellationCallback isCancellationRequested = nullptr) {
5894
struct State {
5995
const std::function<void(const std::string&)>* cb;
96+
CancellationCallback isCancellationRequested;
97+
bool cancellationObserved = false;
6098
std::exception_ptr exception;
61-
} state{&onChunk, nullptr};
62-
63-
auto nativeCallback = [](void* data, int32_t len, void* user) {
64-
if (!data || len <= 0)
65-
return;
99+
} state{&onChunk, std::move(isCancellationRequested), false, nullptr};
66100

101+
auto nativeCallback = [](const void* data, int32_t len, void* user) -> int32_t {
67102
auto* st = static_cast<State*>(user);
68-
if (st->exception)
69-
return;
103+
if (!st) {
104+
return 0;
105+
}
106+
107+
if (st->exception || st->cancellationObserved) {
108+
return 1;
109+
}
110+
111+
if (!data || len <= 0)
112+
return 0;
70113

71114
try {
115+
if (st->isCancellationRequested && st->isCancellationRequested()) {
116+
st->cancellationObserved = true;
117+
return 1;
118+
}
119+
72120
std::string chunk(static_cast<const char*>(data), static_cast<size_t>(len));
73121
(*(st->cb))(chunk);
74122
}
75123
catch (...) {
76124
st->exception = std::current_exception();
125+
return 1;
77126
}
127+
128+
return 0;
78129
};
79130

80-
auto response = core->call(command, logger, &payload, +nativeCallback, &state);
131+
auto response = core->call(command, logger, payload, +nativeCallback, &state);
132+
if (state.cancellationObserved) {
133+
throw Exception("Operation cancelled", logger);
134+
}
135+
81136
if (response.HasError()) {
82137
throw Exception(std::string(errorContext) + response.error, logger);
83138
}
@@ -89,6 +144,15 @@ namespace foundry_local::detail {
89144
return response;
90145
}
91146

147+
inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command,
148+
const std::string& payload, ILogger& logger,
149+
const std::function<void(const std::string&)>& onChunk,
150+
std::string_view errorContext,
151+
CancellationCallback isCancellationRequested = nullptr) {
152+
return CallWithStreamingCallback(core, command, &payload, logger, onChunk, errorContext,
153+
std::move(isCancellationRequested));
154+
}
155+
92156
// Overload: allow Params object directly
93157
inline CoreResponse CallWithParams(Internal::IFoundryLocalCore* core, std::string_view command,
94158
const nlohmann::json& params, ILogger& logger) {

sdk/cpp/src/flcore_native.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ extern "C"
2323
int32_t ErrorLength;
2424
};
2525

26-
// Callback signature: void(*)(void* data, int length, void* userData)
27-
using UserCallbackFn = void(__cdecl*)(void*, int32_t, void*);
26+
// Callback signature: int32_t(*)(const void* data, int length, void* userData)
27+
// Return 0 to continue, 1 to cancel.
28+
using UserCallbackFn = int32_t(__cdecl*)(const void*, int32_t, void*);
2829

2930
struct StreamingRequestBuffer {
3031
const void* Command;
@@ -37,8 +38,8 @@ extern "C"
3738

3839
// Exported function pointer types
3940
using execute_command_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*);
40-
using execute_command_with_callback_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/,
41-
void* /*userData*/);
41+
using execute_command_with_callback_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*,
42+
UserCallbackFn /*callback*/, void* /*userData*/);
4243
using execute_command_with_binary_fn = void(__cdecl*)(StreamingRequestBuffer*, ResponseBuffer*);
4344
using free_response_fn = void(__cdecl*)(ResponseBuffer*);
4445

sdk/cpp/src/foundry_local_internal_core.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ namespace foundry_local {
1212

1313
/// Native callback signature used by the core DLL interop.
1414
/// Parameters: (data, dataLength, userData).
15-
using NativeCallbackFn = void (*)(void*, int32_t, void*);
15+
/// Return 0 to continue, 1 to cancel the native operation.
16+
using NativeCallbackFn = int32_t(__cdecl*)(const void*, int32_t, void*);
1617

1718
/// Value returned by IFoundryLocalCore::call().
1819
/// On success, `data` contains the response payload and `error` is empty.
@@ -40,4 +41,4 @@ namespace foundry_local {
4041
};
4142

4243
} // namespace Internal
43-
} // namespace foundry_local
44+
} // namespace foundry_local

sdk/cpp/src/foundry_local_manager.cpp

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "foundry_local_internal_core.h"
1313
#include "foundry_local_exception.h"
1414
#include "core_interop_request.h"
15+
#include "core_helpers.h"
1516
#include "core.h"
1617
#include "logger.h"
1718

@@ -128,10 +129,36 @@ void Manager::Cleanup() noexcept {
128129
return urls_;
129130
}
130131

131-
void Manager::EnsureEpsDownloaded() const {
132-
auto response = core_->call("ensure_eps_downloaded", *logger_);
133-
if (response.HasError()) {
134-
throw Exception(std::string("Error ensuring execution providers downloaded: ") + response.error, *logger_);
132+
void Manager::EnsureEpsDownloaded(ExecutionProviderDownloadProgressCallback onProgress,
133+
CancellationCallback isCancellationRequested) const {
134+
if (onProgress || isCancellationRequested) {
135+
auto onChunk = [&onProgress](const std::string& chunk) {
136+
if (!onProgress) {
137+
return;
138+
}
139+
140+
const auto sep = chunk.find('|');
141+
if (sep == std::string::npos) {
142+
return;
143+
}
144+
145+
double percent = 0.0;
146+
if (detail::TryParseDoubleToken(std::string_view(chunk).substr(sep + 1), percent)) {
147+
onProgress(chunk.substr(0, sep), percent);
148+
}
149+
};
150+
151+
detail::CallWithStreamingCallback(core_.get(), "download_and_register_eps",
152+
static_cast<const std::string*>(nullptr), *logger_, onChunk,
153+
"Error ensuring execution providers downloaded: ",
154+
std::move(isCancellationRequested));
155+
}
156+
else {
157+
auto response = core_->call("download_and_register_eps", *logger_);
158+
if (response.HasError()) {
159+
throw Exception(std::string("Error ensuring execution providers downloaded: ") + response.error,
160+
*logger_);
161+
}
135162
}
136163
}
137164

0 commit comments

Comments
 (0)