We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9c14ecd commit 9b68fe2Copy full SHA for 9b68fe2
1 file changed
deepmd/pt/model/network/utils.py
@@ -13,7 +13,6 @@
13
has_torch_scatter = False
14
15
16
-@torch.jit.script
17
def aggregate(
18
data: torch.Tensor,
19
owners: torch.Tensor,
@@ -37,6 +36,7 @@ def aggregate(
37
36
-------
38
output: [num_owner, feature_dim]
39
"""
+ # faster and recommended
40
if has_torch_scatter:
41
output = torch_scatter.segment_coo(
42
src=data,
@@ -46,7 +46,7 @@ def aggregate(
46
)
47
return output
48
49
- # if torch_scatter is not available, use index_add_
+ # if torch_scatter is not available, use native index_add
50
if num_owner is None or average:
51
# requires bincount
52
bin_count = torch.bincount(owners)
0 commit comments