Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ trim_trailing_whitespace = false
[*.sh]
indent_size = 2

[*.{yaml,yml}]
[*.{yaml,yml,json}]
indent_size = 2

[*.{c,cpp,h,hpp}]
Expand Down
14 changes: 7 additions & 7 deletions .github/workflows/dist_pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,34 +34,34 @@ jobs:
exclude_archs: "windows_amd64_mingw;osx_amd64;wasm_mvp;wasm_eh;wasm_threads"

duckdb-stable-build:
uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.5.2
uses: duckdb/extension-ci-tools/.github/workflows/_extension_distribution.yml@v1.5.3
with:
duckdb_version: v1.5.2
ci_tools_version: v1.5.2
duckdb_version: v1.5.3
ci_tools_version: v1.5.3
extension_name: infera
enable_rust: true
exclude_archs: "windows_amd64_mingw;osx_amd64;wasm_mvp;wasm_eh;wasm_threads"

create-release-draft:
name: Create Draft Release with Built Binaries
name: Create draft release with built binaries
needs:
- duckdb-stable-build
if: startsWith(github.ref, 'refs/tags/')
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Download All Build Artifacts
- name: Download all build artifacts
uses: actions/download-artifact@v4
with:
path: dist
merge-multiple: true
- name: List Artifacts
- name: List artifacts
run: |
echo "Downloaded artifacts to: $(pwd)/dist"
ls -la dist || true
find dist -type f -maxdepth 2 -print || true
- name: Create Draft Release and Upload Assets
- name: Create draft release and upload assets
uses: softprops/action-gh-release@v2
with:
draft: true
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/lints.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@ jobs:
runs-on: ubuntu-latest

steps:
- name: Checkout Code
- name: Checkout code
uses: actions/checkout@v4
with:
submodules: recursive

- name: Set up Rust
uses: actions-rust-lang/setup-rust-toolchain@v1

- name: Install Dependencies
- name: Install dependencies
run: |
sudo apt-get update && sudo apt-get upgrade -y
sudo apt-get install -y gcc curl pkg-config libssl-dev make

- name: Run Linters
- name: Run linters
run: make rust-lint
21 changes: 11 additions & 10 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@ on:
push:
branches:
- main
- develop

permissions:
contents: read

jobs:
rust-tests:
name: Rust Tests and Code Coverage
name: Rust tests and code coverage
runs-on: ubuntu-latest
env:
CARGO_TERM_COLOR: always
steps:
- name: Checkout Code
- name: Checkout code
uses: actions/checkout@v4
with:
submodules: recursive
Expand All @@ -41,44 +42,44 @@ jobs:
restore-keys: |
${{ runner.os }}-cargo-

- name: Install System Dependencies
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y gcc curl pkg-config libssl-dev make
cargo install cargo-tarpaulin --locked || true

- name: Run Rust Tests
- name: Run Rust tests
run: |
make rust-test

- name: Generate Coverage (using Tarpaulin)
- name: Generate coverage (using Tarpaulin)
run: |
make rust-coverage

- name: Upload Coverage Reports to Codecov
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
continue-on-error: false

sqllogictest:
name: Sqllogicest Tests
name: Sqllogicest tests
runs-on: ubuntu-latest
steps:
- name: Checkout Code
- name: Checkout code
uses: actions/checkout@v4
with:
submodules: recursive

- name: Set up Rust
uses: actions-rust-lang/setup-rust-toolchain@v1

- name: Install System Dependencies
- name: Install system dependencies
run: |
sudo apt-get update
sudo apt-get install -y gcc curl pkg-config libssl-dev make

- name: Build Extension & Run SQL Tests
- name: Build extension and run SQL tests
run: |
make release
make test
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,5 @@ tests/temp_models/
*.cast
.claude/
.codex
.antigravitycli/
.agents/
1 change: 1 addition & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
@AGENTS.md
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.5)

set(CORROSION_VERBOSE_OUTPUT ON)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED 1)

