diff --git a/.editorconfig b/.editorconfig index fe6d677..8955f96 100644 --- a/.editorconfig +++ b/.editorconfig @@ -20,7 +20,7 @@ trim_trailing_whitespace = false [*.sh] indent_size = 2 -[*.{yaml,yml}] +[*.{yaml,yml,json}] indent_size = 2 [*.{c,cpp,h,hpp}] diff --git a/.github/workflows/dist_pipeline.yml b/.github/workflows/dist_pipeline.yml index c7356f8..174db21 100644 --- a/.github/workflows/dist_pipeline.yml +++ b/.github/workflows/dist_pipeline.yml @@ -34,16 +34,16 @@ jobs: exclude_archs: "windows_amd64_mingw;osx_amd64;wasm_mvp;wasm_eh;wasm_threads" duckdb-stable-build: - uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.5.2 + uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.5.3 with: - duckdb_version: v1.5.2 - ci_tools_version: v1.5.2 + duckdb_version: v1.5.3 + ci_tools_version: v1.5.3 extension_name: infera enable_rust: true exclude_archs: "windows_amd64_mingw;osx_amd64;wasm_mvp;wasm_eh;wasm_threads" create-release-draft: - name: Create Draft Release with Built Binaries + name: Create draft release with built binaries needs: - duckdb-stable-build if: startsWith(github.ref, 'refs/tags/') @@ -51,17 +51,17 @@ jobs: permissions: contents: write steps: - - name: Download All Build Artifacts + - name: Download all build artifacts uses: actions/download-artifact@v4 with: path: dist merge-multiple: true - - name: List Artifacts + - name: List artifacts run: | echo "Downloaded artifacts to: $(pwd)/dist" ls -la dist || true find dist -type f -maxdepth 2 -print || true - - name: Create Draft Release and Upload Assets + - name: Create draft release and upload assets uses: softprops/action-gh-release@v2 with: draft: true diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index 885f73a..8253aba 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - - name: Checkout Code + - name: Checkout code uses: 
actions/checkout@v4 with: submodules: recursive @@ -25,10 +25,10 @@ jobs: - name: Set up Rust uses: actions-rust-lang/setup-rust-toolchain@v1 - - name: Install Dependencies + - name: Install dependencies run: | sudo apt-get update && sudo apt-get upgrade -y sudo apt-get install -y gcc curl pkg-config libssl-dev make - - name: Run Linters + - name: Run linters run: make rust-lint diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 72e9b91..a7a58ed 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,18 +11,19 @@ on: push: branches: - main + - develop permissions: contents: read jobs: rust-tests: - name: Rust Tests and Code Coverage + name: Rust tests and code coverage runs-on: ubuntu-latest env: CARGO_TERM_COLOR: always steps: - - name: Checkout Code + - name: Checkout code uses: actions/checkout@v4 with: submodules: recursive @@ -41,31 +42,31 @@ jobs: restore-keys: | ${{ runner.os }}-cargo- - - name: Install System Dependencies + - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y gcc curl pkg-config libssl-dev make cargo install cargo-tarpaulin --locked || true - - name: Run Rust Tests + - name: Run Rust tests run: | make rust-test - - name: Generate Coverage (using Tarpaulin) + - name: Generate coverage (using Tarpaulin) run: | make rust-coverage - - name: Upload Coverage Reports to Codecov + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} continue-on-error: false sqllogictest: - name: Sqllogicest Tests + name: Sqllogicest tests runs-on: ubuntu-latest steps: - - name: Checkout Code + - name: Checkout code uses: actions/checkout@v4 with: submodules: recursive @@ -73,12 +74,12 @@ jobs: - name: Set up Rust uses: actions-rust-lang/setup-rust-toolchain@v1 - - name: Install System Dependencies + - name: Install system dependencies run: | sudo apt-get update sudo apt-get install -y gcc curl pkg-config libssl-dev 
make - - name: Build Extension & Run SQL Tests + - name: Build extension and run SQL tests run: | make release make test diff --git a/.gitignore b/.gitignore index 7ec3255..2e63bbb 100644 --- a/.gitignore +++ b/.gitignore @@ -97,3 +97,5 @@ tests/temp_models/ *.cast .claude/ .codex +.antigravitycli/ +.agents/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..43c994c --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +@AGENTS.md diff --git a/CMakeLists.txt b/CMakeLists.txt index f831707..e850f7b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.5) set(CORROSION_VERBOSE_OUTPUT ON) -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED 1) option(INFERA_ENABLE_ONNX "Enable ONNX (tract) inference backend in Rust crate" ON) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8fa73e7..f4f5d7a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,6 +29,9 @@ would like to work on or if it has already been resolved. ### Development Workflow +> [!IMPORTANT] +> If you're using an AI-assisted coding tool like Claude Code or Codex, make sure the AI follows the instructions in the [AGENTS.md](AGENTS.md) file. + #### Code Style - Use the `make rust-format` command to format the code. diff --git a/README.md b/README.md index afd1c94..7aef1ca 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ make release #### Trying Infera ```sql --- 0. Install and load Infera +-- 0. 
Install and load Infera -- Skip this step if you built from source and ran `./build/release/duckdb` install infera from community; load infera; diff --git a/extension_config.cmake b/extension_config.cmake index e456c9f..3e4d606 100644 --- a/extension_config.cmake +++ b/extension_config.cmake @@ -69,7 +69,14 @@ if (EXISTS ${INFERA_RUST_LIB}) # Create an imported target for the Rust library add_library(infera_rust STATIC IMPORTED GLOBAL) if(UNIX) - set(_INFERA_RUST_LINK_LIBS "pthread;dl;m") + if(APPLE) + # reqwest 0.13 uses rustls-platform-verifier on Apple platforms, which + # calls into the macOS Security and CoreFoundation frameworks. Link them + # explicitly so the static Rust archive resolves at extension link time. + set(_INFERA_RUST_LINK_LIBS "pthread;dl;m;-framework Security;-framework CoreFoundation;-framework SystemConfiguration") + else() + set(_INFERA_RUST_LINK_LIBS "pthread;dl;m") + endif() else() set(_INFERA_RUST_LINK_LIBS "") endif() @@ -80,7 +87,11 @@ if (EXISTS ${INFERA_RUST_LIB}) # Add the Rust library to global link libraries so it gets linked to everything if(UNIX) - link_libraries(${INFERA_RUST_LIB} pthread dl m) + if(APPLE) + link_libraries(${INFERA_RUST_LIB} pthread dl m "-framework Security" "-framework CoreFoundation" "-framework SystemConfiguration") + else() + link_libraries(${INFERA_RUST_LIB} pthread dl m) + endif() else() link_libraries(${INFERA_RUST_LIB}) if(WIN32) @@ -107,6 +118,11 @@ if (EXISTS ${INFERA_RUST_LIB}) add_link_options($<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:-lpthread>) add_link_options($<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:-ldl>) add_link_options($<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:-lm>) + # Apple framework flags are NOT added via add_link_options: CMake splits + # "-framework Security" at the space inside generator expressions, producing + # the literal token '$<0:-framework' instead of an empty string. The frameworks + # are already propagated to all targets by link_libraries() and + # INTERFACE_LINK_LIBRARIES above, so no add_link_options entry is needed.
else() add_link_options($<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:${INFERA_RUST_LIB}>) endif() diff --git a/infera/.cargo/config.toml b/infera/.cargo/config.toml new file mode 100644 index 0000000..81b00ec --- /dev/null +++ b/infera/.cargo/config.toml @@ -0,0 +1,5 @@ +[target.x86_64-pc-windows-msvc] +rustflags = ["-C", "target-feature=+crt-static"] + +[target.aarch64-pc-windows-msvc] +rustflags = ["-C", "target-feature=+crt-static"] diff --git a/infera/Cargo.toml b/infera/Cargo.toml index 491d4b7..08168ff 100644 --- a/infera/Cargo.toml +++ b/infera/Cargo.toml @@ -20,8 +20,8 @@ thiserror = "2.0" serde_json = "1.0" tract-onnx = { version = "0.22", optional = true } ndarray = { version = "0.17.1", optional = true } -reqwest = { version = "0.12", features = ["blocking", "rustls-tls"], default-features = false } -sha2 = "0.10" +reqwest = { version = "0.13.3", features = ["blocking", "rustls"], default-features = false } +sha2 = "0.11.0" hex = "0.4" filetime = "0.2" diff --git a/infera/bindings/infera_extension.cpp b/infera/bindings/infera_extension.cpp index 92f207c..33d8ed0 100644 --- a/infera/bindings/infera_extension.cpp +++ b/infera/bindings/infera_extension.cpp @@ -37,6 +37,25 @@ static std::string GetInferaError() { return err ? std::string(err) : std::string("unknown error"); } +/** + * @brief Creates a DuckDB scalar function with Infera execution metadata. + * + * Infera functions generally read or mutate model engine state, and many can + * raise runtime errors. The metadata prevents DuckDB from optimizing them as + * pure deterministic functions.
+ */ +static ScalarFunction InferaScalarFunction(const std::string &name, vector<LogicalType> arguments, LogicalType return_type, + scalar_function_t function, bool volatile_state = false, bool fallible = true) { + auto scalar_function = ScalarFunction(name, std::move(arguments), std::move(return_type), function); + if (volatile_state) { + scalar_function.SetVolatile(); + } + if (fallible) { + scalar_function.SetFallible(); + } + return scalar_function; +} + /** * @brief Implements the `infera_set_autoload_dir(path)` SQL function. * @@ -300,11 +319,14 @@ static void PredictFromBlob(DataChunk &args, ExpressionState &state, Vector &res * @param result The result vector to populate. */ static void GetLoadedModels(DataChunk &args, ExpressionState &state, Vector &result) { - char *models_json = infera::infera_get_loaded_models(); + char *models_json_c = infera::infera_get_loaded_models(); + // Guard against the null_mut() path in the Rust fallback (should not occur in + // practice, but infera_get_loaded_models documents that it can return NULL). + std::string models_json = models_json_c ? std::string(models_json_c) : std::string("[]"); + infera::infera_free(models_json_c); result.SetVectorType(VectorType::CONSTANT_VECTOR); ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, models_json); ConstantVector::SetNull(result, false); - infera::infera_free(models_json); } static void IsModelLoaded(DataChunk &args, ExpressionState &state, Vector &result) { @@ -504,30 +526,51 @@ static void GetCacheInfo(DataChunk &args, ExpressionState &state, Vector &result * @param loader The extension loader provided by DuckDB.
*/ static void LoadInternal(ExtensionLoader &loader) { - loader.RegisterFunction(ScalarFunction("infera_load_model", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LoadModel)); - loader.RegisterFunction(ScalarFunction("infera_unload_model", {LogicalType::VARCHAR}, LogicalType::BOOLEAN, UnloadModel)); + loader.RegisterFunction(InferaScalarFunction("infera_load_model", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LoadModel, true)); + loader.RegisterFunction(InferaScalarFunction("infera_unload_model", {LogicalType::VARCHAR}, LogicalType::BOOLEAN, UnloadModel, true)); const idx_t MAX_FEATURES = 127; for (idx_t feature_count = 1; feature_count <= MAX_FEATURES; feature_count++) { - vector<LogicalType> arg_types; - arg_types.reserve(feature_count + 1); - arg_types.push_back(LogicalType::VARCHAR); + vector<LogicalType> float_arg_types; + float_arg_types.reserve(feature_count + 1); + float_arg_types.push_back(LogicalType::VARCHAR); + for (idx_t i = 0; i < feature_count; i++) { + float_arg_types.push_back(LogicalType::FLOAT); + } + // volatile_state=true: inference reads shared mutable model state; the + // planner must not CSE or constant-fold these calls across row groups. + loader.RegisterFunction(InferaScalarFunction("infera_predict", float_arg_types, LogicalType::FLOAT, Predict, true)); + loader.RegisterFunction(InferaScalarFunction("infera_predict_multi", float_arg_types, LogicalType::VARCHAR, PredictMulti, true)); + loader.RegisterFunction(InferaScalarFunction("infera_predict_multi_list", float_arg_types, LogicalType::LIST(LogicalType::FLOAT), PredictMultiList, true)); + + // DOUBLE overloads: DuckDB main changed how it handles DECIMAL→FLOAT implicit + // casts at bind time, causing an internal error for DECIMAL literal inputs. + // Registering DOUBLE overloads gives DuckDB a DECIMAL→DOUBLE path that works + // across all supported versions. ExtractFeatures already handles DOUBLE values.
+ vector<LogicalType> double_arg_types; + double_arg_types.reserve(feature_count + 1); + double_arg_types.push_back(LogicalType::VARCHAR); for (idx_t i = 0; i < feature_count; i++) { - arg_types.push_back(LogicalType::FLOAT); + double_arg_types.push_back(LogicalType::DOUBLE); } - loader.RegisterFunction(ScalarFunction("infera_predict", arg_types, LogicalType::FLOAT, Predict)); - loader.RegisterFunction(ScalarFunction("infera_predict_multi", arg_types, LogicalType::VARCHAR, PredictMulti)); - loader.RegisterFunction(ScalarFunction("infera_predict_multi_list", arg_types, LogicalType::LIST(LogicalType::FLOAT), PredictMultiList)); - } - - loader.RegisterFunction(ScalarFunction("infera_predict_from_blob", {LogicalType::VARCHAR, LogicalType::BLOB}, LogicalType::LIST(LogicalType::FLOAT), PredictFromBlob)); - loader.RegisterFunction(ScalarFunction("infera_get_loaded_models", {}, LogicalType::VARCHAR, GetLoadedModels)); - loader.RegisterFunction(ScalarFunction("infera_get_model_info", {LogicalType::VARCHAR}, LogicalType::VARCHAR, GetModelInfo)); - loader.RegisterFunction(ScalarFunction("infera_get_version", {}, LogicalType::VARCHAR, GetVersion)); - loader.RegisterFunction(ScalarFunction("infera_set_autoload_dir", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetAutoloadDir)); - loader.RegisterFunction(ScalarFunction("infera_is_model_loaded", {LogicalType::VARCHAR}, LogicalType::BOOLEAN, IsModelLoaded)); - loader.RegisterFunction(ScalarFunction("infera_clear_cache", {}, LogicalType::BOOLEAN, ClearCache)); - loader.RegisterFunction(ScalarFunction("infera_get_cache_info", {}, LogicalType::VARCHAR, GetCacheInfo)); + loader.RegisterFunction(InferaScalarFunction("infera_predict", double_arg_types, LogicalType::FLOAT, Predict, true)); + loader.RegisterFunction(InferaScalarFunction("infera_predict_multi", double_arg_types, LogicalType::VARCHAR, PredictMulti, true)); + loader.RegisterFunction(InferaScalarFunction("infera_predict_multi_list", double_arg_types,
LogicalType::LIST(LogicalType::FLOAT), PredictMultiList, true)); + } + + // volatile_state=true: reads mutable model state; same reasoning as predict. + loader.RegisterFunction(InferaScalarFunction("infera_predict_from_blob", {LogicalType::VARCHAR, LogicalType::BLOB}, LogicalType::LIST(LogicalType::FLOAT), PredictFromBlob, true)); + loader.RegisterFunction(InferaScalarFunction("infera_get_loaded_models", {}, LogicalType::VARCHAR, GetLoadedModels, true, false)); + // volatile_state=true: reads the live model registry; a model reload between + // two calls in the same query must produce fresh metadata each time. + loader.RegisterFunction(InferaScalarFunction("infera_get_model_info", {LogicalType::VARCHAR}, LogicalType::VARCHAR, GetModelInfo, true)); + loader.RegisterFunction(InferaScalarFunction("infera_get_version", {}, LogicalType::VARCHAR, GetVersion, false, false)); + loader.RegisterFunction(InferaScalarFunction("infera_set_autoload_dir", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetAutoloadDir, true)); + loader.RegisterFunction(InferaScalarFunction("infera_is_model_loaded", {LogicalType::VARCHAR}, LogicalType::BOOLEAN, IsModelLoaded, true, false)); + loader.RegisterFunction(InferaScalarFunction("infera_clear_cache", {}, LogicalType::BOOLEAN, ClearCache, true)); + // volatile_state=true: cache state changes whenever infera_clear_cache or + // a remote model download updates the cache directory. + loader.RegisterFunction(InferaScalarFunction("infera_get_cache_info", {}, LogicalType::VARCHAR, GetCacheInfo, true, false)); } void InferaExtension::Load(ExtensionLoader &loader) { LoadInternal(loader); } diff --git a/infera/src/http.rs b/infera/src/http.rs index d508551..c3cee4c 100644 --- a/infera/src/http.rs +++ b/infera/src/http.rs @@ -140,12 +140,26 @@ pub(crate) fn clear_cache() -> Result<(), InferaError> { Ok(()) } +/// The result of a remote model cache validation or download check. 
+#[derive(Debug, PartialEq, Eq)] +enum DownloadResult { + /// The remote model has not been modified on the server. + NotModified, + /// A new model was downloaded, optionally returning the server's new ETag. + Downloaded { etag: Option<String> }, +} + /// Handles the download and caching of a remote model from a URL. /// /// If the model for the given URL is already present in the local cache, this -/// function updates its access time and returns the path. Otherwise, it downloads -/// the file, evicts old cache entries if needed, stores it in the cache directory, -/// and then returns the path. +/// function performs an HTTP cache validation check using the stored ETag metadata +/// if available. If the server confirms that the model is unmodified (HTTP 304), +/// the cached model is reused. If the server has an updated model (HTTP 200), +/// it is downloaded, local cache is evicted if needed, and the new ETag metadata +/// is stored. +/// +/// If no local ETag exists but the cached model file does, it falls back to +/// assuming the cached model is valid to prevent unnecessary server requests. /// /// The cache uses an LRU (Least Recently Used) eviction policy with a configurable /// size limit (default 1GB, configurable via INFERA_CACHE_SIZE_LIMIT env var).
@@ -172,24 +186,37 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> { hasher.update(url.as_bytes()); let hash_hex = hex::encode(hasher.finalize()); let cached_path = cache_dir.join(format!("{}.onnx", hash_hex)); + let etag_path = cache_dir.join(format!("{}.etag", hash_hex)); + let mut local_etag = None; if cached_path.exists() { - log!(LogLevel::Info, "Cache hit for URL: {}", url); - // Update access time for LRU tracking - touch_cache_file(&cached_path)?; - return Ok(cached_path); + if etag_path.exists() { + if let Ok(etag_val) = fs::read_to_string(&etag_path) { + local_etag = Some(etag_val.trim().to_string()); + log!(LogLevel::Info, "Found local ETag metadata for URL: {}", url); + } + } else { + log!( + LogLevel::Info, + "Cache hit for URL (no ETag metadata): {}", + url + ); + touch_cache_file(&cached_path)?; + return Ok(cached_path); + } } log!( LogLevel::Info, - "Cache miss for URL: {}, downloading...", - url + "Cache check/download path engaged for URL: {}, local_etag: {:?}", + url, + local_etag ); let temp_path = cached_path.with_extension("onnx.part"); let mut guard = TempFileGuard::new(&temp_path); - // Download with retry logic + // Download or validate with retry logic let max_attempts = CONFIG.http_retry_attempts; let retry_delay_ms = CONFIG.http_retry_delay_ms; let timeout_secs = CONFIG.http_timeout_secs; @@ -199,14 +226,19 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> { for attempt in 1..=max_attempts { log!( LogLevel::Debug, - "Download attempt {}/{} for {}", + "Download/Validation attempt {}/{} for {}", attempt, max_attempts, url ); - match download_file(url, &temp_path, timeout_secs) { - Ok(_) => { + match download_file(url, &temp_path, timeout_secs, local_etag.as_deref()) { + Ok(DownloadResult::NotModified) => { + log!(LogLevel::Info, "Cache hit (ETag verified) for URL: {}", url); + touch_cache_file(&cached_path)?; + return Ok(cached_path); + } + Ok(DownloadResult::Downloaded { etag: new_etag }) => { log!(LogLevel::Info,
"Successfully downloaded: {}", url); // Check file size and evict cache if needed @@ -220,13 +252,22 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> { fs::rename(&temp_path, &cached_path) .map_err(|e| InferaError::IoError(e.to_string()))?; + // Save new ETag metadata if provided, otherwise clean up stale metadata + if let Some(etag_val) = new_etag { + if let Err(e) = fs::write(&etag_path, etag_val) { + log!(LogLevel::Warn, "Failed to write ETag metadata: {}", e); + } + } else { + let _ = fs::remove_file(&etag_path); + } + guard.commit(); return Ok(cached_path); } Err(e) => { log!( LogLevel::Warn, - "Download attempt {}/{} failed: {}", + "Download/Validation attempt {}/{} failed: {}", attempt, max_attempts, e @@ -245,31 +286,52 @@ pub(crate) fn handle_remote_model(url: &str) -> Result<PathBuf, InferaError> { log!( LogLevel::Error, - "Failed to download after {} attempts: {}", + "Failed to download/validate after {} attempts: {}", max_attempts, url ); Err(last_error.unwrap_or_else(|| InferaError::HttpRequestError("Unknown error".to_string()))) } -/// Download a file from a URL to a local path with timeout -fn download_file(url: &str, dest: &Path, timeout_secs: u64) -> Result<(), InferaError> { +/// Download a file from a URL to a local path with timeout, optionally verifying via ETag. +fn download_file( + url: &str, + dest: &Path, + timeout_secs: u64, + etag: Option<&str>, +) -> Result<DownloadResult, InferaError> { let client = reqwest::blocking::Client::builder() .timeout(Duration::from_secs(timeout_secs)) .build() .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; - let mut response = client - .get(url) + let mut request = client.get(url); + if let Some(etag_val) = etag { + request = request.header(reqwest::header::IF_NONE_MATCH, etag_val); + } + + let mut response = request .send() - .map_err(|e| InferaError::HttpRequestError(e.to_string()))?
+ .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; + + if response.status() == reqwest::StatusCode::NOT_MODIFIED { + return Ok(DownloadResult::NotModified); + } + + response = response .error_for_status() .map_err(|e| InferaError::HttpRequestError(e.to_string()))?; + let new_etag = response + .headers() + .get(reqwest::header::ETAG) + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + let mut file = File::create(dest).map_err(|e| InferaError::IoError(e.to_string()))?; io::copy(&mut response, &mut file).map_err(|e| InferaError::IoError(e.to_string()))?; - Ok(()) + Ok(DownloadResult::Downloaded { etag: new_etag }) } #[cfg(test)] @@ -415,6 +477,144 @@ mod tests { assert!(!temp_path.exists(), "no partial file should remain"); } + #[test] + fn test_handle_remote_model_etag_verified_304() { + let mut server = Server::new(); + let body = b"onnxdata".to_vec(); + + // Use a path unique to this test so the URL hash never collides with other + // tests that mockito may schedule on the same port (OS port reuse). + // 1. Initial request (no ETag matched) returns 200 with ETag "tag1" + let m1 = server + .mock("GET", "/ok_model_etag_304.onnx") + .match_header("if-none-match", mockito::Matcher::Missing) + .with_status(200) + .with_header("ETag", "tag1") + .with_body(body.clone()) + .create(); + + // 2. 
Subsequent request (with If-None-Match: tag1) returns 304 Not Modified + let m2 = server + .mock("GET", "/ok_model_etag_304.onnx") + .match_header("if-none-match", "tag1") + .with_status(304) + .create(); + + let url = format!("{}/ok_model_etag_304.onnx", server.url()); + + // First download creates the file and the .etag metadata + let path1 = handle_remote_model(&url).expect("initial download should succeed"); + assert!(path1.exists(), "cached file must exist"); + let content1 = fs::read(&path1).expect("read cached file"); + assert_eq!(content1, body); + + let hash_hex = { + let mut hasher = Sha256::new(); + hasher.update(url.as_bytes()); + hex::encode(hasher.finalize()) + }; + let etag_path = cache_dir().join(format!("{}.etag", hash_hex)); + assert!(etag_path.exists(), "etag metadata must exist"); + let etag_content = fs::read_to_string(&etag_path).expect("read etag metadata"); + assert_eq!(etag_content.trim(), "tag1"); + + // Second check hits the server, gets 304, and reuses the cached file + let path2 = handle_remote_model(&url).expect("validation should succeed"); + assert_eq!(path1, path2); + + m1.assert(); + m2.assert(); + } + + #[test] + fn test_handle_remote_model_etag_changed_200() { + let mut server = Server::new(); + let body1 = b"onnxdata1".to_vec(); + let body2 = b"onnxdata2".to_vec(); + + // Use a path unique to this test so the URL hash never collides with other + // tests that mockito may schedule on the same port (OS port reuse). + // 1. Initial request returns 200 with ETag "tag1" and body1 + let m1 = server + .mock("GET", "/ok_model_etag_200.onnx") + .match_header("if-none-match", mockito::Matcher::Missing) + .with_status(200) + .with_header("ETag", "tag1") + .with_body(body1.clone()) + .create(); + + // 2. 
Subsequent request (with If-None-Match: tag1) returns 200 with ETag "tag2" and body2 + let m2 = server + .mock("GET", "/ok_model_etag_200.onnx") + .match_header("if-none-match", "tag1") + .with_status(200) + .with_header("ETag", "tag2") + .with_body(body2.clone()) + .create(); + + let url = format!("{}/ok_model_etag_200.onnx", server.url()); + + // First download gets body1 + let path1 = handle_remote_model(&url).expect("initial download should succeed"); + let content1 = fs::read(&path1).expect("read cached file"); + assert_eq!(content1, body1); + + // Second validation check gets 200 and downloads body2 + let path2 = handle_remote_model(&url).expect("updated download should succeed"); + assert_eq!(path1, path2); + let content2 = fs::read(&path2).expect("read updated file"); + assert_eq!(content2, body2); + + let hash_hex = { + let mut hasher = Sha256::new(); + hasher.update(url.as_bytes()); + hex::encode(hasher.finalize()) + }; + let etag_path = cache_dir().join(format!("{}.etag", hash_hex)); + assert!(etag_path.exists()); + let etag_content = fs::read_to_string(&etag_path).expect("read etag metadata"); + assert_eq!(etag_content.trim(), "tag2"); + + m1.assert(); + m2.assert(); + } + + #[test] + fn test_handle_remote_model_no_etag_support() { + let mut server = Server::new(); + let body = b"onnxdata".to_vec(); + + // Server does not return ETag + let m = server + .mock("GET", "/no_etag_model.onnx") + .with_status(200) + .with_body(body.clone()) + .create(); + + let url = format!("{}/no_etag_model.onnx", server.url()); + + // First download succeeds but no .etag file is written + let path1 = handle_remote_model(&url).expect("download should succeed"); + assert!(path1.exists()); + let content1 = fs::read(&path1).expect("read cached file"); + assert_eq!(content1, body); + + let hash_hex = { + let mut hasher = Sha256::new(); + hasher.update(url.as_bytes()); + hex::encode(hasher.finalize()) + }; + let etag_path = cache_dir().join(format!("{}.etag", hash_hex)); + 
assert!(!etag_path.exists(), "etag metadata should not be created"); + + // Second download immediately treats it as a fast-path cache hit (no server request) + let path2 = handle_remote_model(&url).expect("fast cache hit should succeed"); + assert_eq!(path1, path2); + + // Verify that the server mock was only called once + m.assert(); + } + #[test] fn test_clear_cache_removes_files() { let dir = cache_dir(); diff --git a/pyproject.toml b/pyproject.toml index 481d460..7b846c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,15 +11,15 @@ dependencies = [ [project.optional-dependencies] dev = [ - "pytest>=8.0.1", - "pytest-cov>=6.0.0", - "pytest-mock>=3.14.0", + "pytest (>=8.0.1,<10.0.0)", + "pytest-cov (>=6.0.0,<8.0.0)", + "pytest-mock (>=3.14.0,<4.0.0)", "pytest-asyncio (>=0.26.0,<0.27.0)", - "pytest-order>=1.2.0", - "mypy>=1.11.1", - "ruff>=0.9.3", + "pytest-order (>=1.2.0,<2.0.0)", + "mypy (>=1.11.1,<2.0.0)", + "ruff (>=0.9.3,<0.10.0)", "icecream (>=2.1.4,<3.0.0)", - "jsonschema>=4.22.0", + "jsonschema (>=4.22.0,<5.0.0)", "duckdb (>=1.4.0,<1.5.0)", - "onnx>=1.19.0", + "onnx (>=1.19.0,<2.0.0)", ] diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 8cbf438..ea6a42a 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] channel = "nightly-2025-09-25" -components = ["rustfmt", "clippy"] +components = ["rustfmt", "clippy", "rust-analyzer"] diff --git a/test/sql/test_decimal_features.test b/test/sql/test_decimal_features.test index 1e7bc18..6398abc 100644 --- a/test/sql/test_decimal_features.test +++ b/test/sql/test_decimal_features.test @@ -1,7 +1,10 @@ # name: test/sql/test_decimal_features.test # group: [infera] -# Validate DECIMAL feature inputs work (via implicit cast) and return correct results +# Validate DECIMAL feature inputs work and return correct results. 
+# Note: bare DECIMAL literals (DECIMAL '1.0') have no precision/scale in DuckDB main +# and resolve to PhysicalType::INVALID, which crashes Vector initialization. Use +# explicit precision (::DECIMAL(10,2)) so the type has a valid physical representation. statement ok pragma enable_verification @@ -12,9 +15,9 @@ load 'build/release/extension/infera/infera.duckdb_extension' statement ok select infera_load_model('linear_dec', 'test/models/linear.onnx') -# Use DECIMAL features; expected y = 2*1 - 1*2 + 0.5*3 + 0.25 = 1.75 +# Use DECIMAL(10,2) features; expected y = 2*1 - 1*2 + 0.5*3 + 0.25 = 1.75 query I -select abs(infera_predict('linear_dec', DECIMAL '1.0', DECIMAL '2.0', DECIMAL '3.0') - 1.75) < 1e-5 +select abs(infera_predict('linear_dec', 1.0::DECIMAL(10,2), 2.0::DECIMAL(10,2), 3.0::DECIMAL(10,2)) - 1.75) < 1e-5 ---- true diff --git a/test/sql/test_volatile_and_null_safety.test b/test/sql/test_volatile_and_null_safety.test new file mode 100644 index 0000000..2c6a26e --- /dev/null +++ b/test/sql/test_volatile_and_null_safety.test @@ -0,0 +1,152 @@ +# name: test/sql/test_volatile_and_null_safety.test +# group: [infera] + +# Regression tests + +statement ok +pragma enable_verification + +statement ok +load 'build/release/extension/infera/infera.duckdb_extension' + +# ── Finding 1: predict volatile ──────────────────────────────────────────────── +# Load the linear model (y = 2*x1 - 1*x2 + 0.5*x3 + 0.25). +# Expected output for (1.0, 2.0, 3.0) is 1.75. + +statement ok +select infera_load_model('linear', 'test/models/linear.onnx') + +# Baseline prediction. +query I +select abs(infera_predict('linear', 1.0, 2.0, 3.0) - 1.75) < 1e-5 +---- +true + +# Reload the same model. If the planner had folded the predict call to a +# constant the reload would have no effect on subsequent results. The test +# below asserts that predictions still work correctly after a reload, which +# would fail if folding had cached an engine-level error. 
+statement ok +select infera_unload_model('linear') + +statement ok +select infera_load_model('linear', 'test/models/linear.onnx') + +query I +select abs(infera_predict('linear', 1.0, 2.0, 3.0) - 1.75) < 1e-5 +---- +true + +# Repeat via infera_predict_multi (JSON string output). +query I +select instr(infera_predict_multi('linear', 1.0, 2.0, 3.0), '1.75') > 0 +---- +true + +# Repeat via infera_predict_multi_list (LIST[FLOAT] output). +query I +select len(infera_predict_multi_list('linear', 1.0, 2.0, 3.0)) > 0 +---- +true + +# Repeat via infera_predict_from_blob. +# The linear model expects 1×3 floats (12 bytes, little-endian). +# encode 1.0, 2.0, 3.0 as IEEE-754 little-endian: each is 4 bytes. +# We use a zero blob to test that the function executes (not the specific value). +query I +select len(infera_predict_from_blob('linear', cast(repeat(chr(0), 12) as blob))) >= 0 +---- +true + +# ── Finding 2: GetLoadedModels null safety ────────────────────────────────────── +# infera_get_loaded_models must return a valid JSON string under normal +# conditions; the null guard is a defensive path for an extreme OOM case in +# Rust. We verify normal behavior returns a well-formed JSON array, not NULL. + +query I +select infera_get_loaded_models() is not null +---- +true + +query I +select infera_get_loaded_models() like '[%' +---- +true + +query I +select instr(infera_get_loaded_models(), 'linear') > 0 +---- +true + +# ── Finding 3: get_model_info volatile ───────────────────────────────────────── +# infera_get_model_info must return fresh data after a reload. If it were +# non-volatile, a constant-folded plan might return stale metadata for a second +# call with the same literal model name. + +query I +select infera_get_model_info('linear') is not null +---- +true + +query I +select infera_get_model_info('linear') like '%input_shape%' +---- +true + +# Unload, then reload. 
get_model_info should still work correctly afterward, +# confirming it queries live state rather than a cached result. +statement ok +select infera_unload_model('linear') + +statement error +select infera_get_model_info('linear') +---- +Failed to get info for model 'linear' + +statement ok +select infera_load_model('linear', 'test/models/linear.onnx') + +query I +select infera_get_model_info('linear') like '%input_shape%' +---- +true + +# ── Finding 4: get_cache_info volatile ───────────────────────────────────────── +# infera_get_cache_info must return non-null results both before and after a +# cache clear. If it were non-volatile and DuckDB folded two calls in the same +# query to one, a clear followed by a stats call could return stale pre-clear +# data. The test below verifies the function returns valid JSON in both states. + +query I +select infera_get_cache_info() is not null +---- +true + +query I +select infera_get_cache_info() like '%cache_dir%' +---- +true + +query I +select infera_get_cache_info() like '%total_size_bytes%' +---- +true + +# After a cache clear, get_cache_info must still return valid data. +statement ok +select infera_clear_cache() + +query I +select infera_get_cache_info() is not null +---- +true + +query I +select infera_get_cache_info() like '%cache_dir%' +---- +true + +# ── Cleanup ───────────────────────────────────────────────────────────────────── + +statement ok +select infera_unload_model('linear')