From 3d98616d323f116e0347e1a7b4e93068f65fe0d9 Mon Sep 17 00:00:00 2001 From: 0xyangl Date: Tue, 10 Feb 2026 10:59:30 +0800 Subject: [PATCH 1/6] Add FLOPs analysis tool for groundingdino ## Motivation The existing `tools/analysis_tools/get_flops.py` does not support grounding / vision-language detection models (e.g., GroundingDINO, GroundingCLIP) because these models require text inputs and have multi-modal architectures that cannot be traced end-to-end with `mmengine.analysis.get_model_complexity_info`. This PR adds a dedicated FLOPs analysis tool that handles the unique architecture of grounding detection models, providing per-component FLOPs and parameter breakdowns. ## Modification **New file: `tools/analysis_tools/get_flops_grounding.py`** A script that computes per-component FLOPs and parameter counts for grounding detection models: - **Vision Backbone**: Accurate FLOPs via `fvcore.nn.FlopCountAnalysis` - **Text Encoder**: Estimated FLOPs based on model type (CLIP, BERT, etc.) - **Neck (ChannelMapper)**: Estimated from config-driven channel/stride info - **Transformer Encoder/Decoder**: Estimated from config-driven architecture params - **Detection Head**: Parameter count only (FLOPs are not estimated) Key design choices: - Automatically disables `with_cp` (gradient checkpointing), which is incompatible with JIT tracing, without modifying the original config - Reads architecture parameters (channels, layers, embed_dim, etc.) dynamically from the model config instead of hardcoding - Uses `MMLogger` consistent with existing mmdet tools **New file: `tests/test_tools/test_get_flops_grounding.py`** 41 unit tests covering the formatting helpers, config readers, and FLOPs estimation functions. ## BC-breaking No.
This PR only adds new files and does not modify any existing code. ## Use cases ```bash # Basic usage python tools/analysis_tools/get_flops_grounding.py \ configs/mm_grounding_dino/grounding_dino_swin-t_finetune_8xb4_20e_cat.py # Custom input shape python tools/analysis_tools/get_flops_grounding.py --shape 640 640 # Specify number of classes for text encoder FLOPs estimation python tools/analysis_tools/get_flops_grounding.py --num-classes 80 --- get_flops_grounding.py | 525 ++++++++++++++++++++++++++++++++++++ test_get_flops_grounding.py | 326 ++++++++++++++++++++++ 2 files changed, 851 insertions(+) create mode 100644 get_flops_grounding.py create mode 100644 test_get_flops_grounding.py diff --git a/get_flops_grounding.py b/get_flops_grounding.py new file mode 100644 index 00000000000..8a5aa2c47c3 --- /dev/null +++ b/get_flops_grounding.py @@ -0,0 +1,525 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Get FLOPs for GroundingDINO models. + +This script is specifically designed for models that require text inputs, +such as GroundingDINO and other vision-language models. + +Gradient Checkpointing (with_cp) Compatibility: + This script uses fvcore's FlopCountAnalysis for accurate backbone + FLOPs, which internally uses PyTorch JIT tracing. However, gradient + checkpointing (with_cp=True) is INCOMPATIBLE with JIT tracing. + This script automatically disables gradient checkpointing when + building the model for analysis. This does NOT affect model accuracy + or your original config file. 
+ +Usage: + python tools/analysis_tools/get_flops_grounding.py + +Example: + python tools/analysis_tools/get_flops_grounding.py \\ + configs/mm_grounding_dino/grounding_dino_swin-t_finetune_8xb4_20e_cat.py +""" +import argparse +import tempfile +from pathlib import Path + +import torch +from mmengine.config import Config, DictAction +from mmengine.logging import MMLogger +from mmengine.model import revert_sync_batchnorm +from mmengine.registry import init_default_scope + +from mmdet.registry import MODELS + +try: + from fvcore.nn import FlopCountAnalysis +except ImportError: + FlopCountAnalysis = None + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description='Get FLOPs for grounding detection models') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[800, 1333], + help='input image size (height width)') + parser.add_argument( + '--text', + type=str, + default='car', + help='text prompt for computing text encoder FLOPs') + parser.add_argument( + '--num-classes', + type=int, + default=None, + help='number of classes (for estimating text encoder FLOPs)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config') + args = parser.parse_args() + return args + + +def format_flops(flops): + """Format FLOPs to human readable string. + + Args: + flops (int or float): Number of FLOPs. + + Returns: + str: Human readable FLOPs string. + """ + if flops >= 1e12: + return f'{flops / 1e12:.2f} T' + elif flops >= 1e9: + return f'{flops / 1e9:.2f} G' + elif flops >= 1e6: + return f'{flops / 1e6:.2f} M' + else: + return f'{flops:.2f}' + + +def format_params(params): + """Format parameters to human readable string. + + Args: + params (int or float): Number of parameters. + + Returns: + str: Human readable parameter count string. 
+ """ + if params >= 1e9: + return f'{params / 1e9:.2f} B' + elif params >= 1e6: + return f'{params / 1e6:.2f} M' + elif params >= 1e3: + return f'{params / 1e3:.2f} K' + else: + return f'{params:.0f}' + + +def _get_backbone_out_channels(cfg): + """Extract backbone output channels from config. + + Tries to read from the neck's in_channels, then falls back to + common defaults based on backbone type. + + Args: + cfg (Config): Model config. + + Returns: + list[int]: Output channels per feature level. + """ + # Try reading from neck config (most reliable) + if hasattr(cfg.model, 'neck'): + neck_cfg = cfg.model.neck + if hasattr(neck_cfg, 'in_channels'): + return list(neck_cfg.in_channels) + + # Fallback: infer from backbone type + backbone_type = cfg.model.backbone.get('type', '') + if 'Swin' in backbone_type: + embed_dims = cfg.model.backbone.get('embed_dims', 96) + depths = cfg.model.backbone.get('depths', [2, 2, 6, 2]) + num_levels = len(depths) + return [embed_dims * (2**i) for i in range(num_levels)] + + # Generic default + return [256, 512, 1024, 2048] + + +def _get_feature_strides(cfg): + """Get feature map strides from config. + + Args: + cfg (Config): Model config. + + Returns: + list[int]: Strides for each feature level. + """ + if hasattr(cfg.model, 'neck'): + neck_cfg = cfg.model.neck + if hasattr(neck_cfg, 'in_channels'): + num_levels = len(neck_cfg.in_channels) + # Common strides: start at 8, double each level + return [8 * (2**i) for i in range(num_levels)] + return [8, 16, 32] + + +def _get_neck_out_channels(cfg): + """Get neck output channels from config. + + Args: + cfg (Config): Model config. + + Returns: + int: Output channels of the neck. + """ + if hasattr(cfg.model, 'neck'): + neck_cfg = cfg.model.neck + if hasattr(neck_cfg, 'out_channels'): + return neck_cfg.out_channels + return 256 + + +def _get_transformer_config(cfg): + """Extract transformer encoder/decoder config. 
+ + Reads layer counts, embed_dim, ffn_dim, and num_queries from the + model config. Falls back to GroundingDINO defaults. + + Args: + cfg (Config): Model config. + + Returns: + dict: Transformer configuration with keys: embed_dim, + num_encoder_layers, num_decoder_layers, ffn_dim, + num_queries. + """ + result = dict( + embed_dim=256, + num_encoder_layers=6, + num_decoder_layers=6, + ffn_dim=2048, + num_queries=900) + + # Try to read from encoder config + if hasattr(cfg.model, 'encoder'): + enc = cfg.model.encoder + if hasattr(enc, 'num_layers'): + result['num_encoder_layers'] = enc.num_layers + if hasattr(enc, 'layer_cfg'): + layer = enc.layer_cfg + if hasattr(layer, 'self_attn_cfg'): + attn = layer.self_attn_cfg + if hasattr(attn, 'embed_dims'): + result['embed_dim'] = attn.embed_dims + if hasattr(layer, 'ffn_cfg'): + ffn = layer.ffn_cfg + if hasattr(ffn, 'feedforward_channels'): + result['ffn_dim'] = ffn.feedforward_channels + + # Try to read from decoder config + if hasattr(cfg.model, 'decoder'): + dec = cfg.model.decoder + if hasattr(dec, 'num_layers'): + result['num_decoder_layers'] = dec.num_layers + + # Try to read num_queries + if hasattr(cfg.model, 'num_queries'): + result['num_queries'] = cfg.model.num_queries + elif hasattr(cfg.model, 'bbox_head'): + head = cfg.model.bbox_head + if hasattr(head, 'num_queries'): + result['num_queries'] = head.num_queries + + return result + + +def count_backbone_flops(model, input_shape, device, logger): + """Count FLOPs for the vision backbone using fvcore. + + Args: + model (nn.Module): The detection model. + input_shape (tuple[int]): Input image (H, W). + device (torch.device): Device to run on. + logger (MMLogger): Logger instance. + + Returns: + tuple: (flops, params) or (None, params) if fvcore + is not available. 
+ """ + if not hasattr(model, 'backbone'): + return None, None + + backbone = model.backbone + backbone.eval() + + h, w = input_shape + x = torch.randn(1, 3, h, w).to(device) + + if FlopCountAnalysis is not None: + flops_analyzer = FlopCountAnalysis(backbone, x) + flops_analyzer.unsupported_ops_warnings(False) + flops_analyzer.uncalled_modules_warnings(False) + flops = flops_analyzer.total() + params = sum(p.numel() for p in backbone.parameters()) + return flops, params + else: + logger.warning('fvcore is not installed, backbone FLOPs cannot be ' + 'computed accurately. Install with: pip install fvcore') + params = sum(p.numel() for p in backbone.parameters()) + return None, params + + +def count_text_encoder_flops(model, num_classes): + """Estimate FLOPs for the text encoder. + + Uses rough per-forward-pass estimates based on model type since + text encoders often have complex control flow that prevents + accurate tracing. + + Args: + model (nn.Module): The detection model. + num_classes (int): Number of classes. + + Returns: + tuple: (flops, params) or (None, None) if no language + model is found. + """ + if not hasattr(model, 'language_model'): + return None, None + + lang_model = model.language_model + params = sum(p.numel() for p in lang_model.parameters()) + + # Estimate FLOPs based on model type + model_name = getattr(lang_model, 'name', '') + if 'clip' in model_name.lower(): + if 'large' in model_name.lower(): + flops_per_forward = 10e9 + elif 'base' in model_name.lower(): + flops_per_forward = 4e9 + else: + flops_per_forward = 4e9 + elif 'bert' in model_name.lower(): + flops_per_forward = 11e9 + else: + flops_per_forward = 5e9 + + if num_classes: + flops = flops_per_forward * num_classes + else: + flops = flops_per_forward + + return flops, params + + +def count_neck_flops(model, input_shape, cfg): + """Estimate FLOPs for the neck (e.g. ChannelMapper). + + Reads in_channels and out_channels from the config to avoid + hardcoded assumptions. 
+ + Args: + model (nn.Module): The detection model. + input_shape (tuple[int]): Input image (H, W). + cfg (Config): Model config. + + Returns: + tuple: (flops, params). + """ + if not hasattr(model, 'neck') or model.neck is None: + return 0, 0 + + neck = model.neck + params = sum(p.numel() for p in neck.parameters()) + + h, w = input_shape + in_channels = _get_backbone_out_channels(cfg) + strides = _get_feature_strides(cfg) + out_channels = _get_neck_out_channels(cfg) + + flops = 0 + for in_c, stride in zip(in_channels, strides): + fh, fw = h // stride, w // stride + # 1x1 conv: in_c * out_c * H * W * 2 (multiply-add) + flops += in_c * out_channels * fh * fw * 2 + + return flops, params + + +def count_transformer_flops(model, input_shape, cfg): + """Estimate FLOPs for the transformer encoder and decoder. + + Reads architecture parameters from the config dynamically + instead of hardcoding values. + + Args: + model (nn.Module): The detection model. + input_shape (tuple[int]): Input image (H, W). + cfg (Config): Model config. + + Returns: + tuple: (encoder_flops, decoder_flops, + encoder_params, decoder_params). 
+ """ + h, w = input_shape + trans_cfg = _get_transformer_config(cfg) + embed_dim = trans_cfg['embed_dim'] + num_enc = trans_cfg['num_encoder_layers'] + num_dec = trans_cfg['num_decoder_layers'] + ffn_dim = trans_cfg['ffn_dim'] + num_queries = trans_cfg['num_queries'] + + strides = _get_feature_strides(cfg) + feat_sizes = [(h // s, w // s) for s in strides] + total_tokens = sum(fh * fw for fh, fw in feat_sizes) + + # Encoder self-attention: 4 * n^2 * d (Q,K,V proj + output) + enc_attn = (4 * total_tokens * total_tokens * embed_dim * num_enc) + # Encoder FFN: 2 * n * d * ffn_dim + enc_ffn = 2 * total_tokens * embed_dim * ffn_dim * num_enc + encoder_flops = enc_attn + enc_ffn + + # Decoder self-attention on queries + dec_self = (4 * num_queries * num_queries * embed_dim * num_dec) + # Decoder cross-attention to image features + dec_cross = (4 * num_queries * total_tokens * embed_dim * num_dec) + # Decoder FFN + dec_ffn = 2 * num_queries * embed_dim * ffn_dim * num_dec + decoder_flops = dec_self + dec_cross + dec_ffn + + encoder_params = 0 + if hasattr(model, 'encoder'): + encoder_params = sum(p.numel() for p in model.encoder.parameters()) + + decoder_params = 0 + if hasattr(model, 'decoder'): + decoder_params = sum(p.numel() for p in model.decoder.parameters()) + + return encoder_flops, decoder_flops, encoder_params, decoder_params + + +def main(): + args = parse_args() + logger = MMLogger.get_instance(name='MMLogger') + + config_path = Path(args.config) + if not config_path.exists(): + logger.error(f'{config_path} not found.') + return + + cfg = Config.fromfile(args.config) + cfg.work_dir = tempfile.TemporaryDirectory().name + + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # Disable gradient checkpointing for FLOPs analysis. + # with_cp=True is incompatible with JIT tracing used by fvcore. 
+ if hasattr(cfg.model, 'backbone'): + if cfg.model.backbone.get('with_cp', False): + cfg.model.backbone.with_cp = False + logger.warning('Auto-disabled gradient checkpointing ' + '(with_cp=False) for FLOPs analysis. ' + 'This does NOT affect your config file.') + + init_default_scope(cfg.get('default_scope', 'mmdet')) + + # Build model + model = MODELS.build(cfg.model) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + model = revert_sync_batchnorm(model) + model.eval() + + # Get input shape + if len(args.shape) == 1: + h = w = args.shape[0] + else: + h, w = args.shape[:2] + + # Get number of classes + num_classes = args.num_classes + if num_classes is None: + if (hasattr(cfg.model, 'bbox_head') + and hasattr(cfg.model.bbox_head, 'num_classes')): + num_classes = cfg.model.bbox_head.num_classes + else: + num_classes = 80 # Default COCO + + split_line = '=' * 60 + logger.info(split_line) + logger.info('GroundingDINO FLOPs Analysis') + logger.info(split_line) + logger.info(f'Config: {args.config}') + logger.info(f'Input shape: ({h}, {w})') + logger.info(f'Number of classes: {num_classes}') + logger.info(split_line) + + total_flops = 0 + total_params = 0 + + # 1. Vision Backbone + backbone_flops, backbone_params = count_backbone_flops( + model, (h, w), device, logger) + if backbone_flops is not None: + logger.info('\n[Vision Backbone]') + logger.info(f' FLOPs: {format_flops(backbone_flops)}') + logger.info(f' Params: {format_params(backbone_params)}') + total_flops += backbone_flops + total_params += backbone_params + else: + logger.info('\n[Vision Backbone]') + logger.info(' FLOPs: (fvcore not installed, cannot compute)') + if backbone_params: + logger.info(f' Params: {format_params(backbone_params)}') + total_params += backbone_params + + # 2. 
Text Encoder + text_flops, text_params = count_text_encoder_flops(model, num_classes) + if text_flops is not None: + logger.info('\n[Text Encoder]') + logger.info(f' FLOPs: {format_flops(text_flops)} ' + f'(estimated, {num_classes} classes)') + logger.info(f' Params: {format_params(text_params)}') + total_flops += text_flops + total_params += text_params + + # 3. Neck + neck_flops, neck_params = count_neck_flops(model, (h, w), cfg) + if neck_params > 0: + logger.info('\n[Neck (ChannelMapper)]') + logger.info(f' FLOPs: {format_flops(neck_flops)} (estimated)') + logger.info(f' Params: {format_params(neck_params)}') + total_flops += neck_flops + total_params += neck_params + + # 4. Transformer Encoder/Decoder + enc_flops, dec_flops, enc_params, dec_params = \ + count_transformer_flops(model, (h, w), cfg) + logger.info('\n[Transformer Encoder]') + logger.info(f' FLOPs: {format_flops(enc_flops)} (estimated)') + logger.info(f' Params: {format_params(enc_params)}') + total_flops += enc_flops + total_params += enc_params + + logger.info('\n[Transformer Decoder]') + logger.info(f' FLOPs: {format_flops(dec_flops)} (estimated)') + logger.info(f' Params: {format_params(dec_params)}') + total_flops += dec_flops + total_params += dec_params + + # 5. Bbox Head + if hasattr(model, 'bbox_head'): + head_params = sum(p.numel() for p in model.bbox_head.parameters()) + logger.info('\n[Detection Head]') + logger.info(f' Params: {format_params(head_params)}') + total_params += head_params + + # Total + logger.info('\n' + split_line) + logger.info(f'TOTAL FLOPs: {format_flops(total_flops)}') + logger.info(f'TOTAL Parameters: {format_params(total_params)}') + logger.info(split_line) + + logger.warning('Note: Some FLOPs are estimated based on model ' + 'architecture. Backbone FLOPs are accurate if fvcore is ' + 'installed. 
Text encoder and transformer FLOPs are ' + 'theoretical estimates.') + + if FlopCountAnalysis is None: + logger.info('Tip: Install fvcore for accurate backbone FLOPs: ' + 'pip install fvcore') + + +if __name__ == '__main__': + main() diff --git a/test_get_flops_grounding.py b/test_get_flops_grounding.py new file mode 100644 index 00000000000..8e26a5f462c --- /dev/null +++ b/test_get_flops_grounding.py @@ -0,0 +1,326 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import sys +import unittest +from unittest.mock import MagicMock + +import torch +import torch.nn as nn +from mmengine.config import Config + +# Add tools path so we can import the module directly +sys.path.insert(0, 'tools/analysis_tools') # pragma: no cover +from get_flops_grounding import ( # noqa: E402 + _get_backbone_out_channels, _get_feature_strides, _get_neck_out_channels, + _get_transformer_config, count_neck_flops, count_text_encoder_flops, + count_transformer_flops, format_flops, format_params, +) + + +class TestFormatFlops(unittest.TestCase): + """Test format_flops helper.""" + + def test_tera(self): + self.assertEqual(format_flops(1.5e12), '1.50 T') + + def test_giga(self): + self.assertEqual(format_flops(2.34e9), '2.34 G') + + def test_mega(self): + self.assertEqual(format_flops(5.67e6), '5.67 M') + + def test_small(self): + self.assertEqual(format_flops(1234.0), '1234.00') + + def test_exact_boundary_tera(self): + self.assertEqual(format_flops(1e12), '1.00 T') + + def test_exact_boundary_giga(self): + self.assertEqual(format_flops(1e9), '1.00 G') + + def test_exact_boundary_mega(self): + self.assertEqual(format_flops(1e6), '1.00 M') + + +class TestFormatParams(unittest.TestCase): + """Test format_params helper.""" + + def test_billion(self): + self.assertEqual(format_params(1.2e9), '1.20 B') + + def test_million(self): + self.assertEqual(format_params(27.52e6), '27.52 M') + + def test_thousand(self): + self.assertEqual(format_params(3.5e3), '3.50 K') + + def test_small(self): + 
self.assertEqual(format_params(42), '42') + + def test_exact_boundary_billion(self): + self.assertEqual(format_params(1e9), '1.00 B') + + def test_exact_boundary_million(self): + self.assertEqual(format_params(1e6), '1.00 M') + + def test_exact_boundary_thousand(self): + self.assertEqual(format_params(1e3), '1.00 K') + + +class TestGetBackboneOutChannels(unittest.TestCase): + """Test _get_backbone_out_channels config reader.""" + + def test_from_neck_in_channels(self): + cfg = Config( + dict( + model=dict( + neck=dict(in_channels=[192, 384, 768]), + backbone=dict(type='SwinTransformer')))) + result = _get_backbone_out_channels(cfg) + self.assertEqual(result, [192, 384, 768]) + + def test_fallback_swin(self): + cfg = Config( + dict( + model=dict( + backbone=dict( + type='SwinTransformer', + embed_dims=96, + depths=[2, 2, 6, 2])))) + result = _get_backbone_out_channels(cfg) + self.assertEqual(result, [96, 192, 384, 768]) + + def test_fallback_swin_custom_dims(self): + cfg = Config( + dict( + model=dict( + backbone=dict( + type='SwinTransformer', + embed_dims=128, + depths=[2, 2, 18, 2])))) + result = _get_backbone_out_channels(cfg) + self.assertEqual(result, [128, 256, 512, 1024]) + + def test_fallback_generic(self): + cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) + result = _get_backbone_out_channels(cfg) + self.assertEqual(result, [256, 512, 1024, 2048]) + + +class TestGetFeatureStrides(unittest.TestCase): + """Test _get_feature_strides config reader.""" + + def test_from_neck(self): + cfg = Config(dict(model=dict(neck=dict(in_channels=[192, 384, 768])))) + result = _get_feature_strides(cfg) + self.assertEqual(result, [8, 16, 32]) + + def test_four_levels(self): + cfg = Config( + dict(model=dict(neck=dict(in_channels=[96, 192, 384, 768])))) + result = _get_feature_strides(cfg) + self.assertEqual(result, [8, 16, 32, 64]) + + def test_fallback(self): + cfg = Config(dict(model=dict())) + result = _get_feature_strides(cfg) + self.assertEqual(result, [8, 
16, 32]) + + +class TestGetNeckOutChannels(unittest.TestCase): + """Test _get_neck_out_channels config reader.""" + + def test_from_config(self): + cfg = Config(dict(model=dict(neck=dict(out_channels=256)))) + self.assertEqual(_get_neck_out_channels(cfg), 256) + + def test_custom_channels(self): + cfg = Config(dict(model=dict(neck=dict(out_channels=512)))) + self.assertEqual(_get_neck_out_channels(cfg), 512) + + def test_fallback(self): + cfg = Config(dict(model=dict())) + self.assertEqual(_get_neck_out_channels(cfg), 256) + + +class TestGetTransformerConfig(unittest.TestCase): + """Test _get_transformer_config config reader.""" + + def test_defaults(self): + cfg = Config(dict(model=dict())) + result = _get_transformer_config(cfg) + self.assertEqual(result['embed_dim'], 256) + self.assertEqual(result['num_encoder_layers'], 6) + self.assertEqual(result['num_decoder_layers'], 6) + self.assertEqual(result['ffn_dim'], 2048) + self.assertEqual(result['num_queries'], 900) + + def test_from_encoder_config(self): + cfg = Config( + dict( + model=dict( + encoder=dict( + num_layers=3, + layer_cfg=dict( + self_attn_cfg=dict(embed_dims=512), + ffn_cfg=dict(feedforward_channels=1024)))))) + result = _get_transformer_config(cfg) + self.assertEqual(result['num_encoder_layers'], 3) + self.assertEqual(result['embed_dim'], 512) + self.assertEqual(result['ffn_dim'], 1024) + + def test_from_decoder_config(self): + cfg = Config(dict(model=dict(decoder=dict(num_layers=4)))) + result = _get_transformer_config(cfg) + self.assertEqual(result['num_decoder_layers'], 4) + + def test_num_queries_from_model(self): + cfg = Config(dict(model=dict(num_queries=300))) + result = _get_transformer_config(cfg) + self.assertEqual(result['num_queries'], 300) + + def test_num_queries_from_bbox_head(self): + cfg = Config(dict(model=dict(bbox_head=dict(num_queries=100)))) + result = _get_transformer_config(cfg) + self.assertEqual(result['num_queries'], 100) + + +class 
TestCountTextEncoderFlops(unittest.TestCase): + """Test count_text_encoder_flops.""" + + def _make_model_with_lang(self, name): + """Create a mock model with a language_model attribute.""" + model = MagicMock() + lang = MagicMock() + lang.name = name + # Create a small parameter for counting + param = nn.Parameter(torch.randn(10, 10)) + lang.parameters.return_value = [param] + model.language_model = lang + return model + + def test_no_language_model(self): + model = MagicMock(spec=[]) # no language_model attr + flops, params = count_text_encoder_flops(model, 80) + self.assertIsNone(flops) + self.assertIsNone(params) + + def test_clip_base(self): + model = self._make_model_with_lang('clip-base') + flops, params = count_text_encoder_flops(model, 80) + self.assertEqual(flops, 4e9 * 80) + self.assertEqual(params, 100) + + def test_clip_large(self): + model = self._make_model_with_lang('clip-large') + flops, params = count_text_encoder_flops(model, 1) + self.assertEqual(flops, 10e9) + + def test_bert(self): + model = self._make_model_with_lang('bert-base') + flops, params = count_text_encoder_flops(model, 1) + self.assertEqual(flops, 11e9) + + def test_unknown_defaults(self): + model = self._make_model_with_lang('some-model') + flops, params = count_text_encoder_flops(model, 1) + self.assertEqual(flops, 5e9) + + def test_num_classes_none(self): + model = self._make_model_with_lang('clip-base') + flops, _ = count_text_encoder_flops(model, None) + self.assertEqual(flops, 4e9) + + +class TestCountNeckFlops(unittest.TestCase): + """Test count_neck_flops.""" + + def test_no_neck(self): + model = MagicMock(spec=[]) + cfg = Config(dict(model=dict())) + flops, params = count_neck_flops(model, (800, 1333), cfg) + self.assertEqual(flops, 0) + self.assertEqual(params, 0) + + def test_neck_none(self): + model = MagicMock() + model.neck = None + cfg = Config(dict(model=dict())) + flops, params = count_neck_flops(model, (800, 1333), cfg) + self.assertEqual(flops, 0) + 
self.assertEqual(params, 0) + + def test_with_neck(self): + model = MagicMock() + neck = nn.Conv2d(3, 3, 1) + model.neck = neck + cfg = Config( + dict( + model=dict( + neck=dict(in_channels=[192, 384, 768], out_channels=256), + backbone=dict(type='SwinTransformer')))) + flops, params = count_neck_flops(model, (800, 1333), cfg) + self.assertGreater(flops, 0) + self.assertGreater(params, 0) + + +class TestCountTransformerFlops(unittest.TestCase): + """Test count_transformer_flops.""" + + def test_basic(self): + model = MagicMock(spec=[]) + cfg = Config(dict(model=dict(neck=dict(in_channels=[192, 384, 768])))) + enc_f, dec_f, enc_p, dec_p = count_transformer_flops( + model, (800, 1333), cfg) + self.assertGreater(enc_f, 0) + self.assertGreater(dec_f, 0) + # No encoder/decoder attrs on mock, so params = 0 + self.assertEqual(enc_p, 0) + self.assertEqual(dec_p, 0) + + def test_with_encoder_decoder(self): + model = MagicMock() + enc_param = nn.Parameter(torch.randn(10, 10)) + model.encoder.parameters.return_value = [enc_param] + dec_param = nn.Parameter(torch.randn(5, 5)) + model.decoder.parameters.return_value = [dec_param] + + cfg = Config( + dict( + model=dict( + encoder=dict(num_layers=6), + decoder=dict(num_layers=6), + neck=dict(in_channels=[192, 384, 768])))) + enc_f, dec_f, enc_p, dec_p = count_transformer_flops( + model, (800, 1333), cfg) + self.assertGreater(enc_f, 0) + self.assertGreater(dec_f, 0) + self.assertEqual(enc_p, 100) # 10*10 + self.assertEqual(dec_p, 25) # 5*5 + + def test_custom_config(self): + """Verify different configs produce different FLOPs.""" + model = MagicMock(spec=[]) + cfg_small = Config( + dict( + model=dict( + encoder=dict(num_layers=3), + decoder=dict(num_layers=3), + neck=dict(in_channels=[192, 384, 768])))) + cfg_large = Config( + dict( + model=dict( + encoder=dict(num_layers=6), + decoder=dict(num_layers=6), + neck=dict(in_channels=[192, 384, 768])))) + + enc_small, dec_small, _, _ = count_transformer_flops( + model, (800, 1333), 
cfg_small) + enc_large, dec_large, _, _ = count_transformer_flops( + model, (800, 1333), cfg_large) + + self.assertGreater(enc_large, enc_small) + self.assertGreater(dec_large, dec_small) + + +if __name__ == '__main__': + unittest.main() From 60764e01194770eddfbf2c9d21c212bfc5d9a137 Mon Sep 17 00:00:00 2001 From: 0xyangl Date: Tue, 10 Feb 2026 11:37:16 +0800 Subject: [PATCH 2/6] fix: pre-commit lint formatting newest --- tests/test_tools/test_get_flops_grounding.py | 331 ++++++++++++ tools/analysis_tools/get_flops_grounding.py | 525 +++++++++++++++++++ 2 files changed, 856 insertions(+) create mode 100644 tests/test_tools/test_get_flops_grounding.py create mode 100644 tools/analysis_tools/get_flops_grounding.py diff --git a/tests/test_tools/test_get_flops_grounding.py b/tests/test_tools/test_get_flops_grounding.py new file mode 100644 index 00000000000..c21ec0120db --- /dev/null +++ b/tests/test_tools/test_get_flops_grounding.py @@ -0,0 +1,331 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import sys +import unittest +from unittest.mock import MagicMock + +import torch +import torch.nn as nn +from mmengine.config import Config + +# Add tools path so we can import the module directly +sys.path.insert(0, 'tools/analysis_tools') +_mod = __import__('get_flops_grounding') +_get_backbone_out_channels = _mod._get_backbone_out_channels +_get_feature_strides = _mod._get_feature_strides +_get_neck_out_channels = _mod._get_neck_out_channels +_get_transformer_config = _mod._get_transformer_config +count_neck_flops = _mod.count_neck_flops +count_text_encoder_flops = _mod.count_text_encoder_flops +count_transformer_flops = _mod.count_transformer_flops +format_flops = _mod.format_flops +format_params = _mod.format_params + + +class TestFormatFlops(unittest.TestCase): + """Test format_flops helper.""" + + def test_tera(self): + self.assertEqual(format_flops(1.5e12), '1.50 T') + + def test_giga(self): + self.assertEqual(format_flops(2.34e9), '2.34 G') + + def test_mega(self): + self.assertEqual(format_flops(5.67e6), '5.67 M') + + def test_small(self): + self.assertEqual(format_flops(1234.0), '1234.00') + + def test_exact_boundary_tera(self): + self.assertEqual(format_flops(1e12), '1.00 T') + + def test_exact_boundary_giga(self): + self.assertEqual(format_flops(1e9), '1.00 G') + + def test_exact_boundary_mega(self): + self.assertEqual(format_flops(1e6), '1.00 M') + + +class TestFormatParams(unittest.TestCase): + """Test format_params helper.""" + + def test_billion(self): + self.assertEqual(format_params(1.2e9), '1.20 B') + + def test_million(self): + self.assertEqual(format_params(27.52e6), '27.52 M') + + def test_thousand(self): + self.assertEqual(format_params(3.5e3), '3.50 K') + + def test_small(self): + self.assertEqual(format_params(42), '42') + + def test_exact_boundary_billion(self): + self.assertEqual(format_params(1e9), '1.00 B') + + def test_exact_boundary_million(self): + self.assertEqual(format_params(1e6), '1.00 M') + + def 
test_exact_boundary_thousand(self): + self.assertEqual(format_params(1e3), '1.00 K') + + +class TestGetBackboneOutChannels(unittest.TestCase): + """Test _get_backbone_out_channels config reader.""" + + def test_from_neck_in_channels(self): + cfg = Config( + dict( + model=dict( + neck=dict(in_channels=[192, 384, 768]), + backbone=dict(type='SwinTransformer')))) + result = _get_backbone_out_channels(cfg) + self.assertEqual(result, [192, 384, 768]) + + def test_fallback_swin(self): + cfg = Config( + dict( + model=dict( + backbone=dict( + type='SwinTransformer', + embed_dims=96, + depths=[2, 2, 6, 2])))) + result = _get_backbone_out_channels(cfg) + self.assertEqual(result, [96, 192, 384, 768]) + + def test_fallback_swin_custom_dims(self): + cfg = Config( + dict( + model=dict( + backbone=dict( + type='SwinTransformer', + embed_dims=128, + depths=[2, 2, 18, 2])))) + result = _get_backbone_out_channels(cfg) + self.assertEqual(result, [128, 256, 512, 1024]) + + def test_fallback_generic(self): + cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) + result = _get_backbone_out_channels(cfg) + self.assertEqual(result, [256, 512, 1024, 2048]) + + +class TestGetFeatureStrides(unittest.TestCase): + """Test _get_feature_strides config reader.""" + + def test_from_neck(self): + cfg = Config(dict(model=dict(neck=dict(in_channels=[192, 384, 768])))) + result = _get_feature_strides(cfg) + self.assertEqual(result, [8, 16, 32]) + + def test_four_levels(self): + cfg = Config( + dict(model=dict(neck=dict(in_channels=[96, 192, 384, 768])))) + result = _get_feature_strides(cfg) + self.assertEqual(result, [8, 16, 32, 64]) + + def test_fallback(self): + cfg = Config(dict(model=dict())) + result = _get_feature_strides(cfg) + self.assertEqual(result, [8, 16, 32]) + + +class TestGetNeckOutChannels(unittest.TestCase): + """Test _get_neck_out_channels config reader.""" + + def test_from_config(self): + cfg = Config(dict(model=dict(neck=dict(out_channels=256)))) + 
self.assertEqual(_get_neck_out_channels(cfg), 256) + + def test_custom_channels(self): + cfg = Config(dict(model=dict(neck=dict(out_channels=512)))) + self.assertEqual(_get_neck_out_channels(cfg), 512) + + def test_fallback(self): + cfg = Config(dict(model=dict())) + self.assertEqual(_get_neck_out_channels(cfg), 256) + + +class TestGetTransformerConfig(unittest.TestCase): + """Test _get_transformer_config config reader.""" + + def test_defaults(self): + cfg = Config(dict(model=dict())) + result = _get_transformer_config(cfg) + self.assertEqual(result['embed_dim'], 256) + self.assertEqual(result['num_encoder_layers'], 6) + self.assertEqual(result['num_decoder_layers'], 6) + self.assertEqual(result['ffn_dim'], 2048) + self.assertEqual(result['num_queries'], 900) + + def test_from_encoder_config(self): + cfg = Config( + dict( + model=dict( + encoder=dict( + num_layers=3, + layer_cfg=dict( + self_attn_cfg=dict(embed_dims=512), + ffn_cfg=dict(feedforward_channels=1024)))))) + result = _get_transformer_config(cfg) + self.assertEqual(result['num_encoder_layers'], 3) + self.assertEqual(result['embed_dim'], 512) + self.assertEqual(result['ffn_dim'], 1024) + + def test_from_decoder_config(self): + cfg = Config(dict(model=dict(decoder=dict(num_layers=4)))) + result = _get_transformer_config(cfg) + self.assertEqual(result['num_decoder_layers'], 4) + + def test_num_queries_from_model(self): + cfg = Config(dict(model=dict(num_queries=300))) + result = _get_transformer_config(cfg) + self.assertEqual(result['num_queries'], 300) + + def test_num_queries_from_bbox_head(self): + cfg = Config(dict(model=dict(bbox_head=dict(num_queries=100)))) + result = _get_transformer_config(cfg) + self.assertEqual(result['num_queries'], 100) + + +class TestCountTextEncoderFlops(unittest.TestCase): + """Test count_text_encoder_flops.""" + + def _make_model_with_lang(self, name): + """Create a mock model with a language_model attribute.""" + model = MagicMock() + lang = MagicMock() + lang.name = name 
+ # Create a small parameter for counting + param = nn.Parameter(torch.randn(10, 10)) + lang.parameters.return_value = [param] + model.language_model = lang + return model + + def test_no_language_model(self): + model = MagicMock(spec=[]) # no language_model attr + flops, params = count_text_encoder_flops(model, 80) + self.assertIsNone(flops) + self.assertIsNone(params) + + def test_clip_base(self): + model = self._make_model_with_lang('clip-base') + flops, params = count_text_encoder_flops(model, 80) + self.assertEqual(flops, 4e9 * 80) + self.assertEqual(params, 100) + + def test_clip_large(self): + model = self._make_model_with_lang('clip-large') + flops, params = count_text_encoder_flops(model, 1) + self.assertEqual(flops, 10e9) + + def test_bert(self): + model = self._make_model_with_lang('bert-base') + flops, params = count_text_encoder_flops(model, 1) + self.assertEqual(flops, 11e9) + + def test_unknown_defaults(self): + model = self._make_model_with_lang('some-model') + flops, params = count_text_encoder_flops(model, 1) + self.assertEqual(flops, 5e9) + + def test_num_classes_none(self): + model = self._make_model_with_lang('clip-base') + flops, _ = count_text_encoder_flops(model, None) + self.assertEqual(flops, 4e9) + + +class TestCountNeckFlops(unittest.TestCase): + """Test count_neck_flops.""" + + def test_no_neck(self): + model = MagicMock(spec=[]) + cfg = Config(dict(model=dict())) + flops, params = count_neck_flops(model, (800, 1333), cfg) + self.assertEqual(flops, 0) + self.assertEqual(params, 0) + + def test_neck_none(self): + model = MagicMock() + model.neck = None + cfg = Config(dict(model=dict())) + flops, params = count_neck_flops(model, (800, 1333), cfg) + self.assertEqual(flops, 0) + self.assertEqual(params, 0) + + def test_with_neck(self): + model = MagicMock() + neck = nn.Conv2d(3, 3, 1) + model.neck = neck + cfg = Config( + dict( + model=dict( + neck=dict(in_channels=[192, 384, 768], out_channels=256), + 
backbone=dict(type='SwinTransformer')))) + flops, params = count_neck_flops(model, (800, 1333), cfg) + self.assertGreater(flops, 0) + self.assertGreater(params, 0) + + +class TestCountTransformerFlops(unittest.TestCase): + """Test count_transformer_flops.""" + + def test_basic(self): + model = MagicMock(spec=[]) + cfg = Config(dict(model=dict(neck=dict(in_channels=[192, 384, 768])))) + enc_f, dec_f, enc_p, dec_p = count_transformer_flops( + model, (800, 1333), cfg) + self.assertGreater(enc_f, 0) + self.assertGreater(dec_f, 0) + # No encoder/decoder attrs on mock, so params = 0 + self.assertEqual(enc_p, 0) + self.assertEqual(dec_p, 0) + + def test_with_encoder_decoder(self): + model = MagicMock() + enc_param = nn.Parameter(torch.randn(10, 10)) + model.encoder.parameters.return_value = [enc_param] + dec_param = nn.Parameter(torch.randn(5, 5)) + model.decoder.parameters.return_value = [dec_param] + + cfg = Config( + dict( + model=dict( + encoder=dict(num_layers=6), + decoder=dict(num_layers=6), + neck=dict(in_channels=[192, 384, 768])))) + enc_f, dec_f, enc_p, dec_p = count_transformer_flops( + model, (800, 1333), cfg) + self.assertGreater(enc_f, 0) + self.assertGreater(dec_f, 0) + self.assertEqual(enc_p, 100) # 10*10 + self.assertEqual(dec_p, 25) # 5*5 + + def test_custom_config(self): + """Verify different configs produce different FLOPs.""" + model = MagicMock(spec=[]) + cfg_small = Config( + dict( + model=dict( + encoder=dict(num_layers=3), + decoder=dict(num_layers=3), + neck=dict(in_channels=[192, 384, 768])))) + cfg_large = Config( + dict( + model=dict( + encoder=dict(num_layers=6), + decoder=dict(num_layers=6), + neck=dict(in_channels=[192, 384, 768])))) + + enc_small, dec_small, _, _ = count_transformer_flops( + model, (800, 1333), cfg_small) + enc_large, dec_large, _, _ = count_transformer_flops( + model, (800, 1333), cfg_large) + + self.assertGreater(enc_large, enc_small) + self.assertGreater(dec_large, dec_small) + + +if __name__ == '__main__': + 
unittest.main() diff --git a/tools/analysis_tools/get_flops_grounding.py b/tools/analysis_tools/get_flops_grounding.py new file mode 100644 index 00000000000..bcb84bd3049 --- /dev/null +++ b/tools/analysis_tools/get_flops_grounding.py @@ -0,0 +1,525 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Get FLOPs for GroundingDINO/GroundingCLIP models. + +This script is specifically designed for models that require text inputs, +such as GroundingDINO, GroundingCLIP, and other vision-language models. + +Gradient Checkpointing (with_cp) Compatibility: + This script uses fvcore's FlopCountAnalysis for accurate backbone + FLOPs, which internally uses PyTorch JIT tracing. However, gradient + checkpointing (with_cp=True) is INCOMPATIBLE with JIT tracing. + This script automatically disables gradient checkpointing when + building the model for analysis. This does NOT affect model accuracy + or your original config file. + +Usage: + python tools/analysis_tools/get_flops_grounding.py + +Example: + python tools/analysis_tools/get_flops_grounding.py \\ + configs/mm_grounding_dino/grounding_dino_swin-t_finetune_8xb4_20e_cat.py +""" +import argparse +import tempfile +from pathlib import Path + +import torch +from mmengine.config import Config, DictAction +from mmengine.logging import MMLogger +from mmengine.model import revert_sync_batchnorm +from mmengine.registry import init_default_scope + +from mmdet.registry import MODELS + +try: + from fvcore.nn import FlopCountAnalysis +except ImportError: + FlopCountAnalysis = None + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description='Get FLOPs for grounding detection models') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[800, 1333], + help='input image size (height width)') + parser.add_argument( + '--text', + type=str, + default='car', + help='text prompt for computing text encoder FLOPs') 
+ parser.add_argument( + '--num-classes', + type=int, + default=None, + help='number of classes (for estimating text encoder FLOPs)') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config') + args = parser.parse_args() + return args + + +def format_flops(flops): + """Format FLOPs to human readable string. + + Args: + flops (int or float): Number of FLOPs. + + Returns: + str: Human readable FLOPs string. + """ + if flops >= 1e12: + return f'{flops / 1e12:.2f} T' + elif flops >= 1e9: + return f'{flops / 1e9:.2f} G' + elif flops >= 1e6: + return f'{flops / 1e6:.2f} M' + else: + return f'{flops:.2f}' + + +def format_params(params): + """Format parameters to human readable string. + + Args: + params (int or float): Number of parameters. + + Returns: + str: Human readable parameter count string. + """ + if params >= 1e9: + return f'{params / 1e9:.2f} B' + elif params >= 1e6: + return f'{params / 1e6:.2f} M' + elif params >= 1e3: + return f'{params / 1e3:.2f} K' + else: + return f'{params:.0f}' + + +def _get_backbone_out_channels(cfg): + """Extract backbone output channels from config. + + Tries to read from the neck's in_channels, then falls back to + common defaults based on backbone type. + + Args: + cfg (Config): Model config. + + Returns: + list[int]: Output channels per feature level. 
+ """ + # Try reading from neck config (most reliable) + if hasattr(cfg.model, 'neck'): + neck_cfg = cfg.model.neck + if hasattr(neck_cfg, 'in_channels'): + return list(neck_cfg.in_channels) + + # Fallback: infer from backbone type + backbone_type = cfg.model.backbone.get('type', '') + if 'Swin' in backbone_type: + embed_dims = cfg.model.backbone.get('embed_dims', 96) + depths = cfg.model.backbone.get('depths', [2, 2, 6, 2]) + num_levels = len(depths) + return [embed_dims * (2**i) for i in range(num_levels)] + + # Generic default + return [256, 512, 1024, 2048] + + +def _get_feature_strides(cfg): + """Get feature map strides from config. + + Args: + cfg (Config): Model config. + + Returns: + list[int]: Strides for each feature level. + """ + if hasattr(cfg.model, 'neck'): + neck_cfg = cfg.model.neck + if hasattr(neck_cfg, 'in_channels'): + num_levels = len(neck_cfg.in_channels) + # Common strides: start at 8, double each level + return [8 * (2**i) for i in range(num_levels)] + return [8, 16, 32] + + +def _get_neck_out_channels(cfg): + """Get neck output channels from config. + + Args: + cfg (Config): Model config. + + Returns: + int: Output channels of the neck. + """ + if hasattr(cfg.model, 'neck'): + neck_cfg = cfg.model.neck + if hasattr(neck_cfg, 'out_channels'): + return neck_cfg.out_channels + return 256 + + +def _get_transformer_config(cfg): + """Extract transformer encoder/decoder config. + + Reads layer counts, embed_dim, ffn_dim, and num_queries from the + model config. Falls back to GroundingDINO defaults. + + Args: + cfg (Config): Model config. + + Returns: + dict: Transformer configuration with keys: embed_dim, + num_encoder_layers, num_decoder_layers, ffn_dim, + num_queries. 
+ """ + result = dict( + embed_dim=256, + num_encoder_layers=6, + num_decoder_layers=6, + ffn_dim=2048, + num_queries=900) + + # Try to read from encoder config + if hasattr(cfg.model, 'encoder'): + enc = cfg.model.encoder + if hasattr(enc, 'num_layers'): + result['num_encoder_layers'] = enc.num_layers + if hasattr(enc, 'layer_cfg'): + layer = enc.layer_cfg + if hasattr(layer, 'self_attn_cfg'): + attn = layer.self_attn_cfg + if hasattr(attn, 'embed_dims'): + result['embed_dim'] = attn.embed_dims + if hasattr(layer, 'ffn_cfg'): + ffn = layer.ffn_cfg + if hasattr(ffn, 'feedforward_channels'): + result['ffn_dim'] = ffn.feedforward_channels + + # Try to read from decoder config + if hasattr(cfg.model, 'decoder'): + dec = cfg.model.decoder + if hasattr(dec, 'num_layers'): + result['num_decoder_layers'] = dec.num_layers + + # Try to read num_queries + if hasattr(cfg.model, 'num_queries'): + result['num_queries'] = cfg.model.num_queries + elif hasattr(cfg.model, 'bbox_head'): + head = cfg.model.bbox_head + if hasattr(head, 'num_queries'): + result['num_queries'] = head.num_queries + + return result + + +def count_backbone_flops(model, input_shape, device, logger): + """Count FLOPs for the vision backbone using fvcore. + + Args: + model (nn.Module): The detection model. + input_shape (tuple[int]): Input image (H, W). + device (torch.device): Device to run on. + logger (MMLogger): Logger instance. + + Returns: + tuple: (flops, params) or (None, params) if fvcore + is not available. 
+ """ + if not hasattr(model, 'backbone'): + return None, None + + backbone = model.backbone + backbone.eval() + + h, w = input_shape + x = torch.randn(1, 3, h, w).to(device) + + if FlopCountAnalysis is not None: + flops_analyzer = FlopCountAnalysis(backbone, x) + flops_analyzer.unsupported_ops_warnings(False) + flops_analyzer.uncalled_modules_warnings(False) + flops = flops_analyzer.total() + params = sum(p.numel() for p in backbone.parameters()) + return flops, params + else: + logger.warning('fvcore is not installed, backbone FLOPs cannot be ' + 'computed accurately. Install with: pip install fvcore') + params = sum(p.numel() for p in backbone.parameters()) + return None, params + + +def count_text_encoder_flops(model, num_classes): + """Estimate FLOPs for the text encoder. + + Uses rough per-forward-pass estimates based on model type since + text encoders often have complex control flow that prevents + accurate tracing. + + Args: + model (nn.Module): The detection model. + num_classes (int): Number of classes. + + Returns: + tuple: (flops, params) or (None, None) if no language + model is found. + """ + if not hasattr(model, 'language_model'): + return None, None + + lang_model = model.language_model + params = sum(p.numel() for p in lang_model.parameters()) + + # Estimate FLOPs based on model type + model_name = getattr(lang_model, 'name', '') + if 'clip' in model_name.lower(): + if 'large' in model_name.lower(): + flops_per_forward = 10e9 + elif 'base' in model_name.lower(): + flops_per_forward = 4e9 + else: + flops_per_forward = 4e9 + elif 'bert' in model_name.lower(): + flops_per_forward = 11e9 + else: + flops_per_forward = 5e9 + + if num_classes: + flops = flops_per_forward * num_classes + else: + flops = flops_per_forward + + return flops, params + + +def count_neck_flops(model, input_shape, cfg): + """Estimate FLOPs for the neck (e.g. ChannelMapper). + + Reads in_channels and out_channels from the config to avoid + hardcoded assumptions. 
+ + Args: + model (nn.Module): The detection model. + input_shape (tuple[int]): Input image (H, W). + cfg (Config): Model config. + + Returns: + tuple: (flops, params). + """ + if not hasattr(model, 'neck') or model.neck is None: + return 0, 0 + + neck = model.neck + params = sum(p.numel() for p in neck.parameters()) + + h, w = input_shape + in_channels = _get_backbone_out_channels(cfg) + strides = _get_feature_strides(cfg) + out_channels = _get_neck_out_channels(cfg) + + flops = 0 + for in_c, stride in zip(in_channels, strides): + fh, fw = h // stride, w // stride + # 1x1 conv: in_c * out_c * H * W * 2 (multiply-add) + flops += in_c * out_channels * fh * fw * 2 + + return flops, params + + +def count_transformer_flops(model, input_shape, cfg): + """Estimate FLOPs for the transformer encoder and decoder. + + Reads architecture parameters from the config dynamically + instead of hardcoding values. + + Args: + model (nn.Module): The detection model. + input_shape (tuple[int]): Input image (H, W). + cfg (Config): Model config. + + Returns: + tuple: (encoder_flops, decoder_flops, + encoder_params, decoder_params). 
+ """ + h, w = input_shape + trans_cfg = _get_transformer_config(cfg) + embed_dim = trans_cfg['embed_dim'] + num_enc = trans_cfg['num_encoder_layers'] + num_dec = trans_cfg['num_decoder_layers'] + ffn_dim = trans_cfg['ffn_dim'] + num_queries = trans_cfg['num_queries'] + + strides = _get_feature_strides(cfg) + feat_sizes = [(h // s, w // s) for s in strides] + total_tokens = sum(fh * fw for fh, fw in feat_sizes) + + # Encoder self-attention: 4 * n^2 * d (Q,K,V proj + output) + enc_attn = (4 * total_tokens * total_tokens * embed_dim * num_enc) + # Encoder FFN: 2 * n * d * ffn_dim + enc_ffn = 2 * total_tokens * embed_dim * ffn_dim * num_enc + encoder_flops = enc_attn + enc_ffn + + # Decoder self-attention on queries + dec_self = (4 * num_queries * num_queries * embed_dim * num_dec) + # Decoder cross-attention to image features + dec_cross = (4 * num_queries * total_tokens * embed_dim * num_dec) + # Decoder FFN + dec_ffn = 2 * num_queries * embed_dim * ffn_dim * num_dec + decoder_flops = dec_self + dec_cross + dec_ffn + + encoder_params = 0 + if hasattr(model, 'encoder'): + encoder_params = sum(p.numel() for p in model.encoder.parameters()) + + decoder_params = 0 + if hasattr(model, 'decoder'): + decoder_params = sum(p.numel() for p in model.decoder.parameters()) + + return encoder_flops, decoder_flops, encoder_params, decoder_params + + +def main(): + args = parse_args() + logger = MMLogger.get_instance(name='MMLogger') + + config_path = Path(args.config) + if not config_path.exists(): + logger.error(f'{config_path} not found.') + return + + cfg = Config.fromfile(args.config) + cfg.work_dir = tempfile.TemporaryDirectory().name + + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + # Disable gradient checkpointing for FLOPs analysis. + # with_cp=True is incompatible with JIT tracing used by fvcore. 
+ if hasattr(cfg.model, 'backbone'): + if cfg.model.backbone.get('with_cp', False): + cfg.model.backbone.with_cp = False + logger.warning('Auto-disabled gradient checkpointing ' + '(with_cp=False) for FLOPs analysis. ' + 'This does NOT affect your config file.') + + init_default_scope(cfg.get('default_scope', 'mmdet')) + + # Build model + model = MODELS.build(cfg.model) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + model = model.to(device) + model = revert_sync_batchnorm(model) + model.eval() + + # Get input shape + if len(args.shape) == 1: + h = w = args.shape[0] + else: + h, w = args.shape[:2] + + # Get number of classes + num_classes = args.num_classes + if num_classes is None: + if (hasattr(cfg.model, 'bbox_head') + and hasattr(cfg.model.bbox_head, 'num_classes')): + num_classes = cfg.model.bbox_head.num_classes + else: + num_classes = 80 # Default COCO + + split_line = '=' * 60 + logger.info(split_line) + logger.info('GroundingDINO/GroundingCLIP FLOPs Analysis') + logger.info(split_line) + logger.info(f'Config: {args.config}') + logger.info(f'Input shape: ({h}, {w})') + logger.info(f'Number of classes: {num_classes}') + logger.info(split_line) + + total_flops = 0 + total_params = 0 + + # 1. Vision Backbone + backbone_flops, backbone_params = count_backbone_flops( + model, (h, w), device, logger) + if backbone_flops is not None: + logger.info('\n[Vision Backbone]') + logger.info(f' FLOPs: {format_flops(backbone_flops)}') + logger.info(f' Params: {format_params(backbone_params)}') + total_flops += backbone_flops + total_params += backbone_params + else: + logger.info('\n[Vision Backbone]') + logger.info(' FLOPs: (fvcore not installed, cannot compute)') + if backbone_params: + logger.info(f' Params: {format_params(backbone_params)}') + total_params += backbone_params + + # 2. 
Text Encoder + text_flops, text_params = count_text_encoder_flops(model, num_classes) + if text_flops is not None: + logger.info('\n[Text Encoder]') + logger.info(f' FLOPs: {format_flops(text_flops)} ' + f'(estimated, {num_classes} classes)') + logger.info(f' Params: {format_params(text_params)}') + total_flops += text_flops + total_params += text_params + + # 3. Neck + neck_flops, neck_params = count_neck_flops(model, (h, w), cfg) + if neck_params > 0: + logger.info('\n[Neck (ChannelMapper)]') + logger.info(f' FLOPs: {format_flops(neck_flops)} (estimated)') + logger.info(f' Params: {format_params(neck_params)}') + total_flops += neck_flops + total_params += neck_params + + # 4. Transformer Encoder/Decoder + enc_flops, dec_flops, enc_params, dec_params = \ + count_transformer_flops(model, (h, w), cfg) + logger.info('\n[Transformer Encoder]') + logger.info(f' FLOPs: {format_flops(enc_flops)} (estimated)') + logger.info(f' Params: {format_params(enc_params)}') + total_flops += enc_flops + total_params += enc_params + + logger.info('\n[Transformer Decoder]') + logger.info(f' FLOPs: {format_flops(dec_flops)} (estimated)') + logger.info(f' Params: {format_params(dec_params)}') + total_flops += dec_flops + total_params += dec_params + + # 5. Bbox Head + if hasattr(model, 'bbox_head'): + head_params = sum(p.numel() for p in model.bbox_head.parameters()) + logger.info('\n[Detection Head]') + logger.info(f' Params: {format_params(head_params)}') + total_params += head_params + + # Total + logger.info('\n' + split_line) + logger.info(f'TOTAL FLOPs: {format_flops(total_flops)}') + logger.info(f'TOTAL Parameters: {format_params(total_params)}') + logger.info(split_line) + + logger.warning('Note: Some FLOPs are estimated based on model ' + 'architecture. Backbone FLOPs are accurate if fvcore is ' + 'installed. 
Text encoder and transformer FLOPs are ' + 'theoretical estimates.') + + if FlopCountAnalysis is None: + logger.info('Tip: Install fvcore for accurate backbone FLOPs: ' + 'pip install fvcore') + + +if __name__ == '__main__': + main() From c78212b01097c58909de4f6fda881f940a626c37 Mon Sep 17 00:00:00 2001 From: 0xyangl Date: Tue, 10 Feb 2026 11:49:26 +0800 Subject: [PATCH 3/6] fix: remove duplicate root-level test file --- test_get_flops_grounding.py | 326 ------------------------------------ 1 file changed, 326 deletions(-) delete mode 100644 test_get_flops_grounding.py diff --git a/test_get_flops_grounding.py b/test_get_flops_grounding.py deleted file mode 100644 index 8e26a5f462c..00000000000 --- a/test_get_flops_grounding.py +++ /dev/null @@ -1,326 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import sys -import unittest -from unittest.mock import MagicMock - -import torch -import torch.nn as nn -from mmengine.config import Config - -# Add tools path so we can import the module directly -sys.path.insert(0, 'tools/analysis_tools') # pragma: no cover -from get_flops_grounding import ( # noqa: E402 - _get_backbone_out_channels, _get_feature_strides, _get_neck_out_channels, - _get_transformer_config, count_neck_flops, count_text_encoder_flops, - count_transformer_flops, format_flops, format_params, -) - - -class TestFormatFlops(unittest.TestCase): - """Test format_flops helper.""" - - def test_tera(self): - self.assertEqual(format_flops(1.5e12), '1.50 T') - - def test_giga(self): - self.assertEqual(format_flops(2.34e9), '2.34 G') - - def test_mega(self): - self.assertEqual(format_flops(5.67e6), '5.67 M') - - def test_small(self): - self.assertEqual(format_flops(1234.0), '1234.00') - - def test_exact_boundary_tera(self): - self.assertEqual(format_flops(1e12), '1.00 T') - - def test_exact_boundary_giga(self): - self.assertEqual(format_flops(1e9), '1.00 G') - - def test_exact_boundary_mega(self): - self.assertEqual(format_flops(1e6), '1.00 M') - - -class 
TestFormatParams(unittest.TestCase): - """Test format_params helper.""" - - def test_billion(self): - self.assertEqual(format_params(1.2e9), '1.20 B') - - def test_million(self): - self.assertEqual(format_params(27.52e6), '27.52 M') - - def test_thousand(self): - self.assertEqual(format_params(3.5e3), '3.50 K') - - def test_small(self): - self.assertEqual(format_params(42), '42') - - def test_exact_boundary_billion(self): - self.assertEqual(format_params(1e9), '1.00 B') - - def test_exact_boundary_million(self): - self.assertEqual(format_params(1e6), '1.00 M') - - def test_exact_boundary_thousand(self): - self.assertEqual(format_params(1e3), '1.00 K') - - -class TestGetBackboneOutChannels(unittest.TestCase): - """Test _get_backbone_out_channels config reader.""" - - def test_from_neck_in_channels(self): - cfg = Config( - dict( - model=dict( - neck=dict(in_channels=[192, 384, 768]), - backbone=dict(type='SwinTransformer')))) - result = _get_backbone_out_channels(cfg) - self.assertEqual(result, [192, 384, 768]) - - def test_fallback_swin(self): - cfg = Config( - dict( - model=dict( - backbone=dict( - type='SwinTransformer', - embed_dims=96, - depths=[2, 2, 6, 2])))) - result = _get_backbone_out_channels(cfg) - self.assertEqual(result, [96, 192, 384, 768]) - - def test_fallback_swin_custom_dims(self): - cfg = Config( - dict( - model=dict( - backbone=dict( - type='SwinTransformer', - embed_dims=128, - depths=[2, 2, 18, 2])))) - result = _get_backbone_out_channels(cfg) - self.assertEqual(result, [128, 256, 512, 1024]) - - def test_fallback_generic(self): - cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) - result = _get_backbone_out_channels(cfg) - self.assertEqual(result, [256, 512, 1024, 2048]) - - -class TestGetFeatureStrides(unittest.TestCase): - """Test _get_feature_strides config reader.""" - - def test_from_neck(self): - cfg = Config(dict(model=dict(neck=dict(in_channels=[192, 384, 768])))) - result = _get_feature_strides(cfg) - 
self.assertEqual(result, [8, 16, 32]) - - def test_four_levels(self): - cfg = Config( - dict(model=dict(neck=dict(in_channels=[96, 192, 384, 768])))) - result = _get_feature_strides(cfg) - self.assertEqual(result, [8, 16, 32, 64]) - - def test_fallback(self): - cfg = Config(dict(model=dict())) - result = _get_feature_strides(cfg) - self.assertEqual(result, [8, 16, 32]) - - -class TestGetNeckOutChannels(unittest.TestCase): - """Test _get_neck_out_channels config reader.""" - - def test_from_config(self): - cfg = Config(dict(model=dict(neck=dict(out_channels=256)))) - self.assertEqual(_get_neck_out_channels(cfg), 256) - - def test_custom_channels(self): - cfg = Config(dict(model=dict(neck=dict(out_channels=512)))) - self.assertEqual(_get_neck_out_channels(cfg), 512) - - def test_fallback(self): - cfg = Config(dict(model=dict())) - self.assertEqual(_get_neck_out_channels(cfg), 256) - - -class TestGetTransformerConfig(unittest.TestCase): - """Test _get_transformer_config config reader.""" - - def test_defaults(self): - cfg = Config(dict(model=dict())) - result = _get_transformer_config(cfg) - self.assertEqual(result['embed_dim'], 256) - self.assertEqual(result['num_encoder_layers'], 6) - self.assertEqual(result['num_decoder_layers'], 6) - self.assertEqual(result['ffn_dim'], 2048) - self.assertEqual(result['num_queries'], 900) - - def test_from_encoder_config(self): - cfg = Config( - dict( - model=dict( - encoder=dict( - num_layers=3, - layer_cfg=dict( - self_attn_cfg=dict(embed_dims=512), - ffn_cfg=dict(feedforward_channels=1024)))))) - result = _get_transformer_config(cfg) - self.assertEqual(result['num_encoder_layers'], 3) - self.assertEqual(result['embed_dim'], 512) - self.assertEqual(result['ffn_dim'], 1024) - - def test_from_decoder_config(self): - cfg = Config(dict(model=dict(decoder=dict(num_layers=4)))) - result = _get_transformer_config(cfg) - self.assertEqual(result['num_decoder_layers'], 4) - - def test_num_queries_from_model(self): - cfg = 
Config(dict(model=dict(num_queries=300))) - result = _get_transformer_config(cfg) - self.assertEqual(result['num_queries'], 300) - - def test_num_queries_from_bbox_head(self): - cfg = Config(dict(model=dict(bbox_head=dict(num_queries=100)))) - result = _get_transformer_config(cfg) - self.assertEqual(result['num_queries'], 100) - - -class TestCountTextEncoderFlops(unittest.TestCase): - """Test count_text_encoder_flops.""" - - def _make_model_with_lang(self, name): - """Create a mock model with a language_model attribute.""" - model = MagicMock() - lang = MagicMock() - lang.name = name - # Create a small parameter for counting - param = nn.Parameter(torch.randn(10, 10)) - lang.parameters.return_value = [param] - model.language_model = lang - return model - - def test_no_language_model(self): - model = MagicMock(spec=[]) # no language_model attr - flops, params = count_text_encoder_flops(model, 80) - self.assertIsNone(flops) - self.assertIsNone(params) - - def test_clip_base(self): - model = self._make_model_with_lang('clip-base') - flops, params = count_text_encoder_flops(model, 80) - self.assertEqual(flops, 4e9 * 80) - self.assertEqual(params, 100) - - def test_clip_large(self): - model = self._make_model_with_lang('clip-large') - flops, params = count_text_encoder_flops(model, 1) - self.assertEqual(flops, 10e9) - - def test_bert(self): - model = self._make_model_with_lang('bert-base') - flops, params = count_text_encoder_flops(model, 1) - self.assertEqual(flops, 11e9) - - def test_unknown_defaults(self): - model = self._make_model_with_lang('some-model') - flops, params = count_text_encoder_flops(model, 1) - self.assertEqual(flops, 5e9) - - def test_num_classes_none(self): - model = self._make_model_with_lang('clip-base') - flops, _ = count_text_encoder_flops(model, None) - self.assertEqual(flops, 4e9) - - -class TestCountNeckFlops(unittest.TestCase): - """Test count_neck_flops.""" - - def test_no_neck(self): - model = MagicMock(spec=[]) - cfg = 
Config(dict(model=dict())) - flops, params = count_neck_flops(model, (800, 1333), cfg) - self.assertEqual(flops, 0) - self.assertEqual(params, 0) - - def test_neck_none(self): - model = MagicMock() - model.neck = None - cfg = Config(dict(model=dict())) - flops, params = count_neck_flops(model, (800, 1333), cfg) - self.assertEqual(flops, 0) - self.assertEqual(params, 0) - - def test_with_neck(self): - model = MagicMock() - neck = nn.Conv2d(3, 3, 1) - model.neck = neck - cfg = Config( - dict( - model=dict( - neck=dict(in_channels=[192, 384, 768], out_channels=256), - backbone=dict(type='SwinTransformer')))) - flops, params = count_neck_flops(model, (800, 1333), cfg) - self.assertGreater(flops, 0) - self.assertGreater(params, 0) - - -class TestCountTransformerFlops(unittest.TestCase): - """Test count_transformer_flops.""" - - def test_basic(self): - model = MagicMock(spec=[]) - cfg = Config(dict(model=dict(neck=dict(in_channels=[192, 384, 768])))) - enc_f, dec_f, enc_p, dec_p = count_transformer_flops( - model, (800, 1333), cfg) - self.assertGreater(enc_f, 0) - self.assertGreater(dec_f, 0) - # No encoder/decoder attrs on mock, so params = 0 - self.assertEqual(enc_p, 0) - self.assertEqual(dec_p, 0) - - def test_with_encoder_decoder(self): - model = MagicMock() - enc_param = nn.Parameter(torch.randn(10, 10)) - model.encoder.parameters.return_value = [enc_param] - dec_param = nn.Parameter(torch.randn(5, 5)) - model.decoder.parameters.return_value = [dec_param] - - cfg = Config( - dict( - model=dict( - encoder=dict(num_layers=6), - decoder=dict(num_layers=6), - neck=dict(in_channels=[192, 384, 768])))) - enc_f, dec_f, enc_p, dec_p = count_transformer_flops( - model, (800, 1333), cfg) - self.assertGreater(enc_f, 0) - self.assertGreater(dec_f, 0) - self.assertEqual(enc_p, 100) # 10*10 - self.assertEqual(dec_p, 25) # 5*5 - - def test_custom_config(self): - """Verify different configs produce different FLOPs.""" - model = MagicMock(spec=[]) - cfg_small = Config( - dict( - 
model=dict( - encoder=dict(num_layers=3), - decoder=dict(num_layers=3), - neck=dict(in_channels=[192, 384, 768])))) - cfg_large = Config( - dict( - model=dict( - encoder=dict(num_layers=6), - decoder=dict(num_layers=6), - neck=dict(in_channels=[192, 384, 768])))) - - enc_small, dec_small, _, _ = count_transformer_flops( - model, (800, 1333), cfg_small) - enc_large, dec_large, _, _ = count_transformer_flops( - model, (800, 1333), cfg_large) - - self.assertGreater(enc_large, enc_small) - self.assertGreater(dec_large, dec_small) - - -if __name__ == '__main__': - unittest.main() From 750d05ddcd4f4b0f681ee344f0b29b429968d9ad Mon Sep 17 00:00:00 2001 From: 0xyangl Date: Tue, 10 Feb 2026 11:50:54 +0800 Subject: [PATCH 4/6] fix: remove duplicate root-level script file --- get_flops_grounding.py | 525 ----------------------------------------- 1 file changed, 525 deletions(-) delete mode 100644 get_flops_grounding.py diff --git a/get_flops_grounding.py b/get_flops_grounding.py deleted file mode 100644 index 8a5aa2c47c3..00000000000 --- a/get_flops_grounding.py +++ /dev/null @@ -1,525 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -"""Get FLOPs for GroundingDINO models. - -This script is specifically designed for models that require text inputs, -such as GroundingDINO and other vision-language models. - -Gradient Checkpointing (with_cp) Compatibility: - This script uses fvcore's FlopCountAnalysis for accurate backbone - FLOPs, which internally uses PyTorch JIT tracing. However, gradient - checkpointing (with_cp=True) is INCOMPATIBLE with JIT tracing. - This script automatically disables gradient checkpointing when - building the model for analysis. This does NOT affect model accuracy - or your original config file. 
- -Usage: - python tools/analysis_tools/get_flops_grounding.py - -Example: - python tools/analysis_tools/get_flops_grounding.py \\ - configs/mm_grounding_dino/grounding_dino_swin-t_finetune_8xb4_20e_cat.py -""" -import argparse -import tempfile -from pathlib import Path - -import torch -from mmengine.config import Config, DictAction -from mmengine.logging import MMLogger -from mmengine.model import revert_sync_batchnorm -from mmengine.registry import init_default_scope - -from mmdet.registry import MODELS - -try: - from fvcore.nn import FlopCountAnalysis -except ImportError: - FlopCountAnalysis = None - - -def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser( - description='Get FLOPs for grounding detection models') - parser.add_argument('config', help='train config file path') - parser.add_argument( - '--shape', - type=int, - nargs='+', - default=[800, 1333], - help='input image size (height width)') - parser.add_argument( - '--text', - type=str, - default='car', - help='text prompt for computing text encoder FLOPs') - parser.add_argument( - '--num-classes', - type=int, - default=None, - help='number of classes (for estimating text encoder FLOPs)') - parser.add_argument( - '--cfg-options', - nargs='+', - action=DictAction, - help='override some settings in the used config') - args = parser.parse_args() - return args - - -def format_flops(flops): - """Format FLOPs to human readable string. - - Args: - flops (int or float): Number of FLOPs. - - Returns: - str: Human readable FLOPs string. - """ - if flops >= 1e12: - return f'{flops / 1e12:.2f} T' - elif flops >= 1e9: - return f'{flops / 1e9:.2f} G' - elif flops >= 1e6: - return f'{flops / 1e6:.2f} M' - else: - return f'{flops:.2f}' - - -def format_params(params): - """Format parameters to human readable string. - - Args: - params (int or float): Number of parameters. - - Returns: - str: Human readable parameter count string. 
- """ - if params >= 1e9: - return f'{params / 1e9:.2f} B' - elif params >= 1e6: - return f'{params / 1e6:.2f} M' - elif params >= 1e3: - return f'{params / 1e3:.2f} K' - else: - return f'{params:.0f}' - - -def _get_backbone_out_channels(cfg): - """Extract backbone output channels from config. - - Tries to read from the neck's in_channels, then falls back to - common defaults based on backbone type. - - Args: - cfg (Config): Model config. - - Returns: - list[int]: Output channels per feature level. - """ - # Try reading from neck config (most reliable) - if hasattr(cfg.model, 'neck'): - neck_cfg = cfg.model.neck - if hasattr(neck_cfg, 'in_channels'): - return list(neck_cfg.in_channels) - - # Fallback: infer from backbone type - backbone_type = cfg.model.backbone.get('type', '') - if 'Swin' in backbone_type: - embed_dims = cfg.model.backbone.get('embed_dims', 96) - depths = cfg.model.backbone.get('depths', [2, 2, 6, 2]) - num_levels = len(depths) - return [embed_dims * (2**i) for i in range(num_levels)] - - # Generic default - return [256, 512, 1024, 2048] - - -def _get_feature_strides(cfg): - """Get feature map strides from config. - - Args: - cfg (Config): Model config. - - Returns: - list[int]: Strides for each feature level. - """ - if hasattr(cfg.model, 'neck'): - neck_cfg = cfg.model.neck - if hasattr(neck_cfg, 'in_channels'): - num_levels = len(neck_cfg.in_channels) - # Common strides: start at 8, double each level - return [8 * (2**i) for i in range(num_levels)] - return [8, 16, 32] - - -def _get_neck_out_channels(cfg): - """Get neck output channels from config. - - Args: - cfg (Config): Model config. - - Returns: - int: Output channels of the neck. - """ - if hasattr(cfg.model, 'neck'): - neck_cfg = cfg.model.neck - if hasattr(neck_cfg, 'out_channels'): - return neck_cfg.out_channels - return 256 - - -def _get_transformer_config(cfg): - """Extract transformer encoder/decoder config. 
- - Reads layer counts, embed_dim, ffn_dim, and num_queries from the - model config. Falls back to GroundingDINO defaults. - - Args: - cfg (Config): Model config. - - Returns: - dict: Transformer configuration with keys: embed_dim, - num_encoder_layers, num_decoder_layers, ffn_dim, - num_queries. - """ - result = dict( - embed_dim=256, - num_encoder_layers=6, - num_decoder_layers=6, - ffn_dim=2048, - num_queries=900) - - # Try to read from encoder config - if hasattr(cfg.model, 'encoder'): - enc = cfg.model.encoder - if hasattr(enc, 'num_layers'): - result['num_encoder_layers'] = enc.num_layers - if hasattr(enc, 'layer_cfg'): - layer = enc.layer_cfg - if hasattr(layer, 'self_attn_cfg'): - attn = layer.self_attn_cfg - if hasattr(attn, 'embed_dims'): - result['embed_dim'] = attn.embed_dims - if hasattr(layer, 'ffn_cfg'): - ffn = layer.ffn_cfg - if hasattr(ffn, 'feedforward_channels'): - result['ffn_dim'] = ffn.feedforward_channels - - # Try to read from decoder config - if hasattr(cfg.model, 'decoder'): - dec = cfg.model.decoder - if hasattr(dec, 'num_layers'): - result['num_decoder_layers'] = dec.num_layers - - # Try to read num_queries - if hasattr(cfg.model, 'num_queries'): - result['num_queries'] = cfg.model.num_queries - elif hasattr(cfg.model, 'bbox_head'): - head = cfg.model.bbox_head - if hasattr(head, 'num_queries'): - result['num_queries'] = head.num_queries - - return result - - -def count_backbone_flops(model, input_shape, device, logger): - """Count FLOPs for the vision backbone using fvcore. - - Args: - model (nn.Module): The detection model. - input_shape (tuple[int]): Input image (H, W). - device (torch.device): Device to run on. - logger (MMLogger): Logger instance. - - Returns: - tuple: (flops, params) or (None, params) if fvcore - is not available. 
- """ - if not hasattr(model, 'backbone'): - return None, None - - backbone = model.backbone - backbone.eval() - - h, w = input_shape - x = torch.randn(1, 3, h, w).to(device) - - if FlopCountAnalysis is not None: - flops_analyzer = FlopCountAnalysis(backbone, x) - flops_analyzer.unsupported_ops_warnings(False) - flops_analyzer.uncalled_modules_warnings(False) - flops = flops_analyzer.total() - params = sum(p.numel() for p in backbone.parameters()) - return flops, params - else: - logger.warning('fvcore is not installed, backbone FLOPs cannot be ' - 'computed accurately. Install with: pip install fvcore') - params = sum(p.numel() for p in backbone.parameters()) - return None, params - - -def count_text_encoder_flops(model, num_classes): - """Estimate FLOPs for the text encoder. - - Uses rough per-forward-pass estimates based on model type since - text encoders often have complex control flow that prevents - accurate tracing. - - Args: - model (nn.Module): The detection model. - num_classes (int): Number of classes. - - Returns: - tuple: (flops, params) or (None, None) if no language - model is found. - """ - if not hasattr(model, 'language_model'): - return None, None - - lang_model = model.language_model - params = sum(p.numel() for p in lang_model.parameters()) - - # Estimate FLOPs based on model type - model_name = getattr(lang_model, 'name', '') - if 'clip' in model_name.lower(): - if 'large' in model_name.lower(): - flops_per_forward = 10e9 - elif 'base' in model_name.lower(): - flops_per_forward = 4e9 - else: - flops_per_forward = 4e9 - elif 'bert' in model_name.lower(): - flops_per_forward = 11e9 - else: - flops_per_forward = 5e9 - - if num_classes: - flops = flops_per_forward * num_classes - else: - flops = flops_per_forward - - return flops, params - - -def count_neck_flops(model, input_shape, cfg): - """Estimate FLOPs for the neck (e.g. ChannelMapper). - - Reads in_channels and out_channels from the config to avoid - hardcoded assumptions. 
- - Args: - model (nn.Module): The detection model. - input_shape (tuple[int]): Input image (H, W). - cfg (Config): Model config. - - Returns: - tuple: (flops, params). - """ - if not hasattr(model, 'neck') or model.neck is None: - return 0, 0 - - neck = model.neck - params = sum(p.numel() for p in neck.parameters()) - - h, w = input_shape - in_channels = _get_backbone_out_channels(cfg) - strides = _get_feature_strides(cfg) - out_channels = _get_neck_out_channels(cfg) - - flops = 0 - for in_c, stride in zip(in_channels, strides): - fh, fw = h // stride, w // stride - # 1x1 conv: in_c * out_c * H * W * 2 (multiply-add) - flops += in_c * out_channels * fh * fw * 2 - - return flops, params - - -def count_transformer_flops(model, input_shape, cfg): - """Estimate FLOPs for the transformer encoder and decoder. - - Reads architecture parameters from the config dynamically - instead of hardcoding values. - - Args: - model (nn.Module): The detection model. - input_shape (tuple[int]): Input image (H, W). - cfg (Config): Model config. - - Returns: - tuple: (encoder_flops, decoder_flops, - encoder_params, decoder_params). 
- """ - h, w = input_shape - trans_cfg = _get_transformer_config(cfg) - embed_dim = trans_cfg['embed_dim'] - num_enc = trans_cfg['num_encoder_layers'] - num_dec = trans_cfg['num_decoder_layers'] - ffn_dim = trans_cfg['ffn_dim'] - num_queries = trans_cfg['num_queries'] - - strides = _get_feature_strides(cfg) - feat_sizes = [(h // s, w // s) for s in strides] - total_tokens = sum(fh * fw for fh, fw in feat_sizes) - - # Encoder self-attention: 4 * n^2 * d (Q,K,V proj + output) - enc_attn = (4 * total_tokens * total_tokens * embed_dim * num_enc) - # Encoder FFN: 2 * n * d * ffn_dim - enc_ffn = 2 * total_tokens * embed_dim * ffn_dim * num_enc - encoder_flops = enc_attn + enc_ffn - - # Decoder self-attention on queries - dec_self = (4 * num_queries * num_queries * embed_dim * num_dec) - # Decoder cross-attention to image features - dec_cross = (4 * num_queries * total_tokens * embed_dim * num_dec) - # Decoder FFN - dec_ffn = 2 * num_queries * embed_dim * ffn_dim * num_dec - decoder_flops = dec_self + dec_cross + dec_ffn - - encoder_params = 0 - if hasattr(model, 'encoder'): - encoder_params = sum(p.numel() for p in model.encoder.parameters()) - - decoder_params = 0 - if hasattr(model, 'decoder'): - decoder_params = sum(p.numel() for p in model.decoder.parameters()) - - return encoder_flops, decoder_flops, encoder_params, decoder_params - - -def main(): - args = parse_args() - logger = MMLogger.get_instance(name='MMLogger') - - config_path = Path(args.config) - if not config_path.exists(): - logger.error(f'{config_path} not found.') - return - - cfg = Config.fromfile(args.config) - cfg.work_dir = tempfile.TemporaryDirectory().name - - if args.cfg_options is not None: - cfg.merge_from_dict(args.cfg_options) - - # Disable gradient checkpointing for FLOPs analysis. - # with_cp=True is incompatible with JIT tracing used by fvcore. 
- if hasattr(cfg.model, 'backbone'): - if cfg.model.backbone.get('with_cp', False): - cfg.model.backbone.with_cp = False - logger.warning('Auto-disabled gradient checkpointing ' - '(with_cp=False) for FLOPs analysis. ' - 'This does NOT affect your config file.') - - init_default_scope(cfg.get('default_scope', 'mmdet')) - - # Build model - model = MODELS.build(cfg.model) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - model = model.to(device) - model = revert_sync_batchnorm(model) - model.eval() - - # Get input shape - if len(args.shape) == 1: - h = w = args.shape[0] - else: - h, w = args.shape[:2] - - # Get number of classes - num_classes = args.num_classes - if num_classes is None: - if (hasattr(cfg.model, 'bbox_head') - and hasattr(cfg.model.bbox_head, 'num_classes')): - num_classes = cfg.model.bbox_head.num_classes - else: - num_classes = 80 # Default COCO - - split_line = '=' * 60 - logger.info(split_line) - logger.info('GroundingDINO FLOPs Analysis') - logger.info(split_line) - logger.info(f'Config: {args.config}') - logger.info(f'Input shape: ({h}, {w})') - logger.info(f'Number of classes: {num_classes}') - logger.info(split_line) - - total_flops = 0 - total_params = 0 - - # 1. Vision Backbone - backbone_flops, backbone_params = count_backbone_flops( - model, (h, w), device, logger) - if backbone_flops is not None: - logger.info('\n[Vision Backbone]') - logger.info(f' FLOPs: {format_flops(backbone_flops)}') - logger.info(f' Params: {format_params(backbone_params)}') - total_flops += backbone_flops - total_params += backbone_params - else: - logger.info('\n[Vision Backbone]') - logger.info(' FLOPs: (fvcore not installed, cannot compute)') - if backbone_params: - logger.info(f' Params: {format_params(backbone_params)}') - total_params += backbone_params - - # 2. 
Text Encoder - text_flops, text_params = count_text_encoder_flops(model, num_classes) - if text_flops is not None: - logger.info('\n[Text Encoder]') - logger.info(f' FLOPs: {format_flops(text_flops)} ' - f'(estimated, {num_classes} classes)') - logger.info(f' Params: {format_params(text_params)}') - total_flops += text_flops - total_params += text_params - - # 3. Neck - neck_flops, neck_params = count_neck_flops(model, (h, w), cfg) - if neck_params > 0: - logger.info('\n[Neck (ChannelMapper)]') - logger.info(f' FLOPs: {format_flops(neck_flops)} (estimated)') - logger.info(f' Params: {format_params(neck_params)}') - total_flops += neck_flops - total_params += neck_params - - # 4. Transformer Encoder/Decoder - enc_flops, dec_flops, enc_params, dec_params = \ - count_transformer_flops(model, (h, w), cfg) - logger.info('\n[Transformer Encoder]') - logger.info(f' FLOPs: {format_flops(enc_flops)} (estimated)') - logger.info(f' Params: {format_params(enc_params)}') - total_flops += enc_flops - total_params += enc_params - - logger.info('\n[Transformer Decoder]') - logger.info(f' FLOPs: {format_flops(dec_flops)} (estimated)') - logger.info(f' Params: {format_params(dec_params)}') - total_flops += dec_flops - total_params += dec_params - - # 5. Bbox Head - if hasattr(model, 'bbox_head'): - head_params = sum(p.numel() for p in model.bbox_head.parameters()) - logger.info('\n[Detection Head]') - logger.info(f' Params: {format_params(head_params)}') - total_params += head_params - - # Total - logger.info('\n' + split_line) - logger.info(f'TOTAL FLOPs: {format_flops(total_flops)}') - logger.info(f'TOTAL Parameters: {format_params(total_params)}') - logger.info(split_line) - - logger.warning('Note: Some FLOPs are estimated based on model ' - 'architecture. Backbone FLOPs are accurate if fvcore is ' - 'installed. 
Text encoder and transformer FLOPs are ' - 'theoretical estimates.') - - if FlopCountAnalysis is None: - logger.info('Tip: Install fvcore for accurate backbone FLOPs: ' - 'pip install fvcore') - - -if __name__ == '__main__': - main() From 8044e1f945c126dcf372b4a170f85a04057806e0 Mon Sep 17 00:00:00 2001 From: lauriebax <47383333+lauriebax@users.noreply.github.com> Date: Tue, 10 Feb 2026 09:13:15 +0100 Subject: [PATCH 5/6] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tools/analysis_tools/get_flops_grounding.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tools/analysis_tools/get_flops_grounding.py b/tools/analysis_tools/get_flops_grounding.py index bcb84bd3049..86d8f30cb12 100644 --- a/tools/analysis_tools/get_flops_grounding.py +++ b/tools/analysis_tools/get_flops_grounding.py @@ -246,11 +246,12 @@ def count_backbone_flops(model, input_shape, device, logger): x = torch.randn(1, 3, h, w).to(device) if FlopCountAnalysis is not None: - flops_analyzer = FlopCountAnalysis(backbone, x) - flops_analyzer.unsupported_ops_warnings(False) - flops_analyzer.uncalled_modules_warnings(False) - flops = flops_analyzer.total() - params = sum(p.numel() for p in backbone.parameters()) + with torch.no_grad(): + flops_analyzer = FlopCountAnalysis(backbone, x) + flops_analyzer.unsupported_ops_warnings(False) + flops_analyzer.uncalled_modules_warnings(False) + flops = flops_analyzer.total() + params = sum(p.numel() for p in backbone.parameters()) return flops, params else: logger.warning('fvcore is not installed, backbone FLOPs cannot be ' @@ -398,7 +399,7 @@ def main(): return cfg = Config.fromfile(args.config) - cfg.work_dir = tempfile.TemporaryDirectory().name + cfg.work_dir = tempfile.mkdtemp() if args.cfg_options is not None: cfg.merge_from_dict(args.cfg_options) From 01937b35ad3b4dc59bf267a03f161ea0c6c0bd3a Mon Sep 17 00:00:00 2001 From: lauriebax 
<47383333+lauriebax@users.noreply.github.com> Date: Wed, 11 Feb 2026 17:11:06 +0100 Subject: [PATCH 6/6] remove --text argument for grounding flops tool Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tools/analysis_tools/get_flops_grounding.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tools/analysis_tools/get_flops_grounding.py b/tools/analysis_tools/get_flops_grounding.py index 86d8f30cb12..bd9ad614f81 100644 --- a/tools/analysis_tools/get_flops_grounding.py +++ b/tools/analysis_tools/get_flops_grounding.py @@ -48,11 +48,6 @@ def parse_args(): nargs='+', default=[800, 1333], help='input image size (height width)') - parser.add_argument( - '--text', - type=str, - default='car', - help='text prompt for computing text encoder FLOPs') parser.add_argument( '--num-classes', type=int,