66from ...utils import init_weights
77
88from .utils import get_mask_indices , get_fingerprint_loss
9+ from ...utils .graph .features import allowable_features
910
1011class_criterion = torch .nn .CrossEntropyLoss ()
1112
@@ -23,7 +24,9 @@ def __init__(
2324 ):
2425 super (GNN , self ).__init__ ()
2526 gnn_name = encoder_type .split ("-" )[0 ]
26- self .num_atom_type = 119
27+ decoding_size = len (allowable_features ['possible_atomic_num_list' ])
28+
29+ self .mask_atom_id = 119
2730 self .hidden_size = hidden_size
2831 self .mask_rate = mask_rate
2932 self .lw_rec = lw_rec
@@ -50,7 +53,7 @@ def __init__(
5053 if self .pool is None :
5154 raise ValueError (f"Invalid graph pooling type { readout } ." )
5255
53- self .predictor = GNN_Decoder (hidden_size , self . num_atom_type )
56+ self .predictor = GNN_Decoder (hidden_size , decoding_size )
5457
5558 def initialize_parameters (self , seed = None ):
5659 """
@@ -82,13 +85,16 @@ def compute_loss(self, batched_data):
8285
8386 # mask nodes' features
8487 for node_idx in masked_node_indices :
85- batched_data .x [node_idx ] = torch .tensor ([self .num_atom_type - 1 ] + [0 ] * (batched_data .x .shape [1 ] - 1 ))
88+ batched_data .x [node_idx ] = torch .tensor ([self .mask_atom_id - 1 ] + [0 ] * (batched_data .x .shape [1 ] - 1 ))
8689
8790 # generate predictions
8891 h_node , _ = self .graph_encoder (batched_data )
8992 h_rep = self .pool (h_node , batched_data .batch )
9093 batched_data .x = h_node
9194 prediction_class = self .predictor (batched_data )[masked_node_indices ]
95+ print ('prediction_class' , prediction_class .max (), prediction_class .min ())
96+ print ('batched_data.y' , batched_data .y .max (), batched_data .y .min ())
97+
9298
9399 # target_class = batched_data.y.to(torch.float32)
94100 loss_class = class_criterion (prediction_class .to (torch .float32 ), batched_data .y .long ())
0 commit comments