forked from google/adk-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathregistry.py
More file actions
130 lines (101 loc) · 3.68 KB
/
registry.py
File metadata and controls
130 lines (101 loc) · 3.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The registry class for model."""
from __future__ import annotations
from functools import lru_cache
import logging
import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .base_llm import BaseLlm
logger = logging.getLogger('google_adk.' + __name__)
# Module-level registry shared by all LLMRegistry static methods.
# Key: a regex string that must full-match a model name.
# Value: the BaseLlm subclass that serves matching models.
_llm_registry_dict: dict[str, type[BaseLlm]] = {}
"""Registry for LLMs.
Key is the regex that matches the model name.
Value is the class that implements the model.
"""
# Compiled counterparts of the keys above, populated at registration time
# (see LLMRegistry._register) so resolution avoids recompiling per lookup.
_compiled_regex_cache: dict[str, re.Pattern] = {}
"""Cache for compiled regex patterns, keyed by the raw regex string."""
class LLMRegistry:
  """Registry for LLMs.

  Maps model-name regex patterns (stored in the module-level
  ``_llm_registry_dict``) to ``BaseLlm`` subclasses and resolves concrete
  model names to the class that should serve them.
  """

  @staticmethod
  def new_llm(model: str) -> BaseLlm:
    """Creates a new LLM instance for the given model name.

    Args:
      model: The model name.

    Returns:
      The LLM instance.

    Raises:
      ValueError: If no registered class matches the model name.
    """
    return LLMRegistry.resolve(model)(model=model)

  @staticmethod
  def _register(model_name_regex: str, llm_cls: type[BaseLlm]):
    """Registers a new LLM class under a model-name regex.

    Args:
      model_name_regex: The regex that matches the model name.
      llm_cls: The class that implements the model.
    """
    if model_name_regex in _llm_registry_dict:
      # Re-registration is allowed; log it so silent overrides are visible.
      logger.info(
          'Updating LLM class for %s from %s to %s',
          model_name_regex,
          _llm_registry_dict[model_name_regex],
          llm_cls,
      )

    _llm_registry_dict[model_name_regex] = llm_cls
    # Compile eagerly at registration time so resolve() can use the cached
    # pattern instead of recompiling on every lookup.
    if model_name_regex not in _compiled_regex_cache:
      _compiled_regex_cache[model_name_regex] = re.compile(model_name_regex)

  @staticmethod
  def register(llm_cls: type[BaseLlm]):
    """Registers a new LLM class for every regex it declares support for.

    Args:
      llm_cls: The class that implements the model.
    """
    for regex in llm_cls.supported_models():
      LLMRegistry._register(regex, llm_cls)

  @staticmethod
  @lru_cache(maxsize=32)
  def resolve(model: str) -> type[BaseLlm]:
    """Resolves the model name to a BaseLlm subclass.

    NOTE(review): results are memoized via lru_cache, so a class registered
    AFTER a successful resolve of the same name will not be picked up for
    that name — confirm registration always happens before resolution.

    Args:
      model: The model name.

    Returns:
      The BaseLlm subclass.

    Raises:
      ValueError: If the model is not found.
    """
    for regex, llm_class in _llm_registry_dict.items():
      compiled = _compiled_regex_cache.get(regex)
      if compiled is None:
        # Fallback for entries added without going through _register();
        # store the compiled pattern so it is compiled at most once.
        compiled = _compiled_regex_cache[regex] = re.compile(regex)
      if compiled.fullmatch(model):
        return llm_class

    # Provide helpful error messages for known patterns.
    error_msg = f'Model {model} not found.'

    # Check if it matches known patterns that require optional dependencies.
    if model.startswith('claude-'):
      error_msg += (
          '\n\nClaude models require the anthropic package.'
          '\nInstall it with: pip install google-adk[extensions]'
          '\nOr: pip install anthropic>=0.43.0'
      )
    elif '/' in model:
      # Any model with provider/model format likely needs LiteLLM.
      error_msg += (
          '\n\nProvider-style models (e.g., "provider/model-name") require'
          ' the litellm package.'
          '\nInstall it with: pip install google-adk[extensions]'
          '\nOr: pip install litellm>=1.75.5'
          '\n\nSupported providers include: openai, groq, anthropic, and 100+'
          ' others.'
          '\nSee https://docs.litellm.ai/docs/providers for a full list.'
      )

    raise ValueError(error_msg)