Fix MLX implementation of LocalDotProductSelfAttention that previously overwrote block_size #16

Open
lancelotblanchard wants to merge 1 commit into google:main from lancelotblanchard:main

@lancelotblanchard

The current MLX implementation of LocalDotProductSelfAttention does:

@property
@override
def block_size(self):
  return self._block_size_config

This overrides the block_size property inherited from the base class and breaks inference using step, which needs block_size=1. The JAX implementation does not override the property. This PR fixes the MLX implementation to match the JAX one.
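To make the clash concrete, here is a minimal pure-Python sketch. It does not use MLX or the real library classes: `Steppable`, `BuggyLocalAttention`, and `FixedLocalAttention` are hypothetical stand-ins, and the base property returning 1 is an assumption taken from the statement above that step needs block_size=1.

```python
class Steppable:
    """Hypothetical stand-in for the base class that owns the API-level block_size."""

    @property
    def block_size(self) -> int:
        # Assumption for this sketch: 1 means the layer can be stepped one frame at a time.
        return 1


class BuggyLocalAttention(Steppable):
    """Redefines block_size, so the API property now reports the attention block size."""

    def __init__(self, block_size: int):
        self._block_size_config = block_size

    @property
    def block_size(self) -> int:
        return self._block_size_config


class FixedLocalAttention(Steppable):
    """Keeps the attention block size in a private attribute and inherits the API property."""

    def __init__(self, block_size: int):
        self._block_size_config = block_size


buggy = BuggyLocalAttention(block_size=4)
fixed = FixedLocalAttention(block_size=4)
print(buggy.block_size)  # 4: the attention setting leaks into the API property
print(fixed.block_size)  # 1: the inherited API property is untouched
```

In the fixed variant the attention block size stays available internally through `_block_size_config`, while callers of `block_size` see the value defined by the base class.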

Additionally, this PR adds corresponding attention tests (self.assertEqual(layer.block_size, 1)) for the step function, and modifies DefaultSteppable so that properties are read from the base class backend_sl.types.Steppable and not from DefaultTestLayer.
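The kind of check described above could look roughly like the following unittest sketch. It is not the test code from this PR. `Steppable`, `LocalAttention`, and `StepTest` are hypothetical names, and the rule that step only accepts a frame count that is a multiple of block_size is an assumption made for illustration.

```python
import unittest


class Steppable:
    """Hypothetical base class; step() checks the input length against block_size."""

    @property
    def block_size(self) -> int:
        return 1

    def step(self, frames: list) -> list:
        # Assumed contract for this sketch: input length must be a multiple of block_size.
        if len(frames) % self.block_size != 0:
            raise ValueError(
                f"step() got {len(frames)} frames, expected a multiple of {self.block_size}"
            )
        return frames  # Identity placeholder; a real layer would compute attention here.


class LocalAttention(Steppable):
    """Stores the attention block size privately, leaving the API property inherited."""

    def __init__(self, block_size: int):
        self._block_size_config = block_size


class StepTest(unittest.TestCase):
    def test_step_single_frame(self):
        layer = LocalAttention(block_size=4)
        # The API-level property must stay at 1 regardless of the attention setting.
        self.assertEqual(layer.block_size, 1)
        # With block_size == 1, stepping a single frame is accepted.
        self.assertEqual(layer.step([0.5]), [0.5])
```

If `LocalAttention` redefined `block_size` to return 4, both assertions would fail: the first directly, and the second because a single frame is not a multiple of 4 under the assumed contract.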

rryan requested a review from JulianSlzr on June 10, 2026, 20:57
@rryan

rryan commented Jun 10, 2026

Collaborator

Great catch -- this is an unfortunate name clash. block_size in LocalDotProductSelfAttention is unrelated to the block_size concept in the SequenceLayer API. I imagine this got mixed up in the initial MLX port.

2 participants