@@ -129,19 +129,24 @@ function normalize_graphdata(data; default_name::Symbol, kws...)
129129 normalize_graphdata (NamedTuple {(default_name,)} ((data,)); default_name, kws... )
130130end
131131
132- function normalize_graphdata (data:: NamedTuple ; default_name, n, duplicate_if_needed = false )
132+ function normalize_graphdata (data:: NamedTuple ; default_name:: Symbol , n:: Int ,
133+ duplicate_if_needed:: Bool = false , glob:: Bool = false )
133134 # This had to workaround two Zygote bugs with NamedTuples
134135 # https://github.com/FluxML/Zygote.jl/issues/1071
135- # https://github.com/FluxML/Zygote.jl/issues/1072
136+ # https://github.com/FluxML/Zygote.jl/issues/1072 # TODO fixed. Can we simplify something?
137+
136138
137139 if n > 1
138140 @assert all (x -> x isa AbstractArray, data) " Non-array features provided."
139141 end
140142
141- if n <= 1
142- # If last array dimension is not 1, add a new dimension.
143- # This is mostly useful to reshape global feature vectors
144- # of size D to Dx1 matrices.
143+ if n <= 1 && glob == true
144+ @assert n == 1
145+ n = - 1 # relax the case of a single graph, allowing to store arbitrary types
146+ # # # If last array dimension is not 1, add a new dimension.
147+ # # # This is mostly useful to reshape global feature vectors
148+ # # # of size D to Dx1 matrices.
149+ # TODO remove this and handle better the batching of global features
145150 unsqz_last (v:: AbstractArray ) = size (v)[end ] != 1 ? reshape (v, size (v)... , 1 ) : v
146151 unsqz_last (v) = v
147152
@@ -161,7 +166,7 @@ function normalize_graphdata(data::NamedTuple; default_name, n, duplicate_if_nee
161166
162167 for x in data
163168 if x isa AbstractArray
164- @assert size (x)[end ]== n " Wrong size in last dimension for feature array, expected $n but got $(size (x)[end ]) ."
169+ @assert size (x)[end ] == n " Wrong size in last dimension for feature array, expected $n but got $(size (x)[end ]) ."
165170 end
166171 end
167172 end
0 commit comments