Skip to content

Commit 9c64a3a

Browse files
committed
update encoder vis
1 parent a4f43d1 commit 9c64a3a

6 files changed

Lines changed: 249 additions & 2 deletions

File tree

README.md

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,89 @@
1-
# L2M
2-
Official implementation of our ICCV'25 paper "Learning Dense Feature Matching via Lifting Single 2D Image to 3D Space"
1+
2+
# Lift to Match (L2M): Learning Dense Feature Matching via Lifting Single 2D Image to 3D Space
3+
4+
*Accepted to ICCV 2025 Conference*
5+
6+
---
7+
8+
## 🧠 Overview
9+
10+
**Lift to Match (L2M)** is a novel two-stage framework for **dense feature matching** that lifts 2D images into 3D space to enhance feature generalization and robustness. Unlike traditional methods that depend on multi-view image pairs, L2M is trained on large-scale, diverse single-view image collections.
11+
12+
- **Stage 1:** Learn a **3D-aware ViT-based encoder** using multi-view image synthesis and 3D Gaussian feature representation.
13+
- **Stage 2:** Learn a **feature decoder** through novel-view rendering and synthetic data, enabling robust matching across diverse scenarios.
14+
15+
> 🚧 Code under construction.
16+
17+
---
18+
19+
## 🧪 Feature Visualization
20+
21+
We compare the 3D-aware ViT encoder from L2M (Stage 1) with other recent methods:
22+
23+
- **DINOv2**
24+
- **FIT3D**
25+
- **Ours: L2M Encoder**
26+
27+
Below are feature comparison results on the Sacré-Cœur dataset:
28+
29+
<div align="center">
30+
<img src="./assets/sacre_coeur_A_compare.png" width="90%">
31+
<br/>
32+
</div>
33+
34+
<div align="center">
35+
<img src="./assets/sacre_coeur_B_compare.png" width="90%">
36+
<br/>
37+
</div>
38+
39+
---
40+
41+
## 🏗️ Data Generation (WIP)
42+
43+
We synthesize multi-view images and 3D-aware Gaussian features from single-view inputs.
44+
Scripts for data generation will be released soon.
45+
46+
---
47+
48+
## 🏋️‍♀️ Model Training (Stage 1)
49+
50+
We provide pretrained weights for the 3D-aware ViT encoder.
51+
52+
> 🔗 **[Download pretrained encoder weights](#)** (Coming soon)
53+
54+
You can visualize features using:
55+
56+
```bash
57+
python vis_feats.py --input ./assets/sacre_coeur_A.jpg --model vit_encoder.pth
58+
```
59+
60+
---
61+
62+
## 🚀 Inference & Stage 2 (Coming Soon)
63+
64+
The second stage—feature decoding with novel-view rendering—is **under development**. Stay tuned!
65+
66+
---
67+
68+
## 📌 Citation
69+
70+
```bibtex
71+
@article{liang2025lift2match,
72+
title={Learning Dense Feature Matching via Lifting Single 2D Image to 3D Space},
73+
author={Liang, Yingping and Hu, Yutao and Shao, Wenqi and Fu, Ying},
74+
journal={ICCV},
75+
year={2025}
76+
}
77+
```
78+
79+
---
80+
81+
## 📋 License
82+
83+
This project is licensed under **CC BY 4.0**.
84+
85+
---
86+
87+
## 🙋‍♂️ Acknowledgements
88+
89+
We build upon recent advances in ROMA and FIT3D.

assets/sacre_coeur_A.jpg

115 KB
Loading

assets/sacre_coeur_A_compare.png

652 KB
Loading

assets/sacre_coeur_B.jpg

149 KB
Loading

assets/sacre_coeur_B_compare.png

659 KB
Loading

vis_feats.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import os
2+
import torch
3+
import torchvision.transforms as T
4+
import matplotlib.pyplot as plt
5+
import numpy as np
6+
from PIL import Image
7+
from sklearn.decomposition import PCA
8+
import argparse
9+
from romatch.models.transformer import vit_base
10+
11+
device = "cuda" if torch.cuda.is_available() else "cpu"
12+
13+
14+
# -------------------- 可视化函数 --------------------
15+
def vis_feat_map(features, patch_h, patch_w, resize_hw=(560, 560)):
16+
features = features.reshape(patch_h * patch_w, -1)
17+
pca = PCA(n_components=3)
18+
pca_feats = pca.fit_transform(features)
19+
pca_feats = (pca_feats - pca_feats.mean(0)) / (pca_feats.std(0) + 1e-5)
20+
pca_feats = np.clip(pca_feats * 0.5 + 0.5, 0, 1)
21+
img = pca_feats.reshape(patch_h, patch_w, 3)
22+
img = (img * 255).astype(np.uint8)
23+
img = Image.fromarray(img)
24+
return img.resize(resize_hw, Image.BICUBIC)
25+
26+
27+
28+
def save_all_visualizations(
29+
feat_dino, feat_fit3d, feat_L2M,
30+
patch_h, patch_w, base_name, save_dir, original_image=None
31+
):
32+
os.makedirs(save_dir, exist_ok=True)
33+
34+
# 单图保存
35+
img_dino = vis_feat_map(feat_dino, patch_h, patch_w)
36+
img_fit3d = vis_feat_map(feat_fit3d, patch_h, patch_w)
37+
img_L2M = vis_feat_map(feat_L2M, patch_h, patch_w)
38+
39+
img_dino.save(os.path.join(save_dir, f"{base_name}_dino.png"))
40+
img_fit3d.save(os.path.join(save_dir, f"{base_name}_fit3d_fit3d.png"))
41+
img_L2M.save(os.path.join(save_dir, f"{base_name}_L2M.png"))
42+
43+
# 拼图(含原图)
44+
fig, ax = plt.subplots(1, 4, figsize=(16, 4))
45+
for a, im, title in zip(
46+
ax,
47+
[original_image, img_dino, img_fit3d, img_L2M],
48+
["Original", "DINOv2", "Fit3D", "L2M (Ours)"]
49+
):
50+
a.imshow(im)
51+
a.set_title(title, fontsize=12)
52+
a.axis("off")
53+
plt.tight_layout()
54+
plt.savefig(os.path.join(save_dir, f"{base_name}_compare.png"))
55+
plt.close()
56+
57+
58+
# -------------------- 特征提取函数 --------------------
59+
def extract_features(model, image_tensor):
60+
with torch.no_grad():
61+
return model.forward_features(image_tensor)["x_norm_patchtokens"].squeeze(0).cpu().numpy()
62+
63+
64+
# -------------------- 主脚本 --------------------
65+
def main(args):
66+
os.makedirs(args.save_dir, exist_ok=True)
67+
68+
patch_h, patch_w = 37, 37
69+
img_size = patch_h * 14 # = 560
70+
feat_dim = 768
71+
72+
transform = T.Compose([
73+
T.GaussianBlur(9, sigma=(0.1, 2.0)),
74+
T.Resize((img_size, img_size)),
75+
T.CenterCrop((img_size, img_size)),
76+
T.ToTensor(),
77+
T.Normalize(mean=(0.485, 0.456, 0.406),
78+
std=(0.229, 0.224, 0.225)),
79+
])
80+
81+
# 初始化模型
82+
vit_kwargs = dict(
83+
img_size=img_size,
84+
patch_size=14,
85+
init_values=1.0,
86+
ffn_layer="mlp",
87+
block_chunks=0
88+
)
89+
90+
# DINOv2
91+
dino = vit_base(**vit_kwargs).eval().to(device)
92+
dino_ckpt_raw = torch.load(args.ckpt_dino, map_location="cpu")
93+
dino_ckpt = {k.replace("model.", ""): v for k, v in dino_ckpt_raw.items()}
94+
dino.load_state_dict(dino_ckpt, strict=False)
95+
96+
# Fit3D
97+
fit3d = vit_base(**vit_kwargs).eval().to(device)
98+
fit3d_ckpt_raw = torch.load(args.ckpt_fit3d, map_location="cpu")["model"]
99+
fit3d_ckpt = {k.replace("model.", ""): v for k, v in fit3d_ckpt_raw.items()}
100+
fit3d.load_state_dict(fit3d_ckpt, strict=False)
101+
102+
# L2M (Ours)
103+
L2M = vit_base(**vit_kwargs).eval().to(device)
104+
L2M_ckpt = torch.load(args.ckpt_L2M, map_location="cpu")
105+
L2M.load_state_dict(L2M_ckpt, strict=False)
106+
107+
for i, img_path in enumerate(args.img_paths):
108+
img = Image.open(img_path).convert("RGB")
109+
x = transform(img).unsqueeze(0).to(device)
110+
111+
# 提取特征
112+
feat_dino = extract_features(dino, x)
113+
feat_fit3d = extract_features(fit3d, x)
114+
feat_L2M = extract_features(L2M, x)
115+
116+
# 保存图
117+
base_name = os.path.splitext(os.path.basename(img_path))[0]
118+
save_all_visualizations(
119+
feat_dino, feat_fit3d, feat_L2M,
120+
patch_h, patch_w, base_name, args.save_dir,
121+
original_image=img
122+
)
123+
124+
125+
print(f"[{i+1}/{len(args.img_paths)}] Saved visualizations for {img_path}")
126+
127+
128+
if __name__ == "__main__":
129+
parser = argparse.ArgumentParser()
130+
parser.add_argument(
131+
"--img_paths",
132+
nargs="+",
133+
default=[
134+
"assets/sacre_coeur_A.jpg",
135+
"assets/sacre_coeur_B.jpg"
136+
],
137+
help="List of image paths"
138+
)
139+
parser.add_argument(
140+
"--ckpt_fit3d",
141+
default="ckpts/fit3d.pth",
142+
help="Original Fit3D checkpoint"
143+
)
144+
parser.add_argument(
145+
"--ckpt_L2M",
146+
default="ckpts/output_20250629/l2m_vit_base.pth",
147+
help="L2M Fit3D checkpoint"
148+
)
149+
parser.add_argument(
150+
"--ckpt_dino",
151+
default="ckpts/dinov2.pth",
152+
help="dino checkpoint"
153+
)
154+
parser.add_argument(
155+
"--save_dir",
156+
default="outputs_vis_feat",
157+
help="Directory to save visualizations"
158+
)
159+
args = parser.parse_args()
160+
main(args)

0 commit comments

Comments
 (0)