2020namespace cuco {
2121namespace experimental {
2222
23- template <typename LabelType>
24- constexpr trie<LabelType>::trie()
25- : num_keys_{0 },
23+ template <typename LabelType, class Allocator >
24+ constexpr trie<LabelType, Allocator>::trie(Allocator const & allocator)
25+ : allocator_{allocator},
26+ num_keys_{0 },
2627 num_nodes_{1 },
2728 last_key_{},
2829 num_levels_{2 },
@@ -37,15 +38,15 @@ constexpr trie<LabelType>::trie()
3738 levels_[0 ].labels_ .push_back (root_label_);
3839}
3940
40- template <typename LabelType>
41- trie<LabelType>::~trie () noexcept (false )
41+ template <typename LabelType, class Allocator >
42+ trie<LabelType, Allocator >::~trie () noexcept (false )
4243{
4344 if (d_levels_ptr_) { CUCO_CUDA_TRY (cudaFree (d_levels_ptr_)); }
4445 if (device_ptr_) { CUCO_CUDA_TRY (cudaFree (device_ptr_)); }
4546}
4647
47- template <typename LabelType>
48- void trie<LabelType>::insert(const std::vector<LabelType>& key) noexcept
48+ template <typename LabelType, class Allocator >
49+ void trie<LabelType, Allocator >::insert(const std::vector<LabelType>& key) noexcept
4950{
5051 if (key == last_key_) { return ; } // Ignore duplicate keys
5152 assert (num_keys_ == 0 || key > last_key_); // Keys are expected to be inserted in sorted order
@@ -95,8 +96,8 @@ void trie<LabelType>::insert(const std::vector<LabelType>& key) noexcept
9596 last_key_ = key;
9697}
9798
98- template <typename LabelType>
99- void trie<LabelType>::build() noexcept (false )
99+ template <typename LabelType, class Allocator >
100+ void trie<LabelType, Allocator >::build() noexcept (false )
100101{
101102 // Perform build level-by-level for all levels, followed by a deep-copy from host to device
102103 size_type offset = 0 ;
@@ -125,13 +126,13 @@ void trie<LabelType>::build() noexcept(false)
125126 CUCO_CUDA_TRY (cudaMemcpy (device_ptr_, this , sizeof (trie<LabelType>), cudaMemcpyHostToDevice));
126127}
127128
128- template <typename LabelType>
129+ template <typename LabelType, class Allocator >
129130template <typename KeyIt, typename OffsetIt, typename OutputIt>
130- void trie<LabelType>::lookup(KeyIt keys_begin,
131- OffsetIt offsets_begin,
132- OffsetIt offsets_end,
133- OutputIt outputs_begin,
134- cuda_stream_ref stream) const noexcept
131+ void trie<LabelType, Allocator >::lookup(KeyIt keys_begin,
132+ OffsetIt offsets_begin,
133+ OffsetIt offsets_end,
134+ OutputIt outputs_begin,
135+ cuda_stream_ref stream) const noexcept
135136{
136137 auto num_keys = cuco::detail::distance (offsets_begin, offsets_end) - 1 ;
137138 if (num_keys == 0 ) { return ; }
@@ -159,16 +160,17 @@ __global__ void trie_lookup_kernel(
159160 }
160161}
161162
162- template <typename LabelType>
163+ template <typename LabelType, class Allocator >
163164template <typename ... Operators>
164- auto trie<LabelType>::ref(Operators...) const noexcept
165+ auto trie<LabelType, Allocator >::ref(Operators...) const noexcept
165166{
166167 static_assert (sizeof ...(Operators), " No operators specified" );
167168 return ref_type<Operators...>{device_ptr_};
168169}
169170
170- template <typename LabelType>
171- trie<LabelType>::level::level() : louds_{}, outs_{}, labels_{}, labels_ptr_{nullptr }, offset_{0 }
171+ template <typename LabelType, class Allocator >
172+ trie<LabelType, Allocator>::level::level()
173+ : louds_{}, outs_{}, labels_{}, labels_ptr_{nullptr }, offset_{0 }
172174{
173175}
174176
0 commit comments