Skip to content

Commit 6c93925

Browse files
committed
revise sanitize config and model card usage
1 parent 6cda963 commit 6c93925

2 files changed

Lines changed: 7 additions & 4 deletions

File tree

torch_molecule/utils/format.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,10 @@ def serialize_config(obj):
5858

5959
# Handle numpy arrays
6060
elif isinstance(obj, (np.ndarray, np.generic)):
61-
if obj.size < 1000:
61+
# If it's a single number wrapped in a numpy array, just return the number
62+
if obj.size == 1:
63+
return obj.item()
64+
elif obj.size < 1000:
6265
return {
6366
"_type": "numpy_array",
6467
"data": obj.tolist(),

torch_molecule/utils/hf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,11 @@ def create_model_card(
182182
repo="{repo_id}"
183183
)
184184
185-
# Make predictions
185+
# For predictor: Make predictions
186186
# predictions = model.predict(smiles_list)
187-
# Make generations
187+
# For generator: Make generations
188188
# generations = model.generate(n_samples)
189-
# Make encodings
189+
# For encoder: Make encodings
190190
# encodings = model.encode(smiles_list)
191191
```
192192

0 commit comments

Comments
 (0)