33"""
44import logging
55import os
6+ import warnings
67
78import numpy as np
89import torch
@@ -104,17 +105,25 @@ def __init__(self,
104105
105106 self .sampler = data_sampler
106107
108+ self .output_dir = output_dir
109+ os .makedirs (output_dir , exist_ok = True )
110+
107111 self .features = features
108112 self ._use_ixs = list (range (len (features )))
109113 if use_features_ord is not None :
110114 feature_ixs = {f : ix for (ix , f ) in enumerate (features )}
111115 self ._use_ixs = []
112- self ._use_features = use_features_ord
113- for f in use_features_ord :
114- self ._use_ixs .append (feature_ixs [f ])
116+ self .features = []
115117
116- self .output_dir = output_dir
117- os .makedirs (output_dir , exist_ok = True )
118+ for f in use_features_ord :
119+ if f in feature_ixs :
120+ self ._use_ixs .append (feature_ixs [f ])
121+ self .features .append (f )
122+ else :
123+ warnings .warn (("Feature {0} in `use_features_ord` "
124+ "does not match any features in the list "
125+ "`features` and will be skipped." ).format (f ))
126+ self ._write_features_ordered_to_file ()
118127
119128 initialize_logger (
120129 os .path .join (self .output_dir , "{0}.log" .format (
@@ -138,13 +147,23 @@ def __init__(self,
138147
139148 self ._test_data , self ._all_test_targets = \
140149 self .sampler .get_data_and_targets (self .batch_size , n_test_samples )
150+ # TODO: we should be able to do this on the sampler end, vs here...
151+ # this is a bad workaround, since self._test_data still has the full
152+ # featureset in it, and we select the subset during `evaluate`
153+ self ._all_test_targets = self ._all_test_targets [:, self ._use_ixs ]
141154
142155 if (hasattr (self .sampler , "reference_sequence" ) and
143- isinstance (self .sampler .reference_sequence , Genome ) and
144- _is_lua_trained_model (model )):
145- Genome .update_bases_order (['A' , 'G' , 'C' , 'T' ])
146- elif isinstance (self .sampler .reference_sequence , Genome ):
147- Genome .update_bases_order (['A' , 'C' , 'G' , 'T' ])
156+ isinstance (self .sampler .reference_sequence , Genome )):
157+ if _is_lua_trained_model (model ):
158+ Genome .update_bases_order (['A' , 'G' , 'C' , 'T' ])
159+ else :
160+ Genome .update_bases_order (['A' , 'C' , 'G' , 'T' ])
161+
162+ def _write_features_ordered_to_file (self ):
163+ fp = os .path .join (self .output_dir , 'use_features_ord.txt' )
164+ with open (fp , 'w+' ) as file_handle :
165+ for f in self .features :
166+ file_handle .write ('{0}\n ' .format (f ))
148167
149168 def _get_feature_from_index (self , index ):
150169 """
@@ -180,7 +199,7 @@ def evaluate(self):
180199 all_predictions = []
181200 for (inputs , targets ) in self ._test_data :
182201 inputs = torch .Tensor (inputs )
183- targets = torch .Tensor (targets )
202+ targets = torch .Tensor (targets [:, self . _use_ixs ] )
184203
185204 if self .use_cuda :
186205 inputs = inputs .cuda ()
@@ -192,7 +211,7 @@ def evaluate(self):
192211 predictions = None
193212 if _is_lua_trained_model (self .model ):
194213 predictions = self .model .forward (
195- inputs .transpose (1 , 2 ).unsqueeze_ (2 ))
214+ inputs .transpose (1 , 2 ).contiguous (). unsqueeze_ (2 ))
196215 else :
197216 predictions = self .model .forward (
198217 inputs .transpose (1 , 2 ))
0 commit comments