import os
from typing import Any, Dict, List, Optional

from ray import serve
from starlette.requests import Request

from graphgen.bases.datatypes import Token
from graphgen.models.tokenizer import Tokenizer

@serve.deployment
class LLMDeployment:
    """Ray Serve deployment that wraps a local LLM backend (vLLM, HuggingFace, or SGLang)."""

    def __init__(self, backend: str, config: Dict[str, Any]):
        self.backend = backend

        # Initialize a tokenizer if the caller did not provide one
        tokenizer_model = os.environ.get("TOKENIZER_MODEL", "cl100k_base")
        if "tokenizer" not in config:
            tokenizer = Tokenizer(model_name=tokenizer_model)
            config["tokenizer"] = tokenizer

        # Import wrappers lazily so only the selected backend's
        # dependencies need to be installed.
        if backend == "vllm":
            from graphgen.models.llm.local.vllm_wrapper import VLLMWrapper

            self.llm_instance = VLLMWrapper(**config)
        elif backend == "huggingface":
            from graphgen.models.llm.local.hf_wrapper import HuggingFaceWrapper

            self.llm_instance = HuggingFaceWrapper(**config)
        elif backend == "sglang":
            from graphgen.models.llm.local.sglang_wrapper import SGLangWrapper

            self.llm_instance = SGLangWrapper(**config)
        else:
            raise NotImplementedError(
                f"Backend {backend} is not implemented for Ray Serve yet."
            )

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> str:
        return await self.llm_instance.generate_answer(text, history, **extra)

    async def generate_topk_per_token(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        return await self.llm_instance.generate_topk_per_token(text, history, **extra)

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None, **extra: Any
    ) -> List[Token]:
        return await self.llm_instance.generate_inputs_prob(text, history, **extra)

    async def __call__(self, request: Request) -> Dict:
        """Handle an HTTP request whose JSON body has the form
        {"text": ..., "history": ..., "method": ..., "kwargs": {...}}."""
        try:
            data = await request.json()
            text = data.get("text")
            history = data.get("history")
            method = data.get("method", "generate_answer")
            kwargs = data.get("kwargs", {})

            if method == "generate_answer":
                result = await self.generate_answer(text, history, **kwargs)
            elif method == "generate_topk_per_token":
                result = await self.generate_topk_per_token(text, history, **kwargs)
            elif method == "generate_inputs_prob":
                result = await self.generate_inputs_prob(text, history, **kwargs)
            else:
                return {"error": f"Method {method} not supported"}

            return {"result": result}
        except Exception as e:
            return {"error": str(e)}


def app_builder(args: Dict[str, str]) -> Any:
    """
    Builder function for 'serve run'.
    Usage: serve run graphgen.models.llm.local.ray_serve_deployment:app_builder backend=vllm model=...
    """
    # args comes from the command-line key=value pairs
    backend = args.pop("backend", "vllm")
    # remaining args are treated as config
    return LLMDeployment.bind(backend=backend, config=args)
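

# For reference, a minimal client sketch against the JSON schema handled by
# `LLMDeployment.__call__` above. This is illustrative and not part of the
# module: it assumes the app was started with `serve run ...` and is reachable
# at Ray Serve's default HTTP address, http://localhost:8000/.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/",
#       json={
#           "text": "Hello, world!",
#           "history": None,
#           "method": "generate_answer",
#           "kwargs": {},
#       },
#   )
#   print(resp.json().get("result"))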