|
35 | 35 |
|
36 | 36 | from nemoguardrails import LLMRails, RailsConfig, utils |
37 | 37 | from nemoguardrails.rails.llm.config import Model |
38 | | -from nemoguardrails.rails.llm.options import GenerationResponse |
| 38 | +from nemoguardrails.rails.llm.options import GenerationResponse, RailStatus |
39 | 39 | from nemoguardrails.server.datastore.datastore import DataStore |
40 | 40 | from nemoguardrails.server.schemas.openai import ( |
| 41 | + GuardrailCheckDataOutput, |
| 42 | + GuardrailCheckRequest, |
| 43 | + GuardrailCheckResponse, |
41 | 44 | GuardrailsChatCompletion, |
42 | 45 | GuardrailsChatCompletionRequest, |
43 | 46 | OpenAIModelsList, |
| 47 | + RailStatusEntry, |
44 | 48 | ) |
45 | 49 | from nemoguardrails.server.schemas.utils import ( |
46 | 50 | create_error_chat_completion, |
@@ -328,6 +332,20 @@ def _update_models_in_config(config: RailsConfig, main_model: Model) -> RailsCon |
328 | 332 | return config.model_copy(update={"models": models}) |
329 | 333 |
|
330 | 334 |
|
| 335 | +def _inject_model(config: RailsConfig, model_name: str) -> RailsConfig: |
| 336 | + """Inject the request's model into a RailsConfig using env-based engine/base_url.""" |
| 337 | + engine = os.environ.get("MAIN_MODEL_ENGINE") |
| 338 | + if not engine: |
| 339 | + engine = "openai" |
| 340 | + log.warning("MAIN_MODEL_ENGINE not set, defaulting to 'openai'. ") |
| 341 | + parameters = {} |
| 342 | + base_url = os.environ.get("MAIN_MODEL_BASE_URL") |
| 343 | + if base_url: |
| 344 | + parameters["base_url"] = base_url |
| 345 | + main_model = Model(model=model_name, type="main", engine=engine, parameters=parameters) |
| 346 | + return _update_models_in_config(config, main_model) |
| 347 | + |
| 348 | + |
331 | 349 | async def _get_rails(config_ids: List[str], model_name: Optional[str] = None) -> LLMRails: |
332 | 350 | """Returns the rails instance for the given config id and model. |
333 | 351 |
|
@@ -373,18 +391,7 @@ async def _get_rails(config_ids: List[str], model_name: Optional[str] = None) -> |
373 | 391 | raise ValueError("No valid rails configuration found.") |
374 | 392 |
|
375 | 393 | if model_name: |
376 | | - engine = os.environ.get("MAIN_MODEL_ENGINE") |
377 | | - if not engine: |
378 | | - engine = "openai" |
379 | | - log.warning("MAIN_MODEL_ENGINE not set, defaulting to 'openai'. ") |
380 | | - |
381 | | - parameters = {} |
382 | | - base_url = os.environ.get("MAIN_MODEL_BASE_URL") |
383 | | - if base_url: |
384 | | - parameters["base_url"] = base_url |
385 | | - |
386 | | - main_model = Model(model=model_name, type="main", engine=engine, parameters=parameters) |
387 | | - full_llm_rails_config = _update_models_in_config(full_llm_rails_config, main_model) |
| 394 | + full_llm_rails_config = _inject_model(full_llm_rails_config, model_name) |
388 | 395 |
|
389 | 396 | llm_rails = LLMRails(config=full_llm_rails_config, verbose=True) |
390 | 397 | llm_rails_instances[configs_cache_key] = llm_rails |
@@ -643,6 +650,114 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques |
643 | 650 | ) |
644 | 651 |
|
645 | 652 |
|
| 653 | +def _filter_log(log_dict: dict, log_options) -> dict: |
| 654 | + """Filter log output based on caller's log preferences. |
| 655 | +
|
| 656 | + check_async always enables activated_rails internally (needed for |
| 657 | + rails_status), but the response log should only include fields the |
| 658 | + caller requested. |
| 659 | + """ |
| 660 | + filtered = {} |
| 661 | + if log_options.activated_rails: |
| 662 | + filtered["activated_rails"] = log_dict.get("activated_rails", []) |
| 663 | + else: |
| 664 | + filtered["activated_rails"] = [] |
| 665 | + if log_options.llm_calls and "llm_calls" in log_dict: |
| 666 | + filtered["llm_calls"] = log_dict["llm_calls"] |
| 667 | + if log_options.internal_events and "internal_events" in log_dict: |
| 668 | + filtered["internal_events"] = log_dict["internal_events"] |
| 669 | + if log_options.colang_history and "colang_history" in log_dict: |
| 670 | + filtered["colang_history"] = log_dict["colang_history"] |
| 671 | + if "stats" in log_dict: |
| 672 | + filtered["stats"] = log_dict["stats"] |
| 673 | + return filtered |
| 674 | + |
| 675 | + |
| 676 | +def _map_rail_status(status: RailStatus) -> str: |
| 677 | + """Map internal RailStatus to upstream StatusEnum values.""" |
| 678 | + if status == RailStatus.BLOCKED: |
| 679 | + return "blocked" |
| 680 | + return "success" |
| 681 | + |
| 682 | + |
| 683 | +def _build_rails_status(result) -> dict: |
| 684 | + """Build rails_status dict from activated rails in the generation log.""" |
| 685 | + rails_status = {} |
| 686 | + if result.log and result.log.activated_rails: |
| 687 | + for rail in result.log.activated_rails: |
| 688 | + rail_status = "blocked" if rail.stop else "success" |
| 689 | + rails_status[rail.name] = RailStatusEntry(status=rail_status) |
| 690 | + return rails_status |
| 691 | + |
| 692 | + |
| 693 | +@app.post( |
| 694 | + "/v1/guardrail/checks", |
| 695 | + response_model=GuardrailCheckResponse, |
| 696 | + response_model_exclude_none=True, |
| 697 | +) |
| 698 | +async def guardrail_check(body: GuardrailCheckRequest, request: Request): |
| 699 | + """Guardrail check request.""" |
| 700 | + api_request_headers.set(request.headers) |
| 701 | + |
| 702 | + if not body.messages: |
| 703 | + raise HTTPException(status_code=422, detail="messages must be non-empty") |
| 704 | + |
| 705 | + config_ids = None |
| 706 | + config = body.guardrails.config |
| 707 | + |
| 708 | + if isinstance(config, dict): |
| 709 | + try: |
| 710 | + rails_config = RailsConfig.from_content(config=config) |
| 711 | + if body.model: |
| 712 | + rails_config = _inject_model(rails_config, body.model) |
| 713 | + llm_rails = LLMRails(config=rails_config, verbose=True) |
| 714 | + except Exception as ex: |
| 715 | + log.exception(ex) |
| 716 | + raise HTTPException(status_code=422, detail=f"Invalid inline config: {ex}") |
| 717 | + else: |
| 718 | + if isinstance(config, str): |
| 719 | + config_ids = [config] |
| 720 | + elif body.guardrails.config_ids: |
| 721 | + config_ids = list(body.guardrails.config_ids) |
| 722 | + elif app.default_config_id: |
| 723 | + config_ids = [app.default_config_id] |
| 724 | + else: |
| 725 | + raise HTTPException( |
| 726 | + status_code=422, |
| 727 | + detail="No guardrails config_id provided and server has no default configuration", |
| 728 | + ) |
| 729 | + try: |
| 730 | + llm_rails = await _get_rails(config_ids, model_name=body.model) |
| 731 | + except ValueError as ex: |
| 732 | + log.exception(ex) |
| 733 | + raise HTTPException(status_code=422, detail=str(ex)) |
| 734 | + |
| 735 | + try: |
| 736 | + messages = list(body.messages) |
| 737 | + if body.guardrails.context: |
| 738 | + messages.insert(0, {"role": "context", "content": body.guardrails.context}) |
| 739 | + |
| 740 | + result = await llm_rails.check_async(messages=messages) |
| 741 | + |
| 742 | + log_dict = _filter_log(result.log.model_dump(), body.guardrails.options.log) if result.log else None |
| 743 | + guardrails_data = GuardrailCheckDataOutput( |
| 744 | + config_ids=config_ids, |
| 745 | + log=log_dict, |
| 746 | + ) |
| 747 | + |
| 748 | + return GuardrailCheckResponse( |
| 749 | + status=_map_rail_status(result.status), |
| 750 | + rails_status=_build_rails_status(result), |
| 751 | + guardrails_data=guardrails_data, |
| 752 | + ) |
| 753 | + |
| 754 | + except HTTPException: |
| 755 | + raise |
| 756 | + except Exception as ex: |
| 757 | + log.exception(ex) |
| 758 | + raise HTTPException(status_code=500, detail="Internal server error") |
| 759 | + |
| 760 | + |
646 | 761 | # By default, there are no challenges |
647 | 762 | challenges = [] |
648 | 763 |
|
|
0 commit comments