diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index ac4d3fc97a..ddf60140f1 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -85,13 +85,9 @@ def swap_shape_xy(seq: List[int]) -> List[int]: return [seq[1], seq[0]] + list(seq[2:]) -def build_fourier_pos_embed( +def _build_fourier_pos_embed( feat_shape: List[int], - bands: Optional[torch.Tensor] = None, - num_bands: int = 64, - max_res: int = 224, - temperature: float = 10000., - linear_bands: bool = False, + bands: torch.Tensor, include_grid: bool = False, in_pixels: bool = True, ref_feat_shape: Optional[List[int]] = None, @@ -100,46 +96,10 @@ def build_fourier_pos_embed( device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, ) -> List[torch.Tensor]: - """ - - Args: - feat_shape: Feature shape for embedding. - bands: Pre-calculated frequency bands. - num_bands: Number of frequency bands (determines output dim). - max_res: Maximum resolution for pixel based freq. - temperature: Temperature for non-pixel freq. - linear_bands: Linear band spacing for pixel based freq. - include_grid: Include the spatial grid in output. - in_pixels: Output in pixel freq. - ref_feat_shape: Reference feature shape for resize / fine-tune. - grid_offset: Constant offset to add to grid for non-pixel freq. - grid_indexing: Indexing mode for meshgrid ('ij' or 'xy') - dtype: Output dtype. - device: Output device. - - Returns: - - """ - if bands is None: - if in_pixels: - bands = pixel_freq_bands( - num_bands, - float(max_res), - linear_bands=linear_bands, - device=device, - ) - else: - bands = freq_bands( - num_bands, - temperature=temperature, - step=1, - device=device, - ) - else: - if device is None: - device = bands.device - if dtype is None: - dtype = bands.dtype + if device is None: + device = bands.device + if dtype is None: + dtype = bands.dtype if grid_indexing == 'xy': feat_shape = swap_shape_xy(feat_shape) @@ -170,6 +130,92 @@ def build_fourier_pos_embed( return out +def _compute_bands( + bands: Optional[torch.Tensor], + num_bands: int, + max_res: int, + temperature: float, + linear_bands: bool, + in_pixels: bool, + device: Optional[torch.device], + dtype: torch.dtype, +) -> torch.Tensor: + if bands is None: + if in_pixels: + bands = pixel_freq_bands( + num_bands, + float(max_res), + linear_bands=linear_bands, + device=device, + ) + else: + bands = freq_bands( + num_bands, + temperature=temperature, + step=1, + device=device, + ) + return bands + + +def build_fourier_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + num_bands: int = 64, + max_res: int = 224, + temperature: float = 10000., + linear_bands: bool = False, + include_grid: bool = False, + in_pixels: bool = True, + ref_feat_shape: Optional[List[int]] = None, + grid_offset: float = 0., + grid_indexing: str = 'ij', + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, +) -> List[torch.Tensor]: + """ + + Args: + feat_shape: Feature shape for embedding. + bands: Pre-calculated frequency bands. + num_bands: Number of frequency bands (determines output dim). + max_res: Maximum resolution for pixel based freq. + temperature: Temperature for non-pixel freq. + linear_bands: Linear band spacing for pixel based freq. + include_grid: Include the spatial grid in output. + in_pixels: Output in pixel freq. + ref_feat_shape: Reference feature shape for resize / fine-tune. + grid_offset: Constant offset to add to grid for non-pixel freq. + grid_indexing: Indexing mode for meshgrid ('ij' or 'xy') + dtype: Output dtype. + device: Output device. + + Returns: + + """ + bands = _compute_bands( + bands=bands, + num_bands=num_bands, + max_res=max_res, + temperature=temperature, + linear_bands=linear_bands, + in_pixels=in_pixels, + device=device, + dtype=dtype, + ) + return _build_fourier_pos_embed( + feat_shape, + bands, + include_grid, + in_pixels, + ref_feat_shape, + grid_offset, + grid_indexing, + device, + dtype, + ) + + class FourierEmbed(nn.Module): def __init__( @@ -206,7 +252,7 @@ def init_non_persistent_buffers(self) -> None: def forward(self, x): B, C = x.shape[:2] feat_shape = x.shape[2:] - emb = build_fourier_pos_embed( + emb = _build_fourier_pos_embed( feat_shape, self.bands, include_grid=self.concat_grid, @@ -336,6 +382,35 @@ def apply_keep_indices_nlc( return pos_embed.gather(-2, keep_indices) +def _build_rotary_pos_embed( + feat_shape: List[int], + bands: torch.Tensor, + in_pixels: bool = True, + ref_feat_shape: Optional[List[int]] = None, + grid_offset: float = 0., + grid_indexing: str = 'ij', + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, +): + sin_emb, cos_emb = _build_fourier_pos_embed( + feat_shape, + bands=bands, + in_pixels=in_pixels, + ref_feat_shape=ref_feat_shape, + grid_offset=grid_offset, + grid_indexing=grid_indexing, + device=device, + dtype=dtype, + ) + num_spatial_dim = 1 + # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks + for x in feat_shape: + num_spatial_dim *= x + sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) + cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) + return sin_emb, cos_emb + + def build_rotary_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, @@ -369,27 +444,26 @@ def build_rotary_pos_embed( Returns: """ - sin_emb, cos_emb = build_fourier_pos_embed( - feat_shape, + bands = _compute_bands( bands=bands, num_bands=dim // 4, max_res=max_res, temperature=temperature, linear_bands=linear_bands, in_pixels=in_pixels, - ref_feat_shape=ref_feat_shape, - grid_offset=grid_offset, - grid_indexing=grid_indexing, device=device, dtype=dtype, ) - num_spatial_dim = 1 - # this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks - for x in feat_shape: - num_spatial_dim *= x - sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) - cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1) - return sin_emb, cos_emb + return _build_rotary_pos_embed( + feat_shape, + bands, + in_pixels, + ref_feat_shape, + grid_offset, + grid_indexing, + device, + dtype, + ) class RotaryEmbedding(nn.Module): @@ -480,12 +554,10 @@ def _compute_bands(self, device=None, dtype=None): return bands.to(device=device, dtype=dtype) def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32): - emb_sin, emb_cos = build_rotary_pos_embed( + bands = self._compute_bands(device, dtype) + emb_sin, emb_cos = _build_rotary_pos_embed( feat_shape=feat_shape, - dim=self.dim, - max_res=self.max_res, - temperature=self.temperature, - linear_bands=self.linear_bands, + bands=bands, in_pixels=self.in_pixels, ref_feat_shape=self.ref_feat_shape, grid_offset=self.grid_offset, @@ -514,7 +586,7 @@ def update_feat_shape(self, feat_shape: List[int]): def get_embed(self, shape: Optional[List[int]] = None): if shape is not None and self.bands is not None: # rebuild embeddings every call, use if target shape changes - return build_rotary_pos_embed( + return _build_rotary_pos_embed( shape, self.bands, in_pixels=self.in_pixels, @@ -614,12 +686,10 @@ def _compute_bands(self, device=None, dtype=None): return bands.to(device=device, dtype=dtype) def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32): - embeds = build_rotary_pos_embed( + bands = self._compute_bands(device, dtype) + embeds = _build_rotary_pos_embed( feat_shape=feat_shape, - dim=self.dim, - max_res=self.max_res, - temperature=self.temperature, - linear_bands=self.linear_bands, + bands=bands, in_pixels=self.in_pixels, ref_feat_shape=self.ref_feat_shape, grid_offset=self.grid_offset, @@ -647,7 +717,7 @@ def update_feat_shape(self, feat_shape: List[int]): def get_embed(self, shape: Optional[List[int]] = None): if shape is not None and self.bands is not None: # rebuild embeddings from cached bands every call, use if target shape changes - embeds = build_rotary_pos_embed( + embeds = _build_rotary_pos_embed( shape, self.bands, in_pixels=self.in_pixels, @@ -691,7 +761,7 @@ def get_batch_embeds( max_w = max(w for h, w in shapes) # Generate embeddings for max size ONCE - sin_emb, cos_emb = build_rotary_pos_embed( + sin_emb, cos_emb = _build_rotary_pos_embed( feat_shape=(max_h, max_w), bands=self.bands, in_pixels=self.in_pixels,