Skip to content

Commit af9b980

Browse files
committed
perf: skip bincount if unnecessary
1 parent 75b175b commit af9b980

1 file changed

Lines changed: 11 additions & 9 deletions

File tree

deepmd/pt/model/network/utils.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,17 @@ def aggregate(
3030
-------
3131
output: [num_owner, feature_dim]
3232
"""
33-
bin_count = torch.bincount(owners)
34-
bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1))
35-
36-
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
37-
difference = num_owner - bin_count.shape[0]
38-
bin_count = torch.cat([bin_count, bin_count.new_ones(difference)])
39-
40-
# make sure this operation is done on the same device of data and owners
41-
output = data.new_zeros([bin_count.shape[0], data.shape[1]])
33+
if num_owner is None or average:
34+
# requires bincount
35+
bin_count = torch.bincount(owners)
36+
bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1))
37+
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
38+
difference = num_owner - bin_count.shape[0]
39+
bin_count = torch.cat([bin_count, bin_count.new_ones(difference)])
40+
else:
41+
num_owner = bin_count.shape[0]
42+
43+
output = data.new_zeros([num_owner, data.shape[1]])
4244
output = output.index_add_(0, owners, data)
4345
if average:
4446
output = (output.T / bin_count).T

0 commit comments

Comments
 (0)