Skip to content

Commit b294d1d

Browse files
committed
fix: address regex filter review findings
Signed-off-by: lucarlig <luca.carlig@ibm.com>
1 parent 97c4675 commit b294d1d

4 files changed

Lines changed: 259 additions & 86 deletions

File tree

plugins/rust/python-package/regex_filter/cpex_regex_filter/regex_filter_rust/__init__.pyi

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,11 @@
11
# This file is automatically generated by pyo3_stub_gen
22
# ruff: noqa: E501, F401, F403, F405
33

4+
import builtins
45
import typing
56
__all__ = [
67
"RegexFilterPluginCore",
8+
"SearchReplacePluginRust",
79
]
810

911
@typing.final
@@ -13,3 +15,9 @@ class RegexFilterPluginCore:
1315
def prompt_post_fetch(self, payload: typing.Any, _context: typing.Any) -> typing.Any: ...
1416
def tool_pre_invoke(self, payload: typing.Any, _context: typing.Any) -> typing.Any: ...
1517
def tool_post_invoke(self, payload: typing.Any, _context: typing.Any) -> typing.Any: ...
18+
19+
@typing.final
20+
class SearchReplacePluginRust:
21+
def __new__(cls, config_dict: dict) -> SearchReplacePluginRust: ...
22+
def apply_patterns(self, text: builtins.str) -> builtins.str: ...
23+
def process_nested(self, data: typing.Any) -> tuple[builtins.bool, typing.Any]: ...

plugins/rust/python-package/regex_filter/src/bin/stub_gen.rs

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,21 @@ use regex_filter_rust::stub_info;
1010

1111
const EXTENSION_STUB_PATH: &str = "cpex_regex_filter/regex_filter_rust/__init__.pyi";
1212
const ORPHAN_EXTENSION_STUB_PATH: &str = "python/regex_filter_rust/__init__.pyi";
13-
const GENERATED_ALL_MARKER: &str = "\"SearchReplacePluginRust\",\n]";
14-
const CURATED_ALL_ENTRY: &str = "\"RegexFilterPluginCore\",\n \"SearchReplacePluginRust\",\n]";
13+
const GENERATED_ALL_MARKER: &str = "__all__ = [\n";
1514
const PLUGIN_CORE_CLASS_MARKER: &str = "class RegexFilterPluginCore:";
16-
const PLUGIN_CORE_CLASS_DEF: &str = "\n\n@typing.final\nclass RegexFilterPluginCore:\n def __new__(cls, config: typing.Any) -> RegexFilterPluginCore: ...\n def prompt_pre_fetch(self, payload: typing.Any, context: typing.Any) -> typing.Any: ...\n def prompt_post_fetch(self, payload: typing.Any, context: typing.Any) -> typing.Any: ...\n def tool_pre_invoke(self, payload: typing.Any, context: typing.Any) -> typing.Any: ...\n def tool_post_invoke(self, payload: typing.Any, context: typing.Any) -> typing.Any: ...\n";
15+
const ENGINE_CLASS_MARKER: &str = "class SearchReplacePluginRust:";
16+
const CURATED_ALL_BLOCK: &str =
17+
"__all__ = [\n \"RegexFilterPluginCore\",\n \"SearchReplacePluginRust\",\n]\n";
18+
const PLUGIN_CORE_CLASS_DEF: &str = "\n\n@typing.final\nclass RegexFilterPluginCore:\n def __new__(cls, config: dict) -> RegexFilterPluginCore: ...\n def prompt_pre_fetch(self, payload: typing.Any, context: typing.Any) -> typing.Any: ...\n def prompt_post_fetch(self, payload: typing.Any, context: typing.Any) -> typing.Any: ...\n def tool_pre_invoke(self, payload: typing.Any, context: typing.Any) -> typing.Any: ...\n def tool_post_invoke(self, payload: typing.Any, context: typing.Any) -> typing.Any: ...\n";
1719

