File tree Expand file tree Collapse file tree
GNNGraphs/src/gnnheterograph Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -90,19 +90,30 @@ function graph_indicator(g::GNNHeteroGraph, node_t::Symbol)
9090 return gi
9191end
9292
93- edge_features (g:: GNNHeteroGraph ) = begin
93+ """
94+ edge_features(g::GNNHeteroGraph)
95+
96+ Return the edge features for a heterogeneous graph with a single edge type.
97+ If the graph has multiple edge types, this will error.
98+ If no edge features are present, returns `nothing`.
99+ """
100+ function edge_features (g:: GNNHeteroGraph )
94101 if isempty (g. edata)
95102 return nothing
96103 elseif length (g. edata) > 1
97- @error " Multiple edge feature arrays , access directly through `g.edata`"
104+ @error " Multiple edge types present , access edge features directly through `g.edata[edge_t] `"
98105 else
99- ds = only (values (g. edata))
100- if isempty (ds)
101- return nothing
102- elseif length (ds) > 1
103- @error " Multiple edge feature arrays, access directly through `g.edata`"
104- else
105- return first (values (ds))
106+ edata = only (values (g. edata))
107+ if edata isa AbstractArray
108+ return isempty (edata) ? nothing : edata
109+ else
110+ if isempty (edata)
111+ return nothing
112+ elseif length (edata) > 1
113+ @error " Multiple edge feature arrays present, access directly through `g.edata[edge_t]`"
114+ else
115+ return first (values (edata))
116+ end
106117 end
107118 end
108119end
You can’t perform that action at this time.
0 commit comments