-
Notifications
You must be signed in to change notification settings - Fork 1k
gemma4_31b: add OpenAI serving entrypoint #20473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |
| """ | ||
|
|
||
| import argparse | ||
| import json | ||
| import os | ||
|
|
||
| import torch | ||
|
|
@@ -135,6 +136,11 @@ def _pack_for_backend(model: nn.Module, path: str, backend: str) -> None: | |
| # Export + lower | ||
|
|
||
|
|
||
| def _mutable_buffer_metadata(model: nn.Module) -> str: | ||
| mutable = [name for name, _ in model.named_buffers() if ".kv_cache." in name] | ||
| return json.dumps({"version": 1, "mutable_buffers": mutable}) | ||
|
|
||
|
|
||
| def export_and_lower( | ||
| model: Gemma4_31B, | ||
| config: Gemma4_31BConfig, | ||
|
|
@@ -181,6 +187,7 @@ def _export_cuda( | |
| import executorch.backends.cuda.quantize_op_dispatch # noqa: F401 | ||
|
|
||
| materialize_runtime_buffers(model, dtype=torch.bfloat16) | ||
| mutable_buffer_metadata = _mutable_buffer_metadata(model) | ||
|
|
||
| if use_turboquant: | ||
| from executorch.examples.models.gemma4_31b.cuda_source_transformations import ( | ||
|
|
@@ -255,6 +262,8 @@ def _export_cuda( | |
| "get_vocab_size": config.vocab_size, | ||
| "get_n_layers": config.num_hidden_layers, | ||
| "get_max_prefill_chunk": max_prefill, | ||
| "get_min_prefill_chunk": 5, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you pad up to this chunk size? Should this be backend specific? MLX can handle seqlen=1.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. get_min_prefill_chunk=5 is CUDA-only; MLX uses the default path |
||
| "get_mutable_buffer_metadata": mutable_buffer_metadata, | ||
| "use_kv_cache": True, | ||
| "use_sdpa_with_kv_cache": False, | ||
| "enable_dynamic_shape": True, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this repeated twice?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
runner and worker both need the MLX metallib copy