
Commit c2bc691

feat: add ray_serve as llm provider
1 parent c35c4f8 commit c2bc691

3 files changed

Lines changed: 176 additions & 0 deletions

File tree

graphgen/common/init_llm.py
graphgen/models/llm/api/ray_serve_client.py
graphgen/models/llm/local/ray_serve_deployment.py

graphgen/common/init_llm.py

Lines changed: 4 additions & 0 deletions
@@ -46,6 +46,10 @@ def __init__(self, backend: str, config: Dict[str, Any]):
             from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

             self.llm_instance = VLLMWrapper(**config)
+        elif backend == "ray_serve":
+            from graphgen.models.llm.api.ray_serve_client import RayServeClient
+
+            self.llm_instance = RayServeClient(**config)
         else:
             raise NotImplementedError(f"Backend {backend} is not implemented yet.")
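For context, the new branch simply forwards the config dict to RayServeClient. A minimal sketch of what that dispatch amounts to is below; the surrounding factory class is not shown in this hunk, and the config keys are illustrative (they are whatever RayServeClient, defined in the new file further down, accepts).

# Illustrative only: mirrors what the "ray_serve" branch above does.
from graphgen.models.llm.api.ray_serve_client import RayServeClient

config = {"app_name": "llm_app"}         # or e.g. {"serve_backend": "vllm", ...}
llm_instance = RayServeClient(**config)  # same call the new elif branch makes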

graphgen/models/llm/api/ray_serve_client.py

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
from typing import Any, List, Optional

from graphgen.bases.base_llm_wrapper import BaseLLMWrapper
from graphgen.bases.datatypes import Token


class RayServeClient(BaseLLMWrapper):
    """
    A client to interact with a Ray Serve deployment.
    """

    def __init__(
        self,
        *,
        app_name: Optional[str] = None,
        deployment_name: Optional[str] = None,
        serve_backend: Optional[str] = None,
        **kwargs: Any,
    ):
        try:
            from ray import serve
        except ImportError as e:
            raise ImportError(
                "Ray is not installed. Please install it with `pip install ray[serve]`."
            ) from e

        super().__init__(**kwargs)

        # Try to get an existing handle first
        self.handle = None
        if app_name:
            try:
                self.handle = serve.get_app_handle(app_name)
            except Exception:
                pass
        elif deployment_name:
            try:
                self.handle = serve.get_deployment(deployment_name).get_handle()
            except Exception:
                pass

        # If no handle was found, try to deploy if serve_backend is provided
        if self.handle is None:
            if serve_backend:
                if not app_name:
                    import uuid

                    app_name = f"llm_app_{serve_backend}_{uuid.uuid4().hex[:8]}"

                print(
                    f"Deploying Ray Serve app '{app_name}' with backend '{serve_backend}'..."
                )
                from graphgen.models.llm.local.ray_serve_deployment import LLMDeployment

                # Filter kwargs to avoid passing unrelated args if necessary,
                # but LLMDeployment's config accepts everything for now.
                # Note: kwargs are passed as the config dict.
                deployment = LLMDeployment.bind(backend=serve_backend, config=kwargs)
                serve.run(deployment, name=app_name, route_prefix=f"/{app_name}")
                self.handle = serve.get_app_handle(app_name)
            elif app_name or deployment_name:
                raise ValueError(
                    f"Ray Serve app/deployment '{app_name or deployment_name}' "
                    "not found and 'serve_backend' not provided to deploy it."
                )
            else:
                raise ValueError(
                    "Either 'app_name', 'deployment_name' or 'serve_backend' "
                    "must be provided for RayServeClient."
                )

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        """Generate an answer from the model."""
        return await self.handle.generate_answer.remote(text, history, **extra)

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate top-k tokens for the next-token prediction."""
        return await self.handle.generate_topk_per_token.remote(text, history, **extra)

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        """Generate probabilities for each token in the input."""
        return await self.handle.generate_inputs_prob.remote(text, history, **extra)
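A quick sketch of how this client might be used; the app name is a placeholder and not part of the commit. Either attach to a Serve app that is already running, or pass serve_backend so the client deploys one itself (any extra keyword arguments are forwarded as the LLMDeployment config, so valid keys depend on the chosen backend wrapper).

# Hypothetical usage; "my_llm_app" is a placeholder name.
client = RayServeClient(app_name="my_llm_app")   # attach to a running Serve app

# Or let the client deploy the app itself from a local backend.
client = RayServeClient(serve_backend="vllm")    # backend-specific kwargs would go here

answer = await client.generate_answer("What is Ray Serve?")  # call from an async context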
graphgen/models/llm/local/ray_serve_deployment.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import os
from typing import Any, Dict, List, Optional

from ray import serve
from starlette.requests import Request

from graphgen.bases.datatypes import Token
from graphgen.models.tokenizer import Tokenizer


@serve.deployment
class LLMDeployment:
    def __init__(self, backend: str, config: Dict[str, Any]):
        self.backend = backend

        # Initialize tokenizer if needed
        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
        if "tokenizer" not in config:
            tokenizer = Tokenizer(model_name=tokenizer_model)
            config["tokenizer"] = tokenizer

        if backend == "vllm":
            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

            self.llm_instance = VLLMWrapper(**config)
        elif backend == "huggingface":
            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

            self.llm_instance = HuggingFaceWrapper(**config)
        elif backend == "sglang":
            from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper

            self.llm_instance = SGLangWrapper(**config)
        else:
            raise NotImplementedError(
                f"Backend {backend} is not implemented for Ray Serve yet."
            )

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        return await self.llm_instance.generate_answer(text, history, **extra)

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        return await self.llm_instance.generate_topk_per_token(text, history, **extra)

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        return await self.llm_instance.generate_inputs_prob(text, history, **extra)

    async def __call__(self, request: Request) -> Dict:
        try:
            data = await request.json()
            text = data.get("text")
            history = data.get("history")
            method = data.get("method", "generate_answer")
            kwargs = data.get("kwargs", {})

            if method == "generate_answer":
                result = await self.generate_answer(text, history, **kwargs)
            elif method == "generate_topk_per_token":
                result = await self.generate_topk_per_token(text, history, **kwargs)
            elif method == "generate_inputs_prob":
                result = await self.generate_inputs_prob(text, history, **kwargs)
            else:
                return {"error": f"Method {method} not supported"}

            return {"result": result}
        except Exception as e:
            return {"error": str(e)}


def app_builder(args: Dict[str, str]) -> Any:
    """
    Builder function for `serve run`.
    Usage: serve run graphgen.models.llm.local.ray_serve_deployment:app_builder backend=vllm model=...
    """
    # args come from the command-line key=value pairs
    backend = args.pop("backend", "vllm")
    # remaining args are treated as config
    return LLMDeployment.bind(backend=backend, config=args)
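Because __call__ parses a JSON body, a deployed app can also be reached over plain HTTP. A minimal sketch follows; the host/port assume Ray Serve's default HTTP proxy, and the route prefix assumes the /{app_name} convention used by RayServeClient above.

# Minimal HTTP sketch; URL and route prefix are assumptions, not part of the commit.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/llm_app",
    json={"method": "generate_answer", "text": "Hello", "history": None, "kwargs": {}},
)
print(resp.json())  # {"result": "..."} on success, {"error": "..."} otherwise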
