forked from lightspeed-core/lightspeed-stack
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
141 lines (111 loc) · 4.66 KB
/
models.py
File metadata and controls
141 lines (111 loc) · 4.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Handler for REST API call to list available models."""
from typing import Annotated, Any
from fastapi import APIRouter, HTTPException, Request, Query
from fastapi.params import Depends
from llama_stack_client import APIConnectionError
from authentication import get_auth_dependency
from authentication.interface import AuthTuple
from authorization.middleware import authorize
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from models.config import Action
from models.requests import ModelFilter
from models.responses import (
ForbiddenResponse,
InternalServerErrorResponse,
ModelsResponse,
ServiceUnavailableResponse,
UnauthorizedResponse,
)
from utils.endpoints import check_configuration_loaded
from log import get_logger
# Module-level logger for the models endpoint handler.
logger = get_logger(__name__)
# Router carrying the endpoint(s) defined in this module, grouped under the
# "models" tag.
router = APIRouter(tags=["models"])
def _metadata_str(metadata: dict[str, Any], key: str, default: str) -> str:
    """Return ``metadata[key]`` as a string, or ``default`` if absent or None.

    An explicit ``None`` value is treated exactly like a missing key so that it
    does not leak into the API response as the literal string ``"None"``.
    """
    value = metadata.get(key)
    return default if value is None else str(value)


def parse_llama_stack_model(model: Any) -> dict[str, Any]:
    """
    Parse llama-stack model.

    Convert a model in the llama-stack 0.4.x format (which carries provider
    details inside ``custom_metadata``) into the legacy flat dictionary shape.

    Args:
        model: Model object from llama-stack (has id, custom_metadata, object fields).
            Missing attributes are tolerated and replaced by defaults.

    Returns:
        dict: Model in legacy format with the keys ``identifier``, ``metadata``,
        ``api_model_type``, ``provider_id``, ``type``, ``provider_resource_id``
        and ``model_type``. ``provider_id`` and ``provider_resource_id`` default
        to ``""``; ``model_type`` and ``api_model_type`` default to
        ``"unknown"``. In each case the default applies when the metadata entry
        is missing or ``None``. ``metadata`` holds every remaining
        ``custom_metadata`` entry.
    """
    # ``custom_metadata`` may be absent or None; normalise to an empty dict.
    custom_metadata = getattr(model, "custom_metadata", {}) or {}

    model_type = _metadata_str(custom_metadata, "model_type", "unknown")

    # Entries promoted to top-level fields are excluded from the free-form
    # metadata so they are not reported twice.
    metadata = {
        key: value
        for key, value in custom_metadata.items()
        if key not in ("provider_id", "provider_resource_id", "model_type")
    }

    legacy_model = {
        "identifier": getattr(model, "id", ""),
        "metadata": metadata,
        "api_model_type": model_type,
        "provider_id": _metadata_str(custom_metadata, "provider_id", ""),
        "type": getattr(model, "object", "model"),
        "provider_resource_id": _metadata_str(
            custom_metadata, "provider_resource_id", ""
        ),
        "model_type": model_type,
    }

    return legacy_model
# OpenAPI response documentation for the /models route, keyed by HTTP status
# code; passed as ``responses=`` when the route is registered.
models_responses: dict[int | str, dict[str, Any]] = {
    200: ModelsResponse.openapi_response(),
    401: UnauthorizedResponse.openapi_response(
        examples=["missing header", "missing token"]
    ),
    403: ForbiddenResponse.openapi_response(examples=["endpoint"]),
    500: InternalServerErrorResponse.openapi_response(examples=["configuration"]),
    # The handler builds a ServiceUnavailableResponse when it cannot connect
    # to Llama Stack.
    503: ServiceUnavailableResponse.openapi_response(),
}
@router.get("/models", responses=models_responses)
@authorize(Action.GET_MODELS)
async def models_endpoint_handler(
    request: Request,
    auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
    model_type: Annotated[ModelFilter, Query()],
) -> ModelsResponse:
    """
    Handle requests to the /models endpoint.

    Process GET requests to the /models endpoint, returning a list of available
    models from the Llama Stack service. It is possible to specify "model_type"
    query parameter that is used as a filter. For example, if model type is set
    to "llm", only LLM models will be returned:

        curl http://localhost:8080/v1/models?model_type=llm

    The "model_type" query parameter is optional. When not specified, all models
    will be returned.

    ## Parameters:
        request: The incoming HTTP request (unused by the handler itself).
        auth: Authentication tuple from the auth dependency.
        model_type: Optional filter to return only models matching this type.

    ## Raises:
        HTTPException: Built from ``ServiceUnavailableResponse`` when the
            connection to the Llama Stack server fails (``APIConnectionError``).
            Other errors raised while listing or parsing models are not
            translated here and propagate unchanged.
            ``check_configuration_loaded`` may also raise if the configuration
            is not loaded (presumably the documented 500 "configuration"
            response; confirm against utils.endpoints).

    ## Returns:
        ModelsResponse: An object containing the list of available models,
        converted to the legacy format and optionally filtered by model type.
    """
    # Used only by the middleware
    _ = auth

    # Nothing interesting in the request
    _ = request

    check_configuration_loaded(configuration)

    llama_stack_configuration = configuration.llama_stack_configuration
    # NOTE(review): confirm this configuration's repr masks secrets (e.g. API
    # keys) before it is logged at INFO level.
    logger.info("Llama stack config: %s", llama_stack_configuration)

    # Keep the try body limited to talking to Llama Stack: parsing and
    # filtering below cannot raise APIConnectionError.
    try:
        client = AsyncLlamaStackClientHolder().get_client()
        models = await client.models.list()
    # Connection to Llama Stack server failed
    except APIConnectionError as e:
        logger.error("Unable to connect to Llama Stack: %s", e)
        response = ServiceUnavailableResponse(backend_name="Llama Stack", cause=str(e))
        raise HTTPException(**response.model_dump()) from e

    # parse models to legacy format
    parsed_models = [parse_llama_stack_model(model) for model in models]

    # optional filtering by model type (None means "no filter")
    if model_type.model_type is not None:
        parsed_models = [
            candidate
            for candidate in parsed_models
            if candidate["model_type"] == model_type.model_type
        ]

    return ModelsResponse(models=parsed_models)