@@ -59,7 +59,12 @@ class EvaluateModel(object):
5959 Default is None. Specify an ordered list of features for which to
6060 run the evaluation. The features in this list must be identical to or
6161 a subset of `features`, and in the order you want the resulting
62- `test_targets.npz` and `test_predictions.npz` to be saved.
62+ `test_targets.npz` and `test_predictions.npz` to be saved. If using
63+ a FileSampler or H5DataLoader for the evaluation, you can pass in
64+ a dataset with the targets matrix only containing these features, but
65+ note that this subsetted targets matrix MUST be ordered the same
66+ way as `features`, and the predictions and targets .npz output
67+ will be reordered according to `use_features_ord`.
6368
6469 Attributes
6570 ----------
@@ -117,17 +122,14 @@ def __init__(self,
117122 self .output_dir = output_dir
118123 os .makedirs (output_dir , exist_ok = True )
119124
120- self .features = features
125+ self .features = np . array ( features )
121126 self ._use_ixs = list (range (len (features )))
122127 if use_features_ord is not None :
123128 feature_ixs = {f : ix for (ix , f ) in enumerate (features )}
124129 self ._use_ixs = []
125- self .features = []
126-
127130 for f in use_features_ord :
128131 if f in feature_ixs :
129132 self ._use_ixs .append (feature_ixs [f ])
130- self .features .append (f )
131133 else :
132134 warnings .warn (("Feature {0} in `use_features_ord` "
133135 "does not match any features in the list "
@@ -157,11 +159,23 @@ def __init__(self,
157159
158160 self ._test_data , self ._all_test_targets = \
159161 self .sampler .get_data_and_targets (self .batch_size , n_test_samples )
160- # TODO: we should be able to do this on the sampler end instead of
161- # here. the current workaround is problematic, since
162- # self._test_data still has the full featureset in it, and we
163- # select the subset during `evaluate`
164- self ._all_test_targets = self ._all_test_targets [:, self ._use_ixs ]
162+
163+ self ._use_testmat_ixs = self ._use_ixs [:]
164+ # if the targets shape is the same as the subsetted features,
165+ # reindex based on the subsetted list
166+ if self ._all_test_targets .shape [1 ] == len (self ._use_ixs ):
167+ subset_features = {self .features [ix ]: i for (i , ix ) in
168+ enumerate (sorted (self ._use_ixs ))}
169+ self ._use_testmat_ixs = [
170+ subset_features [f ] for f in self .features [self ._use_ixs ]]
171+
172+ self ._all_test_targets = self ._all_test_targets [
173+ :, self ._use_testmat_ixs ]
174+
175+ # save the targets dataset now
176+ np .savez_compressed (
177+ os .path .join (self .output_dir , "test_targets.npz" ),
178+ data = self ._all_test_targets )
165179
166180 # reset Genome base ordering when applicable.
167181 if (hasattr (self .sampler , "reference_sequence" ) and
@@ -179,7 +193,7 @@ def _write_features_ordered_to_file(self):
179193 """
180194 fp = os .path .join (self .output_dir , 'use_features_ord.txt' )
181195 with open (fp , 'w+' ) as file_handle :
182- for f in self .features :
196+ for f in self .features [ self . _use_ixs ] :
183197 file_handle .write ('{0}\n ' .format (f ))
184198
185199 def _get_feature_from_index (self , index ):
@@ -196,7 +210,7 @@ def _get_feature_from_index(self, index):
196210 The name of the feature/target at the specified index.
197211
198212 """
199- return self .features [index ]
213+ return self .features [self . _use_ixs ][ index ]
200214
201215 def evaluate (self ):
202216 """
@@ -216,7 +230,7 @@ def evaluate(self):
216230 all_predictions = []
217231 for (inputs , targets ) in self ._test_data :
218232 inputs = torch .Tensor (inputs )
219- targets = torch .Tensor (targets [:, self ._use_ixs ])
233+ targets = torch .Tensor (targets [:, self ._use_testmat_ixs ])
220234
221235 if self .use_cuda :
222236 inputs = inputs .cuda ()
@@ -246,10 +260,6 @@ def evaluate(self):
246260 os .path .join (self .output_dir , "test_predictions.npz" ),
247261 data = all_predictions )
248262
249- np .savez_compressed (
250- os .path .join (self .output_dir , "test_targets.npz" ),
251- data = self ._all_test_targets )
252-
253263 loss = np .average (batch_losses )
254264 logger .info ("test loss: {0}" .format (loss ))
255265 for name , score in average_scores .items ():
0 commit comments