From a8e7175f333859542aa3fa8938995e53943d240e Mon Sep 17 00:00:00 2001 From: Yang Kun Date: Mon, 19 Jan 2026 11:38:51 +0800 Subject: [PATCH 1/7] refine build_fourier_pos_embed() to use precomputed bands --- timm/layers/pos_embed_sincos.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index ac4d3fc97a..50ab6173e0 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -87,7 +87,7 @@ def swap_shape_xy(seq: List[int]) -> List[int]: def build_fourier_pos_embed( feat_shape: List[int], - bands: Optional[torch.Tensor] = None, + bands: torch.Tensor = None, num_bands: int = 64, max_res: int = 224, temperature: float = 10000., @@ -120,26 +120,8 @@ def build_fourier_pos_embed( Returns: """ - if bands is None: - if in_pixels: - bands = pixel_freq_bands( - num_bands, - float(max_res), - linear_bands=linear_bands, - device=device, - ) - else: - bands = freq_bands( - num_bands, - temperature=temperature, - step=1, - device=device, - ) - else: - if device is None: - device = bands.device - if dtype is None: - dtype = bands.dtype + device = device or bands.device + dtype = dtype or bands.dtype if grid_indexing == 'xy': feat_shape = swap_shape_xy(feat_shape) @@ -338,7 +320,7 @@ def apply_keep_indices_nlc( def build_rotary_pos_embed( feat_shape: List[int], - bands: Optional[torch.Tensor] = None, + bands: torch.Tensor = None, dim: int = 64, max_res: int = 224, temperature: float = 10000., @@ -480,8 +462,10 @@ def _compute_bands(self, device=None, dtype=None): return bands.to(device=device, dtype=dtype) def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32): + bands = self._compute_bands(device, dtype) emb_sin, emb_cos = build_rotary_pos_embed( feat_shape=feat_shape, + bands=bands, dim=self.dim, max_res=self.max_res, temperature=self.temperature, @@ -614,8 +598,10 @@ def _compute_bands(self, device=None, 
dtype=None): return bands.to(device=device, dtype=dtype) def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32): + bands = self._compute_bands(device, dtype) embeds = build_rotary_pos_embed( feat_shape=feat_shape, + bands=bands, dim=self.dim, max_res=self.max_res, temperature=self.temperature, From 0732ab62ef6872936867d6268850458ac0b2c082 Mon Sep 17 00:00:00 2001 From: Yang Kun Date: Mon, 19 Jan 2026 11:48:41 +0800 Subject: [PATCH 2/7] remove default parameter --- timm/layers/pos_embed_sincos.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 50ab6173e0..77a50b442d 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -87,7 +87,7 @@ def swap_shape_xy(seq: List[int]) -> List[int]: def build_fourier_pos_embed( feat_shape: List[int], - bands: torch.Tensor = None, + bands: torch.Tensor, num_bands: int = 64, max_res: int = 224, temperature: float = 10000., @@ -320,7 +320,7 @@ def apply_keep_indices_nlc( def build_rotary_pos_embed( feat_shape: List[int], - bands: torch.Tensor = None, + bands: torch.Tensor, dim: int = 64, max_res: int = 224, temperature: float = 10000., From 81b0a5a020f3705bd8934650436e789a53b6b233 Mon Sep 17 00:00:00 2001 From: Yang Kun Date: Mon, 19 Jan 2026 12:06:03 +0800 Subject: [PATCH 3/7] Remove unneeded parameters --- timm/layers/pos_embed_sincos.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 77a50b442d..58bc605ecd 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -88,10 +88,6 @@ def swap_shape_xy(seq: List[int]) -> List[int]: def build_fourier_pos_embed( feat_shape: List[int], bands: torch.Tensor, - num_bands: int = 64, - max_res: int = 224, - temperature: float = 10000., - linear_bands: bool = False, include_grid: bool = False, in_pixels: bool = 
True, ref_feat_shape: Optional[List[int]] = None, @@ -105,10 +101,6 @@ def build_fourier_pos_embed( Args: feat_shape: Feature shape for embedding. bands: Pre-calculated frequency bands. - num_bands: Number of frequency bands (determines output dim). - max_res: Maximum resolution for pixel based freq. - temperature: Temperature for non-pixel freq. - linear_bands: Linear band spacing for pixel based freq. include_grid: Include the spatial grid in output. in_pixels: Output in pixel freq. ref_feat_shape: Reference feature shape for resize / fine-tune. @@ -321,10 +313,6 @@ def apply_keep_indices_nlc( def build_rotary_pos_embed( feat_shape: List[int], bands: torch.Tensor, - dim: int = 64, - max_res: int = 224, - temperature: float = 10000., - linear_bands: bool = False, in_pixels: bool = True, ref_feat_shape: Optional[List[int]] = None, grid_offset: float = 0., @@ -337,10 +325,6 @@ def build_rotary_pos_embed( Args: feat_shape: Spatial shape of the target tensor for embedding. bands: Optional pre-generated frequency bands - dim: Output dimension of embedding tensor. - max_res: Maximum resolution for pixel mode. - temperature: Temperature (inv freq) for non-pixel mode - linear_bands: Linearly (instead of log) spaced bands for pixel mode in_pixels: Pixel vs language (inv freq) mode. ref_feat_shape: Reference feature shape for resize / fine-tune. grid_offset: Constant offset to add to grid for non-pixel freq. @@ -354,10 +338,6 @@ def build_rotary_pos_embed( sin_emb, cos_emb = build_fourier_pos_embed( feat_shape, bands=bands, - num_bands=dim // 4, - max_res=max_res, - temperature=temperature, - linear_bands=linear_bands, in_pixels=in_pixels, ref_feat_shape=ref_feat_shape, grid_offset=grid_offset, @@ -466,10 +446,6 @@ def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch. 
emb_sin, emb_cos = build_rotary_pos_embed( feat_shape=feat_shape, bands=bands, - dim=self.dim, - max_res=self.max_res, - temperature=self.temperature, - linear_bands=self.linear_bands, in_pixels=self.in_pixels, ref_feat_shape=self.ref_feat_shape, grid_offset=self.grid_offset, @@ -602,10 +578,6 @@ def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch. embeds = build_rotary_pos_embed( feat_shape=feat_shape, bands=bands, - dim=self.dim, - max_res=self.max_res, - temperature=self.temperature, - linear_bands=self.linear_bands, in_pixels=self.in_pixels, ref_feat_shape=self.ref_feat_shape, grid_offset=self.grid_offset, From eabb18af70e731659488f0d715d14f2294d9b43a Mon Sep 17 00:00:00 2001 From: Yang Kun Date: Tue, 20 Jan 2026 08:01:53 +0800 Subject: [PATCH 4/7] Use explicit None checks for device/dtype defaults --- timm/layers/pos_embed_sincos.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 58bc605ecd..af63873a6b 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -112,8 +112,10 @@ def build_fourier_pos_embed( Returns: """ - device = device or bands.device - dtype = dtype or bands.dtype + if device is None: + device = bands.device + if dtype is None: + dtype = bands.dtype if grid_indexing == 'xy': feat_shape = swap_shape_xy(feat_shape) From 4e2b121d3d683fdcb3b8e1a80f68d142583d1d0a Mon Sep 17 00:00:00 2001 From: Yang Kun Date: Sat, 24 Jan 2026 15:42:28 +0800 Subject: [PATCH 5/7] Keep original function signature for backward compatibility --- timm/layers/pos_embed_sincos.py | 183 +++++++++++++++++++++++++------- 1 file changed, 144 insertions(+), 39 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index af63873a6b..9e4a1a90e2 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -85,7 +85,7 @@ def swap_shape_xy(seq: List[int]) -> List[int]: return [seq[1], seq[0]] + list(seq[2:]) 
-def build_fourier_pos_embed( +def _build_fourier_pos_embed( feat_shape: List[int], bands: torch.Tensor, include_grid: bool = False, @@ -96,22 +96,6 @@ def build_fourier_pos_embed( device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, ) -> List[torch.Tensor]: - """ - - Args: - feat_shape: Feature shape for embedding. - bands: Pre-calculated frequency bands. - include_grid: Include the spatial grid in output. - in_pixels: Output in pixel freq. - ref_feat_shape: Reference feature shape for resize / fine-tune. - grid_offset: Constant offset to add to grid for non-pixel freq. - grid_indexing: Indexing mode for meshgrid ('ij' or 'xy') - dtype: Output dtype. - device: Output device. - - Returns: - - """ if device is None: device = bands.device if dtype is None: @@ -146,6 +130,75 @@ def build_fourier_pos_embed( return out +def build_fourier_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + num_bands: int = 64, + max_res: int = 224, + temperature: float = 10000., + linear_bands: bool = False, + include_grid: bool = False, + in_pixels: bool = True, + ref_feat_shape: Optional[List[int]] = None, + grid_offset: float = 0., + grid_indexing: str = 'ij', + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, +) -> List[torch.Tensor]: + """ + + Args: + feat_shape: Feature shape for embedding. + bands: Pre-calculated frequency bands. + num_bands: Number of frequency bands (determines output dim). + max_res: Maximum resolution for pixel based freq. + temperature: Temperature for non-pixel freq. + linear_bands: Linear band spacing for pixel based freq. + include_grid: Include the spatial grid in output. + in_pixels: Output in pixel freq. + ref_feat_shape: Reference feature shape for resize / fine-tune. + grid_offset: Constant offset to add to grid for non-pixel freq. + grid_indexing: Indexing mode for meshgrid ('ij' or 'xy') + dtype: Output dtype. + device: Output device. 
+ + Returns: + + """ + if bands is None: + if in_pixels: + bands = pixel_freq_bands( + num_bands, + float(max_res), + linear_bands=linear_bands, + device=device, + ) + else: + bands = freq_bands( + num_bands, + temperature=temperature, + step=1, + device=device, + ) + else: + if device is None: + device = bands.device + if dtype is None: + dtype = bands.dtype + + return _build_fourier_pos_embed( + feat_shape, + bands, + include_grid, + in_pixels, + ref_feat_shape, + grid_offset, + grid_indexing, + device, + dtype, + ) + + class FourierEmbed(nn.Module): def __init__( @@ -182,7 +235,7 @@ def init_non_persistent_buffers(self) -> None: def forward(self, x): B, C = x.shape[:2] feat_shape = x.shape[2:] - emb = build_fourier_pos_embed( + emb = _build_fourier_pos_embed( feat_shape, self.bands, include_grid=self.concat_grid, @@ -312,7 +365,7 @@ def apply_keep_indices_nlc( return pos_embed.gather(-2, keep_indices) -def build_rotary_pos_embed( +def _build_rotary_pos_embed( feat_shape: List[int], bands: torch.Tensor, in_pixels: bool = True, @@ -322,21 +375,6 @@ def build_rotary_pos_embed( device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, ): - """ - - Args: - feat_shape: Spatial shape of the target tensor for embedding. - bands: Optional pre-generated frequency bands - in_pixels: Pixel vs language (inv freq) mode. - ref_feat_shape: Reference feature shape for resize / fine-tune. - grid_offset: Constant offset to add to grid for non-pixel freq. - grid_indexing: Indexing mode for meshgrid ('ij' or 'xy') - device: Output device. - dtype: Output dtype. 
- - Returns: - - """ sin_emb, cos_emb = build_fourier_pos_embed( feat_shape, bands=bands, @@ -356,6 +394,73 @@ def build_rotary_pos_embed( return sin_emb, cos_emb +def build_rotary_pos_embed( + feat_shape: List[int], + bands: Optional[torch.Tensor] = None, + dim: int = 64, + max_res: int = 224, + temperature: float = 10000., + linear_bands: bool = False, + in_pixels: bool = True, + ref_feat_shape: Optional[List[int]] = None, + grid_offset: float = 0., + grid_indexing: str = 'ij', + device: Optional[torch.device] = None, + dtype: torch.dtype = torch.float32, +): + """ + + Args: + feat_shape: Spatial shape of the target tensor for embedding. + bands: Optional pre-generated frequency bands + dim: Output dimension of embedding tensor. + max_res: Maximum resolution for pixel mode. + temperature: Temperature (inv freq) for non-pixel mode + linear_bands: Linearly (instead of log) spaced bands for pixel mode + in_pixels: Pixel vs language (inv freq) mode. + ref_feat_shape: Reference feature shape for resize / fine-tune. + grid_offset: Constant offset to add to grid for non-pixel freq. + grid_indexing: Indexing mode for meshgrid ('ij' or 'xy') + device: Output device. + dtype: Output dtype. 
+ + Returns: + + """ + if bands is None: + if in_pixels: + bands = pixel_freq_bands( + num_bands, + float(max_res), + linear_bands=linear_bands, + device=device, + ) + else: + bands = freq_bands( + num_bands, + temperature=temperature, + step=1, + device=device, + ) + else: + if device is None: + device = bands.device + if dtype is None: + dtype = bands.dtype + + return _build_rotary_pos_embed( + feat_shape, + bands, + include_grid, + in_pixels, + ref_feat_shape, + grid_offset, + grid_indexing, + device, + dtype, + ) + + class RotaryEmbedding(nn.Module): """ Rotary position embedding @@ -445,7 +550,7 @@ def _compute_bands(self, device=None, dtype=None): def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32): bands = self._compute_bands(device, dtype) - emb_sin, emb_cos = build_rotary_pos_embed( + emb_sin, emb_cos = _build_rotary_pos_embed( feat_shape=feat_shape, bands=bands, in_pixels=self.in_pixels, @@ -476,7 +581,7 @@ def update_feat_shape(self, feat_shape: List[int]): def get_embed(self, shape: Optional[List[int]] = None): if shape is not None and self.bands is not None: # rebuild embeddings every call, use if target shape changes - return build_rotary_pos_embed( + return _build_rotary_pos_embed( shape, self.bands, in_pixels=self.in_pixels, @@ -577,7 +682,7 @@ def _compute_bands(self, device=None, dtype=None): def _get_pos_embed_values(self, feat_shape: List[int], device=None, dtype=torch.float32): bands = self._compute_bands(device, dtype) - embeds = build_rotary_pos_embed( + embeds = _build_rotary_pos_embed( feat_shape=feat_shape, bands=bands, in_pixels=self.in_pixels, @@ -607,7 +712,7 @@ def update_feat_shape(self, feat_shape: List[int]): def get_embed(self, shape: Optional[List[int]] = None): if shape is not None and self.bands is not None: # rebuild embeddings from cached bands every call, use if target shape changes - embeds = build_rotary_pos_embed( + embeds = _build_rotary_pos_embed( shape, self.bands, 
in_pixels=self.in_pixels, @@ -651,7 +756,7 @@ def get_batch_embeds( max_w = max(w for h, w in shapes) # Generate embeddings for max size ONCE - sin_emb, cos_emb = build_rotary_pos_embed( + sin_emb, cos_emb = _build_rotary_pos_embed( feat_shape=(max_h, max_w), bands=self.bands, in_pixels=self.in_pixels, From 2f6b316014fbc1287cf032374a0e347244bc1f00 Mon Sep 17 00:00:00 2001 From: Yang Kun Date: Sat, 24 Jan 2026 16:10:06 +0800 Subject: [PATCH 6/7] Extract _compute_bands helper and fix build_rotary_pos_embed arguments --- timm/layers/pos_embed_sincos.py | 79 +++++++++++++++------------------ 1 file changed, 35 insertions(+), 44 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index 9e4a1a90e2..d264d47fe9 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -130,6 +130,34 @@ def _build_fourier_pos_embed( return out +def _compute_bands( + bands: Optional[torch.Tensor], + num_bands: int, + max_res: int, + temperature: float, + linear_bands: bool, + in_pixels: bool, + device: Optional[torch.device], + dtype: torch.dtype, +) -> torch.Tensor: + if bands is None: + if in_pixels: + bands = pixel_freq_bands( + num_bands, + float(max_res), + linear_bands=linear_bands, + device=device, + ) + else: + bands = freq_bands( + num_bands, + temperature=temperature, + step=1, + device=device, + ) + return bands + + def build_fourier_pos_embed( feat_shape: List[int], bands: Optional[torch.Tensor] = None, @@ -165,27 +193,9 @@ def build_fourier_pos_embed( Returns: """ - if bands is None: - if in_pixels: - bands = pixel_freq_bands( - num_bands, - float(max_res), - linear_bands=linear_bands, - device=device, - ) - else: - bands = freq_bands( - num_bands, - temperature=temperature, - step=1, - device=device, - ) - else: - if device is None: - device = bands.device - if dtype is None: - dtype = bands.dtype - + bands = _compute_bands( + bands, num_bands, max_res, temperature, linear_bands, in_pixels, device, dtype + ) return _build_fourier_pos_embed( feat_shape, bands, @@ -375,7 
+385,7 @@ def _build_rotary_pos_embed( device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, ): - sin_emb, cos_emb = build_fourier_pos_embed( + sin_emb, cos_emb = _build_fourier_pos_embed( feat_shape, bands=bands, in_pixels=in_pixels, @@ -427,31 +437,12 @@ def build_rotary_pos_embed( Returns: """ - if bands is None: - if in_pixels: - bands = pixel_freq_bands( - num_bands, - float(max_res), - linear_bands=linear_bands, - device=device, - ) - else: - bands = freq_bands( - num_bands, - temperature=temperature, - step=1, - device=device, - ) - else: - if device is None: - device = bands.device - if dtype is None: - dtype = bands.dtype - + bands = _compute_bands( + bands, dim // 4, max_res, temperature, linear_bands, in_pixels, device, dtype + ) return _build_rotary_pos_embed( feat_shape, bands, - include_grid, in_pixels, ref_feat_shape, grid_offset, From ae2f4c552af79df8e005cf37cb31c23702e393c2 Mon Sep 17 00:00:00 2001 From: Yang Kun Date: Sat, 24 Jan 2026 16:23:55 +0800 Subject: [PATCH 7/7] Use keyword parameters --- timm/layers/pos_embed_sincos.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/timm/layers/pos_embed_sincos.py b/timm/layers/pos_embed_sincos.py index d264d47fe9..ddf60140f1 100644 --- a/timm/layers/pos_embed_sincos.py +++ b/timm/layers/pos_embed_sincos.py @@ -194,7 +194,14 @@ def build_fourier_pos_embed( """ bands = _compute_bands( - bands, num_bands, max_res, temperature, linear_bands, in_pixels, device, dtype + bands=bands, + num_bands=num_bands, + max_res=max_res, + temperature=temperature, + linear_bands=linear_bands, + in_pixels=in_pixels, + device=device, + dtype=dtype, ) return _build_fourier_pos_embed( feat_shape, @@ -438,7 +445,14 @@ def build_rotary_pos_embed( """ bands = _compute_bands( - bands, dim // 4, max_res, temperature, linear_bands, in_pixels, device, dtype + bands=bands, + num_bands=dim // 4, + max_res=max_res, + temperature=temperature, + linear_bands=linear_bands, + 
in_pixels=in_pixels, + device=device, + dtype=dtype, ) return _build_rotary_pos_embed( feat_shape,