File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -30,15 +30,17 @@ def aggregate(
3030 -------
3131 output: [num_owner, feature_dim]
3232 """
33- bin_count = torch .bincount (owners )
34- bin_count = bin_count .where (bin_count != 0 , bin_count .new_ones (1 ))
35-
36- if (num_owner is not None ) and (bin_count .shape [0 ] != num_owner ):
37- difference = num_owner - bin_count .shape [0 ]
38- bin_count = torch .cat ([bin_count , bin_count .new_ones (difference )])
39-
40- # make sure this operation is done on the same device of data and owners
41- output = data .new_zeros ([bin_count .shape [0 ], data .shape [1 ]])
33+ if num_owner is None or average :
34+ # requires bincount
35+ bin_count = torch .bincount (owners )
36+ bin_count = bin_count .where (bin_count != 0 , bin_count .new_ones (1 ))
37+ if (num_owner is not None ) and (bin_count .shape [0 ] != num_owner ):
38+ difference = num_owner - bin_count .shape [0 ]
39+ bin_count = torch .cat ([bin_count , bin_count .new_ones (difference )])
40+ else :
41+ num_owner = bin_count .shape [0 ]
42+
43+ output = data .new_zeros ([num_owner , data .shape [1 ]])
4244 output = output .index_add_ (0 , owners , data )
4345 if average :
4446 output = (output .T / bin_count ).T
You can’t perform that action at this time.
0 commit comments