Commit 000f349

update oracle in docs
1 parent fd47601 commit 000f349

1 file changed

Lines changed: 8 additions & 7 deletions

torch_molecule/generator/graph_ga/modeling_graph_ga.py

@@ -108,18 +108,19 @@ def fit(
 Training data, which will be used as the initial population.
 y_train : Optional[Union[List, np.ndarray]]
 Training labels for conditional generation (num_task is not 0).
-oracles : Optional[List[Callable]]
-Oracles used to score the generated molecules, if not provided, default oracles based on
+oracle : Optional[Callable]
+Oracle used to score the generated molecules. If not provided, default oracles based on
 ``sklearn.ensemble.RandomForestRegressor`` are trained on the X_train and y_train.
 
-For the customized oracle, it should be a Callable object, i.e., ``oracle(X)``, and the number
-of oracles must equal to the number of tasks (``num_task``).
+For a customized oracle, it should be a Callable object, i.e., ``oracle(X, y)``.
+Please properly wrap your oracle to take two inputs:
+- a list of ``rdkit.Chem.rdchem.Mol`` objects and
+- a (1, num_task) numpy array of target values that all the molecules in the list target to achieve. Take care of NaN values if any.
 
-Please properly wrap your oracle that takes a list of ``rdkit.Chem.rdchem.Mol`` and returns a list of scores for each molecule.
-For multi-conditional generation, scores for different tasks should be aggregated, i.e. mean or sum.
+Scores for different tasks should be aggregated, i.e., mean or sum. The return should be a list of scores (float).
 Smaller scores mean closer to the target goal.
 
-We don't need oracles for unconditional generation.
+Oracles are not needed for unconditional generation.
 
 Returns
 -------
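
For reference, an oracle matching the updated docstring could be wrapped roughly as follows. This is only a sketch, not code from this commit: `make_oracle` and `property_fns` are hypothetical names, and each property function is assumed to map one molecule to a float for one task. It takes the `(X, y)` pair, skips NaN targets, and returns one mean absolute error per molecule, so smaller scores mean closer to the target.

```python
import math
from typing import Callable, List, Sequence


def make_oracle(property_fns: Sequence[Callable[[object], float]]) -> Callable:
    """Build an ``oracle(X, y)`` from one property function per task.

    Hypothetical helper: ``property_fns[i]`` is assumed to predict the
    value of task ``i`` for a single molecule.
    """

    def oracle(X: List[object], y) -> List[float]:
        # y has shape (1, num_task); y[0] is the single row of targets.
        targets = [float(t) for t in y[0]]
        if len(targets) != len(property_fns):
            raise ValueError("expected one property function per task")

        # Ignore tasks whose target is NaN (no goal set for that task).
        active = [i for i, t in enumerate(targets) if not math.isnan(t)]

        scores = []
        for mol in X:
            errors = [abs(float(property_fns[i](mol)) - targets[i]) for i in active]
            # Aggregate over tasks with the mean; 0.0 if every target is NaN.
            scores.append(sum(errors) / len(errors) if errors else 0.0)
        return scores

    return oracle
```

With real data, `X` would be a list of `rdkit.Chem.rdchem.Mol` objects and `y` a NumPy array; the indexing above works for either a NumPy array or a nested list. The resulting callable would then be passed as the `oracle` argument of `fit`. The mean is one choice of aggregation; the docstring equally allows a sum.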