Skip to content

Commit af88d62

Browse files
authored
[Fix] Avoid scope switching when using mmdet inference interface (#2039)
1 parent 13acbc8 commit af88d62

4 files changed

Lines changed: 33 additions & 5 deletions

File tree

demo/topdown_demo_with_mmdet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
import mmcv
99
import mmengine
1010
import numpy as np
11-
from mmengine.registry import init_default_scope
1211

1312
from mmpose.apis import inference_topdown
1413
from mmpose.apis import init_model as init_pose_estimator
1514
from mmpose.evaluation.functional import nms
1615
from mmpose.registry import VISUALIZERS
1716
from mmpose.structures import merge_data_samples, split_instances
17+
from mmpose.utils import adapt_mmdet_pipeline
1818

1919
try:
2020
from mmdet.apis import inference_detector, init_detector
@@ -28,7 +28,6 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer,
2828
"""Visualize predicted keypoints (and heatmaps) of one image."""
2929

3030
# predict bbox
31-
init_default_scope(detector.cfg.get('default_scope', 'mmdet'))
3231
det_result = inference_detector(detector, img_path)
3332
pred_instance = det_result.pred_instances.cpu().numpy()
3433
bboxes = np.concatenate(
@@ -147,6 +146,7 @@ def main():
147146
# build detector
148147
detector = init_detector(
149148
args.det_config, args.det_checkpoint, device=args.device)
149+
detector.cfg = adapt_mmdet_pipeline(detector.cfg)
150150

151151
# build pose estimator
152152
pose_estimator = init_pose_estimator(

mmpose/apis/webcam/nodes/model_nodes/detector_node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from typing import Dict, List, Optional, Union
33

44
import numpy as np
5-
from mmengine.registry import init_default_scope
65

6+
from mmpose.utils import adapt_mmdet_pipeline
77
from ...utils import get_config_path
88
from ..node import Node
99
from ..registry import NODES
@@ -92,6 +92,7 @@ def __init__(self,
9292
# Init model
9393
self.model = init_detector(
9494
self.model_config, self.model_checkpoint, device=self.device)
95+
self.model.cfg = adapt_mmdet_pipeline(self.model.cfg)
9596

9697
# Register buffers
9798
self.register_input_buffer(input_buffer, 'input', trigger=True)
@@ -109,7 +110,6 @@ def process(self, input_msgs):
109110

110111
img = input_msg.get_image()
111112

112-
init_default_scope(self.model.cfg.get('default_scope', 'mmdet'))
113113
preds = inference_detector(self.model, img)
114114
objects = self._post_process(preds)
115115
input_msg.update_objects(objects)

mmpose/utils/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .camera import SimpleCamera, SimpleCameraTorch
33
from .collect_env import collect_env
4+
from .config_utils import adapt_mmdet_pipeline
45
from .logger import get_root_logger
56
from .setup_env import register_all_modules, setup_multi_processes
67
from .timer import StopWatch
78

89
__all__ = [
910
'get_root_logger', 'collect_env', 'StopWatch', 'setup_multi_processes',
10-
'register_all_modules', 'SimpleCamera', 'SimpleCameraTorch'
11+
'register_all_modules', 'SimpleCamera', 'SimpleCameraTorch',
12+
'adapt_mmdet_pipeline'
1113
]

mmpose/utils/config_utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from mmpose.utils.typing import ConfigDict
3+
4+
5+
def adapt_mmdet_pipeline(cfg: ConfigDict) -> ConfigDict:
6+
"""Converts pipeline types in MMDetection's test dataloader to use the
7+
'mmdet' namespace.
8+
9+
Args:
10+
cfg (ConfigDict): Configuration dictionary for MMDetection.
11+
12+
Returns:
13+
ConfigDict: Configuration dictionary with updated pipeline types.
14+
"""
15+
# use lazy import to avoid hard dependence on mmdet
16+
from mmdet.datasets import transforms
17+
18+
if 'test_dataloader' not in cfg:
19+
return cfg
20+
21+
pipeline = cfg.test_dataloader.dataset.pipeline
22+
for trans in pipeline:
23+
if trans['type'] in dir(transforms):
24+
trans['type'] = 'mmdet.' + trans['type']
25+
26+
return cfg

0 commit comments

Comments
 (0)