Skip to content

Commit 26722cd

Browse files
committed
chore(utils): centralize context length helper
1 parent dd9846b commit 26722cd

3 files changed

Lines changed: 47 additions & 38 deletions

File tree

xinference/model/llm/mlx/core.py

Lines changed: 3 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,7 @@
5656
QWEN_TOOL_CALL_FAMILY,
5757
ChatModelMixin,
5858
generate_completion_chunk,
59+
get_context_length_from_config,
5960
)
6061

6162
logger = logging.getLogger(__name__)
@@ -436,23 +437,6 @@ class PromptCache:
436437
tokens: List[int] = field(default_factory=list)
437438

438439

439-
def get_context_length(config: dict) -> int:
440-
"""Get the context length of a model from model config."""
441-
if config.get("max_sequence_length") is not None:
442-
max_sequence_length = config["max_sequence_length"]
443-
else:
444-
max_sequence_length = 2048
445-
if config.get("seq_length") is not None:
446-
seq_length = config["seq_length"]
447-
else:
448-
seq_length = 2048
449-
if config.get("max_position_embeddings") is not None:
450-
max_position_embeddings = config["max_position_embeddings"]
451-
else:
452-
max_position_embeddings = 2048
453-
return max(max_sequence_length, seq_length, max_position_embeddings)
454-
455-
456440
class MLXModel(LLM, ChatModelMixin):
457441
_rank_to_addresses: Optional[Dict[int, str]]
458442
allow_batch: bool = False
@@ -752,7 +736,7 @@ def wait_for_load(self):
752736
# get context length
753737
config = load_config(Path(self.model_path))
754738
config.update(self._model_config)
755-
self._context_length = get_context_length(config)
739+
self._context_length = get_context_length_from_config(config)
756740

757741
# Update allow_batch based on distributed inference
758742
# Only enable continuous batching for non-distributed inference (single worker)
@@ -1418,7 +1402,7 @@ def load(self):
14181402
# get context length
14191403
config = load_config(Path(self.model_path))
14201404
config.update(self._model_config)
1421-
self._context_length = get_context_length(config)
1405+
self._context_length = get_context_length_from_config(config)
14221406

14231407
def _generate_stream_inner(self, **kwargs):
14241408
import mlx.core as mx

xinference/model/llm/transformers/utils.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
max_tokens_field,
3838
)
3939
from ...scheduler.request import InferenceRequest
40+
from ..utils import get_context_length_from_config
4041

4142
if TYPE_CHECKING:
4243
from ...llm.transformers.core import PytorchModel
@@ -46,25 +47,7 @@
4647

4748
def get_context_length(config) -> int:
4849
"""Get the context length of a model from a huggingface model config."""
49-
if (
50-
hasattr(config, "max_sequence_length")
51-
and config.max_sequence_length is not None
52-
):
53-
max_sequence_length = config.max_sequence_length
54-
else:
55-
max_sequence_length = 2048
56-
if hasattr(config, "seq_length") and config.seq_length is not None:
57-
seq_length = config.seq_length
58-
else:
59-
seq_length = 2048
60-
if (
61-
hasattr(config, "max_position_embeddings")
62-
and config.max_position_embeddings is not None
63-
):
64-
max_position_embeddings = config.max_position_embeddings
65-
else:
66-
max_position_embeddings = 2048
67-
return max(max_sequence_length, seq_length, max_position_embeddings)
50+
return get_context_length_from_config(config)
6851

6952

7053
def prepare_logits_processor(

xinference/model/llm/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,48 @@
5656
logger = logging.getLogger(__name__)
5757

5858

59+
# Config fields inspected when looking for a model's context length.
_CONTEXT_LENGTH_KEYS: Tuple[str, ...] = (
    "max_sequence_length",
    "seq_length",
    "max_position_embeddings",
    "sliding_window",
)

# Returned when none of the inspected fields holds a usable value.
_DEFAULT_CONTEXT_LENGTH = 2048


def _get_config_value(config: Union[dict, Any], key: str) -> Any:
    """Read ``key`` from a dict or attribute-style config; ``None`` if absent."""
    if isinstance(config, dict):
        return config.get(key)
    return getattr(config, key, None)


def _normalize_context_length(value: Any) -> int:
    """Return ``value`` as a positive int, or ``0`` if it is not a usable length.

    Config files are user-editable, so a field may hold something other than
    a positive integer (``False``, a string, a list, ``0``). Such values must
    not reach ``max()``, where they would either raise ``TypeError`` or win
    the comparison with a meaningless result.
    """
    # bool is a subclass of int; ``False``/``True`` are flags, not lengths.
    if isinstance(value, bool):
        return 0
    # JSON numbers may be parsed as floats (e.g. ``4096.0``).
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    if isinstance(value, int) and value > 0:
        return value
    return 0


def _collect_context_length_candidates(
    config: Union[dict, Any], nested_attrs: Iterable[str]
) -> List[int]:
    """Gather every usable context length found in ``config``.

    Looks at each key in ``_CONTEXT_LENGTH_KEYS`` on ``config`` itself, then
    recurses into every sub-config named in ``nested_attrs``.
    """
    # Materialize once: the names are re-used at every recursion level, so a
    # one-shot iterator would otherwise be exhausted after the first level.
    # A bare string is treated as a single attribute name, not as characters.
    if isinstance(nested_attrs, str):
        attrs: Tuple[str, ...] = (nested_attrs,)
    else:
        attrs = tuple(nested_attrs)

    candidates: List[int] = []
    for key in _CONTEXT_LENGTH_KEYS:
        length = _normalize_context_length(_get_config_value(config, key))
        if length:
            candidates.append(length)
    for nested_attr in attrs:
        nested = _get_config_value(config, nested_attr)
        if nested is not None:
            candidates.extend(_collect_context_length_candidates(nested, attrs))
    return candidates


def get_context_length_from_config(
    config: Union[dict, Any], nested_attrs: Iterable[str] = ("text_config",)
) -> int:
    """
    Determine a reasonable context length from model config dictionaries or
    HuggingFace config objects.

    Args:
        config: A ``dict`` (read with ``.get``) or any object whose fields
            are read with ``getattr``.
        nested_attrs: Names of sub-configs that are searched recursively in
            addition to ``config`` itself.

    Returns:
        The largest positive integer found under any of
        ``_CONTEXT_LENGTH_KEYS``. ``2048`` is returned only when no usable
        value is found; it is a fallback, not a lower bound.
    """
    candidates = _collect_context_length_candidates(config, nested_attrs)
    return max(candidates, default=_DEFAULT_CONTEXT_LENGTH)
99+
100+
59101
QWEN_TOOL_CALL_FAMILY = [
60102
"qwen1.5-chat",
61103
"qwen1.5-moe-chat",

0 commit comments

Comments
 (0)