1820
fn curate_extension_stub_content(content: &str) -> String {
19-
let mut curated = content.replace(GENERATED_ALL_MARKER, CURATED_ALL_ENTRY);
21+
let mut curated = content.to_string();
22+
if let Some(all_start) = curated.find(GENERATED_ALL_MARKER)
23+
&& let Some(relative_end) = curated[all_start..].find("]\n")
24+
{
25+
let all_end = all_start + relative_end + 2;
26+
curated.replace_range(all_start..all_end, CURATED_ALL_BLOCK);
27+
}
2028
if !curated.contains(PLUGIN_CORE_CLASS_MARKER) {
2129
curated.push_str(PLUGIN_CORE_CLASS_DEF);
2230
}
@@ -25,10 +33,18 @@ fn curate_extension_stub_content(content: &str) -> String {
2533
curated.contains("\"RegexFilterPluginCore\""),
2634
"curated extension stub is missing RegexFilterPluginCore in __all__",
2735
);
36+
assert!(
37+
curated.contains("\"SearchReplacePluginRust\""),
38+
"curated extension stub is missing SearchReplacePluginRust in __all__",
39+
);
2840
assert!(
2941
curated.contains(PLUGIN_CORE_CLASS_MARKER),
3042
"curated extension stub is missing RegexFilterPluginCore class definition",
3143
);
44+
assert!(
45+
curated.contains(ENGINE_CLASS_MARKER),
46+
"curated extension stub is missing SearchReplacePluginRust class definition",
47+
);
3248

3349
curated
3450
}
@@ -68,9 +84,10 @@ mod tests {
6884

6985
#[test]
7086
fn test_curate_extension_stub_adds_required_exports_and_class() {
71-
let generated = "# This file is automatically generated by pyo3_stub_gen\n# ruff: noqa: E501, F401, F403, F405\n\nimport typing\n\n__all__ = [\n \"SearchReplacePluginRust\",\n]\n";
87+
let generated = "# This file is automatically generated by pyo3_stub_gen\n# ruff: noqa: E501, F401, F403, F405\n\nimport builtins\nimport typing\n__all__ = [\n \"SearchReplacePluginRust\",\n]\n\n@typing.final\nclass SearchReplacePluginRust:\n def __new__(cls, config: typing.Any) -> SearchReplacePluginRust: ...\n";
7288
let curated = curate_extension_stub_content(generated);
7389
assert!(curated.contains("\"RegexFilterPluginCore\""));
90+
assert!(curated.contains("\"SearchReplacePluginRust\""));
7491
assert!(curated.contains(PLUGIN_CORE_CLASS_MARKER));
7592
}
7693
}

plugins/rust/python-package/regex_filter/src/lib.rs

Lines changed: 157 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,25 @@
44
// Regex Filter Plugin - Rust Implementation
55

66
use std::borrow::Cow;
7+
use std::collections::HashSet;
78
use std::sync::Once;
89

910
use log::debug;
1011
use pyo3::prelude::*;
1112
use pyo3::types::{PyDict, PyList, PyModule, PyTuple};
1213
use pyo3_stub_gen::define_stub_info_gatherer;
14+
use pyo3_stub_gen::derive::*;
1315
use regex::{Regex, RegexSet};
1416

1517
pub mod plugin;
1618

19+
const MAX_NESTED_DEPTH: usize = 64;
20+
21+
enum TraversalResult {
22+
Unchanged(Py<PyAny>),
23+
Modified(Py<PyAny>),
24+
}
25+
1726
#[derive(Debug, Clone)]
1827
pub struct SearchReplace {
1928
pub search: String,
@@ -85,118 +94,185 @@ impl SearchReplaceConfig {
8594
}
8695
}
8796

97+
#[gen_stub_pyclass]
8898
#[derive(Debug)]
8999
#[pyclass]
90100
pub struct SearchReplacePluginRust {
91101
pub config: SearchReplaceConfig,
92102
}
93103

