Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/advanced_topics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ Advanced Topics
===============

- `Just-in-Time Compilation`_
- :doc:`UDF Usage <udf_usage>`

Just-in-Time Compilation
------------------------
Expand All @@ -16,7 +17,13 @@ Thus, the JIT compilation is a one-time cost and you can expect no loss in real
Currently, the following capabilities will trigger a JIT compilation:
- IVF Flat search APIs: :doc:`cuvs::neighbors::ivf_flat::search() <cpp_api/neighbors_ivf_flat>`

UDFs are available in the following APIs:
-----------------------------------------
- IVF Flat search (C++ only): experimental custom distance via ``search_params.metric_udf``; see
:doc:`udf_usage`.

.. toctree::
:maxdepth: 2

jit_lto_guide
udf_usage
221 changes: 220 additions & 1 deletion docs/source/jit_lto_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,10 @@ The planner is responsible for:
#pragma once

#include <cuvs/detail/jit_lto/AlgorithmPlanner.hpp>
#include <cuvs/detail/jit_lto/FragmentEntry.hpp>
#include <cuvs/detail/jit_lto/MakeFragmentKey.hpp>
#include <cuvs/detail/jit_lto/registration_tags.hpp>
#include <memory>
#include <string>

struct SearchPlanner : AlgorithmPlanner {
Expand Down Expand Up @@ -602,6 +604,16 @@ struct SearchPlanner : AlgorithmPlanner {
{
add_static_fragment<fragment_tag_filter<FilterTag, IndexTag>>();
}

void add_metric_udf_fragment(std::unique_ptr<UDFFatbinFragment> fragment)
{
add_fragment(std::move(fragment));
}

void add_filter_udf_fragment(std::unique_ptr<UDFFatbinFragment> fragment)
{
add_fragment(std::move(fragment));
}
Comment on lines +608 to +616
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These do the same and is just semantic sugar, no? Would be good to add a comment to state so for readability and documentation.

};
```

Expand All @@ -617,9 +629,13 @@ Now we integrate the planner into the actual search function:
#include "search_planner.hpp"
#include <cuvs/detail/jit_lto/registration_tags.hpp>
#include <raft/core/device_resources.hpp>
#include <type_traits>

namespace example::detail {

enum class DistanceType { Euclidean };
enum class FilterType { None };

// Type tag helpers
template <typename T>
constexpr auto get_data_type_tag() {
Expand Down Expand Up @@ -671,7 +687,6 @@ void search_jit(
// cannot handle non-type template parameters
SearchPlanner planner;

// Add required device function fragments
planner.add_search_function<data_tag, out_tag, idx_tag, Optimized, Veclen>();
planner.add_compute_distance_device_function<metric_tag, data_tag>();
planner.add_filter_device_function<filter_tag, idx_tag>();
Expand Down Expand Up @@ -701,6 +716,210 @@ void search_jit(
} // namespace example::detail
```

### Step 7b: Example — NVRTC UDFs for `compute_distance` and `apply_filter`
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is clever, however, the explanatory text leading into the code snippets is a really large wall of text that mixes:

  • Architectural rationale (why NVRTC and not static fatbins)
  • Implementation constraints (explicit instantiations, forwarding definitions)
  • Usage instructions (which planner API to call)

Which makes it really hard to read and follow, particularly for someone that has never used NVRTC or familiar with the UDF infra at all.

Consider restructuring into something like:

