Skip to content

Commit cdbabef

Browse files
committed
Add allocator template parameter
1 parent 7f3e3ac commit cdbabef

4 files changed

Lines changed: 44 additions & 33 deletions

File tree

include/cuco/detail/trie/trie.inl

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
namespace cuco {
2121
namespace experimental {
2222

23-
template <typename LabelType>
24-
constexpr trie<LabelType>::trie()
25-
: num_keys_{0},
23+
template <typename LabelType, class Allocator>
24+
constexpr trie<LabelType, Allocator>::trie(Allocator const& allocator)
25+
: allocator_{allocator},
26+
num_keys_{0},
2627
num_nodes_{1},
2728
last_key_{},
2829
num_levels_{2},
@@ -37,15 +38,15 @@ constexpr trie<LabelType>::trie()
3738
levels_[0].labels_.push_back(root_label_);
3839
}
3940

40-
template <typename LabelType>
41-
trie<LabelType>::~trie() noexcept(false)
41+
template <typename LabelType, class Allocator>
42+
trie<LabelType, Allocator>::~trie() noexcept(false)
4243
{
4344
if (d_levels_ptr_) { CUCO_CUDA_TRY(cudaFree(d_levels_ptr_)); }
4445
if (device_ptr_) { CUCO_CUDA_TRY(cudaFree(device_ptr_)); }
4546
}
4647

47-
template <typename LabelType>
48-
void trie<LabelType>::insert(const std::vector<LabelType>& key) noexcept
48+
template <typename LabelType, class Allocator>
49+
void trie<LabelType, Allocator>::insert(const std::vector<LabelType>& key) noexcept
4950
{
5051
if (key == last_key_) { return; } // Ignore duplicate keys
5152
assert(num_keys_ == 0 || key > last_key_); // Keys are expected to be inserted in sorted order
@@ -95,8 +96,8 @@ void trie<LabelType>::insert(const std::vector<LabelType>& key) noexcept
9596
last_key_ = key;
9697
}
9798

98-
template <typename LabelType>
99-
void trie<LabelType>::build() noexcept(false)
99+
template <typename LabelType, class Allocator>
100+
void trie<LabelType, Allocator>::build() noexcept(false)
100101
{
101102
// Perform build level-by-level for all levels, followed by a deep-copy from host to device
102103
size_type offset = 0;
@@ -125,13 +126,13 @@ void trie<LabelType>::build() noexcept(false)
125126
CUCO_CUDA_TRY(cudaMemcpy(device_ptr_, this, sizeof(trie<LabelType>), cudaMemcpyHostToDevice));
126127
}
127128

128-
template <typename LabelType>
129+
template <typename LabelType, class Allocator>
129130
template <typename KeyIt, typename OffsetIt, typename OutputIt>
130-
void trie<LabelType>::lookup(KeyIt keys_begin,
131-
OffsetIt offsets_begin,
132-
OffsetIt offsets_end,
133-
OutputIt outputs_begin,
134-
cuda_stream_ref stream) const noexcept
131+
void trie<LabelType, Allocator>::lookup(KeyIt keys_begin,
132+
OffsetIt offsets_begin,
133+
OffsetIt offsets_end,
134+
OutputIt outputs_begin,
135+
cuda_stream_ref stream) const noexcept
135136
{
136137
auto num_keys = cuco::detail::distance(offsets_begin, offsets_end) - 1;
137138
if (num_keys == 0) { return; }
@@ -159,16 +160,17 @@ __global__ void trie_lookup_kernel(
159160
}
160161
}
161162

162-
template <typename LabelType>
163+
template <typename LabelType, class Allocator>
163164
template <typename... Operators>
164-
auto trie<LabelType>::ref(Operators...) const noexcept
165+
auto trie<LabelType, Allocator>::ref(Operators...) const noexcept
165166
{
166167
static_assert(sizeof...(Operators), "No operators specified");
167168
return ref_type<Operators...>{device_ptr_};
168169
}
169170

170-
template <typename LabelType>
171-
trie<LabelType>::level::level() : louds_{}, outs_{}, labels_{}, labels_ptr_{nullptr}, offset_{0}
171+
template <typename LabelType, class Allocator>
172+
trie<LabelType, Allocator>::level::level()
173+
: louds_{}, outs_{}, labels_{}, labels_ptr_{nullptr}, offset_{0}
172174
{
173175
}
174176

include/cuco/detail/trie/trie_ref.inl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
namespace cuco {
22
namespace experimental {
33

4-
template <typename LabelType, typename... Operators>
5-
__host__ __device__ constexpr trie_ref<LabelType, Operators...>::trie_ref(
6-
const trie<LabelType>* trie) noexcept
4+
template <typename LabelType, class Allocator, typename... Operators>
5+
__host__ __device__ constexpr trie_ref<LabelType, Allocator, Operators...>::trie_ref(
6+
const trie<LabelType, Allocator>* trie) noexcept
77
: trie_{trie}
88
{
99
}
1010

1111
namespace detail {
1212

13-
template <typename LabelType, typename... Operators>
14-
class operator_impl<op::trie_lookup_tag, trie_ref<LabelType, Operators...>> {
15-
using ref_type = trie_ref<LabelType, Operators...>;
13+
template <typename LabelType, class Allocator, typename... Operators>
14+
class operator_impl<op::trie_lookup_tag, trie_ref<LabelType, Allocator, Operators...>> {
15+
using ref_type = trie_ref<LabelType, Allocator, Operators...>;
1616
using size_type = size_t;
1717

1818
public:

include/cuco/trie.cuh

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,17 @@ namespace experimental {
2525
* @brief Trie class
2626
*
2727
* @tparam label_type type of individual characters of vector keys (eg. char or int)
28+
* @tparam Allocator Type of allocator used for device storage
2829
*/
29-
template <typename LabelType>
30+
template <typename LabelType, class Allocator = thrust::device_malloc_allocator<std::byte>>
3031
class trie {
3132
public:
32-
constexpr trie();
33+
/**
34+
* @brief Constructs an empty trie
35+
*
36+
* @param allocator Allocator used for allocating device storage
37+
*/
38+
constexpr trie(Allocator const& allocator = Allocator{});
3339
~trie() noexcept(false);
3440

3541
/**
@@ -88,6 +94,7 @@ class trie {
8894
[[nodiscard]] auto ref(Operators... ops) const noexcept;
8995

9096
private:
97+
Allocator allocator_; ///< Allocator
9198
size_type num_keys_; ///< Number of keys inserted into trie
9299
size_type num_nodes_; ///< Number of internal nodes
93100
std::vector<LabelType> last_key_; ///< Last key inserted into trie
@@ -110,7 +117,8 @@ class trie {
110117

111118
template <typename... Operators>
112119
using ref_type =
113-
cuco::experimental::trie_ref<LabelType, Operators...>; ///< Non-owning container ref type
120+
cuco::experimental::trie_ref<LabelType, Allocator, Operators...>; ///< Non-owning container ref
121+
///< type
114122

115123
// Mixins need to be friends with this class in order to access private members
116124
template <typename Op, typename Ref>

include/cuco/trie_ref.cuh

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
namespace cuco {
66
namespace experimental {
77

8-
template <typename LabelType>
8+
template <typename LabelType, class Allocator>
99
class trie;
1010

1111
/**
@@ -15,18 +15,19 @@ class trie;
1515
* @tparam LabelType Trie label type
1616
* @tparam Operators Device operator options defined in `include/cuco/operator.hpp`
1717
*/
18-
template <typename LabelType, typename... Operators>
19-
class trie_ref : public detail::operator_impl<Operators, trie_ref<LabelType, Operators...>>... {
18+
template <typename LabelType, class Allocator, typename... Operators>
19+
class trie_ref
20+
: public detail::operator_impl<Operators, trie_ref<LabelType, Allocator, Operators...>>... {
2021
public:
2122
/**
2223
* @brief Constructs trie_ref.
2324
*
2425
* @param trie Non-owning ref of trie
2526
*/
26-
__host__ __device__ explicit constexpr trie_ref(const trie<LabelType>* trie) noexcept;
27+
__host__ __device__ explicit constexpr trie_ref(const trie<LabelType, Allocator>* trie) noexcept;
2728

2829
private:
29-
const trie<LabelType>* trie_;
30+
const trie<LabelType, Allocator>* trie_;
3031

3132
// Mixins need to be friends with this class in order to access private members
3233
template <typename Op, typename Ref>

0 commit comments

Comments
 (0)