Skip to content

Commit 42ba044

Browse files
committed
Fix: skip single atom molecule in diffusion model data preprocessing. Fixes #14
1 parent cd8e77a commit 42ba044

4 files changed

Lines changed: 38 additions & 5 deletions

File tree

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import os
2+
3+
from torch_molecule import DigressMolecularGenerator
4+
from torch_molecule.datasets import load_qm9
5+
6+
7+
def train_on_qm9() -> None:
8+
model = DigressMolecularGenerator(verbose=True, batch_size=1024, epochs=2)
9+
10+
smiles_list, _ = load_qm9(local_dir="torchmol_data")
11+
12+
original_count = len(smiles_list)
13+
smiles_list = [s for s in smiles_list if isinstance(s, str) and s]
14+
if original_count > len(smiles_list):
15+
print(f"Data cleaning: removed {original_count - len(smiles_list)} invalid entries from QM9 dataset.")
16+
17+
model.fit(smiles_list)
18+
19+
print("\n=== Generating 10 molecules from QM9-trained model ===")
20+
generated_smiles = model.generate(batch_size=10)
21+
print(f"Generated {len(generated_smiles)} molecules.")
22+
for i, smiles in enumerate(generated_smiles, start=1):
23+
print(f"{i}: {smiles}")
24+
25+
26+
if __name__ == "__main__":
27+
train_on_qm9()

torch_molecule/generator/digress/modeling_digress.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ def _convert_to_pytorch_data(self, X, y=None):
167167

168168
# No H, first heavy atom has type 0
169169
node_type = torch.from_numpy(graph['node_feat'][:, 0] - 1)
170+
if node_type.numel() <= 1:
171+
continue
170172

171173
# Filter out invalid node types (< 0)
172174
valid_mask = node_type >= 0
@@ -191,16 +193,16 @@ def _convert_to_pytorch_data(self, X, y=None):
191193
# Update node and edge data
192194
node_type = node_type[valid_mask]
193195
g.edge_index = remapped_edge_index
194-
g.edge_attr = valid_edge_attr.long().squeeze(-1)
196+
g.edge_attr = valid_edge_attr.long()
195197
else:
196198
# No invalid nodes, proceed normally
197199
g.edge_index = torch.from_numpy(graph["edge_index"])
198200
edge_attr = torch.from_numpy(graph["edge_feat"])[:, 0] + 1
199-
g.edge_attr = edge_attr.long().squeeze(-1)
200-
201+
g.edge_attr = edge_attr.long()
202+
201203
# * is encoded as "misc" which is 119 - 1 and should be 117
202204
node_type[node_type == 118] = 117
203-
g.x = node_type.long().squeeze(-1)
205+
g.x = node_type.long()
204206
# g.y = torch.from_numpy(graph["y"])
205207
g.y = torch.zeros(1, 1)
206208
del graph["node_feat"]
@@ -273,7 +275,7 @@ def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
273275

274276
def fit(self, X_train: List[str]) -> "DigressMolecularGenerator":
275277
num_task = 0 if self.input_dim_y is None else self.input_dim_y
276-
X_train, _ = self._validate_inputs(X_train, num_task=num_task)
278+
X_train, _ = self._validate_inputs(X_train, num_task=num_task, return_rdkit_mol=False)
277279
self._setup_diffusion_params(X_train)
278280
self._initialize_model(self.model_class)
279281
self.model.initialize_parameters()

torch_molecule/generator/gdss/modeling_gdss.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ def _convert_to_pytorch_data(self, X, y=None):
235235

236236
# No H, first heavy atom has type 0
237237
node_type = torch.from_numpy(graph['node_feat'][:, 0] - 1)
238+
if node_type.numel() <= 1:
239+
continue
238240

239241
# Filter out invalid node types (< 0)
240242
valid_mask = node_type >= 0

torch_molecule/generator/graph_dit/modeling_graph_dit.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ def _convert_to_pytorch_data(self, X, y=None):
190190

191191
# No H, first heavy atom has type 0
192192
node_type = torch.from_numpy(graph['node_feat'][:, 0] - 1)
193+
if node_type.numel() <= 1:
194+
continue
193195

194196
# Filter out invalid node types (< 0)
195197
valid_mask = node_type >= 0

0 commit comments

Comments
 (0)