Skip to content

Commit ac2bda5

Browse files
committed
Add kernel export API via cuda.tile.compilation module
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 2b8dd4e commit ac2bda5

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

49 files changed

+2814
-1186
lines changed

cext/cuda_loader.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111

1212
#define FOREACH_CUDA_FUNCTION_TO_LOAD(X) \
1313
X(cuInit, 2000) \
14-
X(cuLibraryLoadFromFile, 12000) \
14+
X(cuLibraryLoadData, 12000) \
1515
X(cuLibraryUnload, 12000) \
1616
X(cuLibraryGetKernel, 12000) \
1717
X(cuGetErrorString, 6000) \

cext/hash_map.h

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,17 @@ class HashMap {
3030
static constexpr size_t kInitialBuckets = 4;
3131

3232
public:
33+
// A single key/value entry of the map.
//
// The key is declared const, so it cannot be modified through an Item
// reference or pointer; the value stays mutable. Entries can be
// move-constructed but neither copy-constructed nor copy-assigned.
struct Item {
34+
const K key;
35+
V value;
36+
37+
Item(K&& k, V&& v) : key(std::move(k)), value(std::move(v)) {}  // takes both by rvalue and moves them in
38+
Item(Item&&) = default;  // NOTE(review): `key` is const, so this presumably copies the key instead of moving it — confirm that is acceptable for K
39+
40+
Item(const Item&) = delete;
41+
void operator=(const Item&) = delete;
42+
};
43+
3344
HashMap()
3445
: nbuckets_(kInitialBuckets),
3546
size_(0),
@@ -50,17 +61,17 @@ class HashMap {
5061
}
5162

5263
template <typename Q>
53-
V* find(Q&& key) {
64+
Item* find(Q&& key) {
5465
uint64_t needle = compute_hash(key) | kOccupiedBit;
5566
auto [found, pos] = lookup(key, needle);
56-
return found ? &items_[pos].value : nullptr;
67+
return found ? &items_[pos] : nullptr;
5768
}
5869

59-
V* insert(K key, V value) {
70+
Item* insert(K key, V value) {
6071
uint64_t needle = compute_hash(key) | kOccupiedBit;
6172
auto [found, pos] = lookup(key, needle);
6273
if (found) {
63-
return &items_[pos].value;
74+
return &items_[pos];
6475
} else {
6576
if (size_ >= nbuckets_ / 2) {
6677
rehash();
@@ -70,18 +81,11 @@ class HashMap {
7081
}
7182
++size_;
7283
hashes_[pos] = needle;
73-
Item* new_item = new (&items_[pos]) Item(std::move(key), std::move(value));
74-
return &new_item->value;
84+
return new (&items_[pos]) Item(std::move(key), std::move(value));
7585
}
7686
}
7787

7888
private:
79-
struct Item {
80-
K key;
81-
V value;
82-
83-
Item(K&& k, V&& v) : key(std::move(k)), value(std::move(v)) {}
84-
};
8589

8690
size_t nbuckets_; // power of two
8791
size_t size_;

cext/py.h

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,11 +41,28 @@ T pylong_as(PyObject* obj) {
4141
return PyLong_AsLong(obj);
4242
} else if constexpr (std::is_same_v<T, long long>) {
4343
return PyLong_AsLongLong(obj);
44+
} else if constexpr (std::is_same_v<T, unsigned long>) {
45+
return PyLong_AsUnsignedLong(obj);
46+
} else if constexpr (std::is_same_v<T, unsigned long long>) {
47+
return PyLong_AsUnsignedLongLong(obj);
4448
} else {
4549
static_assert(!sizeof(T*), "pylong_as<T> not implemented for given T");
4650
}
4751
}
4852

53+
// Converts a Python object to the C integer type T, selecting the conversion
// routine at compile time from T.
//
// obj:      the Python object to convert.
// overflow: forwarded to the CPython *AndOverflow routine for `long` and
//           `long long`; per the CPython documentation those routines store
//           a non-zero value there when the number does not fit the C type.
//
// Any T other than int, long or long long is rejected at compile time.
template <typename T>
54+
T pylong_as_overflow_and(PyObject* obj, int* overflow) {
55+
if constexpr (std::is_same_v<T, int>) {
56+
// NOTE(review): `overflow` is not passed to pylong_as_int, so this branch
// does not visibly write *overflow — confirm how an out-of-range value is
// reported to the caller when T is int.
return pylong_as_int(obj);
57+
} else if constexpr (std::is_same_v<T, long>) {
58+
return PyLong_AsLongAndOverflow(obj, overflow);
59+
} else if constexpr (std::is_same_v<T, long long>) {
60+
return PyLong_AsLongLongAndOverflow(obj, overflow);
61+
} else {
62+
// The condition is always false but depends on T, so the assertion only
// fires when this branch is actually instantiated.
static_assert(!sizeof(T*), "pylong_as_overflow_and<T> not implemented for given T");
63+
}
64+
}
65+
4966
template <typename T>
5067
struct PythonWrapper {
5168
PyObject_HEAD
@@ -75,6 +92,21 @@ void pywrapper_dealloc(PyObject* self) {
7592
Py_TYPE(self)->tp_free(self);
7693
}
7794

95+
// Rich-comparison function for wrapped T objects, implemented via T's
// operator==. The signature matches CPython's richcmpfunc, so it is
// presumably meant for a type's tp_richcompare slot — confirm at the use site.
//
// Returns NotImplemented when either operand fails the type check against
// T::pytype, or when `op` is anything other than Py_EQ / Py_NE. Otherwise
// returns True or False. Every result is produced through Py_NewRef.
template <typename T>
96+
PyObject* pywrapper_richcompare_via_operator_equals(PyObject* self, PyObject* other, int op) {
97+
if (!PyObject_TypeCheck(self, &T::pytype) || !PyObject_TypeCheck(other, &T::pytype))
98+
return Py_NewRef(Py_NotImplemented);
99+
100+
T& a = py_unwrap<T>(self);
101+
T& b = py_unwrap<T>(other);
102+
103+
switch (op) {
104+
case Py_EQ: return Py_NewRef(a == b ? Py_True : Py_False);
105+
case Py_NE: return Py_NewRef(a == b ? Py_False : Py_True);  // != is the negation of ==; T needs no operator!=
106+
default: return Py_NewRef(Py_NotImplemented);  // ordering comparisons are not supported
107+
}
108+
}
109+
78110
struct OK_t{};
79111
struct ErrorRaised_t{};
80112

cext/test/test_hash_map.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -13,18 +13,18 @@ int main() {
1313

1414
hm.insert(0, 0);
1515
CHECK(hm.find(0));
16-
CHECK(*hm.find(0) == 0);
16+
CHECK(hm.find(0)->value == 0);
1717

1818
// Insert doesn't overwrite existing values
1919
hm.insert(0, 20);
20-
CHECK(*hm.find(0) == 0);
20+
CHECK(hm.find(0)->value == 0);
2121

2222
for (int i = 1; i < 1000; ++i) {
2323
hm.insert(i * 16, i * 10);
2424
for (int j = 0; j <= i; ++j) {
25-
int* v = hm.find(j * 16);
26-
CHECK(v);
27-
CHECK(*v == j * 10);
25+
auto* item = hm.find(j * 16);
26+
CHECK(item);
27+
CHECK(item->value == j * 10);
2828
}
2929
}
3030

0 commit comments

Comments
 (0)