Skip to content

Commit a1915c7

Browse files
committed
Misc coding style changes
1 parent cdaa2af commit a1915c7

3 files changed

Lines changed: 80 additions & 79 deletions

File tree

include/cuco/detail/trie/trie.inl

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,20 @@ namespace cuco {
2121
namespace experimental {
2222

2323
template <typename label_type>
24-
trie<label_type>::trie()
25-
: levels_{2},
26-
d_levels_ptr_{nullptr},
27-
num_levels_{2},
28-
n_keys_{0},
29-
n_nodes_{1},
24+
constexpr trie<label_type>::trie()
25+
: num_keys_{0},
26+
num_nodes_{1},
3027
last_key_{},
28+
num_levels_{2},
29+
levels_{2},
30+
d_levels_ptr_{nullptr},
3131
device_ptr_{nullptr}
3232
{
33-
levels_[0].louds.append(0);
34-
levels_[0].louds.append(1);
35-
levels_[1].louds.append(1);
36-
levels_[0].outs.append(0);
37-
levels_[0].labels.push_back(root_label_);
33+
levels_[0].louds_.append(0);
34+
levels_[0].louds_.append(1);
35+
levels_[1].louds_.append(1);
36+
levels_[0].outs_.append(0);
37+
levels_[0].labels_.push_back(root_label_);
3838
}
3939

4040
template <typename label_type>
@@ -45,15 +45,15 @@ trie<label_type>::~trie() noexcept(false)
4545
}
4646

4747
template <typename label_type>
48-
void trie<label_type>::insert(const std::vector<label_type>& key)
48+
void trie<label_type>::insert(const std::vector<label_type>& key) noexcept
4949
{
50-
if (key == last_key_) { return; } // Ignore duplicate keys
51-
assert(n_keys_ == 0 || key > last_key_); // Keys are expected to be inserted in sorted order
50+
if (key == last_key_) { return; } // Ignore duplicate keys
51+
assert(num_keys_ == 0 || key > last_key_); // Keys are expected to be inserted in sorted order
5252

5353
if (key.empty()) {
54-
levels_[0].outs.set(0, 1);
55-
++levels_[1].offset;
56-
++n_keys_;
54+
levels_[0].outs_.set(0, 1);
55+
++levels_[1].offset_;
56+
++num_keys_;
5757
return;
5858
}
5959

@@ -66,37 +66,37 @@ void trie<label_type>::insert(const std::vector<label_type>& key)
6666
auto& level = levels_[pos + 1];
6767
auto label = key[pos];
6868

69-
if ((pos == last_key_.size()) || (label != level.labels.back())) {
70-
level.louds.set_last(0);
71-
level.louds.append(1);
72-
level.outs.append(0);
73-
level.labels.push_back(label);
74-
++n_nodes_;
69+
if ((pos == last_key_.size()) || (label != level.labels_.back())) {
70+
level.louds_.set_last(0);
71+
level.louds_.append(1);
72+
level.outs_.append(0);
73+
level.labels_.push_back(label);
74+
++num_nodes_;
7575
break;
7676
}
7777
}
7878

7979
// Process remaining labels after divergence point from last_key
80-
// Each such label will create a new edge and node pair in trie
80+
// Each such label will create a new edge and node pair
8181
for (++pos; pos < key.size(); ++pos) {
8282
auto& level = levels_[pos + 1];
83-
level.louds.append(0);
84-
level.louds.append(1);
85-
level.outs.append(0);
86-
level.labels.push_back(key[pos]);
87-
++n_nodes_;
83+
level.louds_.append(0);
84+
level.louds_.append(1);
85+
level.outs_.append(0);
86+
level.labels_.push_back(key[pos]);
87+
++num_nodes_;
8888
}
8989

90-
levels_[key.size() + 1].louds.append(1); // Mark end of current key
91-
++levels_[key.size() + 1].offset;
92-
levels_[key.size()].outs.set_last(1); // Set terminal bit indicating valid path
90+
levels_[key.size() + 1].louds_.append(1); // Mark end of current key
91+
++levels_[key.size() + 1].offset_;
92+
levels_[key.size()].outs_.set_last(1); // Set terminal bit indicating valid path
9393

94-
++n_keys_;
94+
++num_keys_;
9595
last_key_ = key;
9696
}
9797

9898
// Helper to move vector from host to device
99-
// Host vector is clear to avoid duplication. Device pointer is returned
99+
// Host vector is cleared to avoid duplication. Device pointer is returned
100100
template <typename T>
101101
T* move_vector_to_device(std::vector<T>& host_vector, thrust::device_vector<T>& device_vector)
102102
{
@@ -106,7 +106,7 @@ T* move_vector_to_device(std::vector<T>& host_vector, thrust::device_vector<T>&
106106
}
107107

108108
template <typename label_type>
109-
void trie<label_type>::build()
109+
void trie<label_type>::build() noexcept(false)
110110
{
111111
// Perform build level-by-level for all levels, followed by a deep-copy from host to device
112112

@@ -115,17 +115,17 @@ void trie<label_type>::build()
115115
size_type offset = 0;
116116

117117
for (auto& level : levels_) {
118-
level.louds.build();
119-
louds_refs.push_back(level.louds.ref(bv_read));
118+
level.louds_.build();
119+
louds_refs.push_back(level.louds_.ref(bv_read));
120120

121-
level.outs.build();
122-
outs_refs.push_back(level.outs.ref(bv_read));
121+
level.outs_.build();
122+
outs_refs.push_back(level.outs_.ref(bv_read));
123123

124124
// Move labels to device
125-
level.d_labels_ptr = move_vector_to_device(level.labels, level.d_labels);
125+
level.d_labels_ptr_ = move_vector_to_device(level.labels_, level.d_labels_);
126126

127-
offset += level.offset;
128-
level.offset = offset;
127+
offset += level.offset_;
128+
level.offset_ = offset;
129129
}
130130

131131
// Move bitvector refs to device
@@ -150,7 +150,7 @@ void trie<label_type>::lookup(KeyIt keys_begin,
150150
OffsetIt offsets_begin,
151151
OffsetIt offsets_end,
152152
OutputIt outputs_begin,
153-
cuda_stream_ref stream) const
153+
cuda_stream_ref stream) const noexcept
154154
{
155155
auto num_keys = cuco::detail::distance(offsets_begin, offsets_end) - 1;
156156
if (num_keys == 0) { return; }
@@ -166,14 +166,14 @@ void trie<label_type>::lookup(KeyIt keys_begin,
166166

167167
template <typename TrieRef, typename KeyIt, typename OffsetIt, typename OutputIt>
168168
__global__ void trie_lookup_kernel(
169-
TrieRef ref, KeyIt keys, OffsetIt offsets, OutputIt outputs, uint64_t num_keys)
169+
TrieRef ref, KeyIt keys, OffsetIt offsets, OutputIt outputs, size_t num_keys)
170170
{
171-
size_t loop_stride = gridDim.x * blockDim.x;
172-
size_t key_id = blockDim.x * blockIdx.x + threadIdx.x;
171+
auto loop_stride = gridDim.x * blockDim.x;
172+
auto key_id = blockDim.x * blockIdx.x + threadIdx.x;
173173

174174
while (key_id < num_keys) {
175175
auto key_start_pos = keys + offsets[key_id];
176-
size_t key_length = offsets[key_id + 1] - offsets[key_id];
176+
auto key_length = offsets[key_id + 1] - offsets[key_id];
177177

178178
outputs[key_id] = ref.lookup_key(key_start_pos, key_length);
179179
key_id += loop_stride;
@@ -189,7 +189,8 @@ auto trie<label_type>::ref(Operators...) const noexcept
189189
}
190190

191191
template <typename label_type>
192-
trie<label_type>::level::level() : louds{}, outs{}, labels{}, d_labels_ptr{nullptr}, offset{0}
192+
trie<label_type>::level::level()
193+
: louds_{}, outs_{}, labels_{}, d_labels_{}, d_labels_ptr_{nullptr}, offset_{0}
193194
{
194195
}
195196

include/cuco/detail/trie/trie_ref.inl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@ namespace experimental {
44
template <typename label_type, typename... Operators>
55
__host__ __device__ constexpr trie_ref<label_type, Operators...>::trie_ref(
66
const trie<label_type>* trie) noexcept
7-
: trie_(trie)
7+
: trie_{trie}
88
{
99
}
1010

1111
namespace detail {
1212

1313
template <typename label_type, typename... Operators>
1414
class operator_impl<op::trie_lookup_tag, trie_ref<label_type, Operators...>> {
15-
using ref_type = trie_ref<label_type, Operators...>;
15+
using ref_type = trie_ref<label_type, Operators...>;
16+
using size_type = size_t;
1617

1718
public:
1819
/**
@@ -24,22 +25,22 @@ class operator_impl<op::trie_lookup_tag, trie_ref<label_type, Operators...>> {
2425
* @return Index of key if it exists in trie, -1 otherwise
2526
*/
2627
template <typename KeyIt>
27-
[[nodiscard]] __device__ uint64_t lookup_key(KeyIt key, uint64_t length) const noexcept
28+
[[nodiscard]] __device__ size_type lookup_key(KeyIt key, size_type length) const noexcept
2829
{
2930
auto const& trie = static_cast<ref_type const&>(*this).trie_;
3031

3132
// Level-by-level search. node_id is updated at each level
32-
uint32_t node_id = 0;
33-
for (uint32_t cur_depth = 1; cur_depth <= length; cur_depth++) {
33+
size_type node_id = 0;
34+
for (size_type cur_depth = 1; cur_depth <= length; cur_depth++) {
3435
if (!search_label_in_children(key[cur_depth - 1], node_id, cur_depth)) { return -1lu; }
3536
}
3637

3738
// Check for terminal node bit that indicates a valid key
38-
uint64_t leaf_level_id = length;
39+
size_type leaf_level_id = length;
3940
if (!trie->d_outs_refs_ptr_[leaf_level_id].get(node_id)) { return -1lu; }
4041

4142
// Key exists in trie, generate the index
42-
auto offset = trie->d_levels_ptr_[leaf_level_id].offset;
43+
auto offset = trie->d_levels_ptr_[leaf_level_id].offset_;
4344
auto rank = trie->d_outs_refs_ptr_[leaf_level_id].rank(node_id);
4445

4546
return offset + rank;
@@ -55,16 +56,16 @@ class operator_impl<op::trie_lookup_tag, trie_ref<label_type, Operators...>> {
5556
* @return Position of last child
5657
*/
5758
template <typename BitVectorRef>
58-
[[nodiscard]] __device__ uint32_t get_last_child_position(BitVectorRef louds,
59-
uint32_t& node_id) const noexcept
59+
[[nodiscard]] __device__ size_type get_last_child_position(BitVectorRef louds,
60+
size_type& node_id) const noexcept
6061
{
61-
uint32_t node_pos = 0;
62+
size_type node_pos = 0;
6263
if (node_id != 0) {
6364
node_pos = louds.select(node_id - 1) + 1;
6465
node_id = node_pos - node_id;
6566
}
6667

67-
uint32_t pos_end = louds.find_next_set(node_pos);
68+
auto pos_end = louds.find_next_set(node_pos);
6869
return node_id + (pos_end - node_pos);
6970
}
7071

@@ -78,17 +79,17 @@ class operator_impl<op::trie_lookup_tag, trie_ref<label_type, Operators...>> {
7879
* @return Boolean indicating success of search process
7980
*/
8081
[[nodiscard]] __device__ bool search_label_in_children(label_type target,
81-
uint32_t& node_id,
82-
uint32_t level_id) const noexcept
82+
size_type& node_id,
83+
size_type level_id) const noexcept
8384
{
8485
auto const& trie = static_cast<ref_type const&>(*this).trie_;
8586
auto louds = trie->d_louds_refs_ptr_[level_id];
8687

87-
uint32_t end = get_last_child_position(louds, node_id); // Position of last child
88-
uint32_t begin = node_id; // Position of first child, initialized after find_last_child call
88+
auto end = get_last_child_position(louds, node_id); // Position of last child
89+
auto begin = node_id; // Position of first child, initialized after find_last_child call
8990

9091
auto& level = trie->d_levels_ptr_[level_id];
91-
auto labels = level.d_labels_ptr;
92+
auto labels = level.d_labels_ptr_;
9293

9394
// Binary search labels array of current level
9495
while (begin < end) {

include/cuco/trie.cuh

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,22 @@ namespace experimental {
2929
template <typename label_type>
3030
class trie {
3131
public:
32-
trie();
32+
constexpr trie();
3333
~trie() noexcept(false);
3434

3535
/**
3636
* @brief Insert new key into trie
3737
*
3838
* @param key Key to insert
3939
*/
40-
void insert(const std::vector<label_type>& key);
40+
void insert(const std::vector<label_type>& key) noexcept;
4141

4242
/**
4343
* @brief Build level-by-level trie indexes after inserting all keys
4444
*
4545
* In addition, a snapshot of current trie state is copied to device
4646
*/
47-
void build();
47+
void build() noexcept(false);
4848

4949
/**
5050
* @brief Bulk lookup vector of keys
@@ -64,7 +64,7 @@ class trie {
6464
OffsetIt offsets_begin,
6565
OffsetIt offsets_end,
6666
OutputIt outputs_begin,
67-
cuda_stream_ref stream = {}) const;
67+
cuda_stream_ref stream = {}) const noexcept;
6868

6969
using size_type = std::size_t; ///< size type
7070

@@ -73,7 +73,7 @@ class trie {
7373
*
7474
* @return Number of keys
7575
*/
76-
size_type constexpr size() const { return n_keys_; }
76+
size_type constexpr size() const noexcept { return num_keys_; }
7777

7878
/**
7979
* @brief Get device ref with operators.
@@ -88,12 +88,11 @@ class trie {
8888
[[nodiscard]] auto ref(Operators... ops) const noexcept;
8989

9090
private:
91-
size_type n_keys_; ///< Number of keys inserted into trie
92-
size_type n_nodes_; ///< Number of nodes in trie
91+
size_type num_keys_; ///< Number of keys inserted into trie
92+
size_type num_nodes_; ///< Number of internal nodes
9393
std::vector<label_type> last_key_; ///< Last key inserted into trie
9494

95-
static constexpr label_type root_label_ =
96-
sizeof(label_type) == 1 ? ' ' : static_cast<label_type>(-1); ///< Sentinel value
95+
static constexpr label_type root_label_ = sizeof(label_type) == 1 ? ' ' : -1; ///< Sentinel value
9796

9897
struct level;
9998
size_type num_levels_; ///< Number of trie levels
@@ -124,14 +123,14 @@ class trie {
124123
level();
125124
level(level&&) = default; ///< Move constructor
126125

127-
bit_vector<> louds; ///< Indicates links to next and previous level
128-
bit_vector<> outs; ///< Indicates terminal nodes of valid keys
126+
bit_vector<> louds_; ///< Indicates links to next and previous level
127+
bit_vector<> outs_; ///< Indicates terminal nodes of valid keys
129128

130-
std::vector<label_type> labels; ///< Stores individual characters of keys
131-
thrust::device_vector<label_type> d_labels; ///< Device-side copy of `labels`
132-
label_type* d_labels_ptr; ///< Raw pointer to d_labels
129+
std::vector<label_type> labels_; ///< Stores individual characters of keys
130+
thrust::device_vector<label_type> d_labels_; ///< Device-side copy of `labels`
131+
label_type* d_labels_ptr_; ///< Raw pointer to d_labels
133132

134-
size_type offset; ///< Count of nodes in all parent levels
133+
size_type offset_; ///< Cumulative node count in parent levels
135134
};
136135
};
137136

0 commit comments

Comments
 (0)