Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 91 additions & 13 deletions nemoguardrails/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@

from nemoguardrails import LLMRails, RailsConfig, utils
from nemoguardrails.rails.llm.config import Model
from nemoguardrails.rails.llm.options import GenerationResponse
from nemoguardrails.rails.llm.options import GenerationResponse, RailStatus
from nemoguardrails.server.datastore.datastore import DataStore
from nemoguardrails.server.schemas.openai import (
GuardrailCheckRequest,
GuardrailCheckResponse,
GuardrailsChatCompletion,
GuardrailsChatCompletionRequest,
OpenAIModelsList,
Expand Down Expand Up @@ -328,6 +330,20 @@ def _update_models_in_config(config: RailsConfig, main_model: Model) -> RailsCon
return config.model_copy(update={"models": models})


def _inject_model(config: RailsConfig, model_name: str) -> RailsConfig:
"""Inject the request's model into a RailsConfig using env-based engine/base_url."""
engine = os.environ.get("MAIN_MODEL_ENGINE")
if not engine:
engine = "openai"
log.warning("MAIN_MODEL_ENGINE not set, defaulting to 'openai'. ")
parameters = {}
base_url = os.environ.get("MAIN_MODEL_BASE_URL")
if base_url:
parameters["base_url"] = base_url
main_model = Model(model=model_name, type="main", engine=engine, parameters=parameters)
return _update_models_in_config(config, main_model)


Comment thread
Pouyanpi marked this conversation as resolved.
async def _get_rails(config_ids: List[str], model_name: Optional[str] = None) -> LLMRails:
"""Returns the rails instance for the given config id and model.

Expand Down Expand Up @@ -373,18 +389,7 @@ async def _get_rails(config_ids: List[str], model_name: Optional[str] = None) ->
raise ValueError("No valid rails configuration found.")

if model_name:
engine = os.environ.get("MAIN_MODEL_ENGINE")
if not engine:
engine = "openai"
log.warning("MAIN_MODEL_ENGINE not set, defaulting to 'openai'. ")

parameters = {}
base_url = os.environ.get("MAIN_MODEL_BASE_URL")
if base_url:
parameters["base_url"] = base_url

main_model = Model(model=model_name, type="main", engine=engine, parameters=parameters)
full_llm_rails_config = _update_models_in_config(full_llm_rails_config, main_model)
full_llm_rails_config = _inject_model(full_llm_rails_config, model_name)

llm_rails = LLMRails(config=full_llm_rails_config, verbose=True)
llm_rails_instances[configs_cache_key] = llm_rails
Expand Down Expand Up @@ -643,6 +648,79 @@ async def chat_completion(body: GuardrailsChatCompletionRequest, request: Reques
)


def _map_rail_status(status: RailStatus) -> str:
"""Map internal RailStatus to API status string."""
return status.value


@app.post(
"/v1/checks",
response_model=GuardrailCheckResponse,
response_model_exclude_none=True,
)
async def guardrail_check(body: GuardrailCheckRequest, request: Request):
"""Guardrail check request."""
api_request_headers.set(request.headers)

if not body.messages:
raise HTTPException(status_code=422, detail="messages must be non-empty")

config_ids = None
config = body.guardrails.config

if isinstance(config, dict):
try:
rails_config = RailsConfig.from_content(config=config)
if body.model:
rails_config = _inject_model(rails_config, body.model)
llm_rails = LLMRails(config=rails_config, verbose=True)
except Exception as ex:
log.exception(ex)
raise HTTPException(status_code=422, detail=f"Invalid inline config: {ex}")
Comment thread
m-misiura marked this conversation as resolved.
else:
if isinstance(config, str):
config_ids = [config]
elif body.guardrails.config_ids:
config_ids = list(body.guardrails.config_ids)
elif app.default_config_id:
config_ids = [app.default_config_id]
else:
raise HTTPException(
status_code=422,
detail="No guardrails config_id provided and server has no default configuration",
)
try:
llm_rails = await _get_rails(config_ids, model_name=body.model)
Comment thread
m-misiura marked this conversation as resolved.
except ValueError as ex:
log.exception(ex)
raise HTTPException(status_code=422, detail=str(ex))

if llm_rails.config.colang_version != "1.0":
raise HTTPException(
status_code=422,
detail="check_async does not support Colang 2.0 configurations.",
)

try:
messages = list(body.messages)
if body.guardrails.context:
messages.insert(0, {"role": "context", "content": body.guardrails.context})

result = await llm_rails.check_async(messages=messages)

return GuardrailCheckResponse(
status=_map_rail_status(result.status),
content=result.content,
rail=result.rail,
)

except HTTPException:
raise
except Exception as ex:
log.exception(ex)
raise HTTPException(status_code=500, detail="Internal server error")


# By default, there are no challenges
challenges = []

Expand Down
34 changes: 34 additions & 0 deletions nemoguardrails/server/schemas/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,37 @@ class OpenAIModelsList(BaseModel):
"""Standard OpenAI models list response."""

data: list[OpenAIModel] = Field(..., description="List of OpenAI model objects.")


class GuardrailCheckDataInput(GuardrailsDataInput):
"""Guardrails input options specific to the checks endpoint."""

config: Optional[Union[str, dict]] = Field(
default=None,
description="The id of the configuration or its dict representation to be used.",
Comment thread
Pouyanpi marked this conversation as resolved.
)

@model_validator(mode="before")
@classmethod
def validate_config_exclusivity(cls, data: Any) -> Any:
if isinstance(data, dict) and data.get("config") is not None:
if data.get("config_id") is not None or data.get("config_ids") is not None:
raise ValueError("config is mutually exclusive with config_id and config_ids")
Comment thread
Pouyanpi marked this conversation as resolved.
return data


class GuardrailCheckRequest(OpenAIChatCompletionRequest):
"""Request body for the /v1/checks endpoint."""

guardrails: GuardrailCheckDataInput = Field(
default_factory=GuardrailCheckDataInput,
description="Guardrails specific options for the request.",
)


class GuardrailCheckResponse(BaseModel):
"""Response from the /v1/checks endpoint."""

status: str = Field(..., description="Overall check result: passed, modified, or blocked.")
content: str = Field(..., description="Content after rails processing.")
rail: Optional[str] = Field(default=None, description="Name of the blocking rail, if any.")
Loading
Loading