@@ -499,15 +499,7 @@ def get_predictions_for_fasta_file(self,
499499 len (self .reference_sequence .BASES_ARR )))
500500 batch_ids = []
501501 for i , fasta_record in enumerate (fasta_file ):
502- cur_sequence = str (fasta_record )
503-
504- if len (cur_sequence ) < self .sequence_length :
505- cur_sequence = _pad_sequence (cur_sequence ,
506- self .sequence_length ,
507- self .reference_sequence .UNK_BASE )
508- elif len (cur_sequence ) > self .sequence_length :
509- cur_sequence = _truncate_sequence (cur_sequence , self .sequence_length )
510-
502+ cur_sequence = self ._pad_or_truncate_sequence (str (fasta_record ))
511503 cur_sequence_encoding = self .reference_sequence .sequence_to_encoding (
512504 cur_sequence )
513505
@@ -531,19 +523,22 @@ def get_predictions_for_fasta_file(self,
531523
532524
533525 def get_predictions (self ,
534- input_path ,
535- output_dir ,
526+ input ,
527+ output_dir = None ,
536528 output_format = "tsv" ,
537529 strand_index = None ):
538530 """
539- Get model predictions for sequences specified in a FASTA or BED file.
531+ Get model predictions for sequences specified as a raw sequence,
532+ FASTA, or BED file.
540533
541534 Parameters
542535 ----------
543- input_path : str
544- Input path to the FASTA or BED file.
545- output_dir : str
546- Output directory to write the model predictions.
536+ input : str
537+ A single sequence, or a path to the FASTA or BED file input.
538+ output_dir : str, optional
539+ Default is None. Output directory to write the model predictions.
540+ If this is left blank a raw sequence input will be assumed, though
541+ an output directory is required for FASTA and BED inputs.
547542 output_format : {'tsv', 'hdf5'}, optional
548543 Default is 'tsv'. Choose whether to save TSV or HDF5 output files.
549544 TSV is easier to access (i.e. open with text editor/Excel) and
@@ -583,16 +578,23 @@ def get_predictions(self,
583578 or .tsv file will mark this sequence or region as `contains_unk = True`.
584579
585580 """
586- if input_path .endswith ('.fa' ) or input_path .endswith ('.fasta' ):
581+ if output_dir is None :
582+ sequence = self ._pad_or_truncate_sequence (input )
583+ seq_enc = self .reference_sequence .sequence_to_encoding (sequence )
584+ seq_enc = np .expand_dims (seq_enc , axis = 0 ) # add batch size of 1
585+ return predict (self .model , seq_enc )
586+ elif input .endswith ('.fa' ) or input .endswith ('.fasta' ):
587587 self .get_predictions_for_fasta_file (
588- input_path , output_dir , output_format = output_format )
588+ input , output_dir , output_format = output_format )
589589 else :
590590 self .get_predictions_for_bed_file (
591- input_path ,
591+ input ,
592592 output_dir ,
593593 output_format = output_format ,
594594 strand_index = strand_index )
595595
596+ return None
597+
596598 def in_silico_mutagenesis_predict (self ,
597599 sequence ,
598600 base_preds ,
@@ -898,14 +900,7 @@ def in_silico_mutagenesis_from_file(self,
898900
899901 fasta_file = pyfaidx .Fasta (input_path )
900902 for i , fasta_record in enumerate (fasta_file ):
901- cur_sequence = str .upper (str (fasta_record ))
902- if len (cur_sequence ) < self .sequence_length :
903- cur_sequence = _pad_sequence (cur_sequence ,
904- self .sequence_length ,
905- self .reference_sequence .UNK_BASE )
906- elif len (cur_sequence ) > self .sequence_length :
907- cur_sequence = _truncate_sequence (
908- cur_sequence , self .sequence_length )
903+ cur_sequence = self ._pad_or_truncate_sequence (str .upper (str (fasta_record )))
909904
910905 # Generate mut sequences and base preds.
911906 mutated_sequences = in_silico_mutagenesis_sequences (
@@ -1143,3 +1138,15 @@ def variant_effect_prediction(self,
11431138
11441139 for r in reporters :
11451140 r .write_to_file ()
1141+
1142+ def _pad_or_truncate_sequence (self , sequence ):
1143+ if len (sequence ) < self .sequence_length :
1144+ sequence = _pad_sequence (
1145+ sequence ,
1146+ self .sequence_length ,
1147+ self .reference_sequence .UNK_BASE ,
1148+ )
1149+ elif len (sequence ) > self .sequence_length :
1150+ sequence = _truncate_sequence (sequence , self .sequence_length )
1151+
1152+ return sequence
0 commit comments