|
| 1 | +import torch |
| 2 | +import torch.nn.functional as F |
| 3 | +from loma.descriptor.dedode import DeDoDeDescriptor |
| 4 | +from loma.detector.dad import DaD |
| 5 | +from loma.device import device |
| 6 | + |
| 7 | +from ..utils.base_model import BaseModel |
| 8 | + |
| 9 | + |
class LoMaExtractor(BaseModel):
    """Feature extractor that pairs a DaD detector with a DeDoDe descriptor.

    Configuration keys:
        max_keypoints: number of keypoints requested from the detector.
        compile: forwarded as ``compile=`` to both the detector and the
            descriptor configuration objects.

    NOTE(review): the outputs are built from index ``0`` of every tensor, so
    the extractor effectively handles a single image per call.
    """

    default_conf = {
        "max_keypoints": 2048,
        "compile": False,
    }
    required_inputs = ["image"]

    def _init(self, conf):
        """Instantiate detector and descriptor and load descriptor weights.

        The descriptor weights are fetched from a release checkpoint; only
        the entries whose key starts with ``_descriptor.`` are kept, with
        that prefix stripped, and loaded strictly.
        """
        # DaD weights loaded by default
        detector_cfg = DaD.Cfg(compile=conf["compile"])
        self.detector = DaD(detector_cfg).eval()

        # Descriptor weights need to be manually loaded
        descriptor_cfg = DeDoDeDescriptor.Cfg(
            compile=conf["compile"], arch="dedode_g"
        )
        self.descriptor = DeDoDeDescriptor(descriptor_cfg).eval()

        checkpoint = torch.hub.load_state_dict_from_url(
            "https://github.com/davnords/storage/releases/download/loma/loma_B.pt",
            map_location=device,
        )
        prefix = "_descriptor."
        descriptor_state = {
            name[len(prefix):]: tensor
            for name, tensor in checkpoint.items()
            if name.startswith(prefix)
        }
        self.descriptor.load_state_dict(descriptor_state, strict=True)

    def preprocess_image(self, image, H=784, W=784):
        """Resize ``image`` bilinearly to ``(H, W)`` and move it to ``device``.

        Only the first element along the leading dimension is kept; it is
        returned with a leading dimension of size one.
        """
        resized = F.interpolate(
            image, size=(H, W), mode="bilinear", align_corners=False
        )
        first_image = resized[0]
        return first_image.unsqueeze(0).to(device)

    def detect_and_describe(self, batch: dict[str, torch.Tensor]):
        """Detect keypoints in ``batch["image"]`` and describe them.

        Returns a dict with single-element lists under ``"keypoints"``,
        ``"descriptors"`` and ``"scores"``, each holding the entry for the
        first image.
        """
        height, width = batch["image"].shape[2:]

        max_keypoints = self.conf["max_keypoints"]
        detections = self.detector.detect(batch, num_keypoints=max_keypoints)
        detected = detections["keypoints"]

        # The descriptor works on the resized image but receives the
        # keypoints exactly as returned by the detector.
        # NOTE(review): presumably these are resolution-independent
        # (normalized) coordinates -- confirm against the loma API.
        resized_image = self.preprocess_image(batch["image"])
        description = self.descriptor.describe_keypoints(resized_image, detected)

        pixel_keypoints = self.detector.to_pixel_coords(detected, height, width)
        pixel_keypoints = pixel_keypoints - 0.5  # be consistent with hloc

        # Clamp x (axis 0) against the width and y (axis 1) against the
        # height so every coordinate lies in [0.5, extent - 1.5].
        for axis, extent in enumerate((width, height)):
            pixel_keypoints[..., axis] = pixel_keypoints[..., axis].clamp(
                0.5, extent - 1.5
            )

        descriptors = description["descriptions"].transpose(-1, -2)
        scores = detections["keypoint_probs"]
        return {
            "keypoints": [pixel_keypoints[0]],
            "descriptors": [descriptors[0]],
            "scores": [scores[0]],
        }

    def _forward(self, data):
        """Run detection and description on ``data``."""
        return self.detect_and_describe(data)
0 commit comments