Skip to content

Commit 34056e9

Browse files
committed
Fix PII result copy regression
Signed-off-by: lucarlig <luca.carlig@ibm.com>
1 parent cfe7087 commit 34056e9

5 files changed

Lines changed: 160 additions & 12 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

plugins/rust/python-package/pii_filter/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pii_filter"
3-
version = "0.3.1"
3+
version = "0.3.2"
44
edition.workspace = true
55
authors.workspace = true
66
license.workspace = true

plugins/rust/python-package/pii_filter/cpex_pii_filter/plugin-manifest.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
description: "Rust-backed PII detection and masking for prompt arguments, tool inputs, and tool outputs"
22
author: "ContextForge Contributors"
3-
version: "0.3.1"
3+
version: "0.3.2"
44
kind: "cpex_pii_filter.pii_filter.PIIFilterPlugin"
55
available_hooks:
66
- "prompt_pre_fetch"

plugins/rust/python-package/pii_filter/src/detector.rs

Lines changed: 64 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55

66
use log::{debug, warn};
77
use pyo3::prelude::*;
8-
use pyo3::types::{PyAny, PyDict, PyList, PySet, PyString, PyTuple};
8+
use pyo3::types::{PyAny, PyDict, PyList, PyMapping, PySet, PyString, PyTuple};
99
use pyo3_stub_gen::derive::*;
1010
use std::collections::HashMap;
1111

@@ -357,24 +357,28 @@ impl PIIDetectorRust {
357357
}
358358
}
359359

360-
// Handle dictionaries
361-
if let Ok(dict) = data.cast::<PyDict>() {
362-
let mut entries: Vec<(Py<PyAny>, Py<PyAny>)> = Vec::with_capacity(dict.len());
360+
// Handle mappings through the Python protocol. CPEX isolation wraps
361+
// dicts in copy-on-write dict subclasses whose visible entries are not
362+
// stored in the underlying PyDict table.
363+
if let Ok(mapping) = data.cast::<PyMapping>() {
364+
let mapping_len = mapping.len()?;
365+
let mut entries: Vec<(Py<PyAny>, Py<PyAny>)> = Vec::with_capacity(mapping_len);
363366
let mut all_detections = HashMap::new();
364-
if dict.len() > self.config.max_collection_items {
367+
if mapping_len > self.config.max_collection_items {
365368
warn!(
366369
"Rejected nested mapping at path '{}' because size {} exceeds max {}",
367-
path,
368-
dict.len(),
369-
self.config.max_collection_items
370+
path, mapping_len, self.config.max_collection_items
370371
);
371372
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
372373
"Nested mapping exceeds maximum size of {} items",
373374
self.config.max_collection_items
374375
)));
375376
}
376377

377-
for (key, value) in dict.iter() {
378+
for item in mapping.items()?.iter() {
379+
let item = item.cast::<PyTuple>()?;
380+
let key = item.get_item(0)?;
381+
let value = item.get_item(1)?;
378382
let key_str = key.str()?.to_string_lossy().into_owned();
379383
let new_path = if path.is_empty() {
380384
key_str.clone()
@@ -1659,6 +1663,57 @@ class ConfigModel:
16591663
});
16601664
}
16611665

