@@ -21,7 +21,7 @@ use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResu
2121use reqwest:: header:: { HeaderMap , HeaderName , HeaderValue } ;
2222use rustls:: crypto:: ring;
2323
24- use super :: { NeMoGuardrailsConfig , RequestDefaultsConfig } ;
24+ use super :: { NeMoGuardrailsConfig , RequestDefaultsConfig , RequestRailsConfig } ;
2525
2626#[ derive( Clone ) ]
2727struct RemoteBackendRuntime {
@@ -88,6 +88,8 @@ impl RemoteBackendRuntime {
8888 & remote. config_id ,
8989 & remote. config_ids ,
9090 request_defaults,
91+ config. input ,
92+ config. output ,
9193 ) ,
9294 tool_input_guardrails : build_tool_check_guardrails_config (
9395 RemoteCheckKind :: Input ,
@@ -862,16 +864,35 @@ fn build_llm_guardrails_config(
862864 config_id : & Option < String > ,
863865 config_ids : & [ String ] ,
864866 request_defaults : Option < & RequestDefaultsConfig > ,
867+ input_enabled : bool ,
868+ output_enabled : bool ,
865869) -> Option < Map < String , Json > > {
866870 let mut guardrails = build_base_guardrails_config ( config_id, config_ids, request_defaults) ;
867- if let Some ( request_defaults) = request_defaults {
868- let mut options = Map :: new ( ) ;
869- if let Some ( rails) = & request_defaults. rails {
870- options. insert (
871- "rails" . to_string ( ) ,
872- serde_json:: to_value ( rails) . expect ( "request rails config should serialize to JSON" ) ,
873- ) ;
871+ let mut options = Map :: new ( ) ;
872+
873+ if let Some ( mut rails) = request_defaults
874+ . and_then ( |defaults| defaults. rails . as_ref ( ) )
875+ . map ( serialize_request_rails)
876+ {
877+ if !input_enabled {
878+ rails. insert ( "input" . to_string ( ) , Json :: Bool ( false ) ) ;
879+ }
880+ if !output_enabled {
881+ rails. insert ( "output" . to_string ( ) , Json :: Bool ( false ) ) ;
882+ }
883+ options. insert ( "rails" . to_string ( ) , Json :: Object ( rails) ) ;
884+ } else if !input_enabled || !output_enabled {
885+ let mut rails = Map :: new ( ) ;
886+ if !input_enabled {
887+ rails. insert ( "input" . to_string ( ) , Json :: Bool ( false ) ) ;
874888 }
889+ if !output_enabled {
890+ rails. insert ( "output" . to_string ( ) , Json :: Bool ( false ) ) ;
891+ }
892+ options. insert ( "rails" . to_string ( ) , Json :: Object ( rails) ) ;
893+ }
894+
895+ if let Some ( request_defaults) = request_defaults {
875896 if let Some ( llm_params) = & request_defaults. llm_params {
876897 options. insert ( "llm_params" . to_string ( ) , llm_params. clone ( ) ) ;
877898 }
@@ -884,9 +905,9 @@ fn build_llm_guardrails_config(
884905 if let Some ( log) = & request_defaults. log {
885906 options. insert ( "log" . to_string ( ) , log. clone ( ) ) ;
886907 }
887- if !options . is_empty ( ) {
888- guardrails . insert ( " options" . to_string ( ) , Json :: Object ( options ) ) ;
889- }
908+ }
909+ if ! options. is_empty ( ) {
910+ guardrails . insert ( "options" . to_string ( ) , Json :: Object ( options ) ) ;
890911 }
891912 ( !guardrails. is_empty ( ) ) . then_some ( guardrails)
892913}
@@ -899,25 +920,31 @@ fn build_tool_check_guardrails_config(
899920) -> Map < String , Json > {
900921 let mut guardrails = build_base_guardrails_config ( config_id, config_ids, request_defaults) ;
901922 let mut options = Map :: new ( ) ;
902- let rails = match kind {
903- RemoteCheckKind :: Input => json ! ( {
904- "input" : false ,
905- "output" : false ,
906- "dialog" : false ,
907- "retrieval" : false ,
908- "tool_input" : true ,
909- "tool_output" : false ,
910- } ) ,
911- RemoteCheckKind :: Output => json ! ( {
912- "input" : false ,
913- "output" : false ,
914- "dialog" : false ,
915- "retrieval" : false ,
916- "tool_input" : false ,
917- "tool_output" : true ,
918- } ) ,
923+ let mut rails = Map :: from_iter ( [
924+ ( "input" . to_string ( ) , Json :: Bool ( false ) ) ,
925+ ( "output" . to_string ( ) , Json :: Bool ( false ) ) ,
926+ ( "dialog" . to_string ( ) , Json :: Bool ( false ) ) ,
927+ ( "retrieval" . to_string ( ) , Json :: Bool ( false ) ) ,
928+ ] ) ;
929+ match kind {
930+ RemoteCheckKind :: Input => {
931+ rails. insert (
932+ "tool_input" . to_string ( ) ,
933+ configured_tool_selector ( request_defaults, RemoteCheckKind :: Input )
934+ . unwrap_or ( Json :: Bool ( true ) ) ,
935+ ) ;
936+ rails. insert ( "tool_output" . to_string ( ) , Json :: Bool ( false ) ) ;
937+ }
938+ RemoteCheckKind :: Output => {
939+ rails. insert ( "tool_input" . to_string ( ) , Json :: Bool ( false ) ) ;
940+ rails. insert (
941+ "tool_output" . to_string ( ) ,
942+ configured_tool_selector ( request_defaults, RemoteCheckKind :: Output )
943+ . unwrap_or ( Json :: Bool ( true ) ) ,
944+ ) ;
945+ }
919946 } ;
920- options. insert ( "rails" . to_string ( ) , rails) ;
947+ options. insert ( "rails" . to_string ( ) , Json :: Object ( rails) ) ;
921948 let mut log = request_defaults
922949 . and_then ( |defaults| defaults. log . as_ref ( ) )
923950 . and_then ( Json :: as_object)
@@ -929,6 +956,28 @@ fn build_tool_check_guardrails_config(
929956 guardrails
930957}
931958
959+ fn serialize_request_rails ( rails : & RequestRailsConfig ) -> Map < String , Json > {
960+ serde_json:: to_value ( rails)
961+ . expect ( "request rails config should serialize to JSON" )
962+ . as_object ( )
963+ . cloned ( )
964+ . expect ( "request rails config should serialize to a JSON object" )
965+ }
966+
967+ fn configured_tool_selector (
968+ request_defaults : Option < & RequestDefaultsConfig > ,
969+ kind : RemoteCheckKind ,
970+ ) -> Option < Json > {
971+ let rails = request_defaults. and_then ( |defaults| defaults. rails . as_ref ( ) ) ?;
972+ match kind {
973+ RemoteCheckKind :: Input => rails. tool_input . as_ref ( ) ,
974+ RemoteCheckKind :: Output => rails. tool_output . as_ref ( ) ,
975+ }
976+ . map ( |selector| {
977+ serde_json:: to_value ( selector) . expect ( "tool rail selector should serialize to JSON" )
978+ } )
979+ }
980+
932981#[ cfg( any( target_arch = "wasm32" , not( feature = "guardrails-remote" ) ) ) ]
933982pub ( super ) fn register_remote_backend (
934983 _config : NeMoGuardrailsConfig ,
0 commit comments