|
1 | | -import os |
| 1 | +""" |
| 2 | +Import as: |
| 3 | +
|
| 4 | +import config.config as cconf |
| 5 | +""" |
| 6 | + |
2 | 7 | import dataclasses |
3 | 8 | import functools |
4 | | -import pydantic |
| 9 | +import os |
5 | 10 |
|
6 | 11 | import dotenv |
| 12 | +import langchain_anthropic |
| 13 | +import langchain_google_genai |
7 | 14 | import langchain_openai |
8 | | -import langchain_anthropic # ChatAnthropic |
9 | | -import langchain_google_genai # ChatGoogleGenerativeAI |
10 | | -# import langchain_groq # ChatGroq |
11 | | -# import langchain_mistralai # ChatMistralAI |
12 | | -# import langchain_ollama # ChatOllama |
13 | | - |
| 15 | +import pydantic |
14 | 16 |
|
15 | 17 | dataclass = dataclasses.dataclass |
16 | 18 | lru_cache = functools.lru_cache |
17 | 19 | ChatOpenAI = langchain_openai.ChatOpenAI |
18 | 20 | ChatAnthropic = langchain_anthropic.ChatAnthropic |
19 | 21 | ChatGoogleGenerativeAI = langchain_google_genai.ChatGoogleGenerativeAI |
20 | | -# ChatGroq = langchain_groq.ChatGroq |
21 | | -# ChatMistralAI = langchain_mistralai.ChatMistralAI |
22 | | -# ChatOllama = langchain_ollama.ChatOllama |
23 | 22 | SecretStr = pydantic.SecretStr |
24 | 23 |
|
25 | | -# Load Variables |
26 | 24 | dotenv.load_dotenv() |
27 | 25 |
|
28 | 26 |
|
29 | | -# Immutable data class |
30 | 27 | @dataclass(frozen=True) |
31 | 28 | class Settings: |
| 29 | + """ |
| 30 | + Store model provider settings. |
| 31 | + """ |
| 32 | + |
32 | 33 | provider: str |
33 | 34 | model: str |
34 | 35 | temperature: float |
35 | 36 | timeout: float |
36 | 37 | max_retries: int |
37 | 38 |
|
38 | | -def _need(name:str) -> str: |
39 | | - v = os.getenv(name) |
40 | | - if v is None or v == "": |
| 39 | + |
| 40 | +def _need(name: str) -> str: |
| 41 | + """ |
| 42 | + Read a required environment variable. |
| 43 | +
|
| 44 | + :param name: environment variable name |
| 45 | + :return: environment variable value |
| 46 | + """ |
| 47 | + value = os.getenv(name) |
| 48 | + if value is None or value == "": |
41 | 49 | raise RuntimeError(f"Missing required environment variable: {name}") |
42 | | - return v |
| 50 | + return value |
| 51 | + |
43 | 52 |
|
44 | 53 | @lru_cache(maxsize=1) |
45 | 54 | def get_settings() -> Settings: |
46 | | - return Settings( |
| 55 | + """ |
| 56 | + Build settings from environment variables. |
| 57 | +
|
| 58 | + :return: configured settings |
| 59 | + """ |
| 60 | + settings = Settings( |
47 | 61 | provider=os.getenv("LLM_PROVIDER", "openai"), |
48 | 62 | model=os.getenv("LLM_MODEL", "gpt-5-nano"), |
49 | 63 | temperature=float(os.getenv("LLM_TEMP", 0.2)), |
50 | 64 | timeout=float(os.getenv("LLM_TIMEOUT", 60)), |
51 | 65 | max_retries=int(os.getenv("LLM_MAX_RETRIES", 2)), |
52 | | - |
53 | 66 | ) |
| 67 | + return settings |
54 | 68 |
|
55 | | -@lru_cache(maxsize=1) |
56 | | -def get_chat_model(model=get_settings().model): |
57 | | - s = get_settings() |
58 | | - |
59 | | - # OpenAI-adjacent |
60 | 69 |
|
61 | | - if s.provider == "openai": |
62 | | - |
63 | | - # READ API KEY. |
| 70 | +@lru_cache(maxsize=1) |
| 71 | +def get_chat_model(*, model: str | None = None) -> object: |
| 72 | + """ |
| 73 | + Build the configured chat model client. |
| 74 | +
|
| 75 | + :param model: optional model override |
| 76 | + :return: langchain chat model client |
| 77 | + """ |
| 78 | + settings = get_settings() |
| 79 | + model_name = settings.model if model is None else model |
| 80 | + provider = settings.provider |
| 81 | + if provider == "openai": |
64 | 82 | _need("OPENAI_API_KEY") |
65 | | - |
66 | | - # Return the chatmodel |
67 | | - |
68 | | - return ChatOpenAI( |
69 | | - model=s.model, |
70 | | - temperature=s.temperature, |
71 | | - timeout=s.timeout, |
72 | | - max_retries=s.max_retries, |
| 83 | + chat_model = ChatOpenAI( |
| 84 | + model=model_name, |
| 85 | + temperature=settings.temperature, |
| 86 | + timeout=settings.timeout, |
| 87 | + max_retries=settings.max_retries, |
73 | 88 | ) |
74 | | - |
75 | | - if s.provider == "openai_compatible": |
76 | | - |
77 | | - # Secrets. |
| 89 | + elif provider == "openai_compatible": |
78 | 90 | base_url = _need("OPENAI_COMPAT_BASE_URL") |
79 | 91 | api_key = _need("OPENAI_COMPAT_API_KEY") |
80 | | - return ChatOpenAI( |
81 | | - model=model, |
| 92 | + chat_model = ChatOpenAI( |
| 93 | + model=model_name, |
82 | 94 | base_url=base_url, |
83 | 95 | api_key=SecretStr(api_key), |
84 | | - temperature=s.temperature, |
85 | | - timeout=s.timeout, |
86 | | - max_retries=s.max_retries, |
87 | | - |
| 96 | + temperature=settings.temperature, |
| 97 | + timeout=settings.timeout, |
| 98 | + max_retries=settings.max_retries, |
88 | 99 | ) |
89 | | - |
90 | | - if s.provider == "azure_openai_v1": |
91 | | - |
92 | | - # Secrets. |
| 100 | + elif provider == "azure_openai_v1": |
93 | 101 | azure_base = _need("AZURE_OPENAI_BASE_URL") |
94 | 102 | azure_key = SecretStr(_need("AZURE_OPENAI_API_KEY")) |
95 | | - |
96 | | - return ChatOpenAI( |
97 | | - model=s.model, |
| 103 | + chat_model = ChatOpenAI( |
| 104 | + model=model_name, |
98 | 105 | base_url=azure_base, |
99 | 106 | api_key=azure_key, |
100 | | - temperature=s.temperature, |
101 | | - timeout=s.timeout, |
102 | | - max_retries=s.max_retries, |
103 | | - |
| 107 | + temperature=settings.temperature, |
| 108 | + timeout=settings.timeout, |
| 109 | + max_retries=settings.max_retries, |
104 | 110 | ) |
105 | | - |
106 | | - # Anthropic |
107 | | - |
108 | | - if s.provider == "anthropic": |
109 | | - |
110 | | - # Secrets. |
111 | | - _need("ANTHROPIC_API_KEY") |
112 | | - return ChatAnthropic( |
113 | | - model_name=s.model, |
114 | | - temperature=s.temperature, |
115 | | - timeout=s.timeout, |
116 | | - max_retries=s.max_retries, |
117 | | - stop=None |
118 | | - ) |
119 | | - |
120 | | - # Google |
121 | | - if s.provider in ("google", "gemini", "google_genai"): |
122 | | - # Secrets. |
| 111 | + elif provider == "anthropic": |
| 112 | + _need("ANTHROPIC_API_KEY") |
| 113 | + chat_model = ChatAnthropic( |
| 114 | + model_name=model_name, |
| 115 | + temperature=settings.temperature, |
| 116 | + timeout=settings.timeout, |
| 117 | + max_retries=settings.max_retries, |
| 118 | + stop=None, |
| 119 | + ) |
| 120 | + elif provider in ("google", "gemini", "google_genai"): |
123 | 121 | _need("GOOGLE_API_KEY") |
124 | | - return ChatGoogleGenerativeAI( |
125 | | - model=s.model, |
126 | | - temperature=s.temperature, |
| 122 | + chat_model = ChatGoogleGenerativeAI( |
| 123 | + model=model_name, |
| 124 | + temperature=settings.temperature, |
127 | 125 | ) |
128 | | - |
129 | | - |
130 | | - |
131 | | - |
132 | | - |
133 | | - raise ValueError("TODO(*): expand support!") |
| 126 | + else: |
| 127 | + raise ValueError(f"Unsupported provider='{provider}'") |
| 128 | + return chat_model |
0 commit comments