  • What you're building (1-2 sentences)
  • Architecture diagram or flow (entry kernel → forward declarations → NVRTC TU → explicit instantiations → planner registration)
  • Code with inline comments
  • Pitfalls/constraints as a separate callout


Same entry kernel as Steps 1–7, but `compute_distance` / `apply_filter` are **not** linked from static matrix fatbins: one NVRTC TU per hook (or compile twice) and register each through the planner’s **UDF-specific** APIs (see below—not the same calls as static matrix fragments). Both are **templates** in this example, so each TU must include a **forwarding definition** of the hook plus an **explicit instantiation** for every concrete specialization the entry fatbin calls (e.g. `compute_distance<float>` and `apply_filter<uint32_t>`).

**1. Entry / shared header — declarations only**

```cpp
namespace example::detail {

template <typename T>
__device__ float compute_distance(T q, T d);

template <typename IdxT>
__device__ bool apply_filter(uint32_t query_id, IdxT node_id, void* filter_data);

} // namespace example::detail
```

Register each NVRTC fatbin through your planner’s **UDF** hooks (the snippet below uses `add_metric_udf_fragment` / `add_filter_udf_fragment`), not through the static-matrix helpers from Step 6 (`add_compute_distance_function` / `add_filter_function` on `SearchPlanner`). Those paths select different embedded fatbins; using both for the same hook would duplicate device definitions.

If you only declare `template ... compute_distance` / `apply_filter` in the entry TU, the NVRTC program must define them **and** emit matching `template __device__ ...` **explicit instantiations** for the concrete types the entry calls. For `IdxT`-dependent hooks, do not bake `IdxT` into a macro string: append those lines from a small host helper (`instantiate_apply_filter_udf` below) that takes the same NVRTC spelling `type_name<IdxT>()` returns, analogous to `instantiate_compute_distance_udf` for `T`.

**2. Write the body first, then a `#define` turns it into NVRTC source**

The hooks are **function-like macros** (`#define EXAMPLE_UDF_DISTANCE(NAME, BODY) ...`). You only edit the braced **body** argument. The preprocessor emits (1) a real `__device__` function template so NVCC checks `T`, `q`, `d`, etc., and (2) a `NAME_udf()` / `NAME_filter_udf()` factory that concatenates boilerplate with `std::string(#BODY)` so NVRTC compiles the **same** tokens NVCC already parsed. If the body needs raw `"` characters, splice that part with raw-string concatenation instead of relying on `#BODY` alone. For distance here, the NVRTC text also defines `compute_distance_udf_impl`, which forwards to `NAME_distance`.

**Macro definitions** (typically a shared header included before the invocations):

```cpp
#include <sstream>
#include <string>

#define EXAMPLE_PP_CAT_(a, b) a##b
#define EXAMPLE_PP_CAT(a, b) EXAMPLE_PP_CAT_(a, b)
#define EXAMPLE_PP_STR_(x) #x
#define EXAMPLE_PP_STR(x) EXAMPLE_PP_STR_(x)

// NAME_udf(): NVRTC program defines NAME_distance and compute_distance_udf_impl; host appends
// instantiate_compute_distance_udf (forwarding compute_distance + explicit inst only).
#define EXAMPLE_UDF_DISTANCE(NAME, BODY) \
template <typename T> \
__device__ float EXAMPLE_PP_CAT(NAME, _distance)(T q, T d) BODY \
\
inline std::string EXAMPLE_PP_CAT(NAME, _udf)() \
{ \
return std::string("#include <cuda_runtime.h>\n" \
"namespace example::detail {\n" \
"template <typename T>\n" \
"__device__ float " EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _distance)) \
"(T q, T d) ") \
+ std::string(#BODY) + \
std::string("\n" \
"template <typename T>\n" \
"__device__ float compute_distance_udf_impl(T q, T d) {\n" \
" return ") + \
std::string(EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _distance))) + \
std::string("(q, d);\n" \
"}\n}\n"); \
}

// Forwarding compute_distance + explicit inst only (user metric stays inside NAME_udf()).
inline std::string instantiate_compute_distance_udf(char const* t_type)
{
std::ostringstream oss;
oss << "\nnamespace example::detail {\n"
<< "template <typename T>\n"
<< "__device__ float compute_distance(T q, T d) {\n"
<< " return compute_distance_udf_impl(q, d);\n"
<< "}\n"
<< "template __device__ float compute_distance<" << t_type << ">(" << t_type << ", " << t_type
<< ");\n"
<< "}\n";
return oss.str();
}

// Device NAME_filter + filter_udf(); append instantiate_apply_filter_udf for apply_filter<IdxT>.
#define EXAMPLE_UDF_FILTER(NAME, BODY) \
template <typename IdxT> \
__device__ bool EXAMPLE_PP_CAT(NAME, _filter)(uint32_t query_id, IdxT node_id, void* filter_data) \
BODY \
\
inline std::string EXAMPLE_PP_CAT(NAME, _filter_udf)() \
{ \
return std::string("#include <cuda_runtime.h>\n" \
"namespace example::detail {\n" \
"template <typename IdxT>\n" \
"__device__ bool " EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _filter)) \
"(uint32_t query_id, IdxT node_id, void* filter_data) ") \
+ std::string(#BODY) + \
std::string("\n" \
"template <typename IdxT>\n" \
"__device__ bool apply_filter(uint32_t query_id, IdxT node_id, " \
"void* filter_data) {\n" \
" return " EXAMPLE_PP_STR(EXAMPLE_PP_CAT(NAME, _filter)) \
"(query_id, node_id, filter_data);\n" \
"}\n" \
"}\n"); \
}

// Call after NAME_filter_udf() string is concatenated, before compile.
inline std::string instantiate_apply_filter_udf(char const* idx_type)
{
std::ostringstream oss;
oss << "\nnamespace example::detail {\n"
<< "template __device__ bool apply_filter<" << idx_type << ">(uint32_t, " << idx_type
<< ", void*);\n"
<< "}\n";
return oss.str();
}
```

**Invocations** (same CUDA source or header — each line is a **function-like macro call** that expands into a device template plus a `std::string` factory; no extra `#define` is required unless you want a named alias):

```cpp
EXAMPLE_UDF_DISTANCE(my_l2, {
T diff = q - d;
return diff * diff;
})

EXAMPLE_UDF_FILTER(my_pass, {
(void)query_id;
(void)node_id;
(void)filter_data;
return true;
})
```

`#BODY` follows the usual preprocessor rules (avoid raw `"` inside the body unless you splice that section with raw-string literals around `#BODY`).

**3. Host — one NVRTC compile per UDF fragment**

Assemble the full CUDA program for each hook, compile it once, and register the fatbin through the planner’s UDF entry point (e.g. `add_metric_udf_fragment`). Do not merge unrelated UDFs into a single compile / fragment.

Step 7b’s toy kernel uses `compute_distance<float>` and `apply_filter<uint32_t>` (explicit lines come from `instantiate_compute_distance_udf` / `instantiate_apply_filter_udf`, not from hard-coded types inside the macros). A matching `type_name` only has to spell those concrete types; each return must be the **exact** token NVRTC will parse (same characters as in the generated CUDA). Strip cv/ref first, then add `if constexpr` branches as you support more index and element types.

```cpp
#include <type_traits>

template <typename U>
constexpr const char* type_name()
{
using T = std::remove_cv_t<std::remove_reference_t<U>>;
if constexpr (std::is_same_v<T, float>) {
return "float";
} else if constexpr (std::is_same_v<T, uint32_t>) {
return "uint32_t";
} else {
static_assert(std::is_same_v<T, void>, "add a branch for each T / AccT / IdxT the entry uses");
return "";
}
}
```

Call `instantiate_compute_distance_udf` / `instantiate_apply_filter_udf` to append forwarding templates and explicit instantiations, using the same `type_name<...>()` tokens the entry TU will use. Concatenate that glue onto each `*_udf()` string, then compile and register.

**4. Extend `search_jit.cuh` (Step 7) for UDF vs static**

Step 7 only registered static matrix fragments. To add NVRTC UDFs without breaking the Euclidean / no-filter path, add `#include <string>`, replace the single-value `DistanceType` / `FilterType` enums and the `get_metric_tag` / `get_filter_tag` templates with the extended versions below, add the empty UDF tag structs, and keep the UDF glue (`my_l2_udf`, `instantiate_*`, `type_name`, `nvrtc_compiler`) in the same translation unit. Then **replace** the two unconditional `add_compute_distance_device_function` / `add_filter_device_function` lines with the `if constexpr` planner block (second snippet).

```cpp
enum class DistanceType { Euclidean, MetricUdf };
enum class FilterType { None, FilterUdf };
Comment on lines +879 to +880
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
enum class DistanceType { Euclidean, MetricUdf };
enum class FilterType { None, FilterUdf };
// Extend the DistanceType / FilterType enums from Step 7:
enum class DistanceType { Euclidean, MetricUdf };
enum class FilterType { None, FilterUdf };

Small comment that might help readability


struct tag_metric_custom_udf {};
struct tag_filter_custom_udf {};

template <DistanceType Metric>
constexpr auto get_metric_tag() {
if constexpr (Metric == DistanceType::Euclidean) return tag_metric_euclidean{};
else if constexpr (Metric == DistanceType::MetricUdf) return tag_metric_custom_udf{};
}

template <FilterType Filter>
constexpr auto get_filter_tag() {
if constexpr (Filter == FilterType::None) return tag_filter_none{};
else if constexpr (Filter == FilterType::FilterUdf) return tag_filter_custom_udf{};
}
Comment on lines +885 to +895
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is similar to a comment I had in the UDF infra PR, if a new enum value is added later but these functions aren't updated, the compiler will give a confusing "no return" error rather than a clear message. Adding an else { static_assert(...); } fallback would be a good idea. Since this is documentation/example code, arguably it's even more important here to model best practices.

```

```cpp
SearchPlanner planner;
planner.add_search_function<data_tag, out_tag, idx_tag, Optimized, Veclen>();

if constexpr (std::is_same_v<metric_tag, tag_metric_custom_udf>) {
std::string metric_udf_code = my_l2_udf();
metric_udf_code += instantiate_compute_distance_udf(type_name<T>());
planner.add_metric_udf_fragment(nvrtc_compiler().compile(metric_udf_code, metric_udf_code));
} else {
planner.add_compute_distance_device_function<metric_tag, data_tag>();
}

if constexpr (std::is_same_v<filter_tag, tag_filter_custom_udf>) {
std::string filter_udf_code = my_pass_filter_udf();
filter_udf_code += instantiate_apply_filter_udf(type_name<IdxT>());
planner.add_filter_udf_fragment(nvrtc_compiler().compile(filter_udf_code, filter_udf_code));
} else {
planner.add_filter_device_function<filter_tag, idx_tag>();
}

auto launcher = planner.get_launcher();
```

Instantiate `search_jit` with `DistanceType::MetricUdf` and/or `FilterType::FilterUdf` only when you intend the NVRTC branches; `Euclidean` and `FilterType::None` keep the original static behavior.

## Key Concepts

### Fragment Tags
Expand Down
71 changes: 71 additions & 0 deletions docs/source/udf_usage.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
UDF Usage
=========

.. caution::

Custom distance metrics for IVF-flat search are **experimental**. They live under the
``cuvs::neighbors::ivf_flat::experimental::udf`` namespace and the associated ``CUVS_METRIC``
macro. APIs and behavior may change without a major release.

What this feature does
----------------------

You can supply **your own CUDA device code** that defines how distance accumulates between a query
vector and database vectors **inside the IVF-flat interleaved scan** (the fine search over lists).
Technical background on compilation and linking is in :doc:`jit_lto_guide`.

Available via C++ APIs for the following algorithms
---------------------------------------------------

* IVF-flat — :doc:`search <cpp_api/neighbors_ivf_flat>` (``search_params.metric_udf`` / ``CUVS_METRIC``).

Requirements and tips
-----------------------

* Include ``<cuvs/neighbors/ivf_flat.hpp>`` and define a metric with ``CUVS_METRIC(MyName, { ... })``.
Set ``search_params.metric_udf`` to the string returned by ``MyName_udf()``.
* Prefer the helpers documented next to the macro (``squared_diff``, ``abs_diff``, ``dot_product``,
``point`` element access, and so on) so the same definition works across ``float``, ``int8_t`` /
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"and so on" kinda leaves the user guessing, almost like "left as an exercise to the reader". A user writing a custom metric needs to know exactly what's available, a brief table listing all the helpers (name, signature, description) or at minimum a cross-reference to the header line numbers where they're defined would be helpful here.

``uint8_t`` packed lanes, and related accumulator types.
* Custom UDF is **not supported for fp16** (``__half`` / ``half``) indices at this time; the headers
enforce this with a static assertion when applicable.
* The scan assumes **ascending** distance order for top-*k* selection; metrics that do not behave
like a distance in that sense need careful validation.
* The first search with a new metric string may pay a one-time compilation cost; reuse the same
string (and run a warmup) to benefit from the caches described in :doc:`advanced_topics`.

Example
-------

.. code-block:: cpp

#include <cuvs/neighbors/ivf_flat.hpp>

namespace ivf = cuvs::neighbors::ivf_flat;

// L∞ (Chebyshev): per dimension, acc = max(acc, |x - y|); acc starts at 0 in the scan kernel.
CUVS_METRIC(my_chebyshev, {
auto d = abs_diff(x, y);
acc = (d > acc) ? d : acc;
})
Comment on lines +47 to +50
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small thing, but this differs from the example in ivf_flat.hpp where we don't use the helper, would be good to make them match


void run_search(raft::resources const& res,
ivf::index<float, int64_t> const& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances)
{
ivf::search_params params;
params.metric_udf = my_chebyshev_udf();

ivf::search(res, params, index, queries, neighbors, distances);
}

For more examples (L2 via ``squared_diff``, raw string fragments, and so on), see
``cpp/tests/neighbors/ann_ivf_flat/test_udf.cu`` in the cuVS repository.

Further reading
---------------

* C++ API reference: :doc:`cpp_api/neighbors_ivf_flat`
* JIT LTO architecture and IVF-flat fragments: :doc:`jit_lto_guide`
Loading