Skip to content

Commit 01d050f

Browse files
committed
comment on todos/workarounds
1 parent 66468ae commit 01d050f

2 files changed

Lines changed: 37 additions & 18 deletions

File tree

selene_sdk/evaluate_model.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44
import logging
55
import os
6+
import warnings
67

78
import numpy as np
89
import torch
@@ -104,17 +105,25 @@ def __init__(self,
104105

105106
self.sampler = data_sampler
106107

108+
self.output_dir = output_dir
109+
os.makedirs(output_dir, exist_ok=True)
110+
107111
self.features = features
108112
self._use_ixs = list(range(len(features)))
109113
if use_features_ord is not None:
110114
feature_ixs = {f: ix for (ix, f) in enumerate(features)}
111115
self._use_ixs = []
112-
self._use_features = use_features_ord
113-
for f in use_features_ord:
114-
self._use_ixs.append(feature_ixs[f])
116+
self.features = []
115117

116-
self.output_dir = output_dir
117-
os.makedirs(output_dir, exist_ok=True)
118+
for f in use_features_ord:
119+
if f in feature_ixs:
120+
self._use_ixs.append(feature_ixs[f])
121+
self.features.append(f)
122+
else:
123+
warnings.warn(("Feature {0} in `use_features_ord` "
124+
"does not match any features in the list "
125+
"`features` and will be skipped.").format(f))
126+
self._write_features_ordered_to_file()
118127

119128
initialize_logger(
120129
os.path.join(self.output_dir, "{0}.log".format(
@@ -138,13 +147,23 @@ def __init__(self,
138147

139148
self._test_data, self._all_test_targets = \
140149
self.sampler.get_data_and_targets(self.batch_size, n_test_samples)
150+
# TODO: this subsetting should happen in the sampler rather than here.
151+
# This is a workaround: self._test_data still holds the full feature
152+
# set, and the subset is only selected per batch during `evaluate`.
153+
self._all_test_targets = self._all_test_targets[:, self._use_ixs]
141154

142155
if (hasattr(self.sampler, "reference_sequence") and
143-
isinstance(self.sampler.reference_sequence, Genome) and
144-
_is_lua_trained_model(model)):
145-
Genome.update_bases_order(['A', 'G', 'C', 'T'])
146-
elif isinstance(self.sampler.reference_sequence, Genome):
147-
Genome.update_bases_order(['A', 'C', 'G', 'T'])
156+
isinstance(self.sampler.reference_sequence, Genome)):
157+
if _is_lua_trained_model(model):
158+
Genome.update_bases_order(['A', 'G', 'C', 'T'])
159+
else:
160+
Genome.update_bases_order(['A', 'C', 'G', 'T'])
161+
162+
def _write_features_ordered_to_file(self):
    """
    Write the names in `self.features` to `use_features_ord.txt`.

    The file is created (or overwritten) inside `self.output_dir` and
    contains one feature name per line, in the same order as
    `self.features`.
    """
    out_path = os.path.join(self.output_dir, 'use_features_ord.txt')
    with open(out_path, 'w+') as out_file:
        # One formatted line per feature, emitted in list order.
        out_file.writelines(
            '{0}\n'.format(name) for name in self.features)
148167

149168
def _get_feature_from_index(self, index):
150169
"""
@@ -180,7 +199,7 @@ def evaluate(self):
180199
all_predictions = []
181200
for (inputs, targets) in self._test_data:
182201
inputs = torch.Tensor(inputs)
183-
targets = torch.Tensor(targets)
202+
targets = torch.Tensor(targets[:, self._use_ixs])
184203

185204
if self.use_cuda:
186205
inputs = inputs.cuda()
@@ -192,7 +211,7 @@ def evaluate(self):
192211
predictions = None
193212
if _is_lua_trained_model(self.model):
194213
predictions = self.model.forward(
195-
inputs.transpose(1, 2).unsqueeze_(2))
214+
inputs.transpose(1, 2).contiguous().unsqueeze_(2))
196215
else:
197216
predictions = self.model.forward(
198217
inputs.transpose(1, 2))

selene_sdk/predict/model_predict.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def __init__(self,
176176
if type(self.reference_sequence) == Genome and \
177177
_is_lua_trained_model(model):
178178
Genome.update_bases_order(['A', 'G', 'C', 'T'])
179-
else:
179+
else: # even if not using Genome, I guess we can update?
180180
Genome.update_bases_order(['A', 'C', 'G', 'T'])
181181
self._write_mem_limit = write_mem_limit
182182

@@ -421,11 +421,11 @@ def get_predictions_for_bed_file(self,
421421
batch_ids.append(label+(contains_unk,))
422422
sequences[ i % self.batch_size, :, :] = encoding
423423
if contains_unk:
424-
warnings.warn("For region {0}, "
425-
"reference sequence contains unknown base(s). "
426-
"--will be marked `True` in the `contains_unk` column "
427-
"of the .tsv or the row_labels .txt file.".format(
428-
label))
424+
warnings.warn(("For region {0}, "
425+
"reference sequence contains unknown "
426+
"base(s). --will be marked `True` in the "
427+
"`contains_unk` column of the .tsv or "
428+
"row_labels .txt file.").format(label))
429429

430430
if (batch_ids and i == 0) or i % self.batch_size != 0:
431431
sequences = sequences[:i % self.batch_size + 1, :, :]

0 commit comments

Comments
 (0)