Skip to content

Commit 359e989

Browse files
committed
✨ implement /v1/guardrail/checks endpoint
1 parent 7285f2c commit 359e989

5 files changed

Lines changed: 504 additions & 17 deletions

File tree

nemoguardrails/rails/llm/llmrails.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,11 +1680,11 @@ async def check_async(
16801680
result_content = _get_last_response_content(response)
16811681

16821682
if blocking_rail:
1683-
return RailsResult(status=RailStatus.BLOCKED, content=result_content, rail=blocking_rail)
1683+
return RailsResult(status=RailStatus.BLOCKED, content=result_content, rail=blocking_rail, log=response.log)
16841684

16851685
if result_content != original_content:
1686-
return RailsResult(status=RailStatus.MODIFIED, content=result_content)
1687-
return RailsResult(status=RailStatus.PASSED, content=result_content)
1686+
return RailsResult(status=RailStatus.MODIFIED, content=result_content, log=response.log)
1687+
return RailsResult(status=RailStatus.PASSED, content=result_content, log=response.log)
16881688

16891689
def check(
16901690
self,

nemoguardrails/rails/llm/options.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ class RailsResult(BaseModel):
100100
status: RailStatus = Field(description="Status of the rails check: passed, modified, or blocked.")
101101
content: str = Field(description="The content after rails processing.")
102102
rail: Optional[str] = Field(default=None, description="Name of the rail that blocked the content.")
103+
log: Optional["GenerationLog"] = Field(default=None, description="Generation log from the rails check.")
103104

104105

105106
class GenerationLogOptions(BaseModel):

nemoguardrails/server/api.py

Lines changed: 128 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,16 @@
3535

3636
from nemoguardrails import LLMRails, RailsConfig, utils
3737
from nemoguardrails.rails.llm.config import Model
38-
from nemoguardrails.rails.llm.options import GenerationResponse
38+
from nemoguardrails.rails.llm.options import GenerationResponse, RailStatus
3939
from nemoguardrails.server.datastore.datastore import DataStore
4040
from nemoguardrails.server.schemas.openai import (
41+
GuardrailCheckDataOutput,
42+
GuardrailCheckRequest,
43+
GuardrailCheckResponse,
4144
GuardrailsChatCompletion,
4245
GuardrailsChatCompletionRequest,
4346
OpenAIModelsList,
47+
RailStatusEntry,
4448
)
4549
from nemoguardrails.server.schemas.utils import (
4650
create_error_chat_completion,
@@ -328,6 +332,20 @@ def _update_models_in_config(config: RailsConfig, main_model: Model) -> RailsCon
328332
return config.model_copy(update={"models": models})
329333

330334

335+
def _inject_model(config: RailsConfig, model_name: str) -> RailsConfig:
336+
"""Inject the request's model into a RailsConfig using env-based engine/base_url."""
337+
engine = os.environ.get("MAIN_MODEL_ENGINE")
338+
if not engine:
339+
engine = "openai"
340+
log.warning("MAIN_MODEL_ENGINE not set, defaulting to 'openai'. ")
341+
parameters = {}
342+
base_url = os.environ.get("MAIN_MODEL_BASE_URL")
343+
if base_url:
344+
parameters["base_url"] = base_url
345+
main_model = Model(model=model_name, type="main", engine=engine, parameters=parameters)
346+
return _update_models_in_config(config, main_model)
347+
348+
331349
async def _get_rails(config_ids: List[str], model_name: Optional[str] = None) -> LLMRails:
332350
"""Returns the rails instance for the given config id and model.
333351
@@ -373,18 +391,7 @@ async def _get_rails(config_ids: List[str], model_name: Optional[str] = None) ->
373391
raise ValueError("No valid rails configuration found.")
374392

375393
if model_name:
376-
engine = os.environ.get("MAIN_MODEL_ENGINE")
377-
if not engine:
378-
engine = "openai"
379-
log.warning("MAIN_MODEL_ENGINE not set, defaulting to 'openai'. ")
380-
381-
parameters = {}
382-
base_url = os.environ.get("MAIN_MODEL_BASE_URL")
383-
if base_url:
384-
parameters["base_url"] = base_url
385-
386-
main_model = Model(model=model_name, type="main", engine=engine, parameters=parameters)
387-
full_llm_rails_config = _update_models_in_config(full_llm_rails_config, main_model)
394+
full_llm_rails_config = _inject_model(full_llm_rails_config, model_name)
388395

389396
llm_rails = LLMRails(config=full_llm_rails_config, verbose=True)
390397
llm_rails_instances[configs_cache_key] = llm_rails
@@ -643,6 +650,114 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques
643650
)
644651

645652

653+
def _filter_log(log_dict: dict, log_options) -> dict:
654+
"""Filter log output based on caller's log preferences.
655+
656+
check_async always enables activated_rails internally (needed for
657+
rails_status), but the response log should only include fields the
658+
caller requested.
659+
"""
660+
filtered = {}
661+
if log_options.activated_rails:
662+
filtered["activated_rails"] = log_dict.get("activated_rails", [])
663+
else:
664+
filtered["activated_rails"] = []
665+
if log_options.llm_calls and "llm_calls" in log_dict:
666+
filtered["llm_calls"] = log_dict["llm_calls"]
667+
if log_options.internal_events and "internal_events" in log_dict:
668+
filtered["internal_events"] = log_dict["internal_events"]
669+
if log_options.colang_history and "colang_history" in log_dict:
670+
filtered["colang_history"] = log_dict["colang_history"]
671+
if "stats" in log_dict:
672+
filtered["stats"] = log_dict["stats"]
673+
return filtered
674+
675+
676+
def _map_rail_status(status: RailStatus) -> str:
677+
"""Map internal RailStatus to upstream StatusEnum values."""
678+
if status == RailStatus.BLOCKED:
679+
return "blocked"
680+
return "success"
681+
682+
683+
def _build_rails_status(result) -> dict:
684+
"""Build rails_status dict from activated rails in the generation log."""
685+
rails_status = {}
686+
if result.log and result.log.activated_rails:
687+
for rail in result.log.activated_rails:
688+
rail_status = "blocked" if rail.stop else "success"
689+
rails_status[rail.name] = RailStatusEntry(status=rail_status)
690+
return rails_status
691+
692+
693+
@app.post(
694+
"/v1/guardrail/checks",
695+
response_model=GuardrailCheckResponse,
696+
response_model_exclude_none=True,
697+
)
698+
async def guardrail_check(body: GuardrailCheckRequest, request: Request):
699+
"""Guardrail check request."""
700+
api_request_headers.set(request.headers)
701+
702+
if not body.messages:
703+
raise HTTPException(status_code=422, detail="messages must be non-empty")
704+
705+
config_ids = None
706+
config = body.guardrails.config
707+
708+
if isinstance(config, dict):
709+
try:
710+
rails_config = RailsConfig.from_content(config=config)
711+
if body.model:
712+
rails_config = _inject_model(rails_config, body.model)
713+
llm_rails = LLMRails(config=rails_config, verbose=True)
714+
except Exception as ex:
715+
log.exception(ex)
716+
raise HTTPException(status_code=422, detail=f"Invalid inline config: {ex}")
717+
else:
718+
if isinstance(config, str):
719+
config_ids = [config]
720+
elif body.guardrails.config_ids:
721+
config_ids = list(body.guardrails.config_ids)
722+
elif app.default_config_id:
723+
config_ids = [app.default_config_id]
724+
else:
725+
raise HTTPException(
726+
status_code=422,
727+
detail="No guardrails config_id provided and server has no default configuration",
728+
)
729+
try:
730+
llm_rails = await _get_rails(config_ids, model_name=body.model)
731+
except ValueError as ex:
732+
log.exception(ex)
733+
raise HTTPException(status_code=422, detail=str(ex))
734+
735+
try:
736+
messages = list(body.messages)
737+
if body.guardrails.context:
738+
messages.insert(0, {"role": "context", "content": body.guardrails.context})
739+
740+
result = await llm_rails.check_async(messages=messages)
741+
742+
log_dict = _filter_log(result.log.model_dump(), body.guardrails.options.log) if result.log else None
743+
guardrails_data = GuardrailCheckDataOutput(
744+
config_ids=config_ids,
745+
log=log_dict,
746+
)
747+
748+
return GuardrailCheckResponse(
749+
status=_map_rail_status(result.status),
750+
rails_status=_build_rails_status(result),
751+
guardrails_data=guardrails_data,
752+
)
753+
754+
except HTTPException:
755+
raise
756+
except Exception as ex:
757+
log.exception(ex)
758+
raise HTTPException(status_code=500, detail="Internal server error")
759+
760+
646761
# By default, there are no challenges
647762
challenges = []
648763

nemoguardrails/server/schemas/openai.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""OpenAI API schema definitions for the NeMo Guardrails server."""
1717

1818
import os
19-
from typing import Any, List, Literal, Optional, Union
19+
from typing import Any, Dict, List, Literal, Optional, Union
2020

2121
from openai.types.chat.chat_completion import ChatCompletion
2222
from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
@@ -165,3 +165,60 @@ class OpenAIModelsList(BaseModel):
165165
"""Standard OpenAI models list response."""
166166

167167
data: list[OpenAIModel] = Field(..., description="List of OpenAI model objects.")
168+
169+
170+
class RailStatusEntry(BaseModel):
171+
"""Status of an individual rail."""
172+
173+
status: str = Field(..., description="Status of the individual rail.")
174+
175+
176+
class GuardrailCheckDataInput(GuardrailsDataInput):
177+
"""Guardrails input options specific to the checks endpoint."""
178+
179+
config: Optional[Union[str, dict]] = Field(
180+
default=None,
181+
description="The id of the configuration or its dict representation to be used.",
182+
)
183+
184+
@model_validator(mode="before")
185+
@classmethod
186+
def validate_config_exclusivity(cls, data: Any) -> Any:
187+
if isinstance(data, dict) and data.get("config") is not None:
188+
if data.get("config_id") is not None or data.get("config_ids") is not None:
189+
raise ValueError("config is mutually exclusive with config_id and config_ids")
190+
return data
191+
192+
193+
class GuardrailCheckRequest(OpenAIChatCompletionRequest):
194+
"""Request body for the /v1/guardrail/checks endpoint."""
195+
196+
guardrails: GuardrailCheckDataInput = Field(
197+
default_factory=GuardrailCheckDataInput,
198+
description="Guardrails specific options for the request.",
199+
)
200+
201+
202+
class GuardrailCheckDataOutput(BaseModel):
203+
"""Guardrails-specific output data for the checks endpoint (upstream-aligned)."""
204+
205+
llm_output: Optional[dict] = Field(default=None, description="Contains any additional output coming from the LLM.")
206+
config_ids: Optional[List[str]] = Field(
207+
default=None,
208+
description="The list of configuration ids that were used.",
209+
)
210+
output_data: Optional[dict] = Field(
211+
default=None,
212+
description="The output data, i.e. a dict with the values corresponding to the output_vars.",
213+
)
214+
log: Optional[dict] = Field(default=None, description="Additional logging information.")
215+
216+
217+
class GuardrailCheckResponse(BaseModel):
218+
"""Response from the /v1/guardrail/checks endpoint."""
219+
220+
status: str = Field(..., description="Overall status indicating if all rails passed or if any failed.")
221+
rails_status: Dict[str, RailStatusEntry] = Field(..., description="Dictionary mapping each rail to its status.")
222+
guardrails_data: Optional[GuardrailCheckDataOutput] = Field(
223+
default=None, description="Additional data related to guardrails."
224+
)

0 commit comments

Comments
 (0)