Skip to content

Commit 032a220

Browse files
AmoghAmogh
authored andcommitted
Aligning with changes in GMMConv
1 parent 0d09d2a commit 032a220

1 file changed

Lines changed: 20 additions & 9 deletions

File tree

GNNGraphs/src/gnnheterograph/query.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,19 +90,30 @@ function graph_indicator(g::GNNHeteroGraph, node_t::Symbol)
9090
return gi
9191
end
9292

93-
edge_features(g::GNNHeteroGraph) = begin
93+
"""
94+
edge_features(g::GNNHeteroGraph)
95+
96+
Return the edge features for a heterogeneous graph with a single edge type.
97+
If the graph has multiple edge types, this will error.
98+
If no edge features are present, returns `nothing`.
99+
"""
100+
function edge_features(g::GNNHeteroGraph)
94101
if isempty(g.edata)
95102
return nothing
96103
elseif length(g.edata) > 1
97-
@error "Multiple edge feature arrays, access directly through `g.edata`"
104+
@error "Multiple edge types present, access edge features directly through `g.edata[edge_t]`"
98105
else
99-
ds = only(values(g.edata))
100-
if isempty(ds)
101-
return nothing
102-
elseif length(ds) > 1
103-
@error "Multiple edge feature arrays, access directly through `g.edata`"
104-
else
105-
return first(values(ds))
106+
edata = only(values(g.edata))
107+
if edata isa AbstractArray
108+
return isempty(edata) ? nothing : edata
109+
else
110+
if isempty(edata)
111+
return nothing
112+
elseif length(edata) > 1
113+
@error "Multiple edge feature arrays present, access directly through `g.edata[edge_t]`"
114+
else
115+
return first(values(edata))
116+
end
106117
end
107118
end
108119
end

0 commit comments

Comments
 (0)