option(INFERA_ENABLE_ONNX "Enable ONNX (tract) inference backend in Rust crate" ON)
Expand Down
3 changes: 3 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ would like to work on or if it has already been resolved.

### Development Workflow

> [!IMPORTANT]
> If you're using an AI-assisted coding tool like Claude Code or Codex, make sure the AI follows the instructions in the [AGENTS.md](AGENTS.md) file.

#### Code Style

- Use the `make rust-format` command to format the code.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ make release
#### Trying Infera

```sql
-- 0. Install and load Infera
-- 0. Install and load Infera
-- Skip this step if you built from source and ran `./build/release/duckdb`
install infera from community;
load infera;
Expand Down
20 changes: 18 additions & 2 deletions extension_config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,14 @@ if (EXISTS ${INFERA_RUST_LIB})
# Create an imported target for the Rust library
add_library(infera_rust STATIC IMPORTED GLOBAL)
if(UNIX)
set(_INFERA_RUST_LINK_LIBS "pthread;dl;m")
if(APPLE)
# reqwest 0.13 uses rustls-platform-verifier on Apple platforms, which
# calls into the macOS Security and CoreFoundation frameworks. Link them
# explicitly so the static Rust archive resolves at extension link time.
set(_INFERA_RUST_LINK_LIBS "pthread;dl;m;-framework Security;-framework CoreFoundation;-framework SystemConfiguration")
else()
set(_INFERA_RUST_LINK_LIBS "pthread;dl;m")
endif()
else()
set(_INFERA_RUST_LINK_LIBS "")
endif()
Expand All @@ -80,7 +87,11 @@ if (EXISTS ${INFERA_RUST_LIB})

# Add the Rust library to global link libraries so it gets linked to everything
if(UNIX)
link_libraries(${INFERA_RUST_LIB} pthread dl m)
if(APPLE)
link_libraries(${INFERA_RUST_LIB} pthread dl m "-framework Security" "-framework CoreFoundation" "-framework SystemConfiguration")
else()
link_libraries(${INFERA_RUST_LIB} pthread dl m)
endif()
else()
link_libraries(${INFERA_RUST_LIB})
if(WIN32)
Expand All @@ -107,6 +118,11 @@ if (EXISTS ${INFERA_RUST_LIB})
add_link_options($<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:-lpthread>)
add_link_options($<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:-ldl>)
add_link_options($<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:-lm>)
# Apple framework flags are NOT added via add_link_options: CMake splits
# "-framework Security" at the space inside generator expressions, producing
# the literal token '$<0:-framework' instead of an empty string. The frameworks
# are already propagated to all targets by link_libraries() and
# INTERFACE_LINK_LIBRARIES above, so no add_link_options entry is needed.
else()
add_link_options($<$<STREQUAL:$<TARGET_PROPERTY:TYPE>,EXECUTABLE>:${INFERA_RUST_LIB}>)
endif()
Expand Down
5 changes: 5 additions & 0 deletions infera/.cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[target.x86_64-pc-windows-msvc]
rustflags = ["-C", "target-feature=+crt-static"]

[target.aarch64-pc-windows-msvc]
rustflags = ["-C", "target-feature=+crt-static"]
4 changes: 2 additions & 2 deletions infera/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ thiserror = "2.0"
serde_json = "1.0"
tract-onnx = { version = "0.22", optional = true }
ndarray = { version = "0.17.1", optional = true }
reqwest = { version = "0.12", features = ["blocking", "rustls-tls"], default-features = false }
sha2 = "0.10"
reqwest = { version = "0.13.3", features = ["blocking", "rustls"], default-features = false }
sha2 = "0.11.0"
hex = "0.4"
filetime = "0.2"

Expand Down
85 changes: 64 additions & 21 deletions infera/bindings/infera_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,25 @@ static std::string GetInferaError() {
return err ? std::string(err) : std::string("unknown error");
}

/**
* @brief Creates a DuckDB scalar function with Infera execution metadata.
*
* Infera functions generally read or mutate model engine state, and many can
* raise runtime errors. The metadata prevents DuckDB from optimizing them as
* pure deterministic functions.
*/
static ScalarFunction InferaScalarFunction(const std::string &name, vector<LogicalType> arguments, LogicalType return_type,
scalar_function_t function, bool volatile_state = false, bool fallible = true) {
auto scalar_function = ScalarFunction(name, std::move(arguments), std::move(return_type), function);
if (volatile_state) {
scalar_function.SetVolatile();
}
if (fallible) {
scalar_function.SetFallible();
}
return scalar_function;
}

/**
* @brief Implements the `infera_set_autoload_dir(path)` SQL function.
*
Expand Down Expand Up @@ -300,11 +319,14 @@ static void PredictFromBlob(DataChunk &args, ExpressionState &state, Vector &res
* @param result The result vector to populate.
*/
static void GetLoadedModels(DataChunk &args, ExpressionState &state, Vector &result) {
char *models_json = infera::infera_get_loaded_models();
char *models_json_c = infera::infera_get_loaded_models();
// Guard against the null_mut() path in the Rust fallback (should not occur in
// practice, but infera_get_loaded_models documents that it can return NULL).
std::string models_json = models_json_c ? std::string(models_json_c) : std::string("[]");
infera::infera_free(models_json_c);
result.SetVectorType(VectorType::CONSTANT_VECTOR);
ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, models_json);
ConstantVector::SetNull(result, false);
infera::infera_free(models_json);
}

static void IsModelLoaded(DataChunk &args, ExpressionState &state, Vector &result) {
Expand Down Expand Up @@ -504,30 +526,51 @@ static void GetCacheInfo(DataChunk &args, ExpressionState &state, Vector &result
* @param loader The extension loader provided by DuckDB.
*/
static void LoadInternal(ExtensionLoader &loader) {
loader.RegisterFunction(ScalarFunction("infera_load_model", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LoadModel));
loader.RegisterFunction(ScalarFunction("infera_unload_model", {LogicalType::VARCHAR}, LogicalType::BOOLEAN, UnloadModel));
loader.RegisterFunction(InferaScalarFunction("infera_load_model", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, LoadModel, true));
loader.RegisterFunction(InferaScalarFunction("infera_unload_model", {LogicalType::VARCHAR}, LogicalType::BOOLEAN, UnloadModel, true));

const idx_t MAX_FEATURES = 127;
for (idx_t feature_count = 1; feature_count <= MAX_FEATURES; feature_count++) {
vector<LogicalType> arg_types;
arg_types.reserve(feature_count + 1);
arg_types.push_back(LogicalType::VARCHAR);
vector<LogicalType> float_arg_types;
float_arg_types.reserve(feature_count + 1);
float_arg_types.push_back(LogicalType::VARCHAR);
for (idx_t i = 0; i < feature_count; i++) {
float_arg_types.push_back(LogicalType::FLOAT);
}
// volatile_state=true: inference reads shared mutable model state; the
// planner must not CSE or constant-fold these calls across row groups.
loader.RegisterFunction(InferaScalarFunction("infera_predict", float_arg_types, LogicalType::FLOAT, Predict, true));
loader.RegisterFunction(InferaScalarFunction("infera_predict_multi", float_arg_types, LogicalType::VARCHAR, PredictMulti, true));
loader.RegisterFunction(InferaScalarFunction("infera_predict_multi_list", float_arg_types, LogicalType::LIST(LogicalType::FLOAT), PredictMultiList, true));

// DOUBLE overloads: DuckDB main changed how it handles DECIMAL→FLOAT implicit
// casts at bind time, causing an internal error for DECIMAL literal inputs.
// Registering DOUBLE overloads gives DuckDB a DECIMAL→DOUBLE path that works
// across all supported versions. ExtractFeatures already handles DOUBLE values.
vector<LogicalType> double_arg_types;
double_arg_types.reserve(feature_count + 1);
double_arg_types.push_back(LogicalType::VARCHAR);
for (idx_t i = 0; i < feature_count; i++) {
arg_types.push_back(LogicalType::FLOAT);
double_arg_types.push_back(LogicalType::DOUBLE);
}
loader.RegisterFunction(ScalarFunction("infera_predict", arg_types, LogicalType::FLOAT, Predict));
loader.RegisterFunction(ScalarFunction("infera_predict_multi", arg_types, LogicalType::VARCHAR, PredictMulti));
loader.RegisterFunction(ScalarFunction("infera_predict_multi_list", arg_types, LogicalType::LIST(LogicalType::FLOAT), PredictMultiList));
}

loader.RegisterFunction(ScalarFunction("infera_predict_from_blob", {LogicalType::VARCHAR, LogicalType::BLOB}, LogicalType::LIST(LogicalType::FLOAT), PredictFromBlob));
loader.RegisterFunction(ScalarFunction("infera_get_loaded_models", {}, LogicalType::VARCHAR, GetLoadedModels));
loader.RegisterFunction(ScalarFunction("infera_get_model_info", {LogicalType::VARCHAR}, LogicalType::VARCHAR, GetModelInfo));
loader.RegisterFunction(ScalarFunction("infera_get_version", {}, LogicalType::VARCHAR, GetVersion));
loader.RegisterFunction(ScalarFunction("infera_set_autoload_dir", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetAutoloadDir));
loader.RegisterFunction(ScalarFunction("infera_is_model_loaded", {LogicalType::VARCHAR}, LogicalType::BOOLEAN, IsModelLoaded));
loader.RegisterFunction(ScalarFunction("infera_clear_cache", {}, LogicalType::BOOLEAN, ClearCache));
loader.RegisterFunction(ScalarFunction("infera_get_cache_info", {}, LogicalType::VARCHAR, GetCacheInfo));
loader.RegisterFunction(InferaScalarFunction("infera_predict", double_arg_types, LogicalType::FLOAT, Predict, true));
loader.RegisterFunction(InferaScalarFunction("infera_predict_multi", double_arg_types, LogicalType::VARCHAR, PredictMulti, true));
loader.RegisterFunction(InferaScalarFunction("infera_predict_multi_list", double_arg_types, LogicalType::LIST(LogicalType::FLOAT), PredictMultiList, true));
}

// volatile_state=true: reads mutable model state; same reasoning as predict.
loader.RegisterFunction(InferaScalarFunction("infera_predict_from_blob", {LogicalType::VARCHAR, LogicalType::BLOB}, LogicalType::LIST(LogicalType::FLOAT), PredictFromBlob, true));
loader.RegisterFunction(InferaScalarFunction("infera_get_loaded_models", {}, LogicalType::VARCHAR, GetLoadedModels, true, false));
// volatile_state=true: reads the live model registry; a model reload between
// two calls in the same query must produce fresh metadata each time.
loader.RegisterFunction(InferaScalarFunction("infera_get_model_info", {LogicalType::VARCHAR}, LogicalType::VARCHAR, GetModelInfo, true));
loader.RegisterFunction(InferaScalarFunction("infera_get_version", {}, LogicalType::VARCHAR, GetVersion, false, false));
loader.RegisterFunction(InferaScalarFunction("infera_set_autoload_dir", {LogicalType::VARCHAR}, LogicalType::VARCHAR, SetAutoloadDir, true));
loader.RegisterFunction(InferaScalarFunction("infera_is_model_loaded", {LogicalType::VARCHAR}, LogicalType::BOOLEAN, IsModelLoaded, true, false));
loader.RegisterFunction(InferaScalarFunction("infera_clear_cache", {}, LogicalType::BOOLEAN, ClearCache, true));
// volatile_state=true: cache state changes whenever infera_clear_cache or
// a remote model download updates the cache directory.
loader.RegisterFunction(InferaScalarFunction("infera_get_cache_info", {}, LogicalType::VARCHAR, GetCacheInfo, true, false));
}

void InferaExtension::Load(ExtensionLoader &loader) { LoadInternal(loader); }
Expand Down
Loading
Loading