Skip to content

Commit 737ec9f

Browse files
authored
Update Zero-shot Classification Task (#27)
1 parent a247b4e commit 737ec9f

3 files changed

Lines changed: 44 additions & 19 deletions

File tree

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,17 @@ Evaluates the quality of the learned representations in retrieving the <i>k</i>
156156
using the recall@k metric. This is applicable to any number of pairs of modalities at once, depending on memory constraints.
157157
</td>
158158
</tr>
159+
<tr>
160+
<td>
161+
162+
Zero-shot Classification
163+
</td>
164+
<td>
165+
Evaluates the ability of a pre-trained encoder-based multimodal model to predict classes that were not explicitly seen
166+
during training. The new classes are given as text prompts, and the query modality can be any of the supported modalities.
167+
Binary and multi-class classification tasks are supported.
168+
</td>
169+
</tr>
159170
</table>
160171

161172
## Components

mmlearn/tasks/zero_shot_classification.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,13 @@ def evaluation_step(
195195
query_embeddings /= query_embeddings.norm(p=2, dim=-1, keepdim=True)
196196
query_embeddings = query_embeddings[matching_indices]
197197

198-
logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
198+
if self.all_dataset_info[dataset_index]["num_classes"] == 2:
199+
softmax_output = _safe_matmul(
200+
query_embeddings, class_embeddings
201+
).softmax(dim=-1)
202+
logits = softmax_output[:, 1] - softmax_output[:, 0]
203+
else:
204+
logits = 100.0 * _safe_matmul(query_embeddings, class_embeddings)
199205
targets = batch[Modalities.get_modality(query_modality).target][
200206
matching_indices
201207
]
@@ -233,27 +239,36 @@ def _create_metrics(
233239
num_classes: int, top_k: List[int], prefix: str, postfix: str
234240
) -> MetricCollection:
235241
"""Create a collection of classification metrics."""
242+
task_type = "binary" if num_classes == 2 else "multiclass"
243+
acc_metrics = (
244+
{
245+
f"top{k}_accuracy": Accuracy(
246+
task=task_type, num_classes=num_classes, top_k=k, average="micro"
247+
)
248+
for k in top_k
249+
}
250+
if num_classes > 2
251+
else {"accuracy": Accuracy(task=task_type, num_classes=num_classes)}
252+
)
236253
return MetricCollection(
237254
{
238255
"precision": Precision(
239-
task="multiclass", num_classes=num_classes, average="macro"
256+
task=task_type,
257+
num_classes=num_classes,
258+
average="macro" if num_classes > 2 else "micro",
240259
),
241260
"recall": Recall(
242-
task="multiclass", num_classes=num_classes, average="macro"
261+
task=task_type,
262+
num_classes=num_classes,
263+
average="macro" if num_classes > 2 else "micro",
243264
),
244265
"f1_score_macro": F1Score(
245-
task="multiclass", num_classes=num_classes, average="macro"
266+
task=task_type,
267+
num_classes=num_classes,
268+
average="macro" if num_classes > 2 else "micro",
246269
),
247-
"aucroc": AUROC(task="multiclass", num_classes=num_classes),
248-
**{
249-
f"top{k}_accuracy": Accuracy(
250-
task="multiclass",
251-
num_classes=num_classes,
252-
top_k=k,
253-
average="micro",
254-
)
255-
for k in top_k
256-
},
270+
"aucroc": AUROC(task=task_type, num_classes=num_classes),
271+
**acc_metrics,
257272
},
258273
prefix=prefix,
259274
postfix=postfix,

projects/med_benchmarking/configs/experiment/zeroshot_classification_eval.yaml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ datasets:
143143

144144
dataloader:
145145
test:
146-
batch_size: 64
146+
batch_size: 128
147147
num_workers: 4
148148

149149
task:
@@ -153,15 +153,14 @@ task:
153153
task_specs:
154154
- top_k: [1]
155155
query_modality: rgb
156-
run_on_validation: false
157-
run_on_test: true
156+
run_on_validation: False
157+
run_on_test: True
158158
compute_validation_loss: False
159159
compute_test_loss: False
160160

161161
trainer:
162162
precision: 16-mixed
163-
deterministic: False
164-
benchmark: True
163+
deterministic: True
165164
sync_batchnorm: False # set to True if using DDP with batchnorm
166165
log_every_n_steps: 100
167166

0 commit comments

Comments
 (0)