Skip to content

Commit 188b060

Browse files
committed
add missing cuda parameter to predict call
1 parent 87a8744 commit 188b060

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

selene_sdk/predict/model_predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ def get_predictions(self,
582582
sequence = self._pad_or_truncate_sequence(input)
583583
seq_enc = self.reference_sequence.sequence_to_encoding(sequence)
584584
seq_enc = np.expand_dims(seq_enc, axis=0) # add batch size of 1
585-
return predict(self.model, seq_enc)
585+
return predict(self.model, seq_enc, use_cuda=self.use_cuda)
586586
elif input.endswith('.fa') or input.endswith('.fasta'):
587587
self.get_predictions_for_fasta_file(
588588
input, output_dir, output_format=output_format)

0 commit comments

Comments
 (0)