
Commit 53c58dc

[Feat] Add Vit-Det as LW-DETR encoder (#2063)
1 parent fb4bb21 commit 53c58dc

16 files changed

Lines changed: 668 additions & 24 deletions


docs/source/modules/models.rst

Lines changed: 4 additions & 0 deletions
@@ -45,6 +45,10 @@ doctr.models.classification
 
 .. autofunction:: doctr.models.classification.vip_base
 
+.. autofunction:: doctr.models.classification.vit_det_s
+
+.. autofunction:: doctr.models.classification.vit_det_m
+
 .. autofunction:: doctr.models.classification.crop_orientation_predictor
 
 .. autofunction:: doctr.models.classification.page_orientation_predictor
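The two new backbones are exposed through the classification module like the existing factories. A minimal smoke test, assuming the usual doctr factory signature with a pretrained flag (the 1024x1024 input size is an arbitrary illustrative choice, not a documented default):

import torch
from doctr.models.classification import vit_det_s

model = vit_det_s(pretrained=False).eval()  # vit_det_m works the same way
with torch.inference_mode():
    out = model(torch.rand(1, 3, 1024, 1024))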

doctr/models/classification/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,4 +5,5 @@
 from .vit import *
 from .textnet import *
 from .vip import *
+from .vit_det import *
 from .zoo import *

doctr/models/classification/predictor/pytorch.py

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ def forward(
 
         processed_batches = self.pre_processor(inputs)
         _params = next(self.model.parameters())
-        self.model, processed_batches = set_device_and_dtype(
+        self.model, processed_batches = set_device_and_dtype(  # type: ignore[assignment]
             self.model, processed_batches, _params.device, _params.dtype
         )
         predicted_batches = [self.model(batch) for batch in processed_batches]
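The new # type: ignore[assignment] only silences mypy on the tuple unpacking; behavior is unchanged. For orientation, a sketch of what a helper with this call shape does — the body below is an assumption inferred from the call site, not doctr's actual implementation:

import torch
import torch.nn as nn

def set_device_and_dtype(
    model: nn.Module, batches: list[torch.Tensor], device: torch.device, dtype: torch.dtype
) -> tuple[nn.Module, list[torch.Tensor]]:
    # Move the model and every pre-processed batch onto the parameters' device/dtype,
    # echoing the (model, batches) return shape unpacked above.
    return model.to(device=device, dtype=dtype), [b.to(device=device, dtype=dtype) for b in batches]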

doctr/models/classification/textnet/pytorch.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def __init__(
     ) -> None:
         _layers: list[nn.Module] = [
             *conv_sequence_pt(
-                in_channels=3, out_channels=64, relu=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
+                in_channels=3, out_channels=64, act=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)
             ),
             *[
                 nn.Sequential(*[
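The only change here is the rename of conv_sequence_pt's activation flag from relu to act; the same rename is applied to the ViP layers below. A minimal sketch of the updated call style, assuming the helper is importable from doctr.models.utils and still returns a list of modules to unpack into a container:

import torch.nn as nn
from doctr.models.utils import conv_sequence_pt

# conv + batch norm + activation, mirroring the stem above
stem = nn.Sequential(
    *conv_sequence_pt(in_channels=3, out_channels=64, act=True, bn=True, kernel_size=3, stride=2, padding=(1, 1)),
)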

doctr/models/classification/vip/layers/pytorch.py

Lines changed: 4 additions & 4 deletions
@@ -63,11 +63,11 @@ def __init__(self, in_channels: int = 3, embed_dim: int = 128) -> None:
         self.embed_dim = embed_dim
         self.proj = nn.Sequential(
             *conv_sequence_pt(
-                in_channels, embed_dim // 2, kernel_size=3, stride=2, padding=1, bias=False, bn=True, relu=False
+                in_channels, embed_dim // 2, kernel_size=3, stride=2, padding=1, bias=False, bn=True, act=False
             ),
             nn.GELU(),
             *conv_sequence_pt(
-                embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False, bn=True, relu=False
+                embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1, bias=False, bn=True, act=False
             ),
             nn.GELU(),
         )
@@ -240,10 +240,10 @@ def __init__(
                     groups=dim,
                     bias=False,
                     bn=True,
-                    relu=False,
+                    act=False,
                 ),
                 nn.GELU(),
-                *conv_sequence_pt(dim, dim, kernel_size=1, groups=dim, bias=False, bn=True, relu=False),
+                *conv_sequence_pt(dim, dim, kernel_size=1, groups=dim, bias=False, bn=True, act=False),
             )
         else:
             self.sr = nn.Identity()  # type: ignore[assignment]
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .pytorch import *
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .pytorch import *
Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
+# Copyright (C) 2021-2026, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from doctr.models.modules import DropPath
+
+__all__ = ["PatchEmbed", "MLP", "AttentionWithCAE", "WindowedCAETransformerBlock"]
+
+
+class PatchEmbed(nn.Module):
+    """Simple 2D convolutional patch embedding layer for ViT-Det
+
+    Args:
+        kernel_size: kernel size of the projection layer.
+        stride: stride of the projection layer.
+        padding: padding size of the projection layer.
+        in_chans: number of input image channels.
+        embed_dim: patch embedding dimension.
+    """
+
+    def __init__(
+        self,
+        kernel_size: tuple[int, int] = (16, 16),
+        stride: tuple[int, int] = (16, 16),
+        padding: tuple[int, int] = (0, 0),
+        in_chans: int = 3,
+        embed_dim: int = 768,
+    ):
+        super().__init__()
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # B C H W -> B H W C
+        return self.proj(x).permute(0, 2, 3, 1)
+
+
+class MLP(nn.Module):
+    """Simple Multilayer Perceptron (MLP)
+
+    Args:
+        in_features: number of input features
+        hidden_features: number of hidden features (default: in_features)
+        out_features: number of output features (default: in_features)
+        act_layer: activation layer (default: nn.GELU)
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int | None = None,
+        out_features: int | None = None,
+        act_layer=nn.GELU,
+    ):
+        super().__init__()
+
+        hidden_features = hidden_features or in_features
+        out_features = out_features or in_features
+
+        self.net = nn.Sequential(
+            nn.Linear(in_features, hidden_features),
+            act_layer(),
+            nn.Linear(hidden_features, out_features),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.net(x)
+
+
+class AttentionWithCAE(nn.Module):
+    """Multi-head attention block with CAE bias construction.
+
+    Args:
+        dim: Number of input channels.
+        num_heads: Number of attention heads.
+        qkv_bias: If True, add a learnable bias to query, key, value.
+        use_cae: If True, use CAE bias construction (separate q and v bias).
+    """
+
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = True, use_cae: bool = False):
+        super().__init__()
+
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim**-0.5
+        self.use_cae = use_cae
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=(qkv_bias and not use_cae))
+
+        # CAE bias
+        if use_cae:
+            self.q_bias = nn.Parameter(torch.zeros(dim))
+            self.v_bias = nn.Parameter(torch.zeros(dim))
+
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
+        B, N, C = x.shape
+
+        # QKV projection
+        if self.use_cae:
+            zeros = torch.zeros_like(self.v_bias, requires_grad=False)
+            qkv_bias = torch.cat([self.q_bias, zeros, self.v_bias])
+            qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
+        else:  # pragma: no cover
+            qkv = self.qkv(x)
+
+        # Reshape to multi-head
+        qkv = qkv.view(B, N, 3, self.num_heads, self.head_dim)
+        q, k, v = qkv.permute(2, 0, 3, 1, 4)
+
+        # Attention
+        attn = (q * self.scale) @ k.transpose(-2, -1)
+
+        if mask is not None:
+            attn = attn.masked_fill(mask.view(B, 1, 1, N).expand_as(attn), float("-inf"))
+
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        return self.proj(x)
+
+
+class WindowedCAETransformerBlock(nn.Module):
+    """Transformer block with support for window attention and residual propagation
+
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads in each ViT block.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool): If True, add a learnable bias to query, key, value.
+        drop_prob (float): Stochastic depth rate.
+        norm_layer (nn.Module): Normalization layer.
+        act_layer (nn.Module): Activation layer.
+        window (bool): If True, use window attention. Otherwise, use global attention.
+        use_cae (bool): If True, use CAE bias construction (separate q and v bias).
+    """
+
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        drop_prob=0.0,
+        norm_layer=nn.LayerNorm,
+        act_layer=nn.GELU,
+        window=False,
+        use_cae=False,
+    ):
+        super().__init__()
+
+        self.norm1 = norm_layer(dim)
+        self.norm2 = norm_layer(dim)
+
+        self.attn = AttentionWithCAE(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            use_cae=use_cae,
+        )
+        self.mlp = MLP(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+        )
+        self.drop_path = DropPath(drop_prob) if drop_prob > 0.0 else nn.Identity()
+
+        self.window = window
+        self.use_cae = use_cae
+
+        if use_cae:
+            self.gamma_1 = nn.Parameter(0.1 * torch.ones(dim), requires_grad=True)
+            self.gamma_2 = nn.Parameter(0.1 * torch.ones(dim), requires_grad=True)
+
+    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
+        B, HW, C = x.shape
+        shortcut = x
+
+        x = self.norm1(x)
+        mask_r = mask
+
+        # Window partitioning logic
+        if not self.window:
+            x = x.reshape(B // 16, 16 * HW, C)
+            shortcut_r = shortcut.reshape(B // 16, 16 * HW, C)
+
+            if mask is not None:  # pragma: no cover
+                mask_r = mask.reshape(B // 16, 16 * HW)
+            else:
+                mask_r = None
+        else:
+            shortcut_r = shortcut
+
+        # Attention
+        attn_out = self.attn(x, mask_r)
+
+        if self.use_cae:
+            attn_out = self.gamma_1 * attn_out
+
+        x = shortcut_r + self.drop_path(attn_out)
+
+        # Reshape back if needed
+        if not self.window:
+            x = x.reshape(B, HW, C)
+            if mask is not None:  # pragma: no cover
+                mask = mask.reshape(B, HW)
+
+        x = x + self.drop_path((self.gamma_2 * self.mlp(self.norm2(x))) if self.use_cae else self.mlp(self.norm2(x)))
+
+        return x
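For context, the CAE-style bias in AttentionWithCAE gives the query and value projections learnable biases while pinning the key bias to zero (the torch.cat([self.q_bias, zeros, self.v_bias]) above), as in BEiT/CAE-style attention. A quick shape walk-through of the new blocks; the sizes (64x64 input, embed_dim=192, 3 heads) are illustrative assumptions, not vit_det defaults:

import torch

embed = PatchEmbed(embed_dim=192)  # 16x16 patches by default
block_win = WindowedCAETransformerBlock(dim=192, num_heads=3, window=True, use_cae=True)
block_glob = WindowedCAETransformerBlock(dim=192, num_heads=3, window=False, use_cae=True)

x = embed(torch.rand(2, 3, 64, 64))  # (2, 4, 4, 192): B H W C
tokens = x.flatten(1, 2)  # (2, 16, 192): B HW C
out_win = block_win(tokens)  # windowed attention, shape preserved

# window=False folds groups of 16 windows back into one global sequence,
# so the leading dimension must be a multiple of 16:
windows = tokens.repeat(8, 1, 1)  # (16, 16, 192)
out_glob = block_glob(windows)  # global attention over 16 * 16 tokens, back to (16, 16, 192)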
