Skip to content

Commit 76555e7

Browse files
committed
Fix secrets detector model dump counting
Signed-off-by: lucarlig <luca.carlig@ibm.com>
1 parent c9c07ae commit 76555e7

2 files changed

Lines changed: 226 additions & 0 deletions

File tree

plugins/rust/python-package/secrets_detection/src/scanner/python_scan.rs

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,11 @@ fn should_scan_serialized_state(
192192
}
193193

194194
if serialized_state.is_exact_instance_of::<PyDict>() {
195+
if let Some(rebuild_state) = rebuild_state
196+
&& serialized_dict_duplicates_rebuild_state(serialized_state, rebuild_state)?
197+
{
198+
return Ok(false);
199+
}
195200
return Ok(true);
196201
}
197202

@@ -221,6 +226,92 @@ fn should_scan_serialized_state(
221226
Ok(!serialized_rebuild_state.as_any().eq(rebuild_state)?)
222227
}
223228

229+
fn serialized_dict_duplicates_rebuild_state(
230+
serialized_state: &Bound<'_, PyAny>,
231+
rebuild_state: &Bound<'_, PyAny>,
232+
) -> PyResult<bool> {
233+
let serialized_dict = serialized_state.cast::<PyDict>()?;
234+
let Ok(rebuild_dict) = rebuild_state.cast::<PyDict>() else {
235+
return Ok(false);
236+
};
237+
238+
if !dict_has_only_exact_string_keys(serialized_dict)
239+
|| !dict_has_only_exact_string_keys(rebuild_dict)
240+
{
241+
return Ok(false);
242+
}
243+
244+
for (key, serialized_value) in serialized_dict.iter() {
245+
let Some(rebuild_value) = rebuild_dict.get_item(&key)? else {
246+
return Ok(false);
247+
};
248+
if !same_safe_value(&serialized_value, &rebuild_value)? {
249+
return Ok(false);
250+
}
251+
}
252+
253+
Ok(true)
254+
}
255+
256+
fn same_safe_value(left: &Bound<'_, PyAny>, right: &Bound<'_, PyAny>) -> PyResult<bool> {
257+
if left.is(right) {
258+
return Ok(true);
259+
}
260+
261+
if left.is_exact_instance_of::<PyString>() && right.is_exact_instance_of::<PyString>() {
262+
return Ok(left.extract::<String>()? == right.extract::<String>()?);
263+
}
264+
265+
if let (Ok(left_list), Ok(right_list)) = (left.cast::<PyList>(), right.cast::<PyList>()) {
266+
if left_list.len() != right_list.len() {
267+
return Ok(false);
268+
}
269+
for (left_item, right_item) in left_list.iter().zip(right_list.iter()) {
270+
if !same_safe_value(&left_item, &right_item)? {
271+
return Ok(false);
272+
}
273+
}
274+
return Ok(true);
275+
}
276+
277+
if let (Ok(left_tuple), Ok(right_tuple)) = (left.cast::<PyTuple>(), right.cast::<PyTuple>()) {
278+
if left_tuple.len() != right_tuple.len() {
279+
return Ok(false);
280+
}
281+
for (left_item, right_item) in left_tuple.iter().zip(right_tuple.iter()) {
282+
if !same_safe_value(&left_item, &right_item)? {
283+
return Ok(false);
284+
}
285+
}
286+
return Ok(true);
287+
}
288+
289+
if let (Ok(left_dict), Ok(right_dict)) = (left.cast::<PyDict>(), right.cast::<PyDict>()) {
290+
if left_dict.len() != right_dict.len()
291+
|| !dict_has_only_exact_string_keys(left_dict)
292+
|| !dict_has_only_exact_string_keys(right_dict)
293+
{
294+
return Ok(false);
295+
}
296+
for (key, left_value) in left_dict.iter() {
297+
let Some(right_value) = right_dict.get_item(&key)? else {
298+
return Ok(false);
299+
};
300+
if !same_safe_value(&left_value, &right_value)? {
301+
return Ok(false);
302+
}
303+
}
304+
return Ok(true);
305+
}
306+
307+
Ok(false)
308+
}
309+
310+
fn dict_has_only_exact_string_keys(dict: &Bound<'_, PyDict>) -> bool {
311+
dict.iter()
312+
.all(|(key, _)| key.is_exact_instance_of::<PyString>())
313+
}
314+
224315
fn serialized_result<'py>(
225316
py: Python<'py>,
226317
container: &Bound<'py, PyAny>,
@@ -373,4 +464,97 @@ dummy = object()
373464
})
374465
.unwrap();
375466
}
467+
468+
#[test]
469+
fn scan_container_does_not_double_count_matching_model_dump_dict() {
470+
Python::initialize();
471+
Python::attach(|py| -> PyResult<()> {
472+
let code = CString::new(
473+
r#"
474+
class Model:
475+
def __init__(self):
476+
self.text = "AWS_ACCESS_KEY_ID=AKIAFAKE12345EXAMPLE"
477+
478+
def model_dump(self):
479+
return {"text": self.text}
480+
"#,
481+
)
482+
.unwrap();
483+
let module =
484+
PyModule::from_code(py, code.as_c_str(), c"test_module.py", c"test_module")?;
485+
let instance = module.getattr("Model")?.call0()?;
486+
let config = SecretsDetectionConfig::default();
487+
488+
let (count, _, findings) = scan_container(py, &instance, &config)?;
489+
490+
assert_eq!(count, 1);
491+
assert_eq!(findings.len(), 1);
492+
493+
Ok(())
494+
})
495+
.unwrap();
496+
}
497+
498+
#[test]
499+
fn scan_container_does_not_double_count_matching_model_dump_list() {
500+
Python::initialize();
501+
Python::attach(|py| -> PyResult<()> {
502+
let code = CString::new(
503+
r#"
504+
class Model:
505+
def __init__(self):
506+
self.items = ["AWS_ACCESS_KEY_ID=AKIAFAKE12345EXAMPLE"]
507+
508+
def model_dump(self):
509+
return {"items": list(self.items)}
510+
"#,
511+
)
512+
.unwrap();
513+
let module =
514+
PyModule::from_code(py, code.as_c_str(), c"test_module.py", c"test_module")?;
515+
let instance = module.getattr("Model")?.call0()?;
516+
let config = SecretsDetectionConfig::default();
517+
518+
let (count, _, findings) = scan_container(py, &instance, &config)?;
519+
520+
assert_eq!(count, 1);
521+
assert_eq!(findings.len(), 1);
522+
523+
Ok(())
524+
})
525+
.unwrap();
526+
}
527+
528+
#[test]
529+
fn duplicate_gate_ignores_non_string_model_dump_keys_without_lookup() {
530+
Python::initialize();
531+
Python::attach(|py| -> PyResult<()> {
532+
let code = CString::new(
533+
r#"
534+
class BadKey:
535+
def __hash__(self):
536+
return hash("text")
537+
538+
def __eq__(self, other):
539+
raise RuntimeError("duplicate gate should not compare custom keys")
540+
"#,
541+
)
542+
.unwrap();
543+
let module =
544+
PyModule::from_code(py, code.as_c_str(), c"test_module.py", c"test_module")?;
545+
let bad_key = module.getattr("BadKey")?.call0()?;
546+
let serialized = PyDict::new(py);
547+
serialized.set_item(&bad_key, "AWS_ACCESS_KEY_ID=AKIAFAKE12345EXAMPLE")?;
548+
let rebuild = PyDict::new(py);
549+
rebuild.set_item("text", "AWS_ACCESS_KEY_ID=AKIAFAKE12345EXAMPLE")?;
550+
551+
let duplicates =
552+
serialized_dict_duplicates_rebuild_state(serialized.as_any(), rebuild.as_any())?;
553+
554+
assert!(!duplicates);
555+
556+
Ok(())
557+
})
558+
.unwrap();
559+
}
376560
}

