# GNNGraphsCUDAExt.jl
# Forked from JuliaGraphs/GraphNeuralNetworks.jl.
module GNNGraphsCUDAExt

# Package extension that adds CUDA-specific methods to functions owned by
# GNNGraphs. Every definition below is written as `GNNGraphs.name(...)` so that
# it extends the GNNGraphs function rather than creating a new function that
# only exists inside this extension module.

using CUDA
using Random, Statistics, LinearAlgebra
using GNNGraphs
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
using SparseArrays

# Matrix types handled by the CUDA-specific methods below: CUDA dense matrices
# (`AnyCuMatrix`) and CUSPARSE sparse matrices.
const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}

# Query

# Draw a vector of normally distributed samples with CUDA's generator, with one
# entry per row of `A`.
function GNNGraphs._rand_dense_vector(A::CUMAT_T)
    return CUDA.randn(size(A, 1))
end

# Transform

# Allocate a zero-filled CUDA array with element type `T` and size `sz`
# (defaulting to the size of `a`).
function GNNGraphs.dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a))
    return CUDA.zeros(T, sz)
end

# Utils

# Trait method: anything matching `AnyCuArray` is reported as a CUDA array.
GNNGraphs.iscuarray(x::AnyCuArray) = true

# Build a CSC matrix with the same column pointers and row indices as `Mat`
# whose stored values are all set to `one(T)`.
#
# `similar(nonzeros(Mat))` is called without an element type, so the value
# buffer keeps the element type of `Mat`; `T` only determines the value written
# by `fill!` (which is converted to that element type). The `colPtr` and row
# index arrays of `Mat` are handed to the constructor as they are.
# NOTE(review): if the result is meant to have element type `T`, this would need
# `similar(nonzeros(Mat), T)` -- confirm against the CPU `binarize` method.
function GNNGraphs.binarize(Mat::CUSPARSE.CuSparseMatrixCSC, T::DataType = Bool)
    bin_vals = fill!(similar(nonzeros(Mat)), one(T))
    return CUSPARSE.CuSparseMatrixCSC(Mat.colPtr, rowvals(Mat), bin_vals, size(Mat))
end

# Sort an edge index stored on the GPU by moving it to the CPU, running the
# generic `GNNGraphs.sort_edge_index` there, and moving the result back to the
# device `u` came from.
#
# The method is defined on `GNNGraphs.sort_edge_index`. Defining it unqualified
# would create a separate function local to this module: GNNGraphs would never
# dispatch to it, and the fallback call below would resolve to that local
# function (which has no CPU method) instead of the generic implementation.
#
# `get_device` and `cpu_device` are accessed through GNNGraphs because it is the
# only module loaded in this file that can provide them.
# NOTE(review): confirm GNNGraphs binds both names (they look like the
# MLDataDevices functions).
function GNNGraphs.sort_edge_index(u::AnyCuArray, v::AnyCuArray)
    dev = GNNGraphs.get_device(u)
    cdev = GNNGraphs.cpu_device()
    u_cpu = u |> cdev
    v_cpu = v |> cdev
    #TODO proper cuda friendly implementation
    return GNNGraphs.sort_edge_index(u_cpu, v_cpu) |> dev
end

end #module