@@ -31,6 +31,7 @@ class InteractionNetworkBlock(MessagePassing):
3131 In Advances in Neural Information Processing Systems (NeurIPS 2016).
3232 DOI: `<https://doi.org/10.48550/arXiv.1612.00222>_`.
3333 """
34+
3435 def __init__ (
3536 self ,
3637 node_feature_dim ,
@@ -84,19 +85,18 @@ def __init__(
8485 "`node_feature_dim` must be a positive integer,"
8586 f" got { node_feature_dim } ."
8687 )
87-
88+
8889 if hidden_dim <= 0 :
8990 raise ValueError (
90- "`hidden_dim` must be a positive integer,"
91- f" got { hidden_dim } ."
91+ "`hidden_dim` must be a positive integer," f" got { hidden_dim } ."
9292 )
93-
93+
9494 if n_message_layers <= 0 :
9595 raise ValueError (
9696 "`n_message_layers` must be a positive integer,"
9797 f" got { n_message_layers } ."
9898 )
99-
99+
100100 if n_update_layers <= 0 :
101101 raise ValueError (
102102 "`n_update_layers` must be a positive integer,"
@@ -110,7 +110,7 @@ def __init__(
110110
111111 # Message network
112112 self .message_net = FeedForward (
113- input_dimensions = 2 * self .node_feature_dim ,
113+ input_dimensions = 2 * self .node_feature_dim ,
114114 output_dimensions = self .hidden_dim ,
115115 inner_size = self .hidden_dim ,
116116 n_layers = n_message_layers ,
@@ -165,4 +165,4 @@ def update(self, message, x):
165165 :return: The updated node features.
166166 :rtype: torch.Tensor
167167 """
168- return self .update_net (torch .cat ((x , message ), dim = - 1 ))
168+ return self .update_net (torch .cat ((x , message ), dim = - 1 ))
0 commit comments