plugins/tests/secrets_detection/test_integration.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,48 @@ def serialize_model(self):
13141314
assert result.violation is not None
13151315
assert result.violation.code == "SECRETS_DETECTED"
13161316

1317+
async def test_tool_post_invoke_does_not_double_count_model_dump_fields(self):
1318+
class SecretModel(BaseModel):
1319+
text: str
1320+
1321+
plugin = SecretsDetectionPlugin(
1322+
_make_config(block_on_detection=True, redact=False, min_findings_to_block=2)
1323+
)
1324+
payload = ToolPostInvokePayload(
1325+
name="writer",
1326+
result=SecretModel(text="AWS_ACCESS_KEY_ID=AKIAFAKE12345EXAMPLE"),
1327+
)
1328+
1329+
result = await plugin.tool_post_invoke(payload, _make_context())
1330+
1331+
assert result.continue_processing is True
1332+
assert result.violation is None
1333+
assert result.metadata == {
1334+
"count": 1,
1335+
"secrets_findings": [{"type": "aws_access_key_id"}],
1336+
}
1337+
1338+
async def test_tool_post_invoke_does_not_double_count_model_dump_list_fields(self):
1339+
class SecretListModel(BaseModel):
1340+
items: list[str]
1341+
1342+
plugin = SecretsDetectionPlugin(
1343+
_make_config(block_on_detection=True, redact=False, min_findings_to_block=2)
1344+
)
1345+
payload = ToolPostInvokePayload(
1346+
name="writer",
1347+
result=SecretListModel(items=["AWS_ACCESS_KEY_ID=AKIAFAKE12345EXAMPLE"]),
1348+
)
1349+
1350+
result = await plugin.tool_post_invoke(payload, _make_context())
1351+
1352+
assert result.continue_processing is True
1353+
assert result.violation is None
1354+
assert result.metadata == {
1355+
"count": 1,
1356+
"secrets_findings": [{"type": "aws_access_key_id"}],
1357+
}
1358+
13171359
async def test_tool_post_invoke_redacts_secret_exposed_only_by_model_dump(self, plugin):
13181360
class SplitSecretModel(BaseModel):
13191361
prefix: str

0 commit comments

Comments
 (0)