Skip to content

Commit 60999ac

Browse files
committed
Address guardrails remote review feedback
Signed-off-by: Alex Fournier <afournier@nvidia.com>
1 parent 3c2276c commit 60999ac

3 files changed

Lines changed: 66 additions & 40 deletions

File tree

crates/core/src/plugins/nemo_guardrails/component.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -986,18 +986,20 @@ fn validate_request_defaults(
986986
.state
987987
.as_ref()
988988
.and_then(|value| value.as_object())
989-
&& !state.is_empty()
990-
&& !state.contains_key("events")
991-
&& !state.contains_key("state")
992989
{
993-
push_policy_diag(
994-
diagnostics,
995-
policy.unsupported_value,
996-
"nemo_guardrails.unsupported_value",
997-
Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()),
998-
Some("request_defaults.state".to_string()),
999-
"request_defaults.state must be empty or contain 'events' or 'state'".to_string(),
1000-
);
990+
let contains_supported_key = state.contains_key("events") || state.contains_key("state");
991+
let contains_unsupported_key = state.keys().any(|key| key != "events" && key != "state");
992+
if (!state.is_empty() && !contains_supported_key) || contains_unsupported_key {
993+
push_policy_diag(
994+
diagnostics,
995+
policy.unsupported_value,
996+
"nemo_guardrails.unsupported_value",
997+
Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()),
998+
Some("request_defaults.state".to_string()),
999+
"request_defaults.state must be empty or contain only 'events' or 'state'"
1000+
.to_string(),
1001+
);
1002+
}
10011003
}
10021004
validate_json_object_field(
10031005
diagnostics,

crates/core/src/plugins/nemo_guardrails/remote.rs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ impl RemoteBackendRuntime {
178178
&self.config_id,
179179
&self.config_ids,
180180
Some(status.as_u16()),
181-
Some(payload.clone()),
181+
Some(redact_remote_error_payload(status.as_u16(), &payload)),
182182
),
183183
);
184184
return Err(FlowError::Internal(format!(
@@ -302,7 +302,7 @@ impl RemoteBackendRuntime {
302302
&self.config_id,
303303
&self.config_ids,
304304
Some(status.as_u16()),
305-
Some(payload.clone()),
305+
Some(redact_remote_error_payload(status.as_u16(), &payload)),
306306
),
307307
);
308308
return Err(FlowError::Internal(format!(
@@ -615,7 +615,7 @@ impl RemoteBackendRuntime {
615615
&self.config_id,
616616
&self.config_ids,
617617
Some(status.as_u16()),
618-
Some(payload.clone()),
618+
Some(redact_remote_error_payload(status.as_u16(), &payload)),
619619
),
620620
);
621621
return Err(FlowError::Internal(format!(
@@ -679,20 +679,20 @@ impl RemoteBackendRuntime {
679679
let mut options = Map::new();
680680
let rails = match kind {
681681
RemoteCheckKind::Input => json!({
682-
"input": true,
682+
"input": false,
683683
"output": false,
684684
"dialog": false,
685685
"retrieval": false,
686-
"tool_input": false,
686+
"tool_input": true,
687687
"tool_output": false,
688688
}),
689689
RemoteCheckKind::Output => json!({
690690
"input": false,
691-
"output": true,
691+
"output": false,
692692
"dialog": false,
693693
"retrieval": false,
694694
"tool_input": false,
695-
"tool_output": false,
695+
"tool_output": true,
696696
}),
697697
};
698698
options.insert("rails".to_string(), rails);
@@ -730,6 +730,14 @@ fn tool_input_content(tool_name: &str, args: &Json) -> String {
730730
.expect("tool input payload should serialize to JSON")
731731
}
732732

733+
#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))]
734+
fn redact_remote_error_payload(status: u16, payload: &str) -> String {
735+
format!(
736+
"remote request failed with status {status}; error body omitted from marks ({} bytes)",
737+
payload.len()
738+
)
739+
}
740+
733741
#[cfg(all(not(target_arch = "wasm32"), feature = "guardrails-remote"))]
734742
fn tool_output_content(tool_name: &str, args: &Json, result: &Json) -> String {
735743
serde_json::to_string(&json!({

crates/core/tests/unit/plugins/nemo_guardrails/component_tests.rs

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::net::TcpListener;
1111
use std::sync::atomic::{AtomicBool, Ordering};
1212
use std::sync::{Arc, Mutex, mpsc};
1313
use std::thread;
14+
use std::time::Duration;
1415

1516
use crate::api::event::Event;
1617
use crate::api::llm::{
@@ -37,6 +38,8 @@ use crate::plugin::{
3738
use futures::StreamExt;
3839
use serde_json::json;
3940

41+
const TEST_TIMEOUT: Duration = Duration::from_secs(5);
42+
4043
fn reset_runtime() {
4144
let _ = clear_plugin_configuration();
4245
crate::shared_runtime::reset_runtime_owner_for_tests();
@@ -191,6 +194,12 @@ fn header_value<'a>(headers_text: &'a str, header_name: &str) -> Option<&'a str>
191194
})
192195
}
193196

197+
fn recv_captured_request(request_rx: &mpsc::Receiver<CapturedHttpRequest>) -> CapturedHttpRequest {
198+
request_rx
199+
.recv_timeout(TEST_TIMEOUT)
200+
.expect("timed out waiting for captured HTTP request")
201+
}
202+
194203
fn make_chat_request(stream: bool) -> LlmRequest {
195204
LlmRequest {
196205
headers: serde_json::Map::new(),
@@ -749,7 +758,7 @@ fn invalid_shapes_and_values_are_reported() {
749758
);
750759
assert!(invalid_request_defaults.diagnostics.iter().any(|diag| {
751760
diag.message
752-
.contains("request_defaults.state must be empty or contain 'events' or 'state'")
761+
.contains("request_defaults.state must be empty or contain only 'events' or 'state'")
753762
}));
754763
assert!(
755764
invalid_request_defaults
@@ -943,7 +952,7 @@ async fn remote_initialization_installs_non_streaming_execution_intercept() {
943952
json!("server-state")
944953
);
945954

946-
let captured = request_rx.recv().unwrap();
955+
let captured = recv_captured_request(&request_rx);
947956
assert_eq!(captured.path, "/v1/chat/completions");
948957
assert!(captured.content_type.starts_with("application/json"));
949958

@@ -1061,7 +1070,7 @@ async fn remote_request_uses_config_ids_when_config_id_is_not_set() {
10611070
.await
10621071
.unwrap();
10631072

1064-
let captured = request_rx.recv().unwrap();
1073+
let captured = recv_captured_request(&request_rx);
10651074
let request_json: Json = serde_json::from_slice(&captured.body).unwrap();
10661075
assert_eq!(
10671076
request_json["guardrails"]["config_ids"],
@@ -1136,7 +1145,10 @@ async fn remote_initialization_installs_stream_execution_intercept() {
11361145
.unwrap();
11371146

11381147
let mut chunks = Vec::new();
1139-
while let Some(chunk) = stream.next().await {
1148+
while let Some(chunk) = tokio::time::timeout(TEST_TIMEOUT, stream.next())
1149+
.await
1150+
.expect("timed out waiting for remote stream chunk")
1151+
{
11401152
chunks.push(chunk.unwrap());
11411153
}
11421154

@@ -1145,7 +1157,7 @@ async fn remote_initialization_installs_stream_execution_intercept() {
11451157
assert_eq!(chunks[0]["choices"][0]["delta"]["content"], json!("guard"));
11461158
assert_eq!(chunks[1]["choices"][0]["delta"]["content"], json!("ed"));
11471159

1148-
let captured = request_rx.recv().unwrap();
1160+
let captured = recv_captured_request(&request_rx);
11491161
let request_json: Json = serde_json::from_slice(&captured.body).unwrap();
11501162
assert_eq!(request_json["stream"], json!(true));
11511163
assert_eq!(
@@ -1246,7 +1258,7 @@ async fn remote_non_streaming_http_errors_are_reported_and_marked() {
12461258
error_mark.data().unwrap()["error"]
12471259
.as_str()
12481260
.unwrap()
1249-
.contains("backend unavailable")
1261+
.contains("error body omitted from marks")
12501262
);
12511263

12521264
deregister_subscriber("nemo-guardrails-remote-error-events").unwrap();
@@ -1341,7 +1353,7 @@ async fn remote_streaming_http_errors_are_reported_and_marked() {
13411353
error_mark.data().unwrap()["error"]
13421354
.as_str()
13431355
.unwrap()
1344-
.contains("stream backend unavailable")
1356+
.contains("error body omitted from marks")
13451357
);
13461358

13471359
deregister_subscriber("nemo-guardrails-remote-stream-error-events").unwrap();
@@ -1470,7 +1482,11 @@ async fn remote_streaming_malformed_chunk_is_reported_and_marked() {
14701482
.await
14711483
.unwrap();
14721484

1473-
let error = stream.next().await.unwrap().unwrap_err();
1485+
let error = tokio::time::timeout(TEST_TIMEOUT, stream.next())
1486+
.await
1487+
.expect("timed out waiting for remote stream error")
1488+
.unwrap()
1489+
.unwrap_err();
14741490
match error {
14751491
crate::error::FlowError::Internal(message) => {
14761492
assert!(!message.is_empty());
@@ -1759,14 +1775,14 @@ async fn remote_tool_input_block_rejects_before_tool_execution() {
17591775
other => panic!("unexpected error: {other}"),
17601776
}
17611777

1762-
let captured = request_rx.recv().unwrap();
1778+
let captured = recv_captured_request(&request_rx);
17631779
let request_json: Json = serde_json::from_slice(&captured.body).unwrap();
17641780
assert_eq!(
1765-
request_json["guardrails"]["options"]["rails"]["input"],
1781+
request_json["guardrails"]["options"]["rails"]["tool_input"],
17661782
json!(true)
17671783
);
17681784
assert_eq!(
1769-
request_json["guardrails"]["options"]["rails"]["output"],
1785+
request_json["guardrails"]["options"]["rails"]["tool_output"],
17701786
json!(false)
17711787
);
17721788

@@ -1860,7 +1876,7 @@ async fn remote_tool_input_can_rewrite_tool_arguments() {
18601876
assert_eq!(result, json!({"forecast": "sunny"}));
18611877
assert_eq!(*seen_args.lock().unwrap(), Some(json!({"city": "Boston"})));
18621878

1863-
let captured = request_rx.recv().unwrap();
1879+
let captured = recv_captured_request(&request_rx);
18641880
let request_json: Json = serde_json::from_slice(&captured.body).unwrap();
18651881
assert_eq!(request_json["messages"][0]["role"], json!("user"));
18661882
}
@@ -1932,14 +1948,14 @@ async fn remote_tool_output_can_rewrite_tool_result() {
19321948

19331949
assert_eq!(result, json!({"forecast": "cloudy"}));
19341950

1935-
let captured = request_rx.recv().unwrap();
1951+
let captured = recv_captured_request(&request_rx);
19361952
let request_json: Json = serde_json::from_slice(&captured.body).unwrap();
19371953
assert_eq!(
1938-
request_json["guardrails"]["options"]["rails"]["input"],
1954+
request_json["guardrails"]["options"]["rails"]["tool_input"],
19391955
json!(false)
19401956
);
19411957
assert_eq!(
1942-
request_json["guardrails"]["options"]["rails"]["output"],
1958+
request_json["guardrails"]["options"]["rails"]["tool_output"],
19431959
json!(true)
19441960
);
19451961
}
@@ -2307,31 +2323,31 @@ async fn remote_tool_input_and_output_run_in_order() {
23072323
assert_eq!(*seen_args.lock().unwrap(), Some(json!({"city": "Boston"})));
23082324
assert_eq!(result, json!({"forecast": "cloudy"}));
23092325

2310-
let first_request = request_rx.recv().unwrap();
2326+
let first_request = recv_captured_request(&request_rx);
23112327
let first_request_json: Json = serde_json::from_slice(&first_request.body).unwrap();
23122328
assert_eq!(first_request_json["messages"][0]["role"], json!("user"));
23132329
assert_eq!(
2314-
first_request_json["guardrails"]["options"]["rails"]["input"],
2330+
first_request_json["guardrails"]["options"]["rails"]["tool_input"],
23152331
json!(true)
23162332
);
23172333
assert_eq!(
2318-
first_request_json["guardrails"]["options"]["rails"]["output"],
2334+
first_request_json["guardrails"]["options"]["rails"]["tool_output"],
23192335
json!(false)
23202336
);
23212337

2322-
let second_request = request_rx.recv().unwrap();
2338+
let second_request = recv_captured_request(&request_rx);
23232339
let second_request_json: Json = serde_json::from_slice(&second_request.body).unwrap();
23242340
assert_eq!(second_request_json["messages"][0]["role"], json!("user"));
23252341
assert_eq!(
23262342
second_request_json["messages"][1]["role"],
23272343
json!("assistant")
23282344
);
23292345
assert_eq!(
2330-
second_request_json["guardrails"]["options"]["rails"]["input"],
2346+
second_request_json["guardrails"]["options"]["rails"]["tool_input"],
23312347
json!(false)
23322348
);
23332349
assert_eq!(
2334-
second_request_json["guardrails"]["options"]["rails"]["output"],
2350+
second_request_json["guardrails"]["options"]["rails"]["tool_output"],
23352351
json!(true)
23362352
);
23372353
}
@@ -2408,7 +2424,7 @@ async fn remote_tool_checks_forward_context_state_and_thread_id() {
24082424

24092425
assert_eq!(result, json!({"forecast": "sunny"}));
24102426

2411-
let captured = request_rx.recv().unwrap();
2427+
let captured = recv_captured_request(&request_rx);
24122428
let request_json: Json = serde_json::from_slice(&captured.body).unwrap();
24132429
assert_eq!(
24142430
request_json["guardrails"]["context"],

0 commit comments

Comments
 (0)