-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Fix #8462: embed patch sizes in einops pattern for einops >= 0.8 compatibility #8834
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from 2 commits
78756f8
093f0b0
f1cffe8
d401c76
1527dd4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,32 @@ | |
| SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"} | ||
|
|
||
|
|
||
| class _PatchRearrange(nn.Module): | ||
| """Fallback patch rearrangement using pure PyTorch, for einops compatibility.""" | ||
|
|
||
| def __init__(self, spatial_dims: int, patch_size: tuple) -> None: | ||
| super().__init__() | ||
| self.spatial_dims = spatial_dims | ||
| self.patch_size = patch_size | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| B, C = x.shape[0], x.shape[1] | ||
| sp = x.shape[2:] | ||
| g = tuple(s // p for s, p in zip(sp, self.patch_size)) | ||
| v: list[int] = [B, C] | ||
| for gi, pi in zip(g, self.patch_size): | ||
| v += [gi, pi] | ||
| x = x.view(*v) | ||
| n = self.spatial_dims | ||
| gdims = list(range(2, 2 + 2 * n, 2)) | ||
| pdims = list(range(3, 3 + 2 * n, 2)) | ||
| x = x.permute(0, *gdims, *pdims, 1).contiguous() | ||
| n_patches = 1 | ||
| for gi in g: | ||
| n_patches *= gi | ||
| return x.reshape(B, n_patches, -1) | ||
|
|
||
|
|
||
| class PatchEmbeddingBlock(nn.Module): | ||
| """ | ||
| A patch embedding block, based on: "Dosovitskiy et al., | ||
|
|
@@ -97,14 +123,16 @@ def __init__( | |
| in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size | ||
| ) | ||
| elif self.proj_type == "perceptron": | ||
| # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)" | ||
| # for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)" | ||
| chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims] | ||
| from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars) | ||
| to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)" | ||
| axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)} | ||
| self.patch_embeddings = nn.Sequential( | ||
| Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size) | ||
| ) | ||
| try: | ||
| rearrange_layer: nn.Module = Rearrange(f"{from_chars} -> {to_chars}", **axes_len) | ||
| except TypeError: | ||
| rearrange_layer = _PatchRearrange(spatial_dims, tuple(int(p) for p in patch_size)) | ||
| self.patch_embeddings = nn.Sequential(rearrange_layer, nn.Linear(self.patch_dim, hidden_size)) | ||
|
Comment on lines
+115
to
+124
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🌐 Web query:
💡 Result: No, einops Rearrange patterns do not support standalone integer literals (e.g., 'b 2 h w') as axis specifications. Patterns use symbolic axis names (letters like 'b', 'h'), anonymous axes (numbers in parentheses like '(b1 b2)'), ellipsis (...), or underscores (_ for skipping). Integer values are provided via the axes_lengths keyword argument (e.g., rearrange(x, 'b (h1 h) w c -> ...', h1=2)) or parsed from tensor shapes via parse_shape. This syntax has been consistent across versions, including 0.8.0+ (released 2024-04-28), with no changes introducing integer literal support in patterns per release notes, changelogs, and docs. Examples in official documentation and tutorials confirm named axes only. Citations:
🏁 Script executed: # Get the file and examine _PatchRearrange class and test coverage
head -150 monai/networks/blocks/patchembedding.py | tail -100Repository: Project-MONAI/MONAI Length of output: 4640 🏁 Script executed: # Check for existing tests covering the fallback path
find . -name "*.py" -path "*/test*" -type f | xargs grep -l "patchembedding" | head -5Repository: Project-MONAI/MONAI Length of output: 130 🏁 Script executed: # Get the _PatchRearrange class definition (around lines 32-55)
sed -n '32,56p' monai/networks/blocks/patchembedding.pyRepository: Project-MONAI/MONAI Length of output: 958 🏁 Script executed: # Check the test file for coverage of _PatchRearrange and the fallback path
cat monai/tests/networks/blocks/test_patchembedding.pyRepository: Project-MONAI/MONAI Length of output: 145 🏁 Script executed: # Find the test file with correct path
find . -name "test_patchembedding.py" -type f 2>/dev/nullRepository: Project-MONAI/MONAI Length of output: 109 🏁 Script executed: # Also check for any tests that import or test _PatchRearrange
rg "_PatchRearrange" --type pyRepository: Project-MONAI/MONAI Length of output: 275 🏁 Script executed: # Read the test file
cat ./tests/networks/blocks/test_patchembedding.pyRepository: Project-MONAI/MONAI Length of output: 7510 🏁 Script executed: # Check einops version constraints and any comments about the fallback logic
head -30 monai/networks/blocks/patchembedding.pyRepository: Project-MONAI/MONAI Length of output: 1308 🏁 Script executed: # Verify if x.view() in _PatchRearrange could fail on non-contiguous tensors
rg "x\.view\(" monai/networks/blocks/patchembedding.py -A 2 -B 2Repository: Project-MONAI/MONAI Length of output: 233 Retract the simplification suggestion; the current try/except approach is necessary. einops Rearrange does not support integer literals in patterns—numeric axis values must be passed via However, address these remaining issues in
🤖 Prompt for AI Agents
Comment on lines
+120
to
+124
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fallback path isn't deterministically covered by tests.
As per coding guidelines: "Ensure new or modified definitions will be covered by existing or new unit tests." 🤖 Prompt for AI Agents |
||
| self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) | ||
| self.dropout = nn.Dropout(dropout_rate) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.