94-
#[pymethods]
95-
impl SearchReplacePluginRust {
96-
#[new]
97-
pub fn new(config_dict: &Bound<'_, PyDict>) -> PyResult<Self> {
98-
let config = SearchReplaceConfig::from_py_dict(config_dict).map_err(|error| {
99-
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid config: {}", error))
100-
})?;
101-
Ok(Self { config })
104+
fn apply_patterns_impl<'a>(config: &'a SearchReplaceConfig, text: &'a str) -> Cow<'a, str> {
105+
if let Some(ref pattern_set) = config.pattern_set
106+
&& !pattern_set.is_match(text)
107+
{
108+
return Cow::Borrowed(text);
102109
}
103110

104-
pub fn apply_patterns(&self, text: &str) -> String {
105-
if let Some(ref pattern_set) = self.config.pattern_set
106-
&& !pattern_set.is_match(text)
107-
{
108-
return text.to_string();
111+
let mut result = Cow::Borrowed(text);
112+
let mut modified = false;
113+
114+
for pattern in &config.words {
115+
if pattern.compiled.is_match(&result) {
116+
let replaced = pattern.compiled.replace_all(&result, &pattern.replace);
117+
if let Cow::Owned(new_text) = replaced {
118+
result = Cow::Owned(new_text);
119+
modified = true;
120+
} else if modified {
121+
result = Cow::Owned(replaced.into_owned());
122+
}
123+
}
124+
}
125+
126+
result
127+
}
128+
129+
fn process_nested_impl(
130+
plugin: &SearchReplacePluginRust,
131+
py: Python<'_>,
132+
data: &Bound<'_, PyAny>,
133+
depth: usize,
134+
seen: &mut HashSet<usize>,
135+
) -> PyResult<TraversalResult> {
136+
if depth >= MAX_NESTED_DEPTH {
137+
return Err(pyo3::exceptions::PyValueError::new_err(format!(
138+
"Maximum nested depth of {} exceeded",
139+
MAX_NESTED_DEPTH
140+
)));
141+
}
142+
143+
if let Ok(text) = data.extract::<String>() {
144+
let modified_text = apply_patterns_impl(&plugin.config, &text);
145+
return match modified_text {
146+
Cow::Borrowed(_) => Ok(TraversalResult::Unchanged(data.clone().unbind())),
147+
Cow::Owned(value) => Ok(TraversalResult::Modified(
148+
value.into_pyobject(py)?.into_any().unbind(),
149+
)),
150+
};
151+
}
152+
153+
if let Ok(dict) = data.cast::<PyDict>() {
154+
let identity = dict.as_ptr() as usize;
155+
if !seen.insert(identity) {
156+
return Err(pyo3::exceptions::PyValueError::new_err(
157+
"Cyclic containers are not supported",
158+
));
109159
}
110160

111-
let mut result = Cow::Borrowed(text);
112-
let mut modified = false;
113-
114-
for pattern in &self.config.words {
115-
if pattern.compiled.is_match(&result) {
116-
let replaced = pattern.compiled.replace_all(&result, &pattern.replace);
117-
if let Cow::Owned(new_text) = replaced {
118-
result = Cow::Owned(new_text);
119-
modified = true;
120-
} else if modified {
121-
result = Cow::Owned(replaced.into_owned());
161+
let mut any_modified = false;
162+
let mut processed_items = Vec::with_capacity(dict.len());
163+
for (key, value) in dict.iter() {
164+
match process_nested_impl(plugin, py, &value, depth + 1, seen)? {
165+
TraversalResult::Unchanged(new_value) => {
166+
processed_items.push((key.clone().unbind(), new_value));
167+
}
168+
TraversalResult::Modified(new_value) => {
169+
any_modified = true;
170+
processed_items.push((key.clone().unbind(), new_value));
122171
}
123172
}
124173
}
174+
seen.remove(&identity);
125175

126-
result.into_owned()
176+
if !any_modified {
177+
return Ok(TraversalResult::Unchanged(data.clone().unbind()));
178+
}
179+
180+
let new_dict = PyDict::new(py);
181+
for (key, value) in processed_items {
182+
new_dict.set_item(key.bind(py), value.bind(py))?;
183+
}
184+
return Ok(TraversalResult::Modified(new_dict.into_any().unbind()));
127185
}
128186

129-
pub fn process_nested(
130-
&self,
131-
py: Python<'_>,
132-
data: &Bound<'_, PyAny>,
133-
) -> PyResult<(bool, Py<PyAny>)> {
134-
if let Ok(text) = data.extract::<String>() {
135-
let modified_text = self.apply_patterns(&text);
136-
if modified_text == text {
137-
return Ok((false, data.clone().unbind()));
138-
}
139-
return Ok((true, modified_text.into_pyobject(py)?.into_any().unbind()));
187+
if let Ok(list) = data.cast::<PyList>() {
188+
let identity = list.as_ptr() as usize;
189+
if !seen.insert(identity) {
190+
return Err(pyo3::exceptions::PyValueError::new_err(
191+
"Cyclic containers are not supported",
192+
));
140193
}
141194

142-
if let Ok(dict) = data.cast::<PyDict>() {
143-
let mut any_modified = false;
144-
let mut processed_items = Vec::with_capacity(dict.len());
145-
for (key, value) in dict.iter() {
146-
let (item_modified, new_value) = self.process_nested(py, &value)?;
147-
any_modified |= item_modified;
148-
processed_items.push((key.clone().unbind(), new_value));
195+
let mut any_modified = false;
196+
let mut new_items = Vec::with_capacity(list.len());
197+
for item in list.iter() {
198+
match process_nested_impl(plugin, py, &item, depth + 1, seen)? {
199+
TraversalResult::Unchanged(new_item) => new_items.push(new_item),
200+
TraversalResult::Modified(new_item) => {
201+
any_modified = true;
202+
new_items.push(new_item);
203+
}
149204
}
205+
}
206+
seen.remove(&identity);
150207

151-
if !any_modified {
152-
return Ok((false, data.clone().unbind()));
153-
}
208+
if !any_modified {
209+
return Ok(TraversalResult::Unchanged(data.clone().unbind()));
210+
}
154211

155-
let new_dict = PyDict::new(py);
156-
for (key, value) in processed_items {
157-
new_dict.set_item(key.bind(py), value.bind(py))?;
158-
}
159-
return Ok((true, new_dict.into_any().unbind()));
212+
let new_list = PyList::empty(py);
213+
for item in new_items {
214+
new_list.append(item.bind(py))?;
160215
}
216+
return Ok(TraversalResult::Modified(new_list.into_any().unbind()));
217+
}
161218

162-
if let Ok(list) = data.cast::<PyList>() {
163-
let mut any_modified = false;
164-
let mut new_items = Vec::with_capacity(list.len());
165-
for item in list.iter() {
166-
let (item_modified, new_item) = self.process_nested(py, &item)?;
167-
any_modified |= item_modified;
168-
new_items.push(new_item);
169-
}
219+
if let Ok(tuple) = data.cast::<PyTuple>() {
220+
let identity = tuple.as_ptr() as usize;
221+
if !seen.insert(identity) {
222+
return Err(pyo3::exceptions::PyValueError::new_err(
223+
"Cyclic containers are not supported",
224+
));
225+
}
170226

171-
if !any_modified {
172-
return Ok((false, data.clone().unbind()));
227+
let mut any_modified = false;
228+
let mut new_items = Vec::with_capacity(tuple.len());
229+
for item in tuple.iter() {
230+
match process_nested_impl(plugin, py, &item, depth + 1, seen)? {
231+
TraversalResult::Unchanged(new_item) => new_items.push(new_item),
232+
TraversalResult::Modified(new_item) => {
233+
any_modified = true;
234+
new_items.push(new_item);
235+
}
173236
}
237+
}
238+
seen.remove(&identity);
174239

175-
let new_list = PyList::empty(py);
176-
for item in new_items {
177-
new_list.append(item.bind(py))?;
178-
}
179-
return Ok((true, new_list.into_any().unbind()));
240+
if !any_modified {
241+
return Ok(TraversalResult::Unchanged(data.clone().unbind()));
180242
}
181243

182-
if let Ok(tuple) = data.cast::<PyTuple>() {
183-
let mut any_modified = false;
184-
let mut new_items = Vec::with_capacity(tuple.len());
185-
for item in tuple.iter() {
186-
let (item_modified, new_item) = self.process_nested(py, &item)?;
187-
any_modified |= item_modified;
188-
new_items.push(new_item);
189-
}
244+
let new_tuple = PyTuple::new(py, new_items.iter().map(|item| item.bind(py)))?;
245+
return Ok(TraversalResult::Modified(new_tuple.into_any().unbind()));
246+
}
190247

191-
if !any_modified {
192-
return Ok((false, data.clone().unbind()));
193-
}
248+
Ok(TraversalResult::Unchanged(data.clone().unbind()))
249+
}
194250

195-
let new_tuple = PyTuple::new(py, new_items.iter().map(|item| item.bind(py)))?;
196-
return Ok((true, new_tuple.into_any().unbind()));
197-
}
251+
#[gen_stub_pymethods]
252+
#[pymethods]
253+
impl SearchReplacePluginRust {
254+
#[new]
255+
pub fn new(config_dict: &Bound<'_, PyDict>) -> PyResult<Self> {
256+
let config = SearchReplaceConfig::from_py_dict(config_dict).map_err(|error| {
257+
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Invalid config: {}", error))
258+
})?;
259+
Ok(Self { config })
260+
}
198261

199-
Ok((false, data.clone().unbind()))
262+
pub fn apply_patterns(&self, text: &str) -> String {
263+
apply_patterns_impl(&self.config, text).into_owned()
264+
}
265+
266+
pub fn process_nested(
267+
&self,
268+
py: Python<'_>,
269+
data: &Bound<'_, PyAny>,
270+
) -> PyResult<(bool, Py<PyAny>)> {
271+
let mut seen = HashSet::new();
272+
Ok(match process_nested_impl(self, py, data, 0, &mut seen)? {
273+
TraversalResult::Unchanged(value) => (false, value),
274+
TraversalResult::Modified(value) => (true, value),
275+
})
200276
}
201277
}
202278

0 commit comments

Comments
 (0)