@@ -167,6 +167,8 @@ def _convert_to_pytorch_data(self, X, y=None):
167167
168168 # No H, first heavy atom has type 0
169169 node_type = torch .from_numpy (graph ['node_feat' ][:, 0 ] - 1 )
170+ if node_type .numel () <= 1 :
171+ continue
170172
171173 # Filter out invalid node types (< 0)
172174 valid_mask = node_type >= 0
@@ -191,16 +193,16 @@ def _convert_to_pytorch_data(self, X, y=None):
191193 # Update node and edge data
192194 node_type = node_type [valid_mask ]
193195 g .edge_index = remapped_edge_index
194- g .edge_attr = valid_edge_attr .long (). squeeze ( - 1 )
196+ g .edge_attr = valid_edge_attr .long ()
195197 else :
196198 # No invalid nodes, proceed normally
197199 g .edge_index = torch .from_numpy (graph ["edge_index" ])
198200 edge_attr = torch .from_numpy (graph ["edge_feat" ])[:, 0 ] + 1
199- g .edge_attr = edge_attr .long (). squeeze ( - 1 )
200-
201+ g .edge_attr = edge_attr .long ()
202+
201203 # * is encoded as "misc" which is 119 - 1 and should be 117
202204 node_type [node_type == 118 ] = 117
203- g .x = node_type .long (). squeeze ( - 1 )
205+ g .x = node_type .long ()
204206 # g.y = torch.from_numpy(graph["y"])
205207 g .y = torch .zeros (1 , 1 )
206208 del graph ["node_feat" ]
@@ -273,7 +275,7 @@ def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
273275
274276 def fit (self , X_train : List [str ]) -> "DigressMolecularGenerator" :
275277 num_task = 0 if self .input_dim_y is None else self .input_dim_y
276- X_train , _ = self ._validate_inputs (X_train , num_task = num_task )
278+ X_train , _ = self ._validate_inputs (X_train , num_task = num_task , return_rdkit_mol = False )
277279 self ._setup_diffusion_params (X_train )
278280 self ._initialize_model (self .model_class )
279281 self .model .initialize_parameters ()
0 commit comments