|
35 | 35 |
|
36 | 36 | from nemoguardrails import LLMRails, RailsConfig, utils |
37 | 37 | from nemoguardrails.rails.llm.config import Model |
38 | | -from nemoguardrails.rails.llm.options import GenerationResponse |
| 38 | +from nemoguardrails.rails.llm.options import GenerationResponse, RailStatus |
39 | 39 | from nemoguardrails.server.datastore.datastore import DataStore |
40 | 40 | from nemoguardrails.server.schemas.openai import ( |
| 41 | + GuardrailCheckRequest, |
| 42 | + GuardrailCheckResponse, |
41 | 43 | GuardrailsChatCompletion, |
42 | 44 | GuardrailsChatCompletionRequest, |
43 | 45 | OpenAIModelsList, |
@@ -328,6 +330,20 @@ def _update_models_in_config(config: RailsConfig, main_model: Model) -> RailsCon |
328 | 330 | return config.model_copy(update={"models": models}) |
329 | 331 |
|
330 | 332 |
|
def _inject_model(config: RailsConfig, model_name: str) -> RailsConfig:
    """Build a "main" model for ``model_name`` and apply it to ``config``.

    The engine comes from the ``MAIN_MODEL_ENGINE`` environment variable; when
    that is unset or empty, ``"openai"`` is used and a warning is logged. If
    ``MAIN_MODEL_BASE_URL`` is set to a non-empty value, it is forwarded to the
    model as the ``base_url`` parameter.

    Args:
        config: The rails configuration to update.
        model_name: The model name taken from the incoming request.

    Returns:
        Whatever ``_update_models_in_config`` returns for ``config`` and the
        newly built main model.
    """
    model_engine = os.environ.get("MAIN_MODEL_ENGINE")
    if not model_engine:
        log.warning("MAIN_MODEL_ENGINE not set, defaulting to 'openai'. ")
        model_engine = "openai"

    endpoint = os.environ.get("MAIN_MODEL_BASE_URL")
    # Only pass base_url when one is actually configured.
    model_parameters = {"base_url": endpoint} if endpoint else {}

    return _update_models_in_config(
        config,
        Model(
            model=model_name,
            type="main",
            engine=model_engine,
            parameters=model_parameters,
        ),
    )
| 345 | + |
| 346 | + |
331 | 347 | async def _get_rails(config_ids: List[str], model_name: Optional[str] = None) -> LLMRails: |
332 | 348 | """Returns the rails instance for the given config id and model. |
333 | 349 |
|
@@ -373,18 +389,7 @@ async def _get_rails(config_ids: List[str], model_name: Optional[str] = None) -> |
373 | 389 | raise ValueError("No valid rails configuration found.") |
374 | 390 |
|
375 | 391 | if model_name: |
376 | | - engine = os.environ.get("MAIN_MODEL_ENGINE") |
377 | | - if not engine: |
378 | | - engine = "openai" |
379 | | - log.warning("MAIN_MODEL_ENGINE not set, defaulting to 'openai'. ") |
380 | | - |
381 | | - parameters = {} |
382 | | - base_url = os.environ.get("MAIN_MODEL_BASE_URL") |
383 | | - if base_url: |
384 | | - parameters["base_url"] = base_url |
385 | | - |
386 | | - main_model = Model(model=model_name, type="main", engine=engine, parameters=parameters) |
387 | | - full_llm_rails_config = _update_models_in_config(full_llm_rails_config, main_model) |
| 392 | + full_llm_rails_config = _inject_model(full_llm_rails_config, model_name) |
388 | 393 |
|
389 | 394 | llm_rails = LLMRails(config=full_llm_rails_config, verbose=True) |
390 | 395 | llm_rails_instances[configs_cache_key] = llm_rails |
@@ -643,6 +648,79 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques |
643 | 648 | ) |
644 | 649 |
|
645 | 650 |
|
| 651 | +def _map_rail_status(status: RailStatus) -> str: |
| 652 | + """Map internal RailStatus to API status string.""" |
| 653 | + return status.value |
| 654 | + |
| 655 | + |
@app.post(
    "/v1/checks",
    response_model=GuardrailCheckResponse,
    response_model_exclude_none=True,
)
async def guardrail_check(body: GuardrailCheckRequest, request: Request):
    """Run a guardrail check over the messages in the request.

    The rails configuration is resolved from ``body.guardrails`` in this order:

    1. ``config`` given as a dict: loaded inline via ``RailsConfig.from_content``.
    2. ``config`` given as a string: used as a single config id.
    3. ``config_ids``: used as the list of config ids.
    4. The server's ``app.default_config_id``.

    When ``body.model`` is set it is injected as the main model, either directly
    for inline configs or through ``_get_rails`` for config ids.

    Args:
        body: The parsed check request (messages, optional model, guardrails options).
        request: The incoming HTTP request; its headers are stored in
            ``api_request_headers``.

    Returns:
        A ``GuardrailCheckResponse`` built from the status, content and rail of
        the ``check_async`` result.

    Raises:
        HTTPException: 422 when ``messages`` is empty, the inline config cannot
            be loaded, no config id can be resolved, ``_get_rails`` raises
            ``ValueError``, or the configuration's Colang version is not "1.0";
            500 for any other error raised while running the check.
    """
    api_request_headers.set(request.headers)

    if not body.messages:
        raise HTTPException(status_code=422, detail="messages must be non-empty")

    config = body.guardrails.config

    if isinstance(config, dict):
        # Inline configs are instantiated directly for this request instead of
        # going through _get_rails.
        try:
            rails_config = RailsConfig.from_content(config=config)
            if body.model:
                rails_config = _inject_model(rails_config, body.model)
            llm_rails = LLMRails(config=rails_config, verbose=True)
        except Exception as ex:
            log.exception(ex)
            # Chain the original error so the root cause stays attached.
            raise HTTPException(status_code=422, detail=f"Invalid inline config: {ex}") from ex
    else:
        if isinstance(config, str):
            config_ids = [config]
        elif body.guardrails.config_ids:
            config_ids = list(body.guardrails.config_ids)
        elif app.default_config_id:
            config_ids = [app.default_config_id]
        else:
            raise HTTPException(
                status_code=422,
                detail="No guardrails config_id provided and server has no default configuration",
            )
        try:
            llm_rails = await _get_rails(config_ids, model_name=body.model)
        except ValueError as ex:
            log.exception(ex)
            raise HTTPException(status_code=422, detail=str(ex)) from ex

    if llm_rails.config.colang_version != "1.0":
        raise HTTPException(
            status_code=422,
            detail="check_async does not support Colang 2.0 configurations.",
        )

    try:
        # Copy so the request body's own message list is never mutated.
        messages = list(body.messages)
        if body.guardrails.context:
            # Extra context is passed as a leading "context" role message.
            messages.insert(0, {"role": "context", "content": body.guardrails.context})

        result = await llm_rails.check_async(messages=messages)

        return GuardrailCheckResponse(
            status=_map_rail_status(result.status),
            content=result.content,
            rail=result.rail,
        )

    except HTTPException:
        # Let deliberate HTTP errors through unchanged.
        raise
    except Exception as ex:
        log.exception(ex)
        # Do not leak internal error details to the client.
        raise HTTPException(status_code=500, detail="Internal server error") from ex
| 722 | + |
| 723 | + |
646 | 724 | # By default, there are no challenges |
647 | 725 | challenges = [] |
648 | 726 |
|
|
0 commit comments