diff --git a/samples/cpp/live-audio-transcription/README.md b/samples/cpp/live-audio-transcription/README.md index a9fca9774..3e8b8e8d6 100644 --- a/samples/cpp/live-audio-transcription/README.md +++ b/samples/cpp/live-audio-transcription/README.md @@ -26,3 +26,7 @@ g++ -std=c++20 main.cpp -lfoundry_local -o live-audio-transcription-example # Synthetic 440Hz sine wave (no microphone needed) ./live-audio-transcription-example --synth ``` + +Press `Ctrl+C` to request a graceful stop. The sample passes that signal to +execution-provider and model downloads so long-running downloads can be +cancelled before transcription starts. diff --git a/samples/cpp/live-audio-transcription/main.cpp b/samples/cpp/live-audio-transcription/main.cpp index 1a3341e4c..5c94d6180 100644 --- a/samples/cpp/live-audio-transcription/main.cpp +++ b/samples/cpp/live-audio-transcription/main.cpp @@ -122,7 +122,8 @@ int main(int argc, char* argv[]) { foundry_local::Manager::Create(config); auto& manager = foundry_local::Manager::Instance(); - manager.EnsureEpsDownloaded(); + auto isCancellationRequested = [] { return !g_running.load(); }; + manager.DownloadAndRegisterEps(nullptr, isCancellationRequested); auto& catalog = manager.GetCatalog(); auto* model = catalog.GetModel("nemotron-speech-streaming-en-0.6b"); @@ -131,9 +132,12 @@ int main(int argc, char* argv[]) { } std::cout << "Downloading model (if needed)..." << std::endl; - model->Download([](float pct) { - std::cout << "\rDownloading: " << pct << "% " << std::flush; - }); + model->Download( + [](float pct) { + std::cout << "\rDownloading: " << pct << "% " << std::flush; + return true; + }, + isCancellationRequested); std::cout << std::endl; std::cout << "Loading model..." 
<< std::endl; model->Load(); diff --git a/sdk/cpp/include/foundry_local_manager.h b/sdk/cpp/include/foundry_local_manager.h index 51af7f161..ac04a335a 100644 --- a/sdk/cpp/include/foundry_local_manager.h +++ b/sdk/cpp/include/foundry_local_manager.h @@ -83,15 +83,21 @@ namespace foundry_local { /// Download and register all available execution providers. /// @param progressCallback Optional callback invoked with (ep_name, percent) during download. + /// @param isCancellationRequested Optional callback checked on each progress update. Return true to cancel. /// @return Result describing which EPs were registered or failed. - EpDownloadResult DownloadAndRegisterEps(EpProgressCallback progressCallback = nullptr) const; + EpDownloadResult DownloadAndRegisterEps( + EpProgressCallback progressCallback = nullptr, + CancellationCallback isCancellationRequested = nullptr) const; /// Download and register specific execution providers by name. /// @param names EP names to download (as returned by DiscoverEps). /// @param progressCallback Optional callback invoked with (ep_name, percent) during download. + /// @param isCancellationRequested Optional callback checked on each progress update. Return true to cancel. /// @return Result describing which EPs were registered or failed. 
- EpDownloadResult DownloadAndRegisterEps(const std::vector& names, - EpProgressCallback progressCallback = nullptr) const; + EpDownloadResult DownloadAndRegisterEps( + const std::vector& names, + EpProgressCallback progressCallback = nullptr, + CancellationCallback isCancellationRequested = nullptr) const; private: explicit Manager(Configuration configuration, ILogger* logger); diff --git a/sdk/cpp/include/model.h b/sdk/cpp/include/model.h index b52fae76c..052cf45e9 100644 --- a/sdk/cpp/include/model.h +++ b/sdk/cpp/include/model.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include @@ -33,6 +34,7 @@ namespace foundry_local { #endif using DownloadProgressCallback = std::function; + using CancellationCallback = std::function; class IModel { public: @@ -43,7 +45,11 @@ namespace foundry_local { virtual bool IsLoaded() const = 0; virtual bool IsCached() const = 0; virtual const std::filesystem::path& GetPath() const = 0; - virtual void Download(DownloadProgressCallback onProgress = nullptr) = 0; + + /// Download the model, with an optional cancellation callback checked on each progress update. + /// Return true from isCancellationRequested (or false from onProgress) to cancel the in-progress download; a cancelled download throws Exception.
+ virtual void Download(DownloadProgressCallback onProgress = nullptr, + CancellationCallback isCancellationRequested = nullptr) = 0; virtual void Load() = 0; virtual void Unload() = 0; virtual void RemoveFromCache() = 0; @@ -123,7 +129,8 @@ namespace foundry_local { const ModelInfo& GetInfo() const; const std::filesystem::path& GetPath() const override; - void Download(DownloadProgressCallback onProgress = nullptr) override; + void Download(DownloadProgressCallback onProgress = nullptr, + CancellationCallback isCancellationRequested = nullptr) override; void Load() override; bool IsLoaded() const override; @@ -158,8 +165,9 @@ namespace foundry_local { bool IsLoaded() const override { return SelectedVariant().IsLoaded(); } bool IsCached() const override { return SelectedVariant().IsCached(); } const std::filesystem::path& GetPath() const override { return SelectedVariant().GetPath(); } - void Download(DownloadProgressCallback onProgress = nullptr) override { - SelectedVariant().Download(std::move(onProgress)); + void Download(DownloadProgressCallback onProgress = nullptr, + CancellationCallback isCancellationRequested = nullptr) override { + SelectedVariant().Download(std::move(onProgress), std::move(isCancellationRequested)); } void Load() override { SelectedVariant().Load(); } void Unload() override { SelectedVariant().Unload(); } diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index 7c377da99..b82047800 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -4,6 +4,8 @@ #include "foundry_local.h" #include +#include +#include #include #include #include @@ -14,6 +16,18 @@ using namespace foundry_local; +namespace { +std::atomic g_cancelRequested{false}; + +void SignalHandler(int /*signum*/) { + g_cancelRequested.store(true); +} + +bool IsCancellationRequested() { + return g_cancelRequested.load(); +} +} // namespace + // --------------------------------------------------------------------------- // Logger // 
--------------------------------------------------------------------------- @@ -118,7 +132,8 @@ void ChatNonStreaming(Manager& manager, const std::string& alias) { PreferCpuVariant(*concreteModel); } - model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }); + model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }, + IsCancellationRequested); std::cout << "\n"; model->Load(); @@ -211,7 +226,8 @@ void TranscribeAudio(Manager& manager, const std::string& alias, const std::stri PreferCpuVariant(*concreteModel); } - model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }); + model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }, + IsCancellationRequested); std::cout << "\n"; model->Load(); @@ -263,7 +279,8 @@ void ChatWithToolCalling(Manager& manager, const std::string& alias) { PreferCpuVariant(*concreteModel); } - model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }); + model->Download([](float pct) { printf("\rDownloading: %5.1f%%", pct); fflush(stdout); return true; }, + IsCancellationRequested); std::cout << "\n"; model->Load(); @@ -376,6 +393,8 @@ int main(int argc, char* argv[]) { const std::string audioPath = (argc > 3) ? 
argv[3] : ""; try { + std::signal(SIGINT, SignalHandler); + StdLogger logger; Manager::Create({"SampleApp"}, &logger); auto& manager = Manager::Instance(); @@ -399,7 +418,7 @@ int main(int argc, char* argv[]) { } printf("\r %-30s %5.1f%%", epName.c_str(), percent); fflush(stdout); - }); + }, IsCancellationRequested); if (!currentEp.empty()) std::cout << "\n"; } else { std::cout << "\nNo execution providers to download.\n"; diff --git a/sdk/cpp/src/core.h b/sdk/cpp/src/core.h index eb598373d..a69f961cc 100644 --- a/sdk/cpp/src/core.h +++ b/sdk/cpp/src/core.h @@ -187,7 +187,7 @@ namespace foundry_local { std::unique_ptr responseGuard(&response, safeDeleter); if (callback != nullptr) { - execCbCmd_(&request, &response, reinterpret_cast(callback), data); + execCbCmd_(&request, &response, callback, data); } else { execCmd_(&request, &response); diff --git a/sdk/cpp/src/core_helpers.h b/sdk/cpp/src/core_helpers.h index c46f294a2..e0cd37092 100644 --- a/sdk/cpp/src/core_helpers.h +++ b/sdk/cpp/src/core_helpers.h @@ -6,12 +6,15 @@ #pragma once +#include #include #include #include #include #include +#include #include +#include #include @@ -47,38 +50,82 @@ namespace foundry_local::detail { return core->call(command, logger, &payload, callback, userData); } + inline bool TryParseFloatToken(std::string_view token, float& value) { + if (token.empty()) { + return false; + } + + const auto* begin = token.data(); + const auto* end = begin + token.size(); + const auto result = std::from_chars(begin, end, value); + return result.ec == std::errc{} && result.ptr == end; + } + + inline bool TryParseDoubleToken(std::string_view token, double& value) { + if (token.empty()) { + return false; + } + + const auto* begin = token.data(); + const auto* end = begin + token.size(); + const auto result = std::from_chars(begin, end, value); + return result.ec == std::errc{} && result.ptr == end; + } + // Serialize + call with a streaming chunk handler. 
// Wraps the caller-supplied onChunk with the native callback boilerplate - // (null/length checks, exception capture, rethrow after the call). + // (null/length checks, exception capture, cancellation, rethrow after the call). // The errorContext string is used to prefix any core-layer error message. inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, - const std::string& payload, ILogger& logger, - const std::function& onChunk, - std::string_view errorContext) { + const std::string* payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { struct State { - const std::function* cb; + const std::function* cb; + CancellationCallback isCancellationRequested; + bool cancellationObserved = false; std::exception_ptr exception; - } state{&onChunk, nullptr}; + } state{&onChunk, std::move(isCancellationRequested), false, nullptr}; - auto nativeCallback = [](void* data, int32_t len, void* user) -> int { - if (!data || len <= 0) + auto nativeCallback = [](const void* data, int32_t len, void* user) -> int32_t { + auto* st = static_cast(user); + if (!st) { return 0; + } - auto* st = static_cast(user); - if (st->exception) + if (st->exception || st->cancellationObserved) { + return 1; + } + + if (!data || len <= 0) return 0; try { + if (st->isCancellationRequested && st->isCancellationRequested()) { + st->cancellationObserved = true; + return 1; + } + std::string chunk(static_cast(data), static_cast(len)); - (*(st->cb))(chunk); + if (!(*(st->cb))(chunk)) { + st->cancellationObserved = true; + return 1; + } } catch (...) 
{ st->exception = std::current_exception(); + return 1; } + return 0; }; - auto response = core->call(command, logger, &payload, +nativeCallback, &state); + auto response = core->call(command, logger, payload, +nativeCallback, &state); + if (state.cancellationObserved) { + throw Exception("Operation cancelled", logger); + } + if (response.HasError()) { throw Exception(std::string(errorContext) + response.error, logger); } @@ -90,6 +137,38 @@ namespace foundry_local::detail { return response; } + inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string* payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { + const std::function continuingOnChunk = + [&onChunk](const std::string& chunk) { + onChunk(chunk); + return true; + }; + return CallWithStreamingCallback(core, command, payload, logger, continuingOnChunk, errorContext, + std::move(isCancellationRequested)); + } + + inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string& payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { + return CallWithStreamingCallback(core, command, &payload, logger, onChunk, errorContext, + std::move(isCancellationRequested)); + } + + inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string& payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext, + CancellationCallback isCancellationRequested = nullptr) { + return CallWithStreamingCallback(core, command, &payload, logger, onChunk, errorContext, + std::move(isCancellationRequested)); + } + // Overload: allow Params object directly inline CoreResponse CallWithParams(Internal::IFoundryLocalCore* core, 
std::string_view command, const nlohmann::json& params, ILogger& logger) { diff --git a/sdk/cpp/src/flcore_native.h b/sdk/cpp/src/flcore_native.h index 2ea792b9e..b4c95ac49 100644 --- a/sdk/cpp/src/flcore_native.h +++ b/sdk/cpp/src/flcore_native.h @@ -5,10 +5,12 @@ #include #include -#ifdef _WIN32 - #define FL_CDECL __cdecl -#else - #define FL_CDECL +#ifndef FL_CDECL + #ifdef _WIN32 + #define FL_CDECL __cdecl + #else + #define FL_CDECL + #endif #endif extern "C" @@ -29,8 +31,9 @@ extern "C" int32_t ErrorLength; }; - // Callback signature: int(*)(void* data, int length, void* userData) — returns 0 to continue, 1 to cancel - using UserCallbackFn = int(__cdecl*)(void*, int32_t, void*); + // Callback signature: int32_t(*)(const void* data, int32_t length, void* userData) + // Return 0 to continue, 1 to cancel. + using UserCallbackFn = int32_t(FL_CDECL*)(const void*, int32_t, void*); struct StreamingRequestBuffer { const void* Command; @@ -43,7 +46,8 @@ extern "C" // Exported function pointer types using execute_command_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*); - using execute_command_with_callback_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/, + using execute_command_with_callback_fn = void(FL_CDECL*)(RequestBuffer*, ResponseBuffer*, + UserCallbackFn /*callback*/, void* /*userData*/); using execute_command_with_binary_fn = void(FL_CDECL*)(StreamingRequestBuffer*, ResponseBuffer*); using free_response_fn = void(FL_CDECL*)(ResponseBuffer*); diff --git a/sdk/cpp/src/foundry_local_internal_core.h b/sdk/cpp/src/foundry_local_internal_core.h index 368096dec..3a982b16c 100644 --- a/sdk/cpp/src/foundry_local_internal_core.h +++ b/sdk/cpp/src/foundry_local_internal_core.h @@ -8,11 +8,20 @@ #include #include "logger.h" +#ifndef FL_CDECL + #ifdef _WIN32 + #define FL_CDECL __cdecl + #else + #define FL_CDECL + #endif +#endif + namespace foundry_local { /// Native callback signature used by the core DLL interop.
/// Parameters: (data, dataLength, userData). - using NativeCallbackFn = int (*)(void*, int32_t, void*); + /// Return 0 to continue, 1 to cancel the native operation. + using NativeCallbackFn = int32_t(FL_CDECL*)(const void*, int32_t, void*); /// Value returned by IFoundryLocalCore::call(). /// On success, `data` contains the response payload and `error` is empty. @@ -40,4 +49,4 @@ namespace foundry_local { }; } // namespace Internal -} // namespace foundry_local \ No newline at end of file +} // namespace foundry_local diff --git a/sdk/cpp/src/foundry_local_manager.cpp b/sdk/cpp/src/foundry_local_manager.cpp index 2c1e6177c..2d634f4ca 100644 --- a/sdk/cpp/src/foundry_local_manager.cpp +++ b/sdk/cpp/src/foundry_local_manager.cpp @@ -6,7 +6,7 @@ #include #include #include -#include +#include #include #include @@ -15,6 +15,7 @@ #include "foundry_local_internal_core.h" #include "foundry_local_exception.h" #include "core_interop_request.h" +#include "core_helpers.h" #include "core.h" #include "logger.h" @@ -163,39 +164,16 @@ void Manager::Cleanup() noexcept { return result; } - namespace { - struct EpCallbackContext { - EpProgressCallback* callback; - }; - - int EpProgressNativeCallback(void* data, int32_t dataLength, void* userData) { - auto* ctx = static_cast(userData); - if (!ctx || !ctx->callback || !*ctx->callback) return 0; - if (!data || dataLength <= 0) return 0; - - std::string progressStr(static_cast(data), static_cast(dataLength)); - auto sepIndex = progressStr.find('|'); - if (sepIndex != std::string::npos) { - std::string name = progressStr.substr(0, sepIndex); - // Parse percent using locale-independent std::from_chars - const auto* begin = progressStr.data() + sepIndex + 1; - const auto* end = progressStr.data() + progressStr.size(); - double percent = 0.0; - auto [ptr, ec] = std::from_chars(begin, end, percent); - if (ec == std::errc{}) { - (*ctx->callback)(name, percent); - } - } - return 0; - } + EpDownloadResult Manager::DownloadAndRegisterEps( + 
EpProgressCallback progressCallback, + CancellationCallback isCancellationRequested) const { + return DownloadAndRegisterEps({}, std::move(progressCallback), std::move(isCancellationRequested)); } - EpDownloadResult Manager::DownloadAndRegisterEps(EpProgressCallback progressCallback) const { - return DownloadAndRegisterEps({}, std::move(progressCallback)); - } - - EpDownloadResult Manager::DownloadAndRegisterEps(const std::vector& names, - EpProgressCallback progressCallback) const { + EpDownloadResult Manager::DownloadAndRegisterEps( + const std::vector& names, + EpProgressCallback progressCallback, + CancellationCallback isCancellationRequested) const { std::string requestData; std::string* requestDataPtr = nullptr; @@ -212,16 +190,32 @@ void Manager::Cleanup() noexcept { } CoreResponse response; - if (progressCallback) { - EpCallbackContext ctx{&progressCallback}; - response = core_->call("download_and_register_eps", *logger_, - requestDataPtr, EpProgressNativeCallback, &ctx); + if (progressCallback || isCancellationRequested) { + auto onChunk = [&progressCallback](const std::string& chunk) { + if (!progressCallback) { + return; + } + + const auto sep = chunk.find('|'); + if (sep == std::string::npos) { + return; + } + + double percent = 0.0; + if (detail::TryParseDoubleToken(std::string_view(chunk).substr(sep + 1), percent)) { + progressCallback(chunk.substr(0, sep), percent); + } + }; + + response = detail::CallWithStreamingCallback(core_.get(), "download_and_register_eps", + requestDataPtr, *logger_, onChunk, + "Error downloading execution providers: ", + std::move(isCancellationRequested)); } else { response = core_->call("download_and_register_eps", *logger_, requestDataPtr); - } - - if (response.HasError()) { - throw Exception(std::string("Error downloading execution providers: ") + response.error, *logger_); + if (response.HasError()) { + throw Exception(std::string("Error downloading execution providers: ") + response.error, *logger_); + } } 
EpDownloadResult result; diff --git a/sdk/cpp/src/model.cpp b/sdk/cpp/src/model.cpp index e09f55414..9cc7f3672 100644 --- a/sdk/cpp/src/model.cpp +++ b/sdk/cpp/src/model.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include @@ -67,38 +68,32 @@ namespace foundry_local { return false; } - void ModelVariant::Download(DownloadProgressCallback onProgress) { + void ModelVariant::Download(DownloadProgressCallback onProgress, + CancellationCallback isCancellationRequested) { if (IsCached()) { logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download."); return; } - if (onProgress) { - struct ProgressState { - DownloadProgressCallback* cb; - ILogger* logger; - } state{&onProgress, logger_}; - - auto nativeCallback = [](void* data, int32_t len, void* user) -> int { - if (!data || len <= 0) - return 0; - auto* st = static_cast(user); - std::string perc(static_cast(data), static_cast(len)); - try { - float value = std::stof(perc); - (*(st->cb))(value); + if (onProgress || isCancellationRequested) { + std::function onChunk = [&onProgress](const std::string& chunk) { + if (!onProgress) { + return true; } - catch (...) 
{ - st->logger->Log(LogLevel::Warning, "Failed to parse download progress: " + perc); + + float value = 0.0f; + if (TryParseFloatToken(chunk, value)) { + if (!onProgress(value)) { + return false; + } } - return 0; + return true; }; - auto response = CallWithJsonAndCallback(core_, "download_model", MakeModelParams(info_.name), *logger_, - +nativeCallback, &state); - if (response.HasError()) { - throw Exception("Error downloading model [" + info_.name + "]: " + response.error, *logger_); - } + const std::string payload = MakeModelParams(info_.name).dump(); + CallWithStreamingCallback(core_, "download_model", payload, *logger_, onChunk, + "Error downloading model [" + info_.name + "]: ", + std::move(isCancellationRequested)); } else { auto response = CallWithJson(core_, "download_model", MakeModelParams(info_.name), *logger_); diff --git a/sdk/cpp/test/ep_test.cpp b/sdk/cpp/test/ep_test.cpp index 7649b1efd..78c9ecaf6 100644 --- a/sdk/cpp/test/ep_test.cpp +++ b/sdk/cpp/test/ep_test.cpp @@ -72,7 +72,7 @@ static EpDownloadResult TestDownloadAndRegisterEps( struct EpCallbackContext { EpProgressCallback* callback; }; - auto nativeCb = [](void* data, int32_t dataLength, void* userData) -> int { + auto nativeCb = [](const void* data, int32_t dataLength, void* userData) -> int32_t { auto* ctx = static_cast(userData); if (!ctx || !ctx->callback || !*ctx->callback) return 0; if (!data || dataLength <= 0) return 0; @@ -249,9 +249,9 @@ TEST_F(DownloadAndRegisterEpsTest, CallbackInvokedWithProgressData) { [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback) { std::string p1 = "WebGpuExecutionProvider|25.0"; - callback(const_cast(p1.data()), static_cast(p1.size()), userData); + callback(p1.data(), static_cast(p1.size()), userData); std::string p2 = "WebGpuExecutionProvider|100.0"; - callback(const_cast(p2.data()), static_cast(p2.size()), userData); + callback(p2.data(), static_cast(p2.size()), userData); } return 
R"({"Success": true, "Status": "OK", "RegisteredEps": ["WebGpuExecutionProvider"], "FailedEps": []})"; }); diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index c631f8ff3..82bdea5c0 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -9,6 +9,7 @@ #include "foundry_local_exception.h" #include +#include using namespace foundry_local; using namespace foundry_local::Testing; @@ -136,7 +137,7 @@ TEST_F(ModelVariantTest, Download_WithCallback_ReturnsZeroToContinue) { [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback && userData) { std::string progress = "50"; - int result = callback(progress.data(), static_cast(progress.size()), userData); + const int32_t result = callback(progress.data(), static_cast(progress.size()), userData); EXPECT_EQ(0, result) << "Callback should return 0 (continue), not " << result; } return ""; @@ -146,6 +147,84 @@ TEST_F(ModelVariantTest, Download_WithCallback_ReturnsZeroToContinue) { variant.Download([&](float) { return true; }); } +TEST_F(ModelVariantTest, Download_ParsesNumericProgressChunk) { + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "12.5"; + callback(progress.data(), static_cast(progress.size()), userData); + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + std::vector progressValues; + variant.Download([&](float pct) { + progressValues.push_back(pct); + return true; + }); + + ASSERT_EQ(1u, progressValues.size()); + EXPECT_NEAR(12.5f, progressValues[0], 0.01f); +} + +TEST_F(ModelVariantTest, Download_WithCancellationRequestsNativeCancel) { + core_.OnCall("get_cached_models", R"([])"); + bool nativeCallbackCancelled = false; + core_.OnCall("download_model", + [&](std::string_view, const 
std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "50"; + nativeCallbackCancelled = + callback(progress.data(), static_cast(progress.size()), userData) == 1; + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.Download(nullptr, [] { return true; }), Exception); + EXPECT_TRUE(nativeCallbackCancelled); +} + +TEST_F(ModelVariantTest, Download_ProgressCallbackFalseRequestsNativeCancel) { + core_.OnCall("get_cached_models", R"([])"); + bool nativeCallbackCancelled = false; + core_.OnCall("download_model", + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "50"; + nativeCallbackCancelled = + callback(progress.data(), static_cast(progress.size()), userData) == 1; + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.Download([](float) { return false; }), Exception); + EXPECT_TRUE(nativeCallbackCancelled); +} + +TEST_F(ModelVariantTest, Download_CancellationAfterFinalCallbackDoesNotCancelSuccessfulDownload) { + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string progress = "100"; + callback(progress.data(), static_cast(progress.size()), userData); + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + bool cancel = false; + EXPECT_NO_THROW(variant.Download([&](float) { + cancel = true; + return true; + }, [&] { return cancel; })); + EXPECT_TRUE(cancel); +} + TEST_F(ModelVariantTest, RemoveFromCache_CallsCore) { core_.OnCall("remove_cached_model", ""); auto variant = MakeVariant("test-model"); diff --git a/sdk/cs/README.md b/sdk/cs/README.md index 276ffb716..9493eea0b 100644 --- a/sdk/cs/README.md +++ b/sdk/cs/README.md @@ -99,6 
+99,18 @@ await mgr.DownloadAndRegisterEpsAsync((epName, percent) => Console.WriteLine(); ``` +#### Cancelling model and EP downloads + +Pass a `CancellationToken` to either download API. Cancellation is observed on the next progress update. + +```csharp +// mgr and model already initialized +using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + +await mgr.DownloadAndRegisterEpsAsync(ct: cts.Token); +await model.DownloadAsync(ct: cts.Token); +``` + Catalog access no longer blocks on EP downloads. Call `DownloadAndRegisterEpsAsync` explicitly when you need hardware-accelerated execution providers. ## Quick Start diff --git a/sdk/cs/src/Detail/CoreInterop.cs b/sdk/cs/src/Detail/CoreInterop.cs index 7239a48e4..138aa9411 100644 --- a/sdk/cs/src/Detail/CoreInterop.cs +++ b/sdk/cs/src/Detail/CoreInterop.cs @@ -177,6 +177,7 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput, }; ResponseBuffer response = default; + Exception? callbackException = null; if (callback != null) { @@ -190,18 +191,19 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput, var helperHandle = GCHandle.Alloc(helper); var helperPtr = GCHandle.ToIntPtr(helperHandle); - unsafe + try { - CoreExecuteCommandWithCallback(&request, &response, funcPtr, helperPtr); + unsafe + { + CoreExecuteCommandWithCallback(&request, &response, funcPtr, helperPtr); + } } - - helperHandle.Free(); - - if (helper.Exception != null) + finally { - throw new FoundryLocalException("Exception in callback handler. See InnerException for details", - helper.Exception); + helperHandle.Free(); } + + callbackException = helper.Exception; } else { @@ -239,6 +241,17 @@ public Response ExecuteCommandImpl(string commandName, string? 
commandInput, Marshal.FreeHGlobal(inputPtr!.Value); } + if (callbackException != null) + { + if (callbackException is OperationCanceledException canceledException) + { + throw canceledException; + } + + throw new FoundryLocalException("Exception in callback handler. See InnerException for details", + callbackException); + } + return result; } catch (Exception ex) when (ex is not OperationCanceledException) diff --git a/sdk/cs/src/Detail/ModelVariant.cs b/sdk/cs/src/Detail/ModelVariant.cs index 250c601a2..442817228 100644 --- a/sdk/cs/src/Detail/ModelVariant.cs +++ b/sdk/cs/src/Detail/ModelVariant.cs @@ -6,6 +6,8 @@ namespace Microsoft.AI.Foundry.Local; +using System.Globalization; + using Microsoft.AI.Foundry.Local.Detail; using Microsoft.Extensions.Logging; @@ -63,8 +65,8 @@ public async Task DownloadAsync(Action? downloadProgress = null, CancellationToken? ct = null) { await Utils.CallWithExceptionHandling(() => DownloadImplAsync(downloadProgress, ct), - $"Error downloading model {Id}", _logger) - .ConfigureAwait(false); + $"Error downloading model {Id}", _logger) + .ConfigureAwait(false); } public async Task LoadAsync(CancellationToken? ct = null) @@ -144,16 +146,26 @@ private async Task DownloadImplAsync(Action? downloadProgress = null, }; ICoreInterop.Response? response; + var useCallbackPath = downloadProgress != null || (ct?.CanBeCanceled ?? 
false); - if (downloadProgress == null) - { - response = await _coreInterop.ExecuteCommandAsync("download_model", request, ct).ConfigureAwait(false); - } - else + if (useCallbackPath) { var callback = new ICoreInterop.CallbackFn(progressString => { - if (float.TryParse(progressString, out var progress)) + if (ct is CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + if (downloadProgress == null) + { + return; + } + + if (float.TryParse(progressString, + NumberStyles.Float, + CultureInfo.InvariantCulture, + out var progress)) { downloadProgress(progress); } @@ -162,6 +174,10 @@ private async Task DownloadImplAsync(Action? downloadProgress = null, response = await _coreInterop.ExecuteCommandWithCallbackAsync("download_model", request, callback, ct).ConfigureAwait(false); } + else + { + response = await _coreInterop.ExecuteCommandAsync("download_model", request, ct).ConfigureAwait(false); + } if (response.Error != null) { diff --git a/sdk/cs/src/FoundryLocalManager.cs b/sdk/cs/src/FoundryLocalManager.cs index b014850f6..855aed4a2 100644 --- a/sdk/cs/src/FoundryLocalManager.cs +++ b/sdk/cs/src/FoundryLocalManager.cs @@ -6,6 +6,7 @@ namespace Microsoft.AI.Foundry.Local; using System; +using System.Globalization; using System.Text.Json; using System.Threading.Tasks; @@ -373,20 +374,27 @@ private async Task DownloadAndRegisterEpsImplAsync(IEnumerable ICoreInterop.Response result; - if (progressCallback != null) + var useCallbackPath = progressCallback != null || (ct?.CanBeCanceled ?? 
false); + + if (useCallbackPath) { var callback = new ICoreInterop.CallbackFn(progressString => { + if (ct is CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + } + var sepIndex = progressString.IndexOf('|'); if (sepIndex >= 0) { var name = progressString[..sepIndex]; if (double.TryParse(progressString[(sepIndex + 1)..], - System.Globalization.NumberStyles.Float, - System.Globalization.CultureInfo.InvariantCulture, + NumberStyles.Float, + CultureInfo.InvariantCulture, out var percent)) { - progressCallback(string.IsNullOrEmpty(name) ? "" : name, percent); + progressCallback?.Invoke(string.IsNullOrEmpty(name) ? "" : name, percent); } } }); diff --git a/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs b/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs new file mode 100644 index 000000000..04738b3b2 --- /dev/null +++ b/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs @@ -0,0 +1,121 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. 
+// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local.Tests; + +using Microsoft.AI.Foundry.Local.Detail; + +using Microsoft.Extensions.Logging; + +using Moq; + +internal sealed class DownloadCancellationTests +{ + [Test] + public async Task ModelVariantDownload_WithCancellationToken_UsesCallbackPathAndPropagatesCancellation() + { + var modelInfo = new ModelInfo + { + Id = "test-model-cpu:1", + Name = "test-model-cpu", + Alias = "test-model", + Version = 1, + ProviderType = "AzureFoundry", + Uri = "azureml://registries/azureml/models/test-model-cpu/versions/1", + ModelType = "ONNX", + }; + + var modelLoadManager = new Mock(MockBehavior.Strict); + var coreInterop = new Mock(MockBehavior.Strict); + var logger = new Mock(); + using var cts = new CancellationTokenSource(); + + coreInterop.Setup(x => x.ExecuteCommandWithCallbackAsync( + It.Is(s => s == "download_model"), + It.Is(r => r != null && + r.Params != null && + r.Params.ContainsKey("Model") && + r.Params["Model"] == modelInfo.Id), + It.IsAny(), + It.IsAny())) + .Returns((string commandName, + CoreInteropRequest? request, + ICoreInterop.CallbackFn callback, + CancellationToken? cancellationToken) => + { + callback("10"); + cts.Cancel(); + callback("20"); + return Task.FromResult(new ICoreInterop.Response()); + }); + + IModel model = new ModelVariant(modelInfo, modelLoadManager.Object, coreInterop.Object, logger.Object); + + OperationCanceledException? 
caught = null; + try + { + await model.DownloadAsync(ct: cts.Token); + } + catch (OperationCanceledException ex) + { + caught = ex; + } + + await Assert.That(caught).IsNotNull(); + coreInterop.Verify(x => x.ExecuteCommandWithCallbackAsync( + "download_model", + It.IsAny(), + It.IsAny(), + It.IsAny()), + Times.Once); + coreInterop.Verify(x => x.ExecuteCommandAsync( + It.IsAny(), + It.IsAny(), + It.IsAny()), + Times.Never); + } + + [Test] + public async Task ModelVariantDownload_WithProgressChunk_ParsesInvariantFloat() + { + var modelInfo = new ModelInfo + { + Id = "test-model-cpu:1", + Name = "test-model-cpu", + Alias = "test-model", + Version = 1, + ProviderType = "AzureFoundry", + Uri = "azureml://registries/azureml/models/test-model-cpu/versions/1", + ModelType = "ONNX", + }; + + var modelLoadManager = new Mock(MockBehavior.Strict); + var coreInterop = new Mock(MockBehavior.Strict); + var logger = new Mock(); + + coreInterop.Setup(x => x.ExecuteCommandWithCallbackAsync( + It.Is(s => s == "download_model"), + It.IsAny(), + It.IsAny(), + It.IsAny())) + .Returns((string commandName, + CoreInteropRequest? request, + ICoreInterop.CallbackFn callback, + CancellationToken? cancellationToken) => + { + callback("12.5"); + return Task.FromResult(new ICoreInterop.Response()); + }); + + var model = new ModelVariant(modelInfo, modelLoadManager.Object, coreInterop.Object, logger.Object); + var progressValues = new List(); + + await model.DownloadAsync(progressValues.Add); + + await Assert.That(progressValues.Count).IsEqualTo(1); + await Assert.That(progressValues[0]).IsEqualTo(12.5f); + } +} diff --git a/sdk/js/README.md b/sdk/js/README.md index 26471cc8c..fad973353 100644 --- a/sdk/js/README.md +++ b/sdk/js/README.md @@ -77,6 +77,19 @@ await manager.downloadAndRegisterEps((epName, percent) => { process.stdout.write('\n'); ``` +#### Cancelling model and EP downloads + +Use an `AbortController` with either `downloadAndRegisterEps()` or `model.download()`. 
Aborting the signal rejects the in-progress download promise. + +```typescript +// manager and model already initialized +const controller = new AbortController(); +setTimeout(() => controller.abort(), 5000); + +await manager.downloadAndRegisterEps(controller.signal); +await model.download(controller.signal); +``` + Catalog access does not block on EP downloads. Call `downloadAndRegisterEps()` when you need hardware-accelerated execution providers. ## Quick Start @@ -336,4 +349,4 @@ See `test/README.md` for details on prerequisites and setup. npm run example ``` -This runs the chat completion example in `examples/chat-completion.ts`. \ No newline at end of file +This runs the chat completion example in `examples/chat-completion.ts`. diff --git a/sdk/js/src/detail/coreInterop.ts b/sdk/js/src/detail/coreInterop.ts index 72013815c..36098d4ab 100644 --- a/sdk/js/src/detail/coreInterop.ts +++ b/sdk/js/src/detail/coreInterop.ts @@ -136,9 +136,47 @@ export class CoreInterop { return this.addon.executeCommandWithBinary(command, dataStr, binBuf); } - public executeCommandStreaming(command: string, params: any, callback: (chunk: string) => void): Promise { + public async executeCommandStreaming( + command: string, + params: any, + callback: (chunk: string) => void, + signal?: AbortSignal + ): Promise { + const createAbortError = (): Error => { + const error = new Error('Operation cancelled'); + error.name = 'AbortError'; + return error; + }; + + if (signal?.aborted) { + throw createAbortError(); + } + const dataStr = params ? 
JSON.stringify(params) : ''; - return this.addon.executeCommandStreaming(command, dataStr, callback); + let cancelled = false; + const wrappedCallback = (chunk: string) => { + if (signal?.aborted) { + cancelled = true; + throw createAbortError(); + } + + callback(chunk); + }; + + try { + const result = await this.addon.executeCommandStreaming(command, dataStr, wrappedCallback); + if (cancelled) { + throw createAbortError(); + } + + return result; + } catch (error) { + if (cancelled) { + throw createAbortError(); + } + + throw error; + } } } diff --git a/sdk/js/src/detail/model.ts b/sdk/js/src/detail/model.ts index ffd962db5..e70c0703c 100644 --- a/sdk/js/src/detail/model.ts +++ b/sdk/js/src/detail/model.ts @@ -125,10 +125,14 @@ export class Model implements IModel { /** * Downloads the currently selected variant. - * @param progressCallback - Optional callback to report download progress. + * @param progressCallbackOrSignal - Optional progress callback or AbortSignal. + * @param signal - Optional AbortSignal when a progress callback is provided. */ - public download(progressCallback?: (progress: number) => void): Promise { - return this.selectedVariant.download(progressCallback); + public download( + progressCallbackOrSignal?: ((progress: number) => void) | AbortSignal, + signal?: AbortSignal + ): Promise { + return this.selectedVariant.download(progressCallbackOrSignal, signal); } /** @@ -202,4 +206,4 @@ export class Model implements IModel { public createResponsesClient(baseUrl: string): ResponsesClient { return this.selectedVariant.createResponsesClient(baseUrl); } -} \ No newline at end of file +} diff --git a/sdk/js/src/detail/modelVariant.ts b/sdk/js/src/detail/modelVariant.ts index af150bb81..7f78353ac 100644 --- a/sdk/js/src/detail/modelVariant.ts +++ b/sdk/js/src/detail/modelVariant.ts @@ -107,19 +107,38 @@ export class ModelVariant implements IModel { /** * Downloads the model variant. 
- * @param progressCallback - Optional callback to report download progress (0-100). - */ - public async download(progressCallback?: (progress: number) => void): Promise { + * @param progressCallbackOrSignal - Optional progress callback (0-100) or AbortSignal. + * @param signal - Optional AbortSignal when a progress callback is provided. + */ + public async download( + progressCallbackOrSignal?: ((progress: number) => void) | AbortSignal, + signal?: AbortSignal + ): Promise { + const progressCallback = typeof progressCallbackOrSignal === 'function' + ? progressCallbackOrSignal + : undefined; + const abortSignal = typeof progressCallbackOrSignal === 'function' + ? signal + : progressCallbackOrSignal ?? signal; const request = { Params: { Model: this._modelInfo.id } }; - if (!progressCallback) { + if (!progressCallback && !abortSignal) { await this.coreInterop.executeCommandAsync("download_model", request); } else { + // Use the streaming path when progress or cancellation is needed. + // Provide a no-op callback when only cancellation is requested so + // the native callback mechanism is engaged. + const cb = progressCallback ?? 
(() => {}); await this.coreInterop.executeCommandStreaming("download_model", request, (chunk: string) => { - const progress = parseFloat(chunk); - if (!isNaN(progress)) { - progressCallback(progress); + const progressChunk = chunk.trim(); + if (progressChunk.length === 0) { + return; + } + + const progress = Number(progressChunk); + if (!Number.isNaN(progress)) { + cb(progress); } - }); + }, abortSignal); } } diff --git a/sdk/js/src/foundryLocalManager.ts b/sdk/js/src/foundryLocalManager.ts index f3224e656..449d84094 100644 --- a/sdk/js/src/foundryLocalManager.ts +++ b/sdk/js/src/foundryLocalManager.ts @@ -5,6 +5,13 @@ import { Catalog } from './catalog.js'; import { ResponsesClient } from './openai/responsesClient.js'; import { EpInfo, EpDownloadResult } from './types.js'; +function isAbortSignal(value: unknown): value is AbortSignal { + return typeof value === 'object' + && value !== null + && 'aborted' in value + && typeof (value as AbortSignal).aborted === 'boolean'; +} + /** * The main entry point for the Foundry Local SDK. * Manages the initialization of the core system and provides access to the Catalog and ModelLoadManager. @@ -178,18 +185,38 @@ export class FoundryLocalManager { * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(): Promise; + /** + * Downloads and registers execution providers. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(signal: AbortSignal): Promise; /** * Downloads and registers execution providers. * @param names - Array of EP names to download. * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(names: string[]): Promise; + /** + * Downloads and registers execution providers. + * @param names - Array of EP names to download. 
+ * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(names: string[], signal: AbortSignal): Promise; /** * Downloads and registers execution providers, reporting progress. * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(progressCallback: (epName: string, percent: number) => void): Promise; + /** + * Downloads and registers execution providers, reporting progress. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(progressCallback: (epName: string, percent: number) => void, signal: AbortSignal): Promise; /** * Downloads and registers execution providers, reporting progress. * @param names - Array of EP names to download. @@ -197,17 +224,45 @@ export class FoundryLocalManager { * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(names: string[], progressCallback: (epName: string, percent: number) => void): Promise; + /** + * Downloads and registers execution providers, reporting progress. + * @param names - Array of EP names to download. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. 
+ */ + public downloadAndRegisterEps(names: string[], progressCallback: (epName: string, percent: number) => void, signal: AbortSignal): Promise; + /** + * Downloads and registers execution providers, preserving compatibility with callers that pass undefined for names. + * @param names - Undefined to download all EPs. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(names: undefined, signal: AbortSignal): Promise; + /** + * Downloads and registers execution providers, preserving compatibility with callers that pass undefined for names. + * @param names - Undefined to download all EPs. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(names: undefined, progressCallback: (epName: string, percent: number) => void, signal?: AbortSignal): Promise; public async downloadAndRegisterEps( - namesOrCallback?: string[] | ((epName: string, percent: number) => void), - progressCallback?: (epName: string, percent: number) => void + namesOrCallbackOrSignal?: string[] | ((epName: string, percent: number) => void) | AbortSignal, + progressCallbackOrSignal?: ((epName: string, percent: number) => void) | AbortSignal, + maybeSignal?: AbortSignal ): Promise { - let names: string[] | undefined; - if (typeof namesOrCallback === 'function') { - progressCallback = namesOrCallback; - } else { - names = namesOrCallback; - } - + const names = Array.isArray(namesOrCallbackOrSignal) ? namesOrCallbackOrSignal : undefined; + const progressCallback = typeof namesOrCallbackOrSignal === 'function' + ? namesOrCallbackOrSignal + : typeof progressCallbackOrSignal === 'function' + ? 
progressCallbackOrSignal + : undefined; + const signal = isAbortSignal(namesOrCallbackOrSignal) + ? namesOrCallbackOrSignal + : isAbortSignal(progressCallbackOrSignal) + ? progressCallbackOrSignal + : maybeSignal; const params: { Params?: { Names: string } } = {}; if (names && names.length > 0) { params.Params = { Names: names.join(",") }; @@ -221,11 +276,17 @@ export class FoundryLocalManager { }; let response: string; + const commandParams = Object.keys(params).length > 0 ? params : undefined; - if (progressCallback) { + if (!progressCallback && !signal) { + response = await this.coreInterop.executeCommandAsync( + "download_and_register_eps", + commandParams + ); + } else if (progressCallback) { response = await this.coreInterop.executeCommandStreaming( "download_and_register_eps", - Object.keys(params).length > 0 ? params : undefined, + commandParams, (chunk: string) => { const sepIndex = chunk.indexOf('|'); if (sepIndex >= 0) { @@ -235,13 +296,15 @@ export class FoundryLocalManager { progressCallback(epName || '', percent); } } - } + }, + signal ); } else { response = await this.coreInterop.executeCommandStreaming( "download_and_register_eps", - Object.keys(params).length > 0 ? params : undefined, - () => {} // no-op callback + commandParams, + () => {}, // no-op callback + signal ); } diff --git a/sdk/js/src/imodel.ts b/sdk/js/src/imodel.ts index 7a8a79e35..72fdc4d8b 100644 --- a/sdk/js/src/imodel.ts +++ b/sdk/js/src/imodel.ts @@ -17,7 +17,13 @@ export interface IModel { get capabilities(): string | null; get supportsToolCalling(): boolean | null; - download(progressCallback?: (progress: number) => void): Promise; + /** + * Download the model to local cache if not already present. + * @param progressCallbackOrSignal - Optional callback for download progress (0-100), or AbortSignal. + * @param signal - Optional AbortSignal when a progress callback is provided. 
+ */ + download(progressCallbackOrSignal?: ((progress: number) => void) | AbortSignal, + signal?: AbortSignal): Promise; get path(): string; load(): Promise; removeFromCache(): void; diff --git a/sdk/js/test/detail/coreInterop.test.ts b/sdk/js/test/detail/coreInterop.test.ts new file mode 100644 index 000000000..72191a927 --- /dev/null +++ b/sdk/js/test/detail/coreInterop.test.ts @@ -0,0 +1,86 @@ +import { describe, it } from 'mocha'; +import { expect } from 'chai'; +import { CoreInterop } from '../../src/detail/coreInterop.js'; + +describe('CoreInterop Tests', () => { + it('executeCommandStreaming should reject without calling native interop when signal is already aborted', async function() { + const controller = new AbortController(); + controller.abort(); + const interop = Object.create(CoreInterop.prototype) as any; + interop.addon = { + executeCommandStreaming: () => { + throw new Error('native interop should not be called for an already aborted signal'); + } + }; + + let caught: unknown; + try { + await CoreInterop.prototype.executeCommandStreaming.call( + interop, + 'download_model', + undefined, + () => {}, + controller.signal + ); + } catch (error) { + caught = error; + } + + expect(caught).to.be.instanceOf(Error); + expect((caught as Error).name).to.equal('AbortError'); + }); + + it('executeCommandStreaming should reject when signal is aborted before the next callback', async function() { + const controller = new AbortController(); + const chunks: string[] = []; + const interop = Object.create(CoreInterop.prototype) as any; + interop.addon = { + executeCommandStreaming: async (_command: string, _dataJson: string, callback: (chunk: string) => void) => { + callback('50'); + callback('60'); + return 'ok'; + } + }; + + let caught: unknown; + try { + await CoreInterop.prototype.executeCommandStreaming.call( + interop, + 'download_model', + undefined, + (chunk: string) => { + chunks.push(chunk); + controller.abort(); + }, + controller.signal + ); + } catch 
(error) { + caught = error; + } + + expect(chunks).to.deep.equal(['50']); + expect(caught).to.be.instanceOf(Error); + expect((caught as Error).name).to.equal('AbortError'); + }); + + it('executeCommandStreaming should not reject when signal aborts after the final observed callback', async function() { + const controller = new AbortController(); + const interop = Object.create(CoreInterop.prototype) as any; + interop.addon = { + executeCommandStreaming: async (_command: string, _dataJson: string, callback: (chunk: string) => void) => { + callback('100'); + return 'ok'; + } + }; + + const result = await CoreInterop.prototype.executeCommandStreaming.call( + interop, + 'download_model', + undefined, + () => controller.abort(), + controller.signal + ); + + expect(result).to.equal('ok'); + }); +}); diff --git a/sdk/js/test/foundryLocalManager.test.ts b/sdk/js/test/foundryLocalManager.test.ts index 48adcff40..c67959d05 100644 --- a/sdk/js/test/foundryLocalManager.test.ts +++ b/sdk/js/test/foundryLocalManager.test.ts @@ -1,6 +1,7 @@ import { describe, it } from 'mocha'; import { expect } from 'chai'; import { getTestManager } from './testUtils.js'; +import { FoundryLocalManager } from '../src/foundryLocalManager.js'; describe('Foundry Local Manager Tests', () => { it('should initialize successfully', function() { @@ -18,64 +19,160 @@ describe('Foundry Local Manager Tests', () => { }); it('downloadAndRegisterEps should call command without params when names are omitted', async function() { - const manager = getTestManager() as any; const calls: unknown[][] = []; - const originalExecuteCommandStreaming = manager.coreInterop.executeCommandStreaming; - - manager.coreInterop.executeCommandStreaming = (...args: unknown[]) => { - calls.push(args); - return Promise.resolve(JSON.stringify({ - Success: true, - Status: 'All providers registered', - RegisteredEps: ['CUDAExecutionProvider'], - FailedEps: [] - })); + const manager = Object.create(FoundryLocalManager.prototype) as any; + 
manager.coreInterop = { + executeCommandAsync: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + }, + executeCommandStreaming: () => { + throw new Error('download should not use streaming interop without progress or cancellation'); + } + }; + manager._catalog = { + invalidateCache: () => {} }; - try { - const result = await manager.downloadAndRegisterEps(); - expect(calls.length).to.equal(1); - expect(calls[0][0]).to.equal('download_and_register_eps'); - expect(calls[0][1]).to.be.undefined; - expect(result).to.deep.equal({ - success: true, - status: 'All providers registered', - registeredEps: ['CUDAExecutionProvider'], - failedEps: [] - }); - } finally { - manager.coreInterop.executeCommandStreaming = originalExecuteCommandStreaming; - } + const result = await manager.downloadAndRegisterEps(); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_and_register_eps'); + expect(calls[0][1]).to.be.undefined; + expect(result).to.deep.equal({ + success: true, + status: 'All providers registered', + registeredEps: ['CUDAExecutionProvider'], + failedEps: [] + }); }); it('downloadAndRegisterEps should send Names param when subset is provided', async function() { - const manager = getTestManager() as any; const calls: unknown[][] = []; - const originalExecuteCommandStreaming = manager.coreInterop.executeCommandStreaming; + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandAsync: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: false, + Status: 'Some providers failed', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: ['OpenVINOExecutionProvider'] + })); + }, + executeCommandStreaming: () => { + throw new Error('download should not use streaming interop without progress or 
cancellation'); + } + }; + manager._catalog = { + invalidateCache: () => {} + }; + + const result = await manager.downloadAndRegisterEps(['CUDAExecutionProvider', 'OpenVINOExecutionProvider']); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_and_register_eps'); + expect(calls[0][1]).to.deep.equal({ Params: { Names: 'CUDAExecutionProvider,OpenVINOExecutionProvider' } }); + expect(result).to.deep.equal({ + success: false, + status: 'Some providers failed', + registeredEps: ['CUDAExecutionProvider'], + failedEps: ['OpenVINOExecutionProvider'] + }); + }); - manager.coreInterop.executeCommandStreaming = (...args: unknown[]) => { - calls.push(args); - return Promise.resolve(JSON.stringify({ - Success: false, - Status: 'Some providers failed', - RegisteredEps: ['CUDAExecutionProvider'], - FailedEps: ['OpenVINOExecutionProvider'] - })); + it('downloadAndRegisterEps should pass AbortSignal through to streaming interop', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } + }; + manager._catalog = { + invalidateCache: () => {} }; - try { - const result = await manager.downloadAndRegisterEps(['CUDAExecutionProvider', 'OpenVINOExecutionProvider']); - expect(calls.length).to.equal(1); - expect(calls[0][0]).to.equal('download_and_register_eps'); - expect(calls[0][1]).to.deep.equal({ Params: { Names: 'CUDAExecutionProvider,OpenVINOExecutionProvider' } }); - expect(result).to.deep.equal({ - success: false, - status: 'Some providers failed', - registeredEps: ['CUDAExecutionProvider'], - failedEps: ['OpenVINOExecutionProvider'] - }); - } finally { - 
manager.coreInterop.executeCommandStreaming = originalExecuteCommandStreaming; - } + await FoundryLocalManager.prototype.downloadAndRegisterEps.call( + manager, + ['CUDAExecutionProvider'], + controller.signal + ); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_and_register_eps'); + expect(calls[0][3]).to.equal(controller.signal); }); + + it('downloadAndRegisterEps should honor progress callback when names are explicitly undefined', async function() { + const calls: unknown[][] = []; + const progress: Array<[string, number]> = []; + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + const callback = args[2] as (chunk: string) => void; + callback('CUDAExecutionProvider|42.5'); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } + }; + manager._catalog = { + invalidateCache: () => {} + }; + + await FoundryLocalManager.prototype.downloadAndRegisterEps.call( + manager, + undefined, + (epName: string, percent: number) => progress.push([epName, percent]) + ); + + expect(calls.length).to.equal(1); + expect(calls[0][1]).to.be.undefined; + expect(progress).to.deep.equal([['CUDAExecutionProvider', 42.5]]); + }); + + it('downloadAndRegisterEps should pass AbortSignal when names are explicitly undefined', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } + }; + manager._catalog = { + invalidateCache: () => {} + }; + + await 
FoundryLocalManager.prototype.downloadAndRegisterEps.call( + manager, + undefined, + controller.signal + ); + + expect(calls.length).to.equal(1); + expect(calls[0][1]).to.be.undefined; + expect(calls[0][3]).to.equal(controller.signal); + }); + }); diff --git a/sdk/js/test/model.test.ts b/sdk/js/test/model.test.ts index 4048d9a11..7203e5668 100644 --- a/sdk/js/test/model.test.ts +++ b/sdk/js/test/model.test.ts @@ -1,6 +1,9 @@ import { describe, it } from 'mocha'; import { expect } from 'chai'; import { getTestManager, TEST_MODEL_ALIAS } from './testUtils.js'; +import { Model } from '../src/detail/model.js'; +import { ModelVariant } from '../src/detail/modelVariant.js'; +import { DeviceType, type ModelInfo } from '../src/types.js'; describe('Model Tests', () => { it('should verify cached models from test-data-shared', async function() { @@ -58,4 +61,114 @@ describe('Model Tests', () => { await model.unload(); expect(await model.isLoaded()).to.be.false; }); -}); \ No newline at end of file + + it('download should use streaming interop when only an AbortSignal is provided', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const fakeCoreInterop = { + executeCommandAsync: () => { + throw new Error('download should not use executeCommandAsync when a signal is provided'); + }, + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(''); + } + }; + const modelInfo: ModelInfo = { + id: 'test-model-cpu:1', + name: 'test-model-cpu', + version: 1, + alias: TEST_MODEL_ALIAS, + providerType: 'AzureFoundry', + uri: 'azureml://registries/azureml/models/test-model-cpu/versions/1', + modelType: 'ONNX', + cached: false, + createdAtUnix: 0, + runtime: { + deviceType: DeviceType.CPU, + executionProvider: 'CPUExecutionProvider' + } + }; + const variant = new ModelVariant(modelInfo, fakeCoreInterop as any, {} as any); + const model = new Model(variant); + + await model.download(controller.signal); + 
expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_model'); + expect(calls[0][3]).to.equal(controller.signal); + }); + + it('download should preserve undefined progress callback with AbortSignal overload', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const fakeCoreInterop = { + executeCommandAsync: () => { + throw new Error('download should not use executeCommandAsync when a signal is provided'); + }, + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(''); + } + }; + const modelInfo: ModelInfo = { + id: 'test-model-cpu:1', + name: 'test-model-cpu', + version: 1, + alias: TEST_MODEL_ALIAS, + providerType: 'AzureFoundry', + uri: 'azureml://registries/azureml/models/test-model-cpu/versions/1', + modelType: 'ONNX', + cached: false, + createdAtUnix: 0, + runtime: { + deviceType: DeviceType.CPU, + executionProvider: 'CPUExecutionProvider' + } + }; + const variant = new ModelVariant(modelInfo, fakeCoreInterop as any, {} as any); + const model = new Model(variant); + + await model.download(undefined, controller.signal); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_model'); + expect(calls[0][3]).to.equal(controller.signal); + }); + + it('download should parse a numeric progress chunk', async function() { + const progress: number[] = []; + const fakeCoreInterop = { + executeCommandAsync: () => { + throw new Error('download should use streaming interop when progress is provided'); + }, + executeCommandStreaming: async ( + _command: string, + _request: unknown, + callback: (chunk: string) => void + ) => { + callback('12.5'); + return ''; + } + }; + const modelInfo: ModelInfo = { + id: 'test-model-cpu:1', + name: 'test-model-cpu', + version: 1, + alias: TEST_MODEL_ALIAS, + providerType: 'AzureFoundry', + uri: 'azureml://registries/azureml/models/test-model-cpu/versions/1', + modelType: 'ONNX', + cached: false, + 
createdAtUnix: 0, + runtime: { + deviceType: DeviceType.CPU, + executionProvider: 'CPUExecutionProvider' + } + }; + const variant = new ModelVariant(modelInfo, fakeCoreInterop as any, {} as any); + const model = new Model(variant); + + await model.download(progress.push.bind(progress)); + + expect(progress).to.deep.equal([12.5]); + }); +}); diff --git a/sdk/python/README.md b/sdk/python/README.md index 2a121411e..55a6f8d17 100644 --- a/sdk/python/README.md +++ b/sdk/python/README.md @@ -108,6 +108,21 @@ manager.download_and_register_eps(progress_callback=on_progress) print() ``` +### Cancelling model and EP downloads + +Pass a `threading.Event` as `cancel_event` to either download API. Set the event from another thread or handler to cancel the in-progress download. + +```python +import threading + +# manager and model already initialized +cancel_event = threading.Event() +threading.Timer(5.0, cancel_event.set).start() + +manager.download_and_register_eps(cancel_event=cancel_event) +model.download(cancel_event=cancel_event) +``` + Catalog access does not block on EP downloads. Call `download_and_register_eps()` when you need hardware-accelerated execution providers. ## Quick Start @@ -328,4 +343,4 @@ See [test/README.md](test/README.md) for detailed test setup and structure. 
```bash python examples/chat_completion.py -``` \ No newline at end of file +``` diff --git a/sdk/python/src/detail/core_interop.py b/sdk/python/src/detail/core_interop.py index f93b79f03..a013f7ba7 100644 --- a/sdk/python/src/detail/core_interop.py +++ b/sdk/python/src/detail/core_interop.py @@ -10,6 +10,7 @@ import logging import os import sys +import threading from dataclasses import dataclass from pathlib import Path @@ -84,6 +85,10 @@ class Response: error: Optional[str] = None +class CancelledException(Exception): + """Raised internally when a download or streaming operation is cancelled.""" + + class CallbackHelper: """Internal helper class to convert the callback from ctypes to a str and call the python callback.""" @staticmethod @@ -92,18 +97,27 @@ def callback(data_ptr, length, self_ptr): try: self = ctypes.cast(self_ptr, ctypes.POINTER(ctypes.py_object)).contents.value + # Check for cancellation before processing the callback data. + if self._cancel_event is not None and self._cancel_event.is_set(): + raise CancelledException("Operation cancelled") + # convert to a string and pass to the python callback data_bytes = ctypes.string_at(data_ptr, length) data_str = data_bytes.decode('utf-8') self._py_callback(data_str) return 0 # continue + except CancelledException as e: + if self is not None and self.exception is None: + self.exception = e + return 1 # cancel except Exception as e: if self is not None and self.exception is None: self.exception = e # keep the first only as they are likely all the same return 1 # cancel on error - def __init__(self, py_callback: Callable[[str], None]): + def __init__(self, py_callback: Callable[[str], None], cancel_event: Optional['threading.Event'] = None): self._py_callback = py_callback + self._cancel_event = cancel_event self.exception = None @@ -252,37 +266,44 @@ def __init__(self, config: Configuration): logger.info("Foundry.Local.Core initialized successfully: %s", response.data) def _execute_command(self, command: 
str, interop_request: InteropRequest = None, - callback: CoreInterop.CALLBACK_TYPE = None): + callback: CoreInterop.CALLBACK_TYPE = None, + cancel_event: Optional[threading.Event] = None): cmd_ptr, cmd_len, cmd_buf = CoreInterop._to_c_buffer(command) data_ptr, data_len, data_buf = CoreInterop._to_c_buffer(interop_request.to_json() if interop_request else None) req = RequestBuffer(Command=cmd_ptr, CommandLength=cmd_len, Data=data_ptr, DataLength=data_len) resp = ResponseBuffer() lib = CoreInterop._flcore_library + callback_exception = None if (callback is not None): # If a callback is provided, use the execute_command_with_callback method # We need a helper to do the initial conversion from ctypes to Python and pass it through to the # provided callback function - callback_helper = CallbackHelper(callback) + callback_helper = CallbackHelper(callback, cancel_event) callback_py_obj = ctypes.py_object(callback_helper) callback_helper_ptr = ctypes.cast(ctypes.pointer(callback_py_obj), ctypes.c_void_p) callback_fn = CoreInterop.CALLBACK_TYPE(CallbackHelper.callback) lib.execute_command_with_callback(ctypes.byref(req), ctypes.byref(resp), callback_fn, callback_helper_ptr) - - if callback_helper.exception is not None: - raise callback_helper.exception + callback_exception = callback_helper.exception else: lib.execute_command(ctypes.byref(req), ctypes.byref(resp)) req = None # Free Python reference to request - response_str = ctypes.string_at(resp.Data, resp.DataLength).decode("utf-8") if resp.Data else None - error_str = ctypes.string_at(resp.Error, resp.ErrorLength).decode("utf-8") if resp.Error else None - - # C# owns the memory in the response so we need to free it explicitly - lib.free_response(resp) + try: + response_str = ctypes.string_at(resp.Data, resp.DataLength).decode("utf-8") if resp.Data else None + error_str = ctypes.string_at(resp.Error, resp.ErrorLength).decode("utf-8") if resp.Error else None + finally: + # C# owns the memory in the response so we need to 
free it explicitly. + # Do this before surfacing callback exceptions so cancellation does not leak native buffers. + lib.free_response(resp) + + if callback_exception is not None: + if isinstance(callback_exception, CancelledException): + raise FoundryLocalException("Operation cancelled") + raise callback_exception return Response(data=response_str, error=error_str) @@ -303,23 +324,33 @@ def execute_command(self, command_name: str, command_input: Optional[InteropRequ return response def execute_command_with_callback(self, command_name: str, command_input: Optional[InteropRequest], - callback: Callable[[str], None]) -> Response: + callback: Callable[[str], None], + cancel_event: Optional[threading.Event] = None) -> Response: """Execute a command with a streaming callback. The ``callback`` receives incremental string data from the native layer (e.g. streaming chat tokens or download progress). + If ``cancel_event`` is provided and is set, the native call will be + cancelled at the next callback invocation and a ``FoundryLocalException`` + with message ``"Operation cancelled"`` will be raised. + Args: command_name: The native command name. command_input: Optional request parameters. callback: Called with each incremental string response. + cancel_event: Optional ``threading.Event`` that signals cancellation + when set. Returns: A ``Response`` with ``data`` on success or ``error`` on failure. + + Raises: + FoundryLocalException: If the operation is cancelled or fails. 
""" logger.debug("Executing command with callback: %s Input: %s", command_name, command_input.params if command_input else None) - response = self._execute_command(command_name, command_input, callback) + response = self._execute_command(command_name, command_input, callback, cancel_event) return response def execute_command_with_binary(self, command_name: str, diff --git a/sdk/python/src/detail/model.py b/sdk/python/src/detail/model.py index 6d60b7a2f..a71b1dba5 100644 --- a/sdk/python/src/detail/model.py +++ b/sdk/python/src/detail/model.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +from threading import Event from typing import Callable, List, Optional from ..imodel import IModel @@ -115,9 +116,10 @@ def is_loaded(self) -> bool: """Is the currently selected variant loaded in memory?""" return self._selected_variant.is_loaded - def download(self, progress_callback: Optional[Callable[[float], None]] = None) -> None: + def download(self, progress_callback: Optional[Callable[[float], None]] = None, + cancel_event: Optional[Event] = None) -> None: """Download the currently selected variant.""" - self._selected_variant.download(progress_callback) + self._selected_variant.download(progress_callback, cancel_event) def get_path(self) -> str: """Get the path to the currently selected variant.""" diff --git a/sdk/python/src/detail/model_variant.py b/sdk/python/src/detail/model_variant.py index 76efb05cd..a563baabd 100644 --- a/sdk/python/src/detail/model_variant.py +++ b/sdk/python/src/detail/model_variant.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +from threading import Event from typing import Callable, List, Optional from ..imodel import IModel @@ -112,20 +113,40 @@ def is_loaded(self) -> bool: loaded_model_ids = self._model_load_manager.list_loaded() return self.id in loaded_model_ids - def download(self, progress_callback: Callable[[float], None] = None): + def download(self, progress_callback: Callable[[float], None] 
= None, + cancel_event: Optional[Event] = None): """Download this variant to the local cache. Args: progress_callback: Optional callback receiving download progress as a percentage (0.0 to 100.0). + cancel_event: Optional ``threading.Event``. When set, the download will be + cancelled at the next progress update and ``FoundryLocalException`` is raised. """ + self._download_impl(progress_callback, cancel_event) + + def _download_impl(self, progress_callback: Callable[[float], None] = None, + cancel_event: Optional[Event] = None) -> None: request = InteropRequest(params={"Model": self.id}) - if progress_callback is None: + if progress_callback is None and cancel_event is None: response = self._core_interop.execute_command("download_model", request) else: + # Use the callback path when either progress or cancellation is needed. + # Ignore invalid progress chunks so cancellation-only downloads + # still tolerate any non-progress output from the native layer. + def _on_chunk(chunk: str) -> None: + if progress_callback is None: + return + + try: + progress_callback(float(chunk)) + except ValueError: + pass + response = self._core_interop.execute_command_with_callback( "download_model", request, - lambda pct_str: progress_callback(float(pct_str)) + _on_chunk, + cancel_event, ) logger.info("Download response: %s", response) diff --git a/sdk/python/src/foundry_local_manager.py b/sdk/python/src/foundry_local_manager.py index a649f8e56..e47569ecc 100644 --- a/sdk/python/src/foundry_local_manager.py +++ b/sdk/python/src/foundry_local_manager.py @@ -101,6 +101,7 @@ def download_and_register_eps( self, names: Optional[list[str]] = None, progress_callback: Optional[Callable[[str, float], None]] = None, + cancel_event: Optional[threading.Event] = None, ) -> EpDownloadResult: """Download and register execution providers. @@ -109,6 +110,8 @@ def download_and_register_eps( all discoverable EPs are downloaded. 
progress_callback: Optional callback ``(ep_name: str, percent: float) -> None`` invoked as each EP downloads. ``percent`` is 0-100. + cancel_event: Optional ``threading.Event`` that signals cancellation + when set. The download will be cancelled at the next progress update. Returns: ``EpDownloadResult`` describing operation status and per-EP outcomes. @@ -120,19 +123,20 @@ def download_and_register_eps( if names is not None and len(names) > 0: request = InteropRequest(params={"Names": ",".join(names)}) - if progress_callback is not None: + if progress_callback is not None or cancel_event is not None: def _on_chunk(chunk: str) -> None: - sep = chunk.find("|") - if sep >= 0: - ep_name = chunk[:sep] or "" - try: - percent = float(chunk[sep + 1:]) - progress_callback(ep_name, percent) - except ValueError: - pass + if progress_callback is not None: + sep = chunk.find("|") + if sep >= 0: + ep_name = chunk[:sep] or "" + try: + percent = float(chunk[sep + 1:]) + progress_callback(ep_name, percent) + except ValueError: + pass response = self._core_interop.execute_command_with_callback( - "download_and_register_eps", request, _on_chunk + "download_and_register_eps", request, _on_chunk, cancel_event ) else: response = self._core_interop.execute_command("download_and_register_eps", request) diff --git a/sdk/python/src/imodel.py b/sdk/python/src/imodel.py index f723e514a..fc63f3747 100644 --- a/sdk/python/src/imodel.py +++ b/sdk/python/src/imodel.py @@ -5,6 +5,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from threading import Event from typing import Callable, List, Optional from .openai.chat_client import ChatClient @@ -76,10 +77,13 @@ def supports_tool_calling(self) -> Optional[bool]: pass @abstractmethod - def download(self, progress_callback: Callable[[float], None] = None) -> None: + def download(self, progress_callback: Callable[[float], None] = None, + cancel_event: Optional[Event] = None) -> None: """ Download the model to local cache if 
not already present. :param progress_callback: Optional callback function for download progress as a percentage (0.0 to 100.0). + :param cancel_event: Optional ``threading.Event``. When set, the download will be + cancelled at the next progress update and ``FoundryLocalException`` is raised. """ pass diff --git a/sdk/python/test/test_core_interop.py b/sdk/python/test/test_core_interop.py new file mode 100644 index 000000000..336f59ad3 --- /dev/null +++ b/sdk/python/test/test_core_interop.py @@ -0,0 +1,35 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for CoreInterop callback helpers.""" + +from __future__ import annotations + +import ctypes +import threading + +from foundry_local_sdk.detail.core_interop import CallbackHelper, CancelledException + + +class TestCoreInterop: + def test_callback_helper_returns_cancel_when_cancel_event_is_set(self): + """Callback helper should return 1 without invoking Python callback when cancelled.""" + cancel_event = threading.Event() + cancel_event.set() + called = False + + def _callback(_chunk: str) -> None: + nonlocal called + called = True + + helper = CallbackHelper(_callback, cancel_event) + helper_ref = ctypes.py_object(helper) + helper_ptr = ctypes.cast(ctypes.pointer(helper_ref), ctypes.c_void_p) + data = ctypes.create_string_buffer(b"50") + + result = CallbackHelper.callback(data, 2, helper_ptr) + + assert result == 1 + assert called is False + assert isinstance(helper.exception, CancelledException) diff --git a/sdk/python/test/test_foundry_local_manager.py b/sdk/python/test/test_foundry_local_manager.py index 315288912..3abb37f64 100644 --- a/sdk/python/test/test_foundry_local_manager.py +++ b/sdk/python/test/test_foundry_local_manager.py @@ -6,6 +6,10 @@ from __future__ import annotations +import threading + 
+from foundry_local_sdk.foundry_local_manager import FoundryLocalManager + class _Response: def __init__(self, data=None, error=None): @@ -22,6 +26,12 @@ def execute_command(self, command_name, command_input=None): self.calls.append((command_name, command_input)) return self._responses[command_name] + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + self.calls.append((command_name, command_input, callback, cancel_event)) + return self._responses[command_name] + class TestFoundryLocalManager: """Foundry Local Manager Tests.""" @@ -81,3 +91,36 @@ def test_download_and_register_eps_returns_result(self, manager): assert result.status == "ok" assert result.registered_eps == ["CUDAExecutionProvider"] assert result.failed_eps == [] + + def test_download_and_register_eps_uses_callback_path_when_cancel_event_is_provided(self): + fake_core = _FakeCoreInterop( + { + "download_and_register_eps": _Response( + data=( + '{"Success":true,"Status":"ok",' + '"RegisteredEps":["CUDAExecutionProvider"],"FailedEps":[]}' + ), + error=None, + ) + } + ) + manager = FoundryLocalManager.__new__(FoundryLocalManager) + manager._core_interop = fake_core + manager.catalog = type( + "_FakeCatalog", + (), + {"_invalidate_cache": staticmethod(lambda: None)}, + )() + cancel_event = threading.Event() + + result = manager.download_and_register_eps( + ["CUDAExecutionProvider"], cancel_event=cancel_event + ) + + assert result.success is True + assert len(fake_core.calls) == 1 + command_name, command_input, callback, seen_cancel_event = fake_core.calls[0] + assert command_name == "download_and_register_eps" + assert command_input.params == {"Names": "CUDAExecutionProvider"} + assert callable(callback) + assert seen_cancel_event is cancel_event diff --git a/sdk/python/test/test_model.py b/sdk/python/test/test_model.py index e2ea15090..3d83a44ec 100644 --- a/sdk/python/test/test_model.py +++ b/sdk/python/test/test_model.py @@ -6,6 +6,12 @@ 
from __future__ import annotations +import threading + +from types import SimpleNamespace + +from foundry_local_sdk.detail.model_variant import ModelVariant + from .conftest import TEST_MODEL_ALIAS, AUDIO_MODEL_ALIAS @@ -86,3 +92,75 @@ def test_should_expose_supports_tool_calling(self, catalog): assert model is not None stc = model.supports_tool_calling assert stc is None or isinstance(stc, bool) + + def test_download_should_use_callback_path_when_cancel_event_is_provided(self): + """Model download should route through callback interop when cancellation is enabled.""" + + class _Response: + def __init__(self, data=None, error=None): + self.data = data + self.error = error + + class _FakeCoreInterop: + def __init__(self): + self.calls = [] + + def execute_command(self, command_name, command_input=None): + raise AssertionError( + "download should not use execute_command when cancel_event is provided" + ) + + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + self.calls.append((command_name, command_input, callback, cancel_event)) + return _Response(data="", error=None) + + fake_core = _FakeCoreInterop() + cancel_event = threading.Event() + variant = ModelVariant.__new__(ModelVariant) + variant._model_info = SimpleNamespace(id="test-model-cpu:1", alias=TEST_MODEL_ALIAS) + variant._id = "test-model-cpu:1" + variant._alias = TEST_MODEL_ALIAS + variant._core_interop = fake_core + variant._model_load_manager = None + + variant.download(cancel_event=cancel_event) + + assert len(fake_core.calls) == 1 + command_name, command_input, callback, seen_cancel_event = fake_core.calls[0] + assert command_name == "download_model" + assert command_input.params == {"Model": "test-model-cpu:1"} + assert callable(callback) + assert seen_cancel_event is cancel_event + callback("50") + + def test_download_should_parse_numeric_progress_chunk(self): + """Model download progress parsing should parse the numeric native chunk.""" 
+ + class _Response: + def __init__(self, data=None, error=None): + self.data = data + self.error = error + + class _FakeCoreInterop: + def execute_command(self, command_name, command_input=None): + raise AssertionError("download should use callback interop when progress is provided") + + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + callback("12.5") + return _Response(data="", error=None) + + progress = [] + variant = ModelVariant.__new__(ModelVariant) + variant._model_info = SimpleNamespace(id="test-model-cpu:1", alias=TEST_MODEL_ALIAS) + variant._id = "test-model-cpu:1" + variant._alias = TEST_MODEL_ALIAS + variant._core_interop = _FakeCoreInterop() + variant._model_load_manager = None + + variant.download(progress_callback=progress.append) + + assert progress == [12.5] diff --git a/sdk/rust/README.md b/sdk/rust/README.md index ce97a7dd0..058b8f721 100644 --- a/sdk/rust/README.md +++ b/sdk/rust/README.md @@ -107,6 +107,32 @@ manager.download_and_register_eps_with_progress(None, move |ep_name: &str, perce println!(); ``` +#### Cancelling model and EP downloads + +Use a shared `Arc<AtomicBool>` with the download builders. Set the flag from another task or signal handler to stop the in-progress download. + +```rust +use std::sync::{ + atomic::AtomicBool, + Arc, +}; + +// manager and model already initialized +let cancel_flag = Arc::new(AtomicBool::new(false)); +// call cancel_flag.store(true, ...) from another task or signal handler to cancel + +manager + .download_and_register_eps_builder() + .cancel(Arc::clone(&cancel_flag)) + .run() + .await?; +model + .download_builder() + .cancel(Arc::clone(&cancel_flag)) + .run() + .await?; +``` + Catalog access does not block on EP downloads. Call `download_and_register_eps` when you need hardware-accelerated execution providers. 
## Quick Start @@ -197,6 +223,17 @@ model.download(Some(|progress: f64| { std::io::Write::flush(&mut std::io::stdout()).ok(); })).await?; +// Or use the builder when combining progress, cancellation, or future options +let cancel_flag = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); +model.download_builder() + .progress(|progress| { + print!("\r{progress:.1}%"); + std::io::Write::flush(&mut std::io::stdout()).ok(); + }) + .cancel(cancel_flag.clone()) + .run() + .await?; + // Load into memory model.load().await?; diff --git a/sdk/rust/src/detail/core_interop.rs b/sdk/rust/src/detail/core_interop.rs index 0d17fe62d..5120fa2ef 100644 --- a/sdk/rust/src/detail/core_interop.rs +++ b/sdk/rust/src/detail/core_interop.rs @@ -9,6 +9,7 @@ use std::ffi::CString; use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use libloading::{Library, Symbol}; @@ -17,6 +18,12 @@ use serde_json::Value; use crate::configuration::Configuration; use crate::error::{FoundryLocalError, Result}; +fn checked_i32_length(name: &str, len: usize) -> Result { + i32::try_from(len).map_err(|_| FoundryLocalError::CommandExecution { + reason: format!("{name} length {len} exceeds i32::MAX"), + }) +} + // ── FFI types ──────────────────────────────────────────────────────────────── /// Request buffer passed to the native library. 
@@ -143,6 +150,8 @@ unsafe fn free_native_buffer(ptr: *mut u8) { struct StreamingCallbackState<'a> { callback: &'a mut dyn FnMut(&str), buf: Vec, + cancel_flag: Option>, + cancelled_observed: bool, } impl<'a> StreamingCallbackState<'a> { @@ -150,9 +159,37 @@ impl<'a> StreamingCallbackState<'a> { Self { callback, buf: Vec::new(), + cancel_flag: None, + cancelled_observed: false, + } + } + + fn new_cancellable(callback: &'a mut dyn FnMut(&str), cancel_flag: Arc) -> Self { + Self { + callback, + buf: Vec::new(), + cancel_flag: Some(cancel_flag), + cancelled_observed: false, } } + /// Records and returns `true` only when this callback invocation observes a cancellation request. + fn mark_cancelled_if_requested(&mut self) -> bool { + let cancelled = self + .cancel_flag + .as_ref() + .is_some_and(|f| f.load(Ordering::Relaxed)); + if cancelled { + self.cancelled_observed = true; + } + + cancelled + } + + fn cancellation_observed(&self) -> bool { + self.cancelled_observed + } + /// Append raw bytes, decode as much valid UTF-8 as possible, and forward /// complete text to the callback. Any trailing incomplete multi-byte /// sequence is kept in the buffer for the next call. Invalid byte @@ -193,9 +230,13 @@ impl<'a> StreamingCallbackState<'a> { } } - /// Flush any remaining bytes as lossy UTF-8 (called once after the native - /// call completes). + /// Flush any remaining bytes as lossy UTF-8 after a completed native call. fn flush(&mut self) { + if self.cancelled_observed { + self.buf.clear(); + return; + } + if !self.buf.is_empty() { let text = String::from_utf8_lossy(&self.buf).into_owned(); (self.callback)(&text); @@ -225,16 +266,19 @@ unsafe extern "C" fn streaming_trampoline( // by the caller of `execute_command_with_callback` for the duration of // the native call. let state = &mut *(user_data as *mut StreamingCallbackState<'_>); + + // Check for cancellation before processing the chunk. 
+ if state.mark_cancelled_if_requested() { + return 1; // cancel + } + // SAFETY: `data` is valid for `length` bytes as guaranteed by the native // core's callback contract. let slice = std::slice::from_raw_parts(data, length as usize); state.push(slice); + 0 // continue })); - if result.is_err() { - 1 - } else { - 0 - } + result.unwrap_or(1) } // ── CoreInterop ────────────────────────────────────────────────────────────── @@ -366,9 +410,9 @@ impl CoreInterop { let request = RequestBuffer { command: cmd.as_ptr(), - command_length: cmd.as_bytes().len() as i32, + command_length: checked_i32_length("command", cmd.as_bytes().len())?, data: data_cstr.as_ptr(), - data_length: data_cstr.as_bytes().len() as i32, + data_length: checked_i32_length("data", data_cstr.as_bytes().len())?, }; let mut response = ResponseBuffer::new(); @@ -416,15 +460,15 @@ impl CoreInterop { let request = StreamingRequestBuffer { command: cmd.as_ptr(), - command_length: cmd.as_bytes().len() as i32, + command_length: checked_i32_length("command", cmd.as_bytes().len())?, data: data_cstr.as_ptr(), - data_length: data_cstr.as_bytes().len() as i32, + data_length: checked_i32_length("data", data_cstr.as_bytes().len())?, binary_data: if binary_data.is_empty() { std::ptr::null() } else { binary_data.as_ptr() }, - binary_data_length: binary_data.len() as i32, + binary_data_length: checked_i32_length("binary data", binary_data.len())?, }; let mut response = ResponseBuffer::new(); @@ -452,6 +496,32 @@ impl CoreInterop { where F: FnMut(&str), { + self.execute_command_streaming_impl(command, params, &mut callback, None) + } + + /// Like [`Self::execute_command_streaming`], but accepts a cancellation + /// flag. When `cancel_flag` is set to `true`, the native call will be + /// cancelled at the next callback invocation and an error is returned. 
+ pub fn execute_command_streaming_cancellable( + &self, + command: &str, + params: Option<&Value>, + mut callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str), + { + self.execute_command_streaming_impl(command, params, &mut callback, Some(cancel_flag)) + } + + fn execute_command_streaming_impl( + &self, + command: &str, + params: Option<&Value>, + callback: &mut dyn FnMut(&str), + cancel_flag: Option>, + ) -> Result { let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { reason: format!("Invalid command string: {e}"), })?; @@ -467,17 +537,19 @@ impl CoreInterop { let request = RequestBuffer { command: cmd.as_ptr(), - command_length: cmd.as_bytes().len() as i32, + command_length: checked_i32_length("command", cmd.as_bytes().len())?, data: data_cstr.as_ptr(), - data_length: data_cstr.as_bytes().len() as i32, + data_length: checked_i32_length("data", data_cstr.as_bytes().len())?, }; let mut response = ResponseBuffer::new(); // Wrap the closure in a StreamingCallbackState that handles partial // UTF-8 sequences split across native callbacks. - let mut cb = |chunk: &str| callback(chunk); - let mut state = StreamingCallbackState::new(&mut cb); + let mut state = match cancel_flag { + Some(flag) => StreamingCallbackState::new_cancellable(callback, flag), + None => StreamingCallbackState::new(callback), + }; let user_data = &mut state as *mut StreamingCallbackState<'_> as *mut std::ffi::c_void; // SAFETY: `request` fields point into `cmd` and `data_cstr` which are @@ -494,9 +566,19 @@ impl CoreInterop { ); } - // Flush any trailing partial UTF-8 bytes. + let cancelled = state.cancellation_observed(); + + // Flush any trailing partial UTF-8 bytes unless cancellation was observed. state.flush(); + if cancelled { + // Free native response memory before returning the error. 
+ Self::process_response(response).ok(); + return Err(FoundryLocalError::CommandExecution { + reason: "Operation cancelled".to_string(), + }); + } + Self::process_response(response) } @@ -540,6 +622,36 @@ impl CoreInterop { })? } + /// Async version of [`Self::execute_command_streaming_cancellable`]. + /// + /// Accepts a shared cancellation flag (`Arc`). When the flag + /// is set to `true`, the native call will be cancelled at the next + /// callback invocation and an error is returned. + pub async fn execute_command_streaming_cancellable_async( + self: &Arc, + command: String, + params: Option, + callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str) + Send + 'static, + { + let this = Arc::clone(self); + tokio::task::spawn_blocking(move || { + this.execute_command_streaming_cancellable( + &command, + params.as_ref(), + callback, + cancel_flag, + ) + }) + .await + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? + } + /// Async streaming variant that bridges the FFI callback into a /// [`tokio::sync::mpsc`] channel. 
/// @@ -702,3 +814,68 @@ impl CoreInterop { Ok(libs) } } + +#[cfg(test)] +mod tests { + use super::{checked_i32_length, StreamingCallbackState}; + use crate::error::FoundryLocalError; + use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }; + + #[test] + fn cancellation_request_after_callback_is_not_observed_until_next_callback() { + let cancel_flag = Arc::new(AtomicBool::new(false)); + let mut callback = |_chunk: &str| {}; + let mut state = + StreamingCallbackState::new_cancellable(&mut callback, Arc::clone(&cancel_flag)); + + state.push(b"100"); + cancel_flag.store(true, Ordering::Relaxed); + + assert!(!state.cancellation_observed()); + } + + #[test] + fn cancellation_is_recorded_when_callback_observes_cancel_flag() { + let cancel_flag = Arc::new(AtomicBool::new(true)); + let mut callback = |_chunk: &str| {}; + let mut state = StreamingCallbackState::new_cancellable(&mut callback, cancel_flag); + + assert!(state.mark_cancelled_if_requested()); + assert!(state.cancellation_observed()); + } + + #[test] + fn flush_drops_buffer_after_cancellation_without_callback() { + let cancel_flag = Arc::new(AtomicBool::new(true)); + let mut chunks = Vec::new(); + + { + let mut callback = |chunk: &str| chunks.push(chunk.to_owned()); + let mut state = StreamingCallbackState::new_cancellable(&mut callback, cancel_flag); + + state.push(&[0xE2]); + assert!(state.mark_cancelled_if_requested()); + state.flush(); + } + + assert!(chunks.is_empty()); + } + + #[test] + fn checked_i32_length_rejects_too_large_values() { + assert_eq!( + checked_i32_length("data", i32::MAX as usize).unwrap(), + i32::MAX + ); + + match checked_i32_length("data", i32::MAX as usize + 1).unwrap_err() { + FoundryLocalError::CommandExecution { reason } => { + assert!(reason.contains("exceeds i32::MAX")); + } + err => panic!("unexpected error: {err:?}"), + } + } +} diff --git a/sdk/rust/src/detail/model.rs b/sdk/rust/src/detail/model.rs index 08288aee8..83dcc12c3 100644 --- a/sdk/rust/src/detail/model.rs +++ 
b/sdk/rust/src/detail/model.rs @@ -6,7 +6,7 @@ use std::fmt; use std::path::PathBuf; -use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering::Relaxed}; use std::sync::Arc; use super::core_interop::CoreInterop; @@ -32,6 +32,50 @@ pub struct Model { inner: ModelKind, } +type DownloadProgressCallback = Box; + +/// Builder for configuring and running a model download. +/// +/// Use this builder when combining optional settings like progress and cancellation. +pub struct DownloadBuilder<'a> { + model: &'a Model, + progress: Option, + cancel_flag: Option>, +} + +impl<'a> DownloadBuilder<'a> { + fn new(model: &'a Model) -> Self { + Self { + model, + progress: None, + cancel_flag: None, + } + } + + /// Report download progress as a percentage from 0.0 to 100.0. + pub fn progress(mut self, callback: F) -> Self + where + F: FnMut(f64) + Send + 'static, + { + self.progress = Some(Box::new(callback)); + self + } + + /// Cancel the download when `cancel_flag` is set to `true`. + pub fn cancel(mut self, cancel_flag: Arc) -> Self { + self.cancel_flag = Some(cancel_flag); + self + } + + /// Run the configured download. + pub async fn run(self) -> Result<()> { + self.model + .selected_variant() + .download_with_options(self.progress, self.cancel_flag) + .await + } +} + #[allow(clippy::large_enum_variant)] enum ModelKind { /// A single model variant (from `get_model_variant` or `variants()`). @@ -213,6 +257,14 @@ impl Model { self.selected_variant().download(progress).await } + /// Configure and run a model download with a builder. + /// + /// Use this for call sites that need progress, cancellation, or future + /// download options. + pub fn download_builder(&self) -> DownloadBuilder<'_> { + DownloadBuilder::new(self) + } + /// Return the local file-system path of the (selected) variant. 
pub async fn path(&self) -> Result { self.selected_variant().path().await diff --git a/sdk/rust/src/detail/model_variant.rs b/sdk/rust/src/detail/model_variant.rs index 1f8ce7d5b..905ca5f74 100644 --- a/sdk/rust/src/detail/model_variant.rs +++ b/sdk/rust/src/detail/model_variant.rs @@ -5,6 +5,7 @@ use std::fmt; use std::path::PathBuf; +use std::sync::atomic::AtomicBool; use std::sync::Arc; use serde_json::json; @@ -88,26 +89,47 @@ impl ModelVariant { } pub(crate) async fn download(&self, progress: Option) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + self.download_with_options(progress, None).await + } + + pub(crate) async fn download_with_options( + &self, + progress: Option, + cancel_flag: Option>, + ) -> Result<()> where F: FnMut(f64) + Send + 'static, { let params = json!({ "Params": { "Model": self.info.id } }); - match progress { - Some(mut cb) => { - let wrapper = move |chunk: &str| { - for token in chunk.split_whitespace() { - if let Ok(pct) = token.parse::() { - cb(pct); - } + if progress.is_none() && cancel_flag.is_none() { + self.core + .execute_command_async("download_model".into(), Some(params)) + .await?; + } else { + let mut progress = progress; + let wrapper = move |chunk: &str| { + if let Some(cb) = progress.as_mut() { + if let Ok(pct) = chunk.trim().parse::() { + cb(pct); } - }; + } + }; + + if let Some(flag) = cancel_flag { self.core - .execute_command_streaming_async("download_model".into(), Some(params), wrapper) + .execute_command_streaming_cancellable_async( + "download_model".into(), + Some(params), + wrapper, + flag, + ) .await?; - } - None => { + } else { self.core - .execute_command_async("download_model".into(), Some(params)) + .execute_command_streaming_async("download_model".into(), Some(params), wrapper) .await?; } } diff --git a/sdk/rust/src/foundry_local_manager.rs b/sdk/rust/src/foundry_local_manager.rs index 0c22ef154..4d9377ae5 100644 --- a/sdk/rust/src/foundry_local_manager.rs +++ 
b/sdk/rust/src/foundry_local_manager.rs @@ -4,6 +4,7 @@ //! library, provides access to the model [`Catalog`], and can start / stop //! the local web service. +use std::sync::atomic::AtomicBool; use std::sync::{Arc, Mutex, OnceLock}; use serde_json::json; @@ -32,6 +33,67 @@ pub struct FoundryLocalManager { _logger: Option>, } +type EpDownloadProgressCallback = Box; + +/// Builder for configuring and running execution provider downloads. +pub struct EpDownloadBuilder<'a> { + manager: &'a FoundryLocalManager, + names: Option>, + progress_callback: Option, + cancel_flag: Option>, +} + +impl<'a> EpDownloadBuilder<'a> { + fn new(manager: &'a FoundryLocalManager) -> Self { + Self { + manager, + names: None, + progress_callback: None, + cancel_flag: None, + } + } + + /// Download only the named execution providers. + pub fn names(mut self, names: I) -> Self + where + I: IntoIterator, + S: Into, + { + self.names = Some(names.into_iter().map(Into::into).collect()); + self + } + + /// Report per-EP download progress as `(ep_name, percent)`. + pub fn progress(mut self, callback: F) -> Self + where + F: FnMut(&str, f64) + Send + 'static, + { + self.progress_callback = Some(Box::new(callback)); + self + } + + /// Cancel the download when `cancel_flag` is set to `true`. + pub fn cancel(mut self, cancel_flag: Arc) -> Self { + self.cancel_flag = Some(cancel_flag); + self + } + + /// Run the configured execution provider download. + pub async fn run(self) -> Result { + let names: Option> = self + .names + .as_ref() + .map(|names| names.iter().map(String::as_str).collect()); + self.manager + .download_and_register_eps_impl( + names.as_deref(), + self.progress_callback, + self.cancel_flag, + ) + .await + } +} + impl FoundryLocalManager { /// Initialise the SDK. 
/// @@ -150,7 +212,7 @@ impl FoundryLocalManager { &self, names: Option<&[&str]>, ) -> Result { - self.download_and_register_eps_impl(names, None::) + self.download_and_register_eps_impl(names, None::, None) .await } @@ -169,14 +231,23 @@ impl FoundryLocalManager { where F: FnMut(&str, f64) + Send + 'static, { - self.download_and_register_eps_impl(names, Some(progress_callback)) + self.download_and_register_eps_impl(names, Some(progress_callback), None) .await } + /// Configure and run execution provider downloads with a builder. + /// + /// Use this for call sites that need names, progress, cancellation, or + /// future download options. + pub fn download_and_register_eps_builder(&self) -> EpDownloadBuilder<'_> { + EpDownloadBuilder::new(self) + } + async fn download_and_register_eps_impl( &self, names: Option<&[&str]>, progress_callback: Option, + cancel_flag: Option>, ) -> Result where F: FnMut(&str, f64) + Send + 'static, @@ -186,8 +257,28 @@ impl FoundryLocalManager { _ => None, }; - let raw = match progress_callback { - Some(cb) => { + let raw = match (progress_callback, cancel_flag) { + (Some(cb), Some(flag)) => { + let mut callback = cb; + let wrapper = move |chunk: &str| { + if let Some(sep) = chunk.find('|') { + let name = &chunk[..sep]; + if let Ok(percent) = chunk[sep + 1..].parse::() { + callback(if name.is_empty() { "" } else { name }, percent); + } + } + }; + + self.core + .execute_command_streaming_cancellable_async( + "download_and_register_eps".into(), + params, + wrapper, + flag, + ) + .await? + } + (Some(cb), None) => { let mut callback = cb; let wrapper = move |chunk: &str| { if let Some(sep) = chunk.find('|') { @@ -206,7 +297,17 @@ impl FoundryLocalManager { ) .await? } - None => { + (None, Some(flag)) => { + self.core + .execute_command_streaming_cancellable_async( + "download_and_register_eps".into(), + params, + |_chunk: &str| {}, + flag, + ) + .await? 
+ } + (None, None) => { self.core .execute_command_async("download_and_register_eps".into(), params) .await? diff --git a/sdk/rust/src/lib.rs b/sdk/rust/src/lib.rs index 9fb4bb85b..73c4180a0 100644 --- a/sdk/rust/src/lib.rs +++ b/sdk/rust/src/lib.rs @@ -13,9 +13,9 @@ pub mod openai; pub use self::catalog::Catalog; pub use self::configuration::{FoundryLocalConfig, LogLevel, Logger}; -pub use self::detail::model::Model; +pub use self::detail::model::{DownloadBuilder, Model}; pub use self::error::FoundryLocalError; -pub use self::foundry_local_manager::FoundryLocalManager; +pub use self::foundry_local_manager::{EpDownloadBuilder, FoundryLocalManager}; pub use self::types::{ ChatResponseFormat, ChatToolChoice, DeviceType, EpDownloadResult, EpInfo, ModelInfo, ModelSettings, Parameter, PromptTemplate, Runtime,