
Add mixed-attention Core ML mask support for stateful generation #331

Open
Skyline-23 wants to merge 3 commits into huggingface:main from Skyline-23:feat/mixed-attention-coreml-masks
@Skyline-23 commented Mar 8, 2026

Contributor

What

Add support for stateful Core ML language models that require multiple attention masks during generation.

Why

The current runtime only handles attentionMask / causalMask, which is not sufficient for mixed-attention Core ML exports that need separate masks for different layer types.

This change allows the stateful generation path to populate:

  • fullAttentionMask
  • slidingAttentionMask

when those inputs are present in the Core ML model description.
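
As a rough illustration of that detection step, the sketch below decides which masks to populate from a set of input names. `MaskRequirements` and `maskRequirements(forInputNames:)` are hypothetical names used only for this example, and the plain `Set<String>` stands in for the input names found in the Core ML model description; the code in this PR may be organized differently.

```swift
// Hypothetical summary of which mask inputs a model declares.
struct MaskRequirements: Equatable {
    var causalMask = false
    var fullAttentionMask = false
    var slidingAttentionMask = false
}

// `inputNames` stands in for the input names listed in the model description.
func maskRequirements(forInputNames inputNames: Set<String>) -> MaskRequirements {
    MaskRequirements(
        causalMask: inputNames.contains("causalMask"),
        fullAttentionMask: inputNames.contains("fullAttentionMask"),
        slidingAttentionMask: inputNames.contains("slidingAttentionMask")
    )
}
```

A model that declares only `causalMask` yields a value with the two new flags left `false`, which is how existing single-mask models would stay on their current path.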

Implementation

  • add fullAttentionMask and slidingAttentionMask keys to LanguageModel.Keys
  • detect those inputs from modelDescription
  • build additive full-attention and sliding-window masks in the stateful generation path
  • resolve the sliding window size from Core ML metadata first, and fall back to Hugging Face config when needed
  • factor stateful generation input assembly into a reusable helper for test coverage
  • keep existing single-mask models working unchanged
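
The two mask shapes can be sketched with plain Swift arrays. This is a minimal illustration, not the PR's code: `additiveMask` is a hypothetical helper, the blocked value of negative infinity is an assumption (an export may expect a large finite negative instead), and the window convention (a query sees itself plus the previous `slidingWindow - 1` positions) is also assumed.

```swift
// Returns a [queryLength][pastLength + queryLength] additive mask:
// 0 where a query may attend to a key, `blocked` elsewhere.
// When `slidingWindow` is set, keys that are `slidingWindow` or more
// positions behind the query are blocked as well (assumed convention).
func additiveMask(
    queryLength: Int,
    pastLength: Int,
    slidingWindow: Int? = nil,
    blocked: Float = -Float.infinity
) -> [[Float]] {
    let keyLength = pastLength + queryLength
    return (0..<queryLength).map { row -> [Float] in
        // Absolute position of this query token in the full sequence.
        let queryPosition = pastLength + row
        return (0..<keyLength).map { keyPosition -> Float in
            // Causal constraint: never attend to future positions.
            var allowed = keyPosition <= queryPosition
            if let window = slidingWindow {
                // Sliding constraint: stay within the last `window` positions.
                allowed = allowed && (queryPosition - keyPosition) < window
            }
            return allowed ? 0 : blocked
        }
    }
}
```

For a single decode step after three cached tokens, `additiveMask(queryLength: 1, pastLength: 3)` gives one row of four zeros, while passing `slidingWindow: 2` blocks the two oldest positions and leaves zeros only in the last two columns.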

Tests

  • add regression tests for additive full-attention mask construction
  • add regression tests for sliding-window mask construction
  • add an integration-style test that verifies the full input dictionary for a mixed-attention model contract
  • verify with:
    swift test --filter LanguageModelCoreMLMaskTests

Scope clarification

This PR is intended to support explicit multi-mask Core ML generation contracts in the runtime.

It does not attempt to fix exporter-side approaches that reconstruct multiple masks inside a Core ML graph from a single causalMask input.

Additional context

Closes #330

Example converted Core ML repo using the explicit multi-mask contract:
https://huggingface.co/Skyline23/translategemma-4b-it-coreml

- add support for fullAttentionMask and slidingAttentionMask model inputs in the stateful generation path
- derive sliding window masks from model metadata or config when needed
- add regression tests for additive full and sliding attention mask construction
- add fullAttentionMask and slidingAttentionMask handling to the stateful generation path
- resolve sliding window size from model metadata or config for mixed-attention models
- add regression tests for additive full and sliding attention mask construction
- factor stateful generation input assembly into a reusable helper
- verify full and sliding attention mask keys, shapes, and additive values
- keep single-mask generation behavior unchanged while covering mixed-attention inputs

@pcuenca pcuenca left a comment

Member

Very interesting and cool PR @Skyline-23! I won't be able to properly test and review it until the end of the week. Meanwhile, a couple of questions:

  • The converted example model seems to be using float32 instead of float16 (because of this line, and because the repo takes ~16 GB). Did you try to convert to float16? Did you try any quantization options?
  • Are you using or planning to use this Core ML model in a downstream app?

Thanks a lot for the contribution!

@Skyline-23

Contributor Author

@pcuenca Sorry for the late reply! No problem, please take your time with the review.
The conversion script was not optimized for Core ML. I fixed it and added options for 4-bit and 8-bit quantized models.
The problem was with float32 attention.
I am planning to use the model in an on-device translation app, currently in a test bed.
That app may be canceled, but in my opinion this PR moves the repo in a good direction for supporting a wider range of models.
Thanks!

@john-rocky

Contributor

I gave this a local try on macOS 26 / M-series. PR builds cleanly against current main after a trivial rebase, the 4 new mask tests pass, and the full suite (120 tests) passes too.

I also pulled Skyline23/translategemma-4b-it-coreml (Int4 per-channel) to see the runtime path end-to-end. PR #331's detection works as expected on the loaded model:

isRequiringFullAttentionMask    = true
isRequiringSlidingAttentionMask = true
slidingWindowSize               = 1024

A greedy decode of 8 tokens after a 12-token prefill produces coherent output (≈ 2.3 tok/s, .cpuAndGPU). One thing I ran into: the MLTensor-based model.prediction(from:using:) overload that swift-transformers uses errored with "The output feature named logits's shape is data dependent and doesn't allow user specified output backing object." on this model, so I ended up driving inference via MLDictionaryFeatureProvider directly. That's likely orthogonal to this PR, but flagging in case it's familiar.

Two small observations from testing:

1. The convert script doesn't write co.huggingface.exporters.sliding_window into the user-defined metadata. The fallback chain in slidingWindowSize (modelConfig?.textConfig.slidingWindow) recovers 1024 from Hub config.json for google/translategemma-4b-it, so it works on a developer machine with network. For fully offline loads or for users whose co.huggingface.exporters.name doesn't point to a public repo, the metadata key would be the more stable source. A one-line addition to _apply_coreml_metadata would close the gap:

"co.huggingface.exporters.sliding_window": str(text_config.sliding_window),

2. Inside statefulGenerationInputs:

if includeSlidingAttentionMask {
    guard let slidingWindow else {
        fatalError(...)
    }
    // ...
}

combined with try? await slidingWindowSize in predictNextTokenScores, a Hub-fetch failure (offline, rate-limited, gated repo) silently becomes nil and lands at this fatalError. Resolving the value once at init for stateful models that need it would surface the missing-config case earlier and more recoverably.
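
One way to read the two observations together is a single throwing resolver that runs once, prefers the metadata key, and falls back to the config value. The sketch below is only an illustration under stated assumptions: `SlidingWindowError` and `resolveSlidingWindow` are hypothetical names, `metadata` stands in for the model's user-defined metadata dictionary, and `configSlidingWindow` stands in for a value that has already been read from the Hugging Face config.

```swift
// Hypothetical error type for this sketch.
enum SlidingWindowError: Error, Equatable {
    case missing
    case invalid(String)
}

// Resolves the sliding window once, preferring exporter metadata over
// the config fallback, and throws instead of trapping when neither exists.
func resolveSlidingWindow(
    metadata: [String: String],
    configSlidingWindow: Int?
) throws -> Int {
    let key = "co.huggingface.exporters.sliding_window"
    if let raw = metadata[key] {
        // The metadata value is stored as a string, so parse and validate it.
        guard let value = Int(raw), value > 0 else {
            throw SlidingWindowError.invalid(raw)
        }
        return value
    }
    guard let value = configSlidingWindow, value > 0 else {
        throw SlidingWindowError.missing
    }
    return value
}
```

Calling something like this at init for models that declare `slidingAttentionMask` would turn an offline or gated-repo failure into a catchable error at load time rather than a `fatalError` during generation.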

If a real-model integration fixture would be useful later, I'd be happy to convert and publish a small mixed-attention variant under my own namespace as a possible CI fixture — no need to block this PR on it.

Development

Successfully merging this pull request may close these issues.

Stateful Core ML generation does not support mixed-attention models that require multiple mask inputs