-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathoutput_handlers.py
More file actions
215 lines (165 loc) · 6.6 KB
/
output_handlers.py
File metadata and controls
215 lines (165 loc) · 6.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""Output handling for different file formats."""
import os
import logging
from typing import List, Tuple, Optional
import pandas as pd
import numpy as np
import h5py
import inference
from config import NUM_CLASSES, EMBEDDING_DIM, RESULT_CSV_FILE
logger = logging.getLogger(__name__)
def create_csv_columns(has_classifier: bool) -> List[str]:
    """Build the ordered list of column names for the CSV output.

    Args:
        has_classifier: Whether classifier prediction columns are included.

    Returns:
        Column names: "id", optional classifier summary columns, then
        NUM_CLASSES probability columns ("probNN") and EMBEDDING_DIM
        feature columns ("featNNNN").
    """
    classifier_cols = (
        [
            "top_class_name",
            "top_class",
            "top_3_classes_names",
            "top_3_classes",
        ]
        if has_classifier
        else []
    )
    prob_cols = [f"prob{i:02d}" for i in range(NUM_CLASSES)]
    feat_cols = [f"feat{i:04d}" for i in range(EMBEDDING_DIM)]
    return ["id"] + classifier_cols + prob_cols + feat_cols
def compute_top_predictions(probabilities: np.ndarray) -> Tuple[int, str, str, str]:
    """Derive top-1 and top-3 class summaries from a probability vector.

    Args:
        probabilities: 1-D array of per-class probabilities.

    Returns:
        Tuple of (top_class_idx, top_class_name, comma-joined top-3 class
        names, comma-joined top-3 class indices).
    """
    values = probabilities.tolist()
    # Stable descending sort: ties keep the lowest index first, matching
    # list.index(max(...)) semantics for the top class.
    ranked = sorted(range(len(values)), key=values.__getitem__, reverse=True)
    best = ranked[0]
    best_name = inference.CLASS2NAME[best]
    top3 = ranked[:3]
    top3_names = ",".join(inference.CLASS2NAME[i] for i in top3)
    top3_str = ",".join(str(i) for i in top3)
    return best, best_name, top3_names, top3_str
def create_csv_row(
    output_prefix: str,
    embedding: np.ndarray,
    probabilities: Optional[np.ndarray],
    has_classifier: bool,
) -> List:
    """Create a single CSV row from results.

    Args:
        output_prefix: ID/prefix for this sample
        embedding: Embedding vector
        probabilities: Probability vector (None if embeddings_only)
        has_classifier: Whether classifier is being used

    Returns:
        List representing CSV row, always width-aligned with the columns
        produced by create_csv_columns(has_classifier).
    """
    row = [output_prefix]
    if has_classifier:
        if probabilities is not None:
            top_class, top_class_name, top_3_names, top_3_str = compute_top_predictions(probabilities)
            row.extend([top_class_name, top_class, top_3_names, top_3_str])
        else:
            # Placeholder cells keep the row aligned with the four
            # classifier columns when no probabilities are available.
            row.extend([None, None, None, None])
    if probabilities is not None:
        row.extend(probabilities.tolist())
    else:
        # BUG FIX: previously probabilities.tolist() was called
        # unconditionally, raising AttributeError for embeddings-only runs.
        # NaN placeholders keep the row aligned with the NUM_CLASSES
        # probability columns that create_csv_columns always emits.
        row.extend([float("nan")] * NUM_CLASSES)
    row.extend(embedding.tolist())
    return row
class CSVOutputHandler:
    """Accumulates inference results and writes them out as a CSV file."""

    def __init__(self, has_classifier: bool):
        """Set up the column layout and an empty row buffer.

        Args:
            has_classifier: Whether classifier predictions are included
        """
        self.has_classifier = has_classifier
        self.columns = create_csv_columns(has_classifier)
        self.rows = []

    def add_batch(
        self,
        output_prefixes: List[str],
        embeddings: List[np.ndarray],
        probabilities_list: List[Optional[np.ndarray]],
    ) -> None:
        """Append one CSV row per sample in the batch.

        Args:
            output_prefixes: List of sample IDs
            embeddings: List of embedding vectors
            probabilities_list: List of probability vectors (None entries if embeddings_only)
        """
        samples = zip(output_prefixes, embeddings, probabilities_list)
        self.rows.extend(
            create_csv_row(sample_id, vector, probs, self.has_classifier)
            for sample_id, vector, probs in samples
        )

    def save(self, output_path: str = RESULT_CSV_FILE) -> None:
        """Write all buffered rows to a CSV file.

        Args:
            output_path: Path to save CSV file
        """
        if not self.rows:
            logger.warning("No data to save to CSV")
            return
        pd.DataFrame(self.rows, columns=self.columns).to_csv(output_path, index=False)
        logger.info(f"CSV saved to {output_path} with {len(self.rows)} rows")
class H5ADOutputHandler:
    """Accumulates results and writes them to an AnnData-style HDF5 file."""

    def __init__(self, output_dir: str):
        """Initialize empty result buffers.

        Args:
            output_dir: Directory to save H5AD file
        """
        self.output_dir = output_dir
        self.embeddings = []
        self.probabilities = []
        self.image_names = []

    def add_batch(
        self,
        output_prefixes: List[str],
        embeddings: List[np.ndarray],
        probabilities_list: List[Optional[np.ndarray]],
    ) -> None:
        """Buffer one batch of results.

        Args:
            output_prefixes: List of sample IDs
            embeddings: List of embedding vectors
            probabilities_list: List of probability vectors (None entries if embeddings_only)
        """
        # NOTE(review): None entries are skipped for probabilities but not for
        # embeddings, so a batch mixing None and non-None probabilities would
        # misalign 'obsm/probabilities' rows with 'X' — presumably callers
        # pass all-None or all-set per run; confirm upstream.
        for sample_id, vector, probs in zip(output_prefixes, embeddings, probabilities_list):
            self.image_names.append(sample_id)
            self.embeddings.append(vector)
            if probs is not None:
                self.probabilities.append(probs)

    def save(self, embeddings_only: bool = False) -> str:
        """Flush buffered results to an HDF5 file in AnnData-like layout.

        Args:
            embeddings_only: Whether only embeddings were computed
                (recorded as a file attribute).

        Returns:
            Path to saved H5AD file, or "" when there was nothing to save.
        """
        if not self.embeddings:
            logger.warning("No data to save to H5AD")
            return ""
        logger.info(f"Saving H5AD file with {len(self.embeddings)} embeddings...")
        os.makedirs(self.output_dir, exist_ok=True)
        h5ad_path = os.path.join(self.output_dir, "embeddings.h5ad")
        matrix = np.stack(self.embeddings)
        with h5py.File(h5ad_path, 'w') as f:
            # Embeddings form the main data matrix (AnnData convention).
            f.create_dataset('X', data=matrix)
            # Observation names are the image identifiers.
            f.create_dataset('obs/index', data=np.array(self.image_names, dtype='S'))
            if self.probabilities:
                # Per-class probabilities, present only when the classifier ran.
                f.create_dataset('obsm/probabilities', data=np.stack(self.probabilities))
            # File-level metadata attributes.
            f.attrs['n_obs'] = len(self.embeddings)
            f.attrs['n_vars'] = matrix.shape[1]
            f.attrs['created_by'] = 'SubCellPortable'
            f.attrs['embeddings_only'] = embeddings_only
        logger.info(f"H5AD file saved: {h5ad_path}")
        logger.info(f"Shape: {matrix.shape}")
        logger.info("Contains: embeddings, image_names" + (", probabilities" if self.probabilities else ""))
        return h5ad_path