Skip to content

Commit 8f3738b

Browse files
fix: implement __eq__ and __ne__ for CopyOnWriteDict (#55)
* fix: implement __eq__ and __ne__ for CopyOnWriteDict Fixes equality comparison bug where CopyOnWriteDict compared equal to {} even when containing data. This caused apply_policy() to incorrectly drop valid payload modifications when plugins removed all arguments. Changes: - Add __eq__ and __ne__ methods to CopyOnWriteDict - Add 13 comprehensive equality unit tests - Add policy regression tests for empty args scenario - Add end-to-end integration tests Signed-off-by: prakhar-singh1928 <prakhar.singh1928@ibm.com> * fix: added length check for performance Signed-off-by: prakhar-singh1928 <prakhar.singh1928@ibm.com> * fix: restore deleted assertion and add performance optimization - Restored missing 'assert a not in keys' in test_iteration_order_with_deletions - Added fast-path length check in CopyOnWriteDict.__eq__() for better performance - Performance optimization is safe: if lengths differ, mappings cannot be equal Signed-off-by: prakhar-singh1928 <prakhar.singh1928@ibm.com> * fix: linted memory.py, added assertion to test. --------- Signed-off-by: prakhar-singh1928 <prakhar.singh1928@ibm.com> Co-authored-by: Teryl Taylor <teryl.taylor@gmail.com>
1 parent 5ed4b3d commit 8f3738b

4 files changed

Lines changed: 259 additions & 0 deletions

File tree

cpex/framework/memory.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import copy
1515
import logging
1616
import weakref
17+
from collections.abc import Mapping
1718
from typing import Any, Iterator, Optional, TypeVar
1819

1920
# Third-Party
@@ -173,6 +174,48 @@ def __repr__(self) -> str:
173174
"""
174175
return f"CopyOnWriteDict({dict(self.items())})"
175176

177+
__hash__ = None
178+
179+
def __eq__(self, other: Any) -> bool:
180+
"""
181+
Compare equality with another mapping.
182+
183+
Compares the materialized logical mapping (original + modifications - deletions)
184+
rather than the empty base dict storage.
185+
186+
Args:
187+
other: The object to compare with.
188+
189+
Returns:
190+
True if other is a Mapping with the same key-value pairs, False otherwise.
191+
Returns NotImplemented for non-Mapping types to allow other.__eq__ to handle it.
192+
"""
193+
if not isinstance(other, Mapping):
194+
return NotImplemented
195+
196+
# Fast-path: if lengths differ, mappings cannot be equal
197+
if len(self) != len(other):
198+
return False
199+
200+
# Compare materialized items
201+
return dict(self.items()) == dict(other.items())
202+
203+
def __ne__(self, other: Any) -> bool:
204+
"""
205+
Compare inequality with another mapping.
206+
207+
Args:
208+
other: The object to compare with.
209+
210+
Returns:
211+
True if not equal, False if equal.
212+
Returns NotImplemented for non-Mapping types.
213+
"""
214+
eq = self.__eq__(other)
215+
if eq is NotImplemented:
216+
return NotImplemented
217+
return not eq
218+
176219
def get(self, key: Any, default: Optional[Any] = None) -> Any:
177220
"""
178221
Get an item with a default fallback.

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ preview = true
146146
fixable = ["ALL"]
147147
unfixable = []
148148

149+
[tool.ruff.lint.pylint]
150+
# Relaxed from the default of 5; existing code has wider try clauses (max observed 38).
151+
max-statements-in-try = 50
152+
149153
# Ignore D1 (docstring checks) and Pylint checks in tests and other non-production code
150154
[tool.ruff.lint.per-file-ignores]
151155
"tests/**/*.py" = ["D1", "PL"]

tests/unit/cpex/framework/test_memory.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,104 @@ def test_iter_skips_deleted_keys_in_modifications(self):
819819
assert set(keys) == {"b", "c"}
820820
assert "a" not in keys
821821

822+
def test_equality_with_empty_dict(self):
823+
"""CopyOnWriteDict with data should not equal empty dict."""
824+
cow = CopyOnWriteDict({"a": 1, "b": 2})
825+
assert cow != {}
826+
assert {} != cow
827+
assert not (cow == {})
828+
assert not ({} == cow)
829+
830+
def test_equality_with_matching_dict(self):
831+
"""CopyOnWriteDict should equal dict with same key-value pairs."""
832+
original = {"a": 1, "b": 2, "c": 3}
833+
cow = CopyOnWriteDict(original)
834+
assert cow == {"a": 1, "b": 2, "c": 3}
835+
assert {"a": 1, "b": 2, "c": 3} == cow
836+
837+
def test_equality_with_different_dict(self):
838+
"""CopyOnWriteDict should not equal dict with different content."""
839+
cow = CopyOnWriteDict({"a": 1, "b": 2})
840+
assert cow != {"a": 1, "b": 3}
841+
assert cow != {"a": 1}
842+
assert cow != {"a": 1, "b": 2, "c": 3}
843+
# Same length, different keys
844+
assert cow != {"a": 1, "c": 2}
845+
846+
def test_equality_after_modifications(self):
847+
"""Equality should reflect modifications."""
848+
cow = CopyOnWriteDict({"a": 1, "b": 2})
849+
cow["c"] = 3
850+
assert cow == {"a": 1, "b": 2, "c": 3}
851+
assert cow != {"a": 1, "b": 2}
852+
853+
def test_equality_after_deletions(self):
854+
"""Equality should reflect deletions."""
855+
cow = CopyOnWriteDict({"a": 1, "b": 2, "c": 3})
856+
del cow["b"]
857+
assert cow == {"a": 1, "c": 3}
858+
assert cow != {"a": 1, "b": 2, "c": 3}
859+
860+
def test_equality_after_override(self):
861+
"""Equality should reflect overridden values."""
862+
cow = CopyOnWriteDict({"a": 1, "b": 2})
863+
cow["a"] = 10
864+
assert cow == {"a": 10, "b": 2}
865+
assert cow != {"a": 1, "b": 2}
866+
867+
def test_equality_with_another_copyonwritedict(self):
868+
"""Two CopyOnWriteDict instances with same content should be equal."""
869+
cow1 = CopyOnWriteDict({"a": 1, "b": 2})
870+
cow2 = CopyOnWriteDict({"a": 1, "b": 2})
871+
assert cow1 == cow2
872+
assert cow2 == cow1
873+
874+
def test_equality_empty_copyonwritedict(self):
875+
"""Empty CopyOnWriteDict should equal empty dict."""
876+
cow = CopyOnWriteDict({})
877+
assert cow == {}
878+
assert {} == cow
879+
880+
def test_equality_with_non_mapping_returns_notimplemented(self):
881+
"""Equality with non-Mapping types should return NotImplemented."""
882+
cow = CopyOnWriteDict({"a": 1})
883+
# These should not raise, Python will handle NotImplemented
884+
assert cow != "not a dict"
885+
assert cow != 123
886+
assert cow != ["a", "list"]
887+
assert cow != None
888+
889+
def test_inequality_operator(self):
890+
"""Test __ne__ operator works correctly."""
891+
cow = CopyOnWriteDict({"a": 1, "b": 2})
892+
assert cow != {}
893+
assert cow != {"a": 1}
894+
assert not (cow != {"a": 1, "b": 2})
895+
896+
def test_copyonwritedict_is_unhashable(self):
897+
"""CopyOnWriteDict should remain unhashable like dict."""
898+
cow = CopyOnWriteDict({"a": 1})
899+
with pytest.raises(TypeError):
900+
hash(cow)
901+
902+
def test_equality_wxo_args_scenario(self):
903+
"""Regression test for the WXO args bug scenario."""
904+
# This is the exact scenario from the bug report
905+
cow = CopyOnWriteDict({
906+
"wxo_connection_id": "",
907+
"wxo_auth": "fake-token",
908+
"wxo_environment_id": "draft",
909+
})
910+
911+
# These were the failing assertions in the bug
912+
assert cow != {}
913+
assert {} != cow
914+
assert cow == {
915+
"wxo_connection_id": "",
916+
"wxo_auth": "fake-token",
917+
"wxo_environment_id": "draft",
918+
}
919+
822920

823921
class TestCopyOnWriteFunction:
824922
"""Test suite for copyonwrite() factory function."""

tests/unit/cpex/framework/test_policies.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,65 @@ class PayloadWithModel(PluginPayload):
172172
assert result is not None
173173
assert result.nested.x == 99 # type: ignore[union-attr]
174174

175+
def test_copyonwritedict_args_empty_modification_preserved(self):
176+
"""Regression test for bug where CopyOnWriteDict equality caused
177+
apply_policy to drop valid empty args modification.
178+
179+
When a plugin receives args as CopyOnWriteDict with data and returns
180+
an empty dict, apply_policy should treat this as a valid modification.
181+
Previously, CopyOnWriteDict.__eq__ was not implemented, causing the
182+
comparison to use dict's default equality which compared the empty
183+
base storage, incorrectly returning True for CopyOnWriteDict({...}) == {}.
184+
"""
185+
from cpex.framework.memory import CopyOnWriteDict
186+
187+
policy = HookPayloadPolicy(writable_fields=frozenset({"args"}))
188+
189+
# Simulate plugin receiving payload with CopyOnWriteDict args
190+
original = SamplePayload(
191+
name="test",
192+
args=CopyOnWriteDict({
193+
"wxo_connection_id": "",
194+
"wxo_auth": "fake-token",
195+
"wxo_environment_id": "draft",
196+
}),
197+
secret="s",
198+
)
199+
200+
# Plugin strips all args, returning empty dict
201+
modified = SamplePayload(name="test", args={}, secret="s")
202+
203+
result = apply_policy(original, modified, policy)
204+
205+
# The modification should be preserved, not dropped
206+
assert result is not None, "apply_policy should not return None when args changed from {...} to {}"
207+
assert result.args == {} # type: ignore[union-attr]
208+
assert result.name == "test" # type: ignore[union-attr]
209+
assert result.secret == "s" # type: ignore[union-attr]
210+
211+
def test_copyonwritedict_args_partial_modification_preserved(self):
212+
"""Test that partial arg removal is also preserved correctly."""
213+
from cpex.framework.memory import CopyOnWriteDict
214+
215+
policy = HookPayloadPolicy(writable_fields=frozenset({"args"}))
216+
217+
original = SamplePayload(
218+
name="test",
219+
args=CopyOnWriteDict({
220+
"wxo_auth": "token",
221+
"real_arg": "value",
222+
}),
223+
secret="s",
224+
)
225+
226+
# Plugin removes only wxo_auth, keeping real_arg
227+
modified = SamplePayload(name="test", args={"real_arg": "value"}, secret="s")
228+
229+
result = apply_policy(original, modified, policy)
230+
231+
assert result is not None
232+
assert result.args == {"real_arg": "value"} # type: ignore[union-attr]
233+
175234

176235
class TestPluginPayloadFrozen:
177236
"""Tests for frozen PluginPayload base class."""
@@ -752,6 +811,61 @@ async def tool_pre_invoke(self, payload, context):
752811
assert result.modified_payload.secret == "safe" # Policy filtered this out
753812

754813

814+
@pytest.mark.asyncio
815+
async def test_tool_pre_invoke_empty_args_modification_preserved_through_executor(self):
816+
"""Regression test for the tool_pre_invoke executor path.
817+
818+
A plugin receives CoW-wrapped args containing only specific fields,
819+
strips them all, and returns a payload with args={}. The executor should
820+
preserve that empty args modification instead of dropping it as
821+
"unchanged".
822+
"""
823+
from cpex.framework.base import HookRef, Plugin, PluginRef
824+
from cpex.framework.hooks.policies import HookPayloadPolicy
825+
from cpex.framework.hooks.tools import ToolPreInvokePayload
826+
from cpex.framework.manager import PluginExecutor
827+
from cpex.framework.memory import CopyOnWriteDict
828+
from cpex.framework.models import GlobalContext, PluginConfig, PluginResult
829+
830+
seen_arg_types = []
831+
832+
class StripWxoArgsPlugin(Plugin):
833+
async def tool_pre_invoke(self, payload, context):
834+
seen_arg_types.append(type(payload.args))
835+
cleaned_args = {k: v for k, v in payload.args.items() if not k.startswith("wxo_")}
836+
modified = payload.model_copy(update={"args": cleaned_args})
837+
return PluginResult(continue_processing=True, modified_payload=modified)
838+
839+
policies = {
840+
"tool_pre_invoke": HookPayloadPolicy(writable_fields=frozenset({"args"})),
841+
}
842+
executor = PluginExecutor(hook_policies=policies)
843+
844+
config = PluginConfig(name="stripper", kind="test.Plugin", version="1.0", hooks=["tool_pre_invoke"])
845+
plugin = StripWxoArgsPlugin(config)
846+
hook_ref = HookRef("tool_pre_invoke", PluginRef(plugin))
847+
848+
payload = ToolPreInvokePayload(
849+
name="list_all_secrets",
850+
args={
851+
"wxo_connection_id": "",
852+
"wxo_auth": "fake-token",
853+
"wxo_environment_id": "draft",
854+
},
855+
)
856+
global_ctx = GlobalContext(request_id="tool-pre-empty-args")
857+
858+
result, _ = await executor.execute([hook_ref], payload, global_ctx, hook_type="tool_pre_invoke")
859+
860+
assert seen_arg_types == [CopyOnWriteDict]
861+
assert result.modified_payload is not None
862+
assert result.modified_payload == ToolPreInvokePayload(name="list_all_secrets", args={})
863+
assert payload.args == {
864+
"wxo_connection_id": "",
865+
"wxo_auth": "fake-token",
866+
"wxo_environment_id": "draft",
867+
}
868+
755869
class TestMultiPluginDictChain:
756870
"""Tests for multi-plugin chains where an earlier plugin returns a dict payload."""
757871

0 commit comments

Comments
 (0)