1666+
#[test]
1667+
fn test_process_nested_mapping_allows_collection_limit_boundary() {
1668+
Python::initialize();
1669+
Python::attach(|py| {
1670+
let config = PyDict::new(py);
1671+
config.set_item("detect_email", true).unwrap();
1672+
config.set_item("max_collection_items", 1).unwrap();
1673+
1674+
let detector = PIIDetectorRust::new(&config.into_any()).unwrap();
1675+
let data = PyDict::new(py);
1676+
data.set_item("email", "alice@example.com").unwrap();
1677+
1678+
let (modified, new_data, _) =
1679+
detector.process_nested(py, &data.into_any(), "").unwrap();
1680+
1681+
assert!(modified);
1682+
assert_eq!(
1683+
new_data
1684+
.bind(py)
1685+
.cast::<PyDict>()
1686+
.unwrap()
1687+
.get_item("email")
1688+
.unwrap()
1689+
.unwrap()
1690+
.extract::<String>()
1691+
.unwrap(),
1692+
"[REDACTED]"
1693+
);
1694+
});
1695+
}
1696+
1697+
#[test]
1698+
fn test_process_nested_mapping_rejects_over_collection_limit() {
1699+
Python::initialize();
1700+
Python::attach(|py| {
1701+
let config = PyDict::new(py);
1702+
config.set_item("detect_email", true).unwrap();
1703+
config.set_item("max_collection_items", 1).unwrap();
1704+
1705+
let detector = PIIDetectorRust::new(&config.into_any()).unwrap();
1706+
let data = PyDict::new(py);
1707+
data.set_item("first", "alice@example.com").unwrap();
1708+
data.set_item("second", "bob@example.com").unwrap();
1709+
1710+
let err = detector
1711+
.process_nested(py, &data.into_any(), "")
1712+
.unwrap_err();
1713+
assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
1714+
});
1715+
}
1716+
16621717
#[test]
16631718
fn test_detects_plus_prefixed_international_phone_number() {
16641719
let config = PIIConfig {

plugins/tests/pii_filter/test_integration.py

Lines changed: 93 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,10 @@
11
from dataclasses import dataclass
22
import logging
3+
import os
34
from pathlib import Path
5+
import subprocess
6+
import sys
7+
import textwrap
48

59
import pytest
610

@@ -314,6 +318,95 @@ async def test_tool_post_invoke_returns_copied_payload_for_frozen_models():
314318
assert result.modified_payload.result["contact"] == "[REDACTED]"
315319

316320

321+
@pytest.mark.asyncio
322+
async def test_tool_post_invoke_returns_new_nested_result_for_mcp_content():
323+
plugin = PIIFilterPlugin(_make_config())
324+
payload = ToolPostInvokePayload(
325+
name="search",
326+
result={
327+
"content": [
328+
{
329+
"type": "text",
330+
"text": "Contact alice@example.com",
331+
}
332+
],
333+
"isError": False,
334+
},
335+
)
336+
337+
result = await plugin.tool_post_invoke(payload, _make_context())
338+
339+
assert result.modified_payload is not None
340+
assert result.modified_payload is not payload
341+
assert result.modified_payload.result is not payload.result
342+
assert result.modified_payload.result["content"] is not payload.result["content"]
343+
assert result.modified_payload.result["content"][0] is not payload.result["content"][0]
344+
assert payload.result["content"][0]["text"] == "Contact alice@example.com"
345+
assert result.modified_payload.result["content"][0]["text"] == "Contact [REDACTED]"
346+
347+
348+
def test_tool_post_invoke_survives_real_cpex_policy_with_isolated_payload():
349+
plugin_root = (
350+
Path(__file__).resolve().parents[3]
351+
/ "plugins"
352+
/ "rust"
353+
/ "python-package"
354+
/ "pii_filter"
355+
)
356+
script = """
357+
import asyncio
358+
from cpex.framework import PluginConfig, PluginContext
359+
from cpex.framework.hooks.policies import HookPayloadPolicy, apply_policy
360+
from cpex.framework.hooks.tools import ToolPostInvokePayload
361+
from cpex.framework.memory import wrap_payload_for_isolation
362+
from cpex.framework.models import GlobalContext
363+
from cpex_pii_filter.pii_filter import PIIFilterPlugin
364+
365+
async def main():
366+
plugin = PIIFilterPlugin(PluginConfig(
367+
name="pii_filter",
368+
kind="cpex_pii_filter.pii_filter.PIIFilterPlugin",
369+
config={"detect_email": True, "detect_ssn": True, "block_on_detection": False},
370+
))
371+
payload = ToolPostInvokePayload(
372+
name="search",
373+
result={
374+
"content": [{"type": "text", "text": "Contact alice@example.com"}],
375+
"isError": False,
376+
},
377+
)
378+
plugin_input = wrap_payload_for_isolation(payload)
379+
context = PluginContext(global_context=GlobalContext(request_id="req-pii"))
380+
381+
result = await plugin.tool_post_invoke(plugin_input, context)
382+
assert result.modified_payload is not None
383+
filtered = apply_policy(
384+
plugin_input,
385+
result.modified_payload,
386+
HookPayloadPolicy(writable_fields=frozenset({"result"})),
387+
apply_to=payload,
388+
)
389+
assert filtered is not None
390+
assert payload.result["content"][0]["text"] == "Contact alice@example.com"
391+
assert filtered.result["content"][0]["text"] == "Contact [REDACTED]"
392+
393+
asyncio.run(main())
394+
print("ok")
395+
"""
396+
env = os.environ.copy()
397+
env.pop("PYTHONPATH", None)
398+
result = subprocess.run(
399+
[sys.executable, "-c", textwrap.dedent(script)],
400+
cwd=plugin_root,
401+
env=env,
402+
text=True,
403+
capture_output=True,
404+
check=False,
405+
)
406+
assert result.returncode == 0, result.stderr
407+
assert result.stdout.strip() == "ok"
408+
409+
317410
@pytest.mark.asyncio
318411
async def test_tool_post_invoke_blocks_when_configured():
319412
plugin = PIIFilterPlugin(_make_config(block_on_detection=True))

0 commit comments

Comments
 (0)