-
Notifications
You must be signed in to change notification settings - Fork 184
Add UDF Usage and Developer docs #2030
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -573,8 +573,10 @@ The planner is responsible for: | |||||||||||
| #pragma once | ||||||||||||
|
|
||||||||||||
| #include <cuvs/detail/jit_lto/AlgorithmPlanner.hpp> | ||||||||||||
| #include <cuvs/detail/jit_lto/FragmentEntry.hpp> | ||||||||||||
| #include <cuvs/detail/jit_lto/MakeFragmentKey.hpp> | ||||||||||||
| #include <cuvs/detail/jit_lto/registration_tags.hpp> | ||||||||||||
| #include <memory> | ||||||||||||
| #include <string> | ||||||||||||
|
|
||||||||||||
| struct SearchPlanner : AlgorithmPlanner { | ||||||||||||
|
|
@@ -602,6 +604,16 @@ struct SearchPlanner : AlgorithmPlanner { | |||||||||||
| { | ||||||||||||
| add_static_fragment<fragment_tag_filter<FilterTag, IndexTag>>(); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| void add_metric_udf_fragment(std::unique_ptr<UDFFatbinFragment> fragment) | ||||||||||||
| { | ||||||||||||
| add_fragment(std::move(fragment)); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| void add_filter_udf_fragment(std::unique_ptr<UDFFatbinFragment> fragment) | ||||||||||||
| { | ||||||||||||
| add_fragment(std::move(fragment)); | ||||||||||||
| } | ||||||||||||
| }; | ||||||||||||
| ``` | ||||||||||||
|
|
||||||||||||
|
|
@@ -617,9 +629,13 @@ Now we integrate the planner into the actual search function: | |||||||||||
| #include "search_planner.hpp" | ||||||||||||
| #include <cuvs/detail/jit_lto/registration_tags.hpp> | ||||||||||||
| #include <raft/core/device_resources.hpp> | ||||||||||||
| #include <type_traits> | ||||||||||||
|
|
||||||||||||
| namespace example::detail { | ||||||||||||
|
|
||||||||||||
| enum class DistanceType { Euclidean }; | ||||||||||||
| enum class FilterType { None }; | ||||||||||||
|
|
||||||||||||
| // Type tag helpers | ||||||||||||
| template <typename T> | ||||||||||||
| constexpr auto get_data_type_tag() { | ||||||||||||
|
|
@@ -671,7 +687,6 @@ void search_jit( | |||||||||||
| // cannot handle non-type template parameters | ||||||||||||
| SearchPlanner planner; | ||||||||||||
|
|
||||||||||||
| // Add required device function fragments | ||||||||||||
| planner.add_search_function<data_tag, out_tag, idx_tag, Optimized, Veclen>(); | ||||||||||||
| planner.add_compute_distance_device_function<metric_tag, data_tag>(); | ||||||||||||
| planner.add_filter_device_function<filter_tag, idx_tag>(); | ||||||||||||
|
|
@@ -701,6 +716,210 @@ void search_jit( | |||||||||||
| } // namespace example::detail | ||||||||||||
| ``` | ||||||||||||
|
|
||||||||||||
| ### Step 7b: Example — NVRTC UDFs for `compute_distance` and `apply_filter` | ||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is clever, however, the explanatory text leading into the code snippets is a really large wall of text that mixes:
Which makes it really hard to read and follow, particularly for someone that has never used NVRTC or familiar with the UDF infra at all. Consider restructuring into something like:
|
||||||||||||
|
|
||||||||||||
| Same entry kernel as Steps 1–7, but `compute_distance` / `apply_filter` are **not** linked from static matrix fatbins: one NVRTC TU per hook (or compile twice) and register each through the planner’s **UDF-specific** APIs (see below—not the same calls as static matrix fragments). Both are **templates** in this example, so each TU must include a **forwarding definition** of the hook plus an **explicit instantiation** for every concrete specialization the entry fatbin calls (e.g. `compute_distance<float>` and `apply_filter<uint32_t>`). | ||||||||||||
|
|
||||||||||||
| **1. Entry / shared header — declarations only** | ||||||||||||
|
|
||||||||||||
| ```cpp | ||||||||||||
| namespace example::detail { | ||||||||||||
|
|
||||||||||||
| template <typename T> | ||||||||||||
| __device__ float compute_distance(T q, T d); | ||||||||||||
|
|
||||||||||||
| template <typename IdxT> | ||||||||||||
| __device__ bool apply_filter(uint32_t query_id, IdxT node_id, void* filter_data); | ||||||||||||
|
|
||||||||||||
| } // namespace example::detail | ||||||||||||
| ``` | ||||||||||||
|
|
||||||||||||
| Register each NVRTC fatbin through your planner’s **UDF** hooks (the snippet below uses `add_metric_udf_fragment` / `add_filter_udf_fragment`), not through the static-matrix helpers from Step 6 (`add_compute_distance_function` / `add_filter_function` on `SearchPlanner`). Those paths select different embedded fatbins; using both for the same hook would duplicate device definitions. | ||||||||||||
|
|
||||||||||||
| If you only declare `template ... compute_distance` / `apply_filter` in the entry TU, the NVRTC program must define them **and** emit matching `template __device__ ...` **explicit instantiations** for the concrete types the entry calls. For `IdxT`-dependent hooks, do not bake `IdxT` into a macro string: append those lines from a small host helper (`instantiate_apply_filter_udf` below) that takes the same NVRTC spelling `type_name<IdxT>()` returns, analogous to `instantiate_compute_distance_udf` for `T`. | ||||||||||||
|
|
||||||||||||
| **2. Write the body first, then a `#define` turns it into NVRTC source** | ||||||||||||
|
|
||||||||||||
| The hooks are **function-like macros** (`#define EXAMPLE_UDF_DISTANCE(NAME, BODY) ...`). You only edit the braced **body** argument. The preprocessor emits (1) a real `__device__` function template so NVCC checks `T`, `q`, `d`, etc., and (2) a `NAME_udf()` / `NAME_filter_udf()` factory that concatenates boilerplate with `std::string(#BODY)` so NVRTC compiles the **same** tokens NVCC already parsed. If the body needs raw `"` characters, splice that part with raw-string concatenation instead of relying on `#BODY` alone. For distance here, the NVRTC text also defines `compute_distance_udf_impl`, which forwards to `NAME_distance`. | ||||||||||||
|
|
||||||||||||
| **Macro definitions** (typically a shared header included before the invocations): | ||||||||||||
|
|
||||||||||||
| ```cpp | ||||||||||||
| #include <sstream> | ||||||||||||
| #include <string> | ||||||||||||
|
|
||||||||||||
| #define EXAMPLE_PP_CAT_(a, b) a##b | ||||||||||||
| #define EXAMPLE_PP_CAT(a, b) EXAMPLE_PP_CAT_(a, b) | ||||||||||||
| #define EXAMPLE_PP_STR_(x) #x | ||||||||||||
| #define EXAMPLE_PP_STR(x) EXAMPLE_PP_STR_(x) | ||||||||||||
|
|
||||||||||||
| // NAME_udf(): NVRTC program defines NAME_distance and compute_distance_udf_impl; host appends | ||||||||||||
| // instantiate_compute_distance_udf (forwarding compute_distance + explicit inst only). | ||||||||||||
| #define EXAMPLE_UDF_DISTANCE(NAME, BODY) \ | ||||||||||||
| template <typename T> \ | ||||||||||||
| __device__ float EXAMPLE_PP_CAT(NAME, _distance)(T q, T d) BODY \ | ||||||||||||
| \ | ||||||||||||
| inline std::string EXAMPLE_PP_CAT(NAME, _udf)() \ | ||||||||||||
| { \ | ||||||||||||
| return std::string("#include <cuda_runtime.h>\n" \ | ||||||||||||
| "namespace example::detail {\n" \ | ||||||||||||
| "template <typename T>\n" \ | ||||||||||||
| "__device__ float " EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _distance)) \ | ||||||||||||
| "(T q, T d) ") \ | ||||||||||||
| + std::string(#BODY) + \ | ||||||||||||
| std::string("\n" \ | ||||||||||||
| "template <typename T>\n" \ | ||||||||||||
| "__device__ float compute_distance_udf_impl(T q, T d) {\n" \ | ||||||||||||
| " return ") + \ | ||||||||||||
| std::string(EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _distance))) + \ | ||||||||||||
| std::string("(q, d);\n" \ | ||||||||||||
| "}\n}\n"); \ | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // Forwarding compute_distance + explicit inst only (user metric stays inside NAME_udf()). | ||||||||||||
| inline std::string instantiate_compute_distance_udf(char const* t_type) | ||||||||||||
| { | ||||||||||||
| std::ostringstream oss; | ||||||||||||
| oss << "\nnamespace example::detail {\n" | ||||||||||||
| << "template <typename T>\n" | ||||||||||||
| << "__device__ float compute_distance(T q, T d) {\n" | ||||||||||||
| << " return compute_distance_udf_impl(q, d);\n" | ||||||||||||
| << "}\n" | ||||||||||||
| << "template __device__ float compute_distance<" << t_type << ">(" << t_type << ", " << t_type | ||||||||||||
| << ");\n" | ||||||||||||
| << "}\n"; | ||||||||||||
| return oss.str(); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // Device NAME_filter + filter_udf(); append instantiate_apply_filter_udf for apply_filter<IdxT>. | ||||||||||||
| #define EXAMPLE_UDF_FILTER(NAME, BODY) \ | ||||||||||||
| template <typename IdxT> \ | ||||||||||||
| __device__ bool EXAMPLE_PP_CAT(NAME, _filter)(uint32_t query_id, IdxT node_id, void* filter_data) \ | ||||||||||||
| BODY \ | ||||||||||||
| \ | ||||||||||||
| inline std::string EXAMPLE_PP_CAT(NAME, _filter_udf)() \ | ||||||||||||
| { \ | ||||||||||||
| return std::string("#include <cuda_runtime.h>\n" \ | ||||||||||||
| "namespace example::detail {\n" \ | ||||||||||||
| "template <typename IdxT>\n" \ | ||||||||||||
| "__device__ bool " EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _filter)) \ | ||||||||||||
| "(uint32_t query_id, IdxT node_id, void* filter_data) ") \ | ||||||||||||
| + std::string(#BODY) + \ | ||||||||||||
| std::string("\n" \ | ||||||||||||
| "template <typename IdxT>\n" \ | ||||||||||||
| "__device__ bool apply_filter(uint32_t query_id, IdxT node_id, " \ | ||||||||||||
| "void* filter_data) {\n" \ | ||||||||||||
| " return " EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _filter)) \ | ||||||||||||
| "(query_id, node_id, filter_data);\n" \ | ||||||||||||
| "}\n" \ | ||||||||||||
| "}\n"); \ | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // Call after NAME_filter_udf() string is concatenated, before compile. | ||||||||||||
| inline std::string instantiate_apply_filter_udf(char const* idx_type) | ||||||||||||
| { | ||||||||||||
| std::ostringstream oss; | ||||||||||||
| oss << "\nnamespace example::detail {\n" | ||||||||||||
| << "template __device__ bool apply_filter<" << idx_type << ">(uint32_t, " << idx_type | ||||||||||||
| << ", void*);\n" | ||||||||||||
| << "}\n"; | ||||||||||||
| return oss.str(); | ||||||||||||
| } | ||||||||||||
| ``` | ||||||||||||
|
|
||||||||||||
| **Invocations** (same CUDA source or header — each line is a **function-like macro call** that expands into a device template plus a `std::string` factory; no extra `#define` is required unless you want a named alias): | ||||||||||||
|
|
||||||||||||
| ```cpp | ||||||||||||
| EXAMPLE_UDF_DISTANCE(my_l2, { | ||||||||||||
| T diff = q - d; | ||||||||||||
| return diff * diff; | ||||||||||||
| }) | ||||||||||||
|
|
||||||||||||
| EXAMPLE_UDF_FILTER(my_pass, { | ||||||||||||
| (void)query_id; | ||||||||||||
| (void)node_id; | ||||||||||||
| (void)filter_data; | ||||||||||||
| return true; | ||||||||||||
| }) | ||||||||||||
| ``` | ||||||||||||
|
|
||||||||||||
| `#BODY` follows the usual preprocessor rules (avoid raw `"` inside the body unless you splice that section with raw-string literals around `#BODY`). | ||||||||||||
|
|
||||||||||||
| **3. Host — one NVRTC compile per UDF fragment** | ||||||||||||
|
|
||||||||||||
| Assemble the full CUDA program for each hook, compile it once, and register the fatbin through the planner’s UDF entry point (e.g. `add_metric_udf_fragment`). Do not merge unrelated UDFs into a single compile / fragment. | ||||||||||||
|
|
||||||||||||
| Step 7b’s toy kernel uses `compute_distance<float>` and `apply_filter<uint32_t>` (explicit lines come from `instantiate_compute_distance_udf` / `instantiate_apply_filter_udf`, not from hard-coded types inside the macros). A matching `type_name` only has to spell those concrete types; each return must be the **exact** token NVRTC will parse (same characters as in the generated CUDA). Strip cv/ref first, then add `if constexpr` branches as you support more index and element types. | ||||||||||||
|
|
||||||||||||
| ```cpp | ||||||||||||
| #include <type_traits> | ||||||||||||
|
|
||||||||||||
| template <typename U> | ||||||||||||
| constexpr const char* type_name() | ||||||||||||
| { | ||||||||||||
| using T = std::remove_cv_t<std::remove_reference_t<U>>; | ||||||||||||
| if constexpr (std::is_same_v<T, float>) { | ||||||||||||
| return "float"; | ||||||||||||
| } else if constexpr (std::is_same_v<T, uint32_t>) { | ||||||||||||
| return "uint32_t"; | ||||||||||||
| } else { | ||||||||||||
| static_assert(std::is_same_v<T, void>, "add a branch for each T / AccT / IdxT the entry uses"); | ||||||||||||
| return ""; | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| ``` | ||||||||||||
|
|
||||||||||||
| Call `instantiate_compute_distance_udf` / `instantiate_apply_filter_udf` to append forwarding templates and explicit instantiations, using the same `type_name<...>()` tokens the entry TU will use. Concatenate that glue onto each `*_udf()` string, then compile and register. | ||||||||||||
|
|
||||||||||||
| **4. Extend `search_jit.cuh` (Step 7) for UDF vs static** | ||||||||||||
|
|
||||||||||||
| Step 7 only registered static matrix fragments. To add NVRTC UDFs without breaking the Euclidean / no-filter path, add `#include <string>`, replace the single-value `DistanceType` / `FilterType` enums and the `get_metric_tag` / `get_filter_tag` templates with the extended versions below, add the empty UDF tag structs, and keep the UDF glue (`my_l2_udf`, `instantiate_*`, `type_name`, `nvrtc_compiler`) in the same translation unit. Then **replace** the two unconditional `add_compute_distance_device_function` / `add_filter_device_function` lines with the `if constexpr` planner block (second snippet). | ||||||||||||
|
|
||||||||||||
| ```cpp | ||||||||||||
| enum class DistanceType { Euclidean, MetricUdf }; | ||||||||||||
| enum class FilterType { None, FilterUdf }; | ||||||||||||
|
Comment on lines
+879
to
+880
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Small comment that might help readability |
||||||||||||
|
|
||||||||||||
| struct tag_metric_custom_udf {}; | ||||||||||||
| struct tag_filter_custom_udf {}; | ||||||||||||
|
|
||||||||||||
| template <DistanceType Metric> | ||||||||||||
| constexpr auto get_metric_tag() { | ||||||||||||
| if constexpr (Metric == DistanceType::Euclidean) return tag_metric_euclidean{}; | ||||||||||||
| else if constexpr (Metric == DistanceType::MetricUdf) return tag_metric_custom_udf{}; | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| template <FilterType Filter> | ||||||||||||
| constexpr auto get_filter_tag() { | ||||||||||||
| if constexpr (Filter == FilterType::None) return tag_filter_none{}; | ||||||||||||
| else if constexpr (Filter == FilterType::FilterUdf) return tag_filter_custom_udf{}; | ||||||||||||
| } | ||||||||||||
|
Comment on lines
+885
to
+895
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is similar to a comment I had in the UDF infra PR, if a new enum value is added later but these functions aren't updated, the compiler will give a confusing "no return" error rather than a clear message. Adding an else { static_assert(...); } fallback would be a good idea. Since this is documentation/example code, arguably it's even more important here to model best practices. |
||||||||||||
| ``` | ||||||||||||
|
|
||||||||||||
| ```cpp | ||||||||||||
| SearchPlanner planner; | ||||||||||||
| planner.add_search_function<data_tag, out_tag, idx_tag, Optimized, Veclen>(); | ||||||||||||
|
|
||||||||||||
| if constexpr (std::is_same_v<metric_tag, tag_metric_custom_udf>) { | ||||||||||||
| std::string metric_udf_code = my_l2_udf(); | ||||||||||||
| metric_udf_code += instantiate_compute_distance_udf(type_name<T>()); | ||||||||||||
| planner.add_metric_udf_fragment(nvrtc_compiler().compile(metric_udf_code, metric_udf_code)); | ||||||||||||
| } else { | ||||||||||||
| planner.add_compute_distance_device_function<metric_tag, data_tag>(); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| if constexpr (std::is_same_v<filter_tag, tag_filter_custom_udf>) { | ||||||||||||
| std::string filter_udf_code = my_pass_filter_udf(); | ||||||||||||
| filter_udf_code += instantiate_apply_filter_udf(type_name<IdxT>()); | ||||||||||||
| planner.add_filter_udf_fragment(nvrtc_compiler().compile(filter_udf_code, filter_udf_code)); | ||||||||||||
| } else { | ||||||||||||
| planner.add_filter_device_function<filter_tag, idx_tag>(); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| auto launcher = planner.get_launcher(); | ||||||||||||
| ``` | ||||||||||||
|
|
||||||||||||
| Instantiate `search_jit` with `DistanceType::MetricUdf` and/or `FilterType::FilterUdf` only when you intend the NVRTC branches; `Euclidean` and `FilterType::None` keep the original static behavior. | ||||||||||||
|
|
||||||||||||
| ## Key Concepts | ||||||||||||
|
|
||||||||||||
| ### Fragment Tags | ||||||||||||
|
|
||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| UDF Usage | ||
| ========= | ||
|
|
||
| .. caution:: | ||
|
|
||
| Custom distance metrics for IVF-flat search are **experimental**. They live under the | ||
| ``cuvs::neighbors::ivf_flat::experimental::udf`` namespace and the associated ``CUVS_METRIC`` | ||
| macro. APIs and behavior may change without a major release. | ||
|
|
||
| What this feature does | ||
| ---------------------- | ||
|
|
||
| You can supply **your own CUDA device code** that defines how distance accumulates between a query | ||
| vector and database vectors **inside the IVF-flat interleaved scan** (the fine search over lists). | ||
| Technical background on compilation and linking is in :doc:`jit_lto_guide`. | ||
|
|
||
| Available via C++ APIs for the following algorithms | ||
| --------------------------------------------------- | ||
|
|
||
| * IVF-flat — :doc:`search <cpp_api/neighbors_ivf_flat>` (``search_params.metric_udf`` / ``CUVS_METRIC``). | ||
|
|
||
| Requirements and tips | ||
| ----------------------- | ||
|
|
||
| * Include ``<cuvs/neighbors/ivf_flat.hpp>`` and define a metric with ``CUVS_METRIC(MyName, { ... })``. | ||
| Set ``search_params.metric_udf`` to the string returned by ``MyName_udf()``. | ||
| * Prefer the helpers documented next to the macro (``squared_diff``, ``abs_diff``, ``dot_product``, | ||
| ``point`` element access, and so on) so the same definition works across ``float``, ``int8_t`` / | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "and so on" kinda leaves the user guessing, almost like "left as an exercise to the reader". A user writing a custom metric needs to know exactly what's available, a brief table listing all the helpers (name, signature, description) or at minimum a cross-reference to the header line numbers where they're defined would be helpful here. |
||
| ``uint8_t`` packed lanes, and related accumulator types. | ||
| * Custom UDF is **not supported for fp16** (``__half`` / ``half``) indices at this time; the headers | ||
| enforce this with a static assertion when applicable. | ||
| * The scan assumes **ascending** distance order for top-*k* selection; metrics that do not behave | ||
| like a distance in that sense need careful validation. | ||
| * The first search with a new metric string may pay a one-time compilation cost; reuse the same | ||
| string (and run a warmup) to benefit from the caches described in :doc:`advanced_topics`. | ||
|
|
||
| Example | ||
| ------- | ||
|
|
||
| .. code-block:: cpp | ||
|
|
||
| #include <cuvs/neighbors/ivf_flat.hpp> | ||
|
|
||
| namespace ivf = cuvs::neighbors::ivf_flat; | ||
|
|
||
| // L∞ (Chebyshev): per dimension, acc = max(acc, |x - y|); acc starts at 0 in the scan kernel. | ||
| CUVS_METRIC(my_chebyshev, { | ||
| auto d = abs_diff(x, y); | ||
| acc = (d > acc) ? d : acc; | ||
| }) | ||
|
Comment on lines
+47
to
+50
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Small thing, but this differs from the example in ivf_flat.hpp where we don't use the helper, would be good to make them match |
||
|
|
||
| void run_search(raft::resources const& res, | ||
| ivf::index<float, int64_t> const& index, | ||
| raft::device_matrix_view<const float, int64_t, raft::row_major> queries, | ||
| raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors, | ||
| raft::device_matrix_view<float, int64_t, raft::row_major> distances) | ||
| { | ||
| ivf::search_params params; | ||
| params.metric_udf = my_chebyshev_udf(); | ||
|
|
||
| ivf::search(res, params, index, queries, neighbors, distances); | ||
| } | ||
|
|
||
| For more examples (L2 via ``squared_diff``, raw string fragments, and so on), see | ||
| ``cpp/tests/neighbors/ann_ivf_flat/test_udf.cu`` in the cuVS repository. | ||
|
|
||
| Further reading | ||
| --------------- | ||
|
|
||
| * C++ API reference: :doc:`cpp_api/neighbors_ivf_flat` | ||
| * JIT LTO architecture and IVF-flat fragments: :doc:`jit_lto_guide` | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These do the same and is just semantic sugar, no? Would be good to add a comment to state so for readability and documentation.