|
56 | 56 | QWEN_TOOL_CALL_FAMILY, |
57 | 57 | ChatModelMixin, |
58 | 58 | generate_completion_chunk, |
| 59 | + get_context_length_from_config, |
59 | 60 | ) |
60 | 61 |
|
61 | 62 | logger = logging.getLogger(__name__) |
@@ -436,23 +437,6 @@ class PromptCache: |
436 | 437 | tokens: List[int] = field(default_factory=list) |
437 | 438 |
|
438 | 439 |
|
439 | | -def get_context_length(config: dict) -> int: |
440 | | - """Get the context length of a model from model config.""" |
441 | | - if config.get("max_sequence_length") is not None: |
442 | | - max_sequence_length = config["max_sequence_length"] |
443 | | - else: |
444 | | - max_sequence_length = 2048 |
445 | | - if config.get("seq_length") is not None: |
446 | | - seq_length = config["seq_length"] |
447 | | - else: |
448 | | - seq_length = 2048 |
449 | | - if config.get("max_position_embeddings") is not None: |
450 | | - max_position_embeddings = config["max_position_embeddings"] |
451 | | - else: |
452 | | - max_position_embeddings = 2048 |
453 | | - return max(max_sequence_length, seq_length, max_position_embeddings) |
454 | | - |
455 | | - |
456 | 440 | class MLXModel(LLM, ChatModelMixin): |
457 | 441 | _rank_to_addresses: Optional[Dict[int, str]] |
458 | 442 | allow_batch: bool = False |
@@ -752,7 +736,7 @@ def wait_for_load(self): |
752 | 736 | # get context length |
753 | 737 | config = load_config(Path(self.model_path)) |
754 | 738 | config.update(self._model_config) |
755 | | - self._context_length = get_context_length(config) |
| 739 | + self._context_length = get_context_length_from_config(config) |
756 | 740 |
|
757 | 741 | # Update allow_batch based on distributed inference |
758 | 742 | # Only enable continuous batching for non-distributed inference (single worker) |
@@ -1418,7 +1402,7 @@ def load(self): |
1418 | 1402 | # get context length |
1419 | 1403 | config = load_config(Path(self.model_path)) |
1420 | 1404 | config.update(self._model_config) |
1421 | | - self._context_length = get_context_length(config) |
| 1405 | + self._context_length = get_context_length_from_config(config) |
1422 | 1406 |
|
1423 | 1407 | def _generate_stream_inner(self, **kwargs): |
1424 | 1408 | import mlx.core as mx |
|
0 commit comments