-
Notifications
You must be signed in to change notification settings - Fork 310
Expand file tree
/
Copy pathmodel.py
More file actions
153 lines (120 loc) · 5.74 KB
/
model.py
File metadata and controls
153 lines (120 loc) · 5.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations
import logging
from typing import Callable, List, Optional
from ..imodel import IModel
from ..openai.chat_client import ChatClient
from ..openai.audio_client import AudioClient
from ..openai.embedding_client import EmbeddingClient
from ..openai.responses_client import ResponsesClient
from .model_variant import ModelVariant
from ..exception import FoundryLocalException
from .core_interop import CoreInterop
from .model_data_types import ModelInfo
logger = logging.getLogger(__name__)
class Model(IModel):
"""A model identified by an alias that groups one or more variants.
Operations are delegated to the currently selected variant.
"""
def __init__(self, model_variant: ModelVariant, core_interop: CoreInterop):
self._alias = model_variant.alias
self._variants: List[ModelVariant] = [model_variant]
# Variants are sorted by Core, so the first one added is the default
self._selected_variant = model_variant
self._core_interop = core_interop
def _add_variant(self, variant: ModelVariant) -> None:
if variant.alias != self._alias:
raise FoundryLocalException(
f"Variant alias {variant.alias} does not match model alias {self._alias}"
)
self._variants.append(variant)
# Prefer the highest priority locally cached variant
if variant.info.cached and not self._selected_variant.info.cached:
self._selected_variant = variant
def select_variant(self, variant: IModel) -> None:
"""
Select a specific model variant to use for IModel operations.
An IModel from ``variants`` can also be used directly.
:param variant: IModel to select. Must be one of the variants in ``variants``.
:raises FoundryLocalException: If variant is not valid for this model
"""
matching = next((v for v in self._variants if v.id == variant.id), None)
if matching is None:
raise FoundryLocalException(
"Input variant was not found in Variants."
)
self._selected_variant = matching
@property
def variants(self) -> List[IModel]:
"""List of all variants for this model."""
return list(self._variants) # Return a copy to prevent external modification
@property
def id(self) -> str:
"""Model Id of the currently selected variant."""
return self._selected_variant.id
@property
def alias(self) -> str:
"""Alias of this model."""
return self._alias
@property
def info(self) -> ModelInfo:
"""ModelInfo of the currently selected variant."""
return self._selected_variant.info
@property
def context_length(self) -> Optional[int]:
"""Maximum context length (in tokens) of the currently selected variant."""
return self._selected_variant.context_length
@property
def input_modalities(self) -> Optional[str]:
"""Comma-separated input modalities of the currently selected variant."""
return self._selected_variant.input_modalities
@property
def output_modalities(self) -> Optional[str]:
"""Comma-separated output modalities of the currently selected variant."""
return self._selected_variant.output_modalities
@property
def capabilities(self) -> Optional[str]:
"""Comma-separated capability tags of the currently selected variant."""
return self._selected_variant.capabilities
@property
def supports_tool_calling(self) -> Optional[bool]:
"""Whether the currently selected variant supports tool/function calling."""
return self._selected_variant.supports_tool_calling
@property
def is_cached(self) -> bool:
"""Is the currently selected variant cached locally?"""
return self._selected_variant.is_cached
@property
def is_loaded(self) -> bool:
"""Is the currently selected variant loaded in memory?"""
return self._selected_variant.is_loaded
def download(self, progress_callback: Optional[Callable[[float], None]] = None) -> None:
"""Download the currently selected variant."""
self._selected_variant.download(progress_callback)
def get_path(self) -> str:
"""Get the path to the currently selected variant."""
return self._selected_variant.get_path()
def load(self) -> None:
"""Load the currently selected variant into memory."""
self._selected_variant.load()
def unload(self) -> None:
"""Unload the currently selected variant from memory."""
self._selected_variant.unload()
def remove_from_cache(self) -> None:
"""Remove the currently selected variant from the local cache."""
self._selected_variant.remove_from_cache()
def get_chat_client(self) -> ChatClient:
"""Get a chat client for the currently selected variant."""
return self._selected_variant.get_chat_client()
def get_audio_client(self) -> AudioClient:
"""Get an audio client for the currently selected variant."""
return self._selected_variant.get_audio_client()
def get_embedding_client(self) -> EmbeddingClient:
"""Get an embedding client for the currently selected variant."""
return self._selected_variant.get_embedding_client()
def create_responses_client(self, base_url: str) -> "ResponsesClient":
"""Create a Responses API client for the currently selected variant."""
return self._selected_variant.create_responses_client(base_url)