Skip to content

Commit 03736bd

Browse files
Fixes; prepare for deprecation of http_client and obo_token_getter params.
1 parent 8492a9d commit 03736bd

2 files changed

Lines changed: 64 additions & 54 deletions

File tree

singlestoredb/ai/chat.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,31 @@ def SingleStoreChatFactory(
3737
streaming: bool = True,
3838
http_client: Optional[httpx.Client] = None,
3939
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
40+
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
4041
base_url: Optional[str] = None,
4142
hosting_platform: Optional[str] = None,
43+
timeout: Optional[float] = None,
4244
**kwargs: Any,
4345
) -> Union[ChatOpenAI, ChatBedrockConverse]:
4446
"""Return a chat model instance (ChatOpenAI or ChatBedrockConverse).
4547
"""
4648
# Handle api_key and obo_token as callable functions
4749
if callable(api_key):
48-
api_key_getter = api_key
50+
api_key_getter_fn = api_key
4951
else:
50-
def api_key_getter() -> Optional[str]:
52+
def api_key_getter_fn() -> Optional[str]:
5153
if api_key is None:
5254
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
5355
return api_key
5456

55-
if callable(obo_token):
56-
obo_token_getter = obo_token
57+
if obo_token_getter is not None:
58+
obo_token_getter_fn = obo_token_getter
5759
else:
58-
def obo_token_getter() -> Optional[str]:
59-
return obo_token
60+
if callable(obo_token):
61+
obo_token_getter_fn = obo_token
62+
else:
63+
def obo_token_getter_fn() -> Optional[str]:
64+
return obo_token
6065

6166
# handle model info
6267
if base_url is None:
@@ -99,6 +104,10 @@ def obo_token_getter() -> Optional[str]:
99104
elif isinstance(t, (int, float)):
100105
connect_timeout = float(t)
101106
read_timeout = float(t)
107+
if timeout is not None:
108+
connect_timeout = timeout
109+
read_timeout = timeout
110+
t = httpx.Timeout(timeout)
102111

103112
if info.hosting_platform == 'Amazon':
104113
# Instantiate Bedrock client
@@ -123,12 +132,12 @@ def obo_token_getter() -> Optional[str]:
123132

124133
def _inject_headers(request: Any, **_ignored: Any) -> None:
125134
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
126-
if api_key_getter is not None:
127-
token_val = api_key_getter()
135+
if api_key_getter_fn is not None:
136+
token_val = api_key_getter_fn()
128137
if token_val:
129138
request.headers['Authorization'] = f'Bearer {token_val}'
130-
if obo_token_getter is not None:
131-
obo_val = obo_token_getter()
139+
if obo_token_getter_fn is not None:
140+
obo_val = obo_token_getter_fn()
132141
if obo_val:
133142
request.headers['X-S2-OBO'] = obo_val
134143
request.headers.pop('X-Amz-Date', None)
@@ -167,30 +176,26 @@ class OpenAIAuth(httpx.Auth):
167176
def auth_flow(
168177
self, request: httpx.Request,
169178
) -> Generator[httpx.Request, None, None]:
170-
if api_key_getter is not None:
171-
token_val = api_key_getter()
179+
if api_key_getter_fn is not None:
180+
token_val = api_key_getter_fn()
172181
if token_val:
173182
request.headers['Authorization'] = f'Bearer {token_val}'
174-
if obo_token_getter is not None:
175-
obo_val = obo_token_getter()
183+
if obo_token_getter_fn is not None:
184+
obo_val = obo_token_getter_fn()
176185
if obo_val:
177186
request.headers['X-S2-OBO'] = obo_val
178187
yield request
179188

180-
# Build timeout configuration
181-
if connect_timeout is not None and read_timeout is not None:
182-
t = httpx.Timeout(connect=connect_timeout, read=read_timeout)
183-
elif connect_timeout is not None:
184-
t = httpx.Timeout(connect=connect_timeout)
185-
elif read_timeout is not None:
186-
t = httpx.Timeout(read=read_timeout)
189+
if t is not None:
190+
http_client = httpx.Client(
191+
timeout=t,
192+
auth=OpenAIAuth(),
193+
)
187194
else:
188-
t = 60.0 # default OpenAI client timeout
189-
190-
http_client = httpx.Client(
191-
timeout=t,
192-
auth=OpenAIAuth(),
193-
)
195+
http_client = httpx.Client(
196+
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
197+
auth=OpenAIAuth(),
198+
)
194199

195200
# OpenAI / Azure OpenAI path
196201
openai_kwargs = dict(

singlestoredb/ai/embeddings.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,26 +36,31 @@ def SingleStoreEmbeddingsFactory(
3636
api_key: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
3737
http_client: Optional[httpx.Client] = None,
3838
obo_token: Optional[Union[Optional[str], Callable[[], Optional[str]]]] = None,
39+
obo_token_getter: Optional[Callable[[], Optional[str]]] = None,
3940
base_url: Optional[str] = None,
4041
hosting_platform: Optional[str] = None,
42+
timeout: Optional[float] = None,
4143
**kwargs: Any,
4244
) -> Union[OpenAIEmbeddings, BedrockEmbeddings]:
4345
"""Return an embeddings model instance (OpenAIEmbeddings or BedrockEmbeddings).
4446
"""
4547
# Handle api_key and obo_token as callable functions
4648
if callable(api_key):
47-
api_key_getter = api_key
49+
api_key_getter_fn = api_key
4850
else:
49-
def api_key_getter() -> Optional[str]:
51+
def api_key_getter_fn() -> Optional[str]:
5052
if api_key is None:
5153
return os.environ.get('SINGLESTOREDB_USER_TOKEN')
5254
return api_key
5355

54-
if callable(obo_token):
55-
obo_token_getter = obo_token
56+
if obo_token_getter is not None:
57+
obo_token_getter_fn = obo_token_getter
5658
else:
57-
def obo_token_getter() -> Optional[str]:
58-
return obo_token
59+
if callable(obo_token):
60+
obo_token_getter_fn = obo_token
61+
else:
62+
def obo_token_getter_fn() -> Optional[str]:
63+
return obo_token
5964

6065
# handle model info
6166
if base_url is None:
@@ -98,6 +103,10 @@ def obo_token_getter() -> Optional[str]:
98103
elif isinstance(t, (int, float)):
99104
connect_timeout = float(t)
100105
read_timeout = float(t)
106+
if timeout is not None:
107+
connect_timeout = timeout
108+
read_timeout = timeout
109+
t = httpx.Timeout(timeout)
101110

102111
if info.hosting_platform == 'Amazon':
103112
# Instantiate Bedrock client
@@ -122,12 +131,12 @@ def obo_token_getter() -> Optional[str]:
122131

123132
def _inject_headers(request: Any, **_ignored: Any) -> None:
124133
"""Inject dynamic auth/OBO headers prior to Bedrock sending."""
125-
if api_key_getter is not None:
126-
token_val = api_key_getter()
134+
if api_key_getter_fn is not None:
135+
token_val = api_key_getter_fn()
127136
if token_val:
128137
request.headers['Authorization'] = f'Bearer {token_val}'
129-
if obo_token_getter is not None:
130-
obo_val = obo_token_getter()
138+
if obo_token_getter_fn is not None:
139+
obo_val = obo_token_getter_fn()
131140
if obo_val:
132141
request.headers['X-S2-OBO'] = obo_val
133142
request.headers.pop('X-Amz-Date', None)
@@ -157,30 +166,26 @@ class OpenAIAuth(httpx.Auth):
157166
def auth_flow(
158167
self, request: httpx.Request,
159168
) -> Generator[httpx.Request, None, None]:
160-
if api_key_getter is not None:
161-
token_val = api_key_getter()
169+
if api_key_getter_fn is not None:
170+
token_val = api_key_getter_fn()
162171
if token_val:
163172
request.headers['Authorization'] = f'Bearer {token_val}'
164-
if obo_token_getter is not None:
165-
obo_val = obo_token_getter()
173+
if obo_token_getter_fn is not None:
174+
obo_val = obo_token_getter_fn()
166175
if obo_val:
167176
request.headers['X-S2-OBO'] = obo_val
168177
yield request
169178

170-
# Build timeout configuration
171-
if connect_timeout is not None and read_timeout is not None:
172-
t = httpx.Timeout(connect=connect_timeout, read=read_timeout)
173-
elif connect_timeout is not None:
174-
t = httpx.Timeout(connect=connect_timeout)
175-
elif read_timeout is not None:
176-
t = httpx.Timeout(read=read_timeout)
179+
if t is not None:
180+
http_client = httpx.Client(
181+
timeout=t,
182+
auth=OpenAIAuth(),
183+
)
177184
else:
178-
t = 60.0 # default OpenAI client timeout
179-
180-
http_client = httpx.Client(
181-
timeout=t,
182-
auth=OpenAIAuth(),
183-
)
185+
http_client = httpx.Client(
186+
timeout=httpx.Timeout(timeout=600, connect=5.0), # default OpenAI timeout
187+
auth=OpenAIAuth(),
188+
)
184189

185190
# OpenAI / Azure OpenAI path
186191
openai_kwargs = dict(

0 commit comments

Comments
 (0)