Skip to content

Commit a3d45e2

Browse files
Noel2521ankaisen
andauthored
feat(http): add HTTP proxy support for LLM & embedding clients (#310)
Co-authored-by: ankaisen <51148505+ankaisen@users.noreply.github.com>
1 parent 9e31ef2 commit a3d45e2

2 files changed

Lines changed: 21 additions & 6 deletions

File tree

src/memu/embedding/http_client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import os
45
from collections.abc import Callable
56
from typing import Literal
67

@@ -10,6 +11,11 @@
1011
from memu.embedding.backends.doubao import DoubaoEmbeddingBackend, DoubaoMultimodalEmbeddingInput
1112
from memu.embedding.backends.openai import OpenAIEmbeddingBackend
1213

14+
15+
def _load_proxy() -> str | None:
16+
return os.getenv("MEMU_HTTP_PROXY") or os.getenv("HTTP_PROXY") or os.getenv("HTTPS_PROXY") or None
17+
18+
1319
logger = logging.getLogger(__name__)
1420

1521
EMBEDDING_BACKENDS: dict[str, Callable[[], EmbeddingBackend]] = {
@@ -49,6 +55,7 @@ def __init__(
4955
# Strip leading "/" so httpx resolves relative to base_url
5056
self.embedding_endpoint = raw_embedding_ep.lstrip("/")
5157
self.timeout = timeout
58+
self.proxy = _load_proxy()
5259

5360
async def embed(self, inputs: list[str]) -> list[list[float]]:
5461
"""
@@ -61,7 +68,7 @@ async def embed(self, inputs: list[str]) -> list[list[float]]:
6168
List of embedding vectors
6269
"""
6370
payload = self.backend.build_embedding_payload(inputs=inputs, embed_model=self.embed_model)
64-
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
71+
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout, proxy=self.proxy) as client:
6572
resp = await client.post(self.embedding_endpoint, json=payload, headers=self._headers())
6673
resp.raise_for_status()
6774
data = resp.json()
@@ -123,7 +130,7 @@ async def embed_multimodal(
123130
)
124131

125132
endpoint = self.backend.multimodal_embedding_endpoint.lstrip("/")
126-
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
133+
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout, proxy=self.proxy) as client:
127134
resp = await client.post(endpoint, json=payload, headers=self._headers())
128135
resp.raise_for_status()
129136
data = resp.json()

src/memu/llm/http_client.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import base64
44
import logging
5+
import os
56
from collections.abc import Callable
67
from pathlib import Path
78
from typing import Any, cast
@@ -15,6 +16,10 @@
1516
from memu.llm.backends.openrouter import OpenRouterLLMBackend
1617

1718

19+
def _load_proxy() -> str | None:
20+
return os.getenv("MEMU_HTTP_PROXY") or os.getenv("HTTP_PROXY") or os.getenv("HTTPS_PROXY") or None
21+
22+
1823
# Minimal embedding backend support (moved from embedding module)
1924
class _EmbeddingBackend:
2025
name: str
@@ -109,6 +114,7 @@ def __init__(
109114
self.embedding_endpoint = raw_embedding_ep.lstrip("/")
110115
self.timeout = timeout
111116
self.embed_model = embed_model or chat_model
117+
self.proxy = _load_proxy()
112118

113119
async def chat(
114120
self,
@@ -145,7 +151,7 @@ async def summarize(
145151
payload = self.backend.build_summary_payload(
146152
text=text, system_prompt=system_prompt, chat_model=self.chat_model, max_tokens=max_tokens
147153
)
148-
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
154+
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout, proxy=self.proxy) as client:
149155
resp = await client.post(self.summary_endpoint, json=payload, headers=self._headers())
150156
resp.raise_for_status()
151157
data = resp.json()
@@ -195,7 +201,7 @@ async def vision(
195201
max_tokens=max_tokens,
196202
)
197203

198-
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
204+
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout, proxy=self.proxy) as client:
199205
resp = await client.post(self.summary_endpoint, json=payload, headers=self._headers())
200206
resp.raise_for_status()
201207
data = resp.json()
@@ -205,7 +211,7 @@ async def vision(
205211
async def embed(self, inputs: list[str]) -> tuple[list[list[float]], dict[str, Any]]:
206212
"""Create text embeddings using the provider-specific embedding API."""
207213
payload = self.embedding_backend.build_embedding_payload(inputs=inputs, embed_model=self.embed_model)
208-
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
214+
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout, proxy=self.proxy) as client:
209215
resp = await client.post(self.embedding_endpoint, json=payload, headers=self._headers())
210216
resp.raise_for_status()
211217
data = resp.json()
@@ -246,7 +252,9 @@ async def transcribe(
246252
if language:
247253
data["language"] = language
248254

249-
async with httpx.AsyncClient(base_url=self.base_url, timeout=self.timeout * 3) as client:
255+
async with httpx.AsyncClient(
256+
base_url=self.base_url, timeout=self.timeout * 3, proxy=self.proxy
257+
) as client:
250258
resp = await client.post(
251259
"/v1/audio/transcriptions",
252260
files=files,

0 commit comments

Comments
 (0)