Skip to content

Commit 97080e5

Browse files
authored
Merge pull request #139 from FunctionLab/add-raw-sequence-predict
Add raw sequence predict
2 parents 6a9f2a0 + f24a0b1 commit 97080e5

2 files changed

Lines changed: 36 additions & 27 deletions

File tree

selene_sdk/predict/model_predict.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -499,15 +499,7 @@ def get_predictions_for_fasta_file(self,
499499
len(self.reference_sequence.BASES_ARR)))
500500
batch_ids = []
501501
for i, fasta_record in enumerate(fasta_file):
502-
cur_sequence = str(fasta_record)
503-
504-
if len(cur_sequence) < self.sequence_length:
505-
cur_sequence = _pad_sequence(cur_sequence,
506-
self.sequence_length,
507-
self.reference_sequence.UNK_BASE)
508-
elif len(cur_sequence) > self.sequence_length:
509-
cur_sequence = _truncate_sequence(cur_sequence, self.sequence_length)
510-
502+
cur_sequence = self._pad_or_truncate_sequence(str(fasta_record))
511503
cur_sequence_encoding = self.reference_sequence.sequence_to_encoding(
512504
cur_sequence)
513505

@@ -531,19 +523,22 @@ def get_predictions_for_fasta_file(self,
531523

532524

533525
def get_predictions(self,
534-
input_path,
535-
output_dir,
526+
input,
527+
output_dir=None,
536528
output_format="tsv",
537529
strand_index=None):
538530
"""
539-
Get model predictions for sequences specified in a FASTA or BED file.
531+
Get model predictions for sequences specified as a raw sequence,
532+
FASTA, or BED file.
540533
541534
Parameters
542535
----------
543-
input_path : str
544-
Input path to the FASTA or BED file.
545-
output_dir : str
546-
Output directory to write the model predictions.
536+
input : str
537+
A single sequence, or a path to the FASTA or BED file input.
538+
output_dir : str, optional
539+
Default is None. Output directory to write the model predictions.
540+
If this is left blank a raw sequence input will be assumed, though
541+
an output directory is required for FASTA and BED inputs.
547542
output_format : {'tsv', 'hdf5'}, optional
548543
Default is 'tsv'. Choose whether to save TSV or HDF5 output files.
549544
TSV is easier to access (i.e. open with text editor/Excel) and
@@ -583,16 +578,23 @@ def get_predictions(self,
583578
or .tsv file will mark this sequence or region as `contains_unk = True`.
584579
585580
"""
586-
if input_path.endswith('.fa') or input_path.endswith('.fasta'):
581+
if output_dir is None:
582+
sequence = self._pad_or_truncate_sequence(input)
583+
seq_enc = self.reference_sequence.sequence_to_encoding(sequence)
584+
seq_enc = np.expand_dims(seq_enc, axis=0) # add batch size of 1
585+
return predict(self.model, seq_enc)
586+
elif input.endswith('.fa') or input.endswith('.fasta'):
587587
self.get_predictions_for_fasta_file(
588-
input_path, output_dir, output_format=output_format)
588+
input, output_dir, output_format=output_format)
589589
else:
590590
self.get_predictions_for_bed_file(
591-
input_path,
591+
input,
592592
output_dir,
593593
output_format=output_format,
594594
strand_index=strand_index)
595595

596+
return None
597+
596598
def in_silico_mutagenesis_predict(self,
597599
sequence,
598600
base_preds,
@@ -898,14 +900,7 @@ def in_silico_mutagenesis_from_file(self,
898900

899901
fasta_file = pyfaidx.Fasta(input_path)
900902
for i, fasta_record in enumerate(fasta_file):
901-
cur_sequence = str.upper(str(fasta_record))
902-
if len(cur_sequence) < self.sequence_length:
903-
cur_sequence = _pad_sequence(cur_sequence,
904-
self.sequence_length,
905-
self.reference_sequence.UNK_BASE)
906-
elif len(cur_sequence) > self.sequence_length:
907-
cur_sequence = _truncate_sequence(
908-
cur_sequence, self.sequence_length)
903+
cur_sequence = self._pad_or_truncate_sequence(str.upper(str(fasta_record)))
909904

910905
# Generate mut sequences and base preds.
911906
mutated_sequences = in_silico_mutagenesis_sequences(
@@ -1143,3 +1138,15 @@ def variant_effect_prediction(self,
11431138

11441139
for r in reporters:
11451140
r.write_to_file()
1141+
1142+
def _pad_or_truncate_sequence(self, sequence):
1143+
if len(sequence) < self.sequence_length:
1144+
sequence = _pad_sequence(
1145+
sequence,
1146+
self.sequence_length,
1147+
self.reference_sequence.UNK_BASE,
1148+
)
1149+
elif len(sequence) > self.sequence_length:
1150+
sequence = _truncate_sequence(sequence, self.sequence_length)
1151+
1152+
return sequence

tutorials/analyzing_mutations_with_trained_models/analyzing_mutations_with_trained_models.ipynb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
"\n",
1818
"Download the compressed data from [here](https://zenodo.org/record/1319784): \n",
1919
"\n",
20+
"**Note: The tutorials and manuscript examples have been run on Selene versions 0.1.3 through 0.2.0, and PyTorch version 0.4.1. Models associated with the manuscript can only be run with PyTorch 0.4.1, as PyTorch models are not forward-compatible.**\n",
21+
"\n",
2022
"```sh\n",
2123
"wget https://zenodo.org/record/2206957/files/selene_analyzing_mutations_tutorial.tar.gz\n",
2224
"```\n",

0 commit comments

Comments
 (0)