Skip to content

Commit 38aa2e3

Browse files
committed
fix: honor configured remote guardrails rails
Signed-off-by: Alex Fournier <afournier@nvidia.com>
1 parent 46f8146 commit 38aa2e3

2 files changed

Lines changed: 406 additions & 29 deletions

File tree

crates/core/src/plugins/nemo_guardrails/remote.rs

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResu
2121
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
2222
use rustls::crypto::ring;
2323

24-
use super::{NeMoGuardrailsConfig, RequestDefaultsConfig};
24+
use super::{NeMoGuardrailsConfig, RequestDefaultsConfig, RequestRailsConfig};
2525

2626
#[derive(Clone)]
2727
struct RemoteBackendRuntime {
@@ -88,6 +88,8 @@ impl RemoteBackendRuntime {
8888
&remote.config_id,
8989
&remote.config_ids,
9090
request_defaults,
91+
config.input,
92+
config.output,
9193
),
9294
tool_input_guardrails: build_tool_check_guardrails_config(
9395
RemoteCheckKind::Input,
@@ -862,16 +864,35 @@ fn build_llm_guardrails_config(
862864
config_id: &Option<String>,
863865
config_ids: &[String],
864866
request_defaults: Option<&RequestDefaultsConfig>,
867+
input_enabled: bool,
868+
output_enabled: bool,
865869
) -> Option<Map<String, Json>> {
866870
let mut guardrails = build_base_guardrails_config(config_id, config_ids, request_defaults);
867-
if let Some(request_defaults) = request_defaults {
868-
let mut options = Map::new();
869-
if let Some(rails) = &request_defaults.rails {
870-
options.insert(
871-
"rails".to_string(),
872-
serde_json::to_value(rails).expect("request rails config should serialize to JSON"),
873-
);
871+
let mut options = Map::new();
872+
873+
if let Some(mut rails) = request_defaults
874+
.and_then(|defaults| defaults.rails.as_ref())
875+
.map(serialize_request_rails)
876+
{
877+
if !input_enabled {
878+
rails.insert("input".to_string(), Json::Bool(false));
879+
}
880+
if !output_enabled {
881+
rails.insert("output".to_string(), Json::Bool(false));
882+
}
883+
options.insert("rails".to_string(), Json::Object(rails));
884+
} else if !input_enabled || !output_enabled {
885+
let mut rails = Map::new();
886+
if !input_enabled {
887+
rails.insert("input".to_string(), Json::Bool(false));
874888
}
889+
if !output_enabled {
890+
rails.insert("output".to_string(), Json::Bool(false));
891+
}
892+
options.insert("rails".to_string(), Json::Object(rails));
893+
}
894+
895+
if let Some(request_defaults) = request_defaults {
875896
if let Some(llm_params) = &request_defaults.llm_params {
876897
options.insert("llm_params".to_string(), llm_params.clone());
877898
}
@@ -884,9 +905,9 @@ fn build_llm_guardrails_config(
884905
if let Some(log) = &request_defaults.log {
885906
options.insert("log".to_string(), log.clone());
886907
}
887-
if !options.is_empty() {
888-
guardrails.insert("options".to_string(), Json::Object(options));
889-
}
908+
}
909+
if !options.is_empty() {
910+
guardrails.insert("options".to_string(), Json::Object(options));
890911
}
891912
(!guardrails.is_empty()).then_some(guardrails)
892913
}
@@ -899,25 +920,31 @@ fn build_tool_check_guardrails_config(
899920
) -> Map<String, Json> {
900921
let mut guardrails = build_base_guardrails_config(config_id, config_ids, request_defaults);
901922
let mut options = Map::new();
902-
let rails = match kind {
903-
RemoteCheckKind::Input => json!({
904-
"input": false,
905-
"output": false,
906-
"dialog": false,
907-
"retrieval": false,
908-
"tool_input": true,
909-
"tool_output": false,
910-
}),
911-
RemoteCheckKind::Output => json!({
912-
"input": false,
913-
"output": false,
914-
"dialog": false,
915-
"retrieval": false,
916-
"tool_input": false,
917-
"tool_output": true,
918-
}),
923+
let mut rails = Map::from_iter([
924+
("input".to_string(), Json::Bool(false)),
925+
("output".to_string(), Json::Bool(false)),
926+
("dialog".to_string(), Json::Bool(false)),
927+
("retrieval".to_string(), Json::Bool(false)),
928+
]);
929+
match kind {
930+
RemoteCheckKind::Input => {
931+
rails.insert(
932+
"tool_input".to_string(),
933+
configured_tool_selector(request_defaults, RemoteCheckKind::Input)
934+
.unwrap_or(Json::Bool(true)),
935+
);
936+
rails.insert("tool_output".to_string(), Json::Bool(false));
937+
}
938+
RemoteCheckKind::Output => {
939+
rails.insert("tool_input".to_string(), Json::Bool(false));
940+
rails.insert(
941+
"tool_output".to_string(),
942+
configured_tool_selector(request_defaults, RemoteCheckKind::Output)
943+
.unwrap_or(Json::Bool(true)),
944+
);
945+
}
919946
};
920-
options.insert("rails".to_string(), rails);
947+
options.insert("rails".to_string(), Json::Object(rails));
921948
let mut log = request_defaults
922949
.and_then(|defaults| defaults.log.as_ref())
923950
.and_then(Json::as_object)
@@ -929,6 +956,28 @@ fn build_tool_check_guardrails_config(
929956
guardrails
930957
}
931958

959+
fn serialize_request_rails(rails: &RequestRailsConfig) -> Map<String, Json> {
960+
serde_json::to_value(rails)
961+
.expect("request rails config should serialize to JSON")
962+
.as_object()
963+
.cloned()
964+
.expect("request rails config should serialize to a JSON object")
965+
}
966+
967+
fn configured_tool_selector(
968+
request_defaults: Option<&RequestDefaultsConfig>,
969+
kind: RemoteCheckKind,
970+
) -> Option<Json> {
971+
let rails = request_defaults.and_then(|defaults| defaults.rails.as_ref())?;
972+
match kind {
973+
RemoteCheckKind::Input => rails.tool_input.as_ref(),
974+
RemoteCheckKind::Output => rails.tool_output.as_ref(),
975+
}
976+
.map(|selector| {
977+
serde_json::to_value(selector).expect("tool rail selector should serialize to JSON")
978+
})
979+
}
980+
932981
#[cfg(any(target_arch = "wasm32", not(feature = "guardrails-remote")))]
933982
pub(super) fn register_remote_backend(
934983
_config: NeMoGuardrailsConfig,

0 commit comments

Comments
 (0)