-
Notifications
You must be signed in to change notification settings - Fork 54
add torchvision detector functionality #164
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 3 commits
01a2f2e
333f714
3a791aa
15b265c
9ee9768
0639363
3ca1d4e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,9 +14,12 @@ | |
| import torch | ||
| import torchvision.models.detection as detection | ||
|
|
||
| from dlclive.pose_estimation_pytorch.models.detectors.base import BaseDetector | ||
| from dlclive.pose_estimation_pytorch.models.detectors.base import DETECTORS, BaseDetector | ||
|
|
||
| SUPPORTED_TORCHVISION_DETECTORS = ["fasterrcnn_mobilenet_v3_large_fpn"] | ||
|
|
||
|
|
||
| @DETECTORS.register_module | ||
| class TorchvisionDetectorAdaptor(BaseDetector): | ||
|
||
| """An adaptor for torchvision detectors | ||
|
|
||
|
|
@@ -26,8 +29,8 @@ class TorchvisionDetectorAdaptor(BaseDetector): | |
| - fasterrcnn_mobilenet_v3_large_fpn | ||
| - fasterrcnn_resnet50_fpn_v2 | ||
|
Comment on lines
33
to
34
|
||
|
|
||
| This class should not be used out-of-the-box. Subclasses (such as FasterRCNN or | ||
| SSDLite) should be used instead. | ||
| This class can be used directly (e.g. with pre-trained COCO weights) or through its | ||
| subclasses (FasterRCNN or SSDLite) which adapt the model for DLC's 2-class detection. | ||
|
Comment on lines
+36
to
+37
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not sure why! usually base classes that shouldn't be used have some abstract methods and only serve as the external public API contract, so it doesn't seem like this was serving as an ABC before, but usually mixing levels like this is a bad call, having a "pre trained" or whatever subclass might be a good idea to discourage modifying the base class to accommodate any specific needs for this use that doesn't apply to the other subclasses and warps the contract made by the ABC. |
||
|
|
||
| The torchvision implementation does not allow to get both predictions and losses | ||
| with a single forward pass. Therefore, during evaluation only bounding box metrics | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -268,10 +268,24 @@ def load_model(self) -> None: | |
| self.model = self.model.half() | ||
|
|
||
| self.detector = None | ||
| if self.dynamic is None and raw_data.get("detector") is not None: | ||
| detector_cfg = self.cfg.get("detector") | ||
| has_detector_weights = raw_data.get("detector") is not None | ||
| if detector_cfg is not None: | ||
| detector_model_cfg = detector_cfg["model"] | ||
| uses_pretrained = ( | ||
| detector_model_cfg.get("pretrained", False) | ||
| or detector_model_cfg.get("weights") is not None | ||
| ) | ||
|
Comment on lines
+271
to
+278
|
||
| else: | ||
| uses_pretrained = False | ||
|
|
||
| if self.dynamic is None and (has_detector_weights or uses_pretrained): | ||
| self.detector = models.DETECTORS.build(self.cfg["detector"]["model"]) | ||
| self.detector.to(self.device) | ||
| self.detector.load_state_dict(raw_data["detector"]) | ||
|
|
||
| if has_detector_weights: | ||
| self.detector.load_state_dict(raw_data["detector"]) | ||
|
Comment on lines
+282
to
+287
|
||
|
|
||
| self.detector.eval() | ||
| if self.precision == "FP16": | ||
| self.detector = self.detector.half() | ||
|
|
@@ -281,7 +295,8 @@ def load_model(self) -> None: | |
| self.top_down_config.read_config(self.cfg) | ||
|
|
||
| detector_transforms = [v2.ToDtype(torch.float32, scale=True)] | ||
| if self.cfg["detector"]["data"]["inference"].get("normalize_images", False): | ||
| detector_data_cfg = detector_cfg.get("data", {}).get("inference", {}) | ||
| if detector_data_cfg.get("normalize_images", False): | ||
| detector_transforms.append(v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) | ||
| self.detector_transform = v2.Compose(detector_transforms) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The class docstring lists
fasterrcnn_resnet50_fpn_v2as supported, butSUPPORTED_TORCHVISION_DETECTORSonly includesfasterrcnn_mobilenet_v3_large_fpn. This inconsistency will confuse users (and currently the modelzoo validation rejects the resnet50 variant). Either update the docstring to match the allowlist, or expandSUPPORTED_TORCHVISION_DETECTORSand ensure the humanbody path supports that detector.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or, better suggestion, don't list the models in the docstring and just say "one of {variable name}"