Skip to content

Commit e2fce5e

Browse files
committed
update situation's topk 3 to 1
1 parent 6d3d867 commit e2fce5e

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

services/predict_services.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,10 @@ async def classify_style_with_session(session: aiohttp.ClientSession, download_u
112112
labels = {}
113113
for name, (embeds, texts) in embed_dict.items():
114114
logits = logit_scale * img_feats @ embeds.T
115-
_, idx = logits.topk(top_k, dim=-1)
115+
if name == "situations":
116+
_, idx = logits.topk(1, dim=-1)
117+
else:
118+
_, idx = logits.topk(top_k, dim=-1)
116119
labels[name] = [texts[i] for i in idx.squeeze(0).tolist()]
117120

118121
return {

0 commit comments

Comments
 (0)