1+ import numpy as np
2+ import pandas as pd
3+ import os
4+ from torch_molecule import GraphMAEMolecularEncoder
5+
def test_graphmae_encoder():
    """Exercise fit, encode, and a save/load round-trip of the default GraphMAE encoder.

    Reads up to 50 SMILES strings from ``data/molecule100.csv``; when that
    file is absent, five small hard-coded molecules are used instead. The
    model is fitted, the first five molecules are encoded, the model is saved
    to ``graphmae_model.pt`` and reloaded into a fresh instance, and the two
    sets of encodings are compared.

    Raises:
        AssertionError: If the reloaded model's encodings differ in shape or
            (beyond a small tolerance) in value from the fitted model's.
    """
    # Load molecules from CSV file
    data_path = "data/molecule100.csv"
    if not os.path.exists(data_path):
        print(f"Data file not found: {data_path} ")
        # Use simple molecules as fallback
        molecules = [
            "CC(=O)O",  # Acetic acid
            "CCO",  # Ethanol
            "CCCC",  # Butane
            "c1ccccc1",  # Benzene
            "CCN",  # Ethylamine
        ]
    else:
        df = pd.read_csv(data_path)
        molecules = df['smiles'].tolist()[:50]  # Use first 50 molecules
        print(f"Loaded {len(molecules)} molecules from {data_path} ")

    # Initialize GraphMAE model; device is not passed, so the encoder's own
    # default applies.
    model = GraphMAEMolecularEncoder(
        num_layer=3,
        hidden_size=128,
        batch_size=16,
        epochs=30,  # Small number for testing
        mask_rate=0.15,
        verbose=True,
    )
    print("GraphMAE model initialized successfully")

    # Test fitting
    print("\n === Testing GraphMAE model self-supervised fitting ===")
    model.fit(molecules)

    # Test encoding
    print("\n === Testing molecule encoding ===")
    encodings = model.encode(molecules[:5])
    print(f"Encoding shape: {encodings.shape} ")

    # Test saving and loading
    print("\n === Testing model saving and loading ===")
    save_path = "graphmae_model.pt"
    try:
        model.save_to_local(save_path)
        print(f"Model saved to {save_path} ")

        new_model = GraphMAEMolecularEncoder()
        new_model.load_from_local(save_path)
        print("Model loaded successfully")

        # Test encoding with loaded model
        new_encodings = new_model.encode(molecules[:5])
        print(f"New encoding shape: {new_encodings.shape} ")

        # Verify encodings are the same (or very close)
        assert new_encodings.shape == encodings.shape, (
            f"Encoding shape changed after reload: "
            f"{encodings.shape} vs {new_encodings.shape}"
        )
        encoding_diff = (encodings - new_encodings).abs().max().item()
        print(f"Max difference between encodings: {encoding_diff} ")
        # NOTE(review): the tolerance assumes encode() is deterministic up to
        # float32 round-off on the active device - loosen it if a backend
        # proves noisier.
        assert encoding_diff < 1e-4, (
            f"Encodings diverged after save/load round-trip: "
            f"max abs diff {encoding_diff}"
        )
    finally:
        # Clean up even when saving, loading, or the checks above fail, so a
        # failed run does not leave the checkpoint behind.
        if os.path.exists(save_path):
            os.remove(save_path)
            print(f"Cleaned up {save_path} ")
67+
def test_graphmae_with_edge_masking():
    """Exercise fit, encode, and save/load of a GraphMAE encoder with ``mask_edge=True``.

    Reads up to 50 SMILES strings from ``data/molecule100.csv``; when that
    file is absent, five small hard-coded molecules are used instead. The
    model is fitted with edge masking enabled, the first five molecules are
    encoded, and the model is saved to ``graphmae_edge_model.pt`` and
    reloaded into a fresh instance whose ``mask_edge`` attribute is checked.

    Raises:
        AssertionError: If the reloaded model's ``mask_edge`` does not match
            the value the saved model was configured with.
    """
    # Load molecules from CSV file
    data_path = "data/molecule100.csv"
    if not os.path.exists(data_path):
        print(f"Data file not found: {data_path} ")
        # Use simple molecules as fallback
        molecules = [
            "CC(=O)O",  # Acetic acid
            "CCO",  # Ethanol
            "CCCC",  # Butane
            "c1ccccc1",  # Benzene
            "CCN",  # Ethylamine
        ]
    else:
        df = pd.read_csv(data_path)
        molecules = df['smiles'].tolist()[:50]  # Use first 50 molecules
        print(f"Loaded {len(molecules)} molecules from {data_path} ")

    # Initialize GraphMAE model with edge masking enabled; device is not
    # passed, so the encoder's own default applies.
    model = GraphMAEMolecularEncoder(
        num_layer=3,
        hidden_size=128,
        batch_size=16,
        epochs=30,  # Small number for testing
        mask_rate=0.15,
        mask_edge=True,  # Enable edge masking
        verbose=True,
    )
    print("GraphMAE model with edge masking initialized successfully")

    # Test fitting
    print("\n === Testing GraphMAE model with edge masking ===")
    model.fit(molecules)

    # Test encoding
    print("\n === Testing molecule encoding with edge masking model ===")
    encodings = model.encode(molecules[:5])
    print(f"Encoding shape: {encodings.shape} ")

    # Test saving and loading
    print("\n === Testing edge masking model saving and loading ===")
    save_path = "graphmae_edge_model.pt"
    try:
        model.save_to_local(save_path)
        print(f"Model saved to {save_path} ")

        new_model = GraphMAEMolecularEncoder()
        new_model.load_from_local(save_path)
        print("Model loaded successfully")

        # Verify edge masking parameter was preserved
        print(f"Loaded model mask_edge parameter: {new_model.mask_edge} ")
        # NOTE(review): assumes load_from_local restores constructor
        # hyperparameters such as mask_edge - confirm against the library.
        assert new_model.mask_edge == model.mask_edge, (
            f"mask_edge was not preserved by save/load: "
            f"expected {model.mask_edge}, got {new_model.mask_edge}"
        )
    finally:
        # Clean up even when saving, loading, or the check above fails, so a
        # failed run does not leave the checkpoint behind.
        if os.path.exists(save_path):
            os.remove(save_path)
            print(f"Cleaned up {save_path} ")
125+
if __name__ == "__main__":
    # Script entry point: announce each scenario with its banner, then run it,
    # in the order listed.
    scenarios = (
        (
            "=== Testing GraphMAE Encoder (Default Configuration) ===",
            test_graphmae_encoder,
        ),
        (
            "\n === Testing GraphMAE Encoder with Edge Masking ===",
            test_graphmae_with_edge_masking,
        ),
    )
    for banner, run_scenario in scenarios:
        print(banner)
        run_scenario()
0 commit comments