Skip to content

Commit 5390439

Browse files
committed
inference api: batch-wise image analysis for superanimal prepare_2d
1 parent a51661e commit 5390439

1 file changed

Lines changed: 21 additions & 21 deletions

File tree

fmpose3d/inference_api/fmpose3d.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -187,28 +187,28 @@ def predict(
187187
cv2.imwrite(p, frames[idx])
188188
paths.append(p)
189189

190-
# Run DeepLabCut on each frame individually (the API
191-
# expects a single image path).
190+
# Run DeepLabCut once for all frames.
191+
predictions = superanimal_analyze_images(
192+
superanimal_name=cfg.superanimal_name,
193+
model_name=cfg.sa_model_name,
194+
detector_name=cfg.detector_name,
195+
images=paths,
196+
max_individuals=cfg.max_individuals,
197+
out_folder=tmpdir,
198+
)
199+
# predictions: {image_path: {"bodyparts": (N_ind, K, 3), ...}}
200+
# Iterate in input order to keep frame alignment stable.
192201
for img_path in paths:
193-
predictions = superanimal_analyze_images(
194-
superanimal_name=cfg.superanimal_name,
195-
model_name=cfg.sa_model_name,
196-
detector_name=cfg.detector_name,
197-
images=img_path,
198-
max_individuals=cfg.max_individuals,
199-
out_folder=tmpdir,
200-
)
201-
# predictions: {image_path: {"bodyparts": (N_ind, K, 3), ...}}
202-
for _path, payload in predictions.items():
203-
bodyparts = payload.get("bodyparts")
204-
if bodyparts is None:
205-
# No detection -- fill with zeros.
206-
all_mapped.append(np.zeros((1, 26, 2), dtype="float32"))
207-
continue
208-
xy = bodyparts[..., :2] # (N_ind, K, 2)
209-
mapped = self._map_keypoints(xy)
210-
# Take only the first individual.
211-
all_mapped.append(mapped[:1])
202+
payload = predictions.get(img_path)
203+
bodyparts = None if payload is None else payload.get("bodyparts")
204+
if bodyparts is None:
205+
# No detection -- fill with zeros.
206+
all_mapped.append(np.zeros((1, 26, 2), dtype="float32"))
207+
continue
208+
xy = bodyparts[..., :2] # (N_ind, K, 2)
209+
mapped = self._map_keypoints(xy)
210+
# Take only the first individual.
211+
all_mapped.append(mapped[:1])
212212

213213
# Stack along frame axis → (1, N, 26, 2)
214214
kpts = np.stack(all_mapped, axis=1) # (1, N, 26, 2)

0 commit comments

Comments
 (0)