Skip to content

Commit 2dc4dbb

Browse files
committed
Regression test for THRIFT-4002: Immutable exception deserialization
Client: py Patch: Jens Geyer Generated-by: Opencode big-pickle This test verifies that immutable structs (including exceptions, which are immutable by default since Thrift 0.14.0) can be properly deserialized without triggering the __setattr__ TypeError. The bug manifests when: 1. A struct class is marked immutable (has __setattr__ that raises TypeError) 2. Thrift's deserialization tries to set attributes via setattr instead of using the kwargs constructor Test coverage: - Immutable exception creation and hashability - Immutable exception blocks modification/deletion - Round-trip serialization/deserialization with TBinaryProtocol - Round-trip serialization/deserialization with TCompactProtocol - Accelerated protocol tests (C extension) when available Related: THRIFT-4002, THRIFT-5715
1 parent 3b0ab4d commit 2dc4dbb

2 files changed

Lines changed: 244 additions & 0 deletions

File tree

lib/py/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@ if(BUILD_TESTING)
3434
add_test(NAME PythonThriftTZlibTransport COMMAND Python3::Interpreter ${CMAKE_CURRENT_SOURCE_DIR}/test/thrift_TZlibTransport.py)
3535
add_test(NAME PythonThriftProtocol COMMAND Python3::Interpreter ${CMAKE_CURRENT_SOURCE_DIR}/test/thrift_TCompactProtocol.py)
3636
add_test(NAME PythonThriftTNonblockingServer COMMAND Python3::Interpreter ${CMAKE_CURRENT_SOURCE_DIR}/test/thrift_TNonblockingServer.py)
37+
add_test(NAME PythonImmutableException COMMAND Python3::Interpreter ${CMAKE_CURRENT_SOURCE_DIR}/test/test_immutable_exception.py)
3738
endif()
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
#!/usr/bin/env python
2+
3+
#
4+
# Licensed to the Apache Software Foundation (ASF) under one
5+
# or more contributor license agreements. See the NOTICE file
6+
# distributed with this work for additional information
7+
# regarding copyright ownership. The ASF licenses this file
8+
# to you under the Apache License, Version 2.0 (the
9+
# "License"); you may not use this file except in compliance
10+
# with the License. You may obtain a copy of the License at
11+
#
12+
# http://www.apache.org/licenses/LICENSE-2.0
13+
#
14+
# Unless required by applicable law or agreed to in writing,
15+
# software distributed under the License is distributed on an
16+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
17+
# KIND, either express or implied. See the License for the
18+
# specific language governing permissions and limitations
19+
# under the License.
20+
#
21+
22+
"""
23+
Test cases for THRIFT-4002: Immutable exception deserialization.
24+
25+
This test verifies that immutable structs (including exceptions, which are immutable
26+
by default since Thrift 0.14.0) can be properly deserialized without triggering
27+
the __setattr__ TypeError.
28+
29+
The bug manifests when:
30+
1. A struct class is marked immutable (has __setattr__ that raises TypeError)
31+
2. Thrift's deserialization tries to set attributes via setattr instead of
32+
using the kwargs constructor
33+
34+
This test ensures that all deserialization paths (C extension, pure Python,
35+
all protocols) correctly handle immutable structs.
36+
"""
37+
38+
import unittest
39+
from collections.abc import Hashable
40+
41+
import glob
42+
import os
43+
import sys
44+
45+
SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__))
46+
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR)))
47+
48+
for libpath in glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*')):
49+
for pattern in ('-%d.%d', '-%d%d'):
50+
postfix = pattern % (sys.version_info[0], sys.version_info[1])
51+
if libpath.endswith(postfix):
52+
sys.path.insert(0, libpath)
53+
break
54+
else:
55+
src_path = os.path.join(ROOT_DIR, 'lib', 'py', 'src')
56+
if os.path.exists(src_path):
57+
sys.path.insert(0, src_path)
58+
from thrift.Thrift import TException
59+
from thrift.transport import TTransport
60+
from thrift.protocol import TBinaryProtocol, TCompactProtocol
61+
62+
63+
class ImmutableException(TException):
64+
"""Test exception that mimics generated immutable exception behavior."""
65+
66+
thrift_spec = (
67+
None, # 0
68+
(1, 11, 'message', 'UTF8', None, ), # 1: string
69+
)
70+
71+
def __init__(self, message=None):
72+
super(ImmutableException, self).__init__(message)
73+
74+
def __setattr__(self, *args):
75+
raise TypeError("can't modify immutable instance")
76+
77+
def __delattr__(self, *args):
78+
raise TypeError("can't modify immutable instance")
79+
80+
def __hash__(self):
81+
return hash(self.__class__) ^ hash((self.message,))
82+
83+
def __eq__(self, other):
84+
return isinstance(other, self.__class__) and self.message == other.message
85+
86+
def write(self, oprot):
87+
if oprot._fast_encode is not None and self.thrift_spec is not None:
88+
oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
89+
return
90+
oprot.writeStructBegin('ImmutableException')
91+
if self.message is not None:
92+
oprot.writeFieldBegin('message', 11, 1)
93+
oprot.writeString(self.message)
94+
oprot.writeFieldEnd()
95+
oprot.writeFieldStop()
96+
oprot.writeStructEnd()
97+
98+
@classmethod
99+
def read(cls, iprot):
100+
if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and cls.thrift_spec is not None:
101+
return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec])
102+
return iprot.readStruct(cls, cls.thrift_spec, True)
103+
104+
105+
class MutableException(TException):
106+
"""Test exception that mimics generated mutable exception behavior."""
107+
108+
thrift_spec = (
109+
None, # 0
110+
(1, 11, 'message', 'UTF8', None, ), # 1: string
111+
)
112+
113+
def __init__(self, message=None):
114+
super(MutableException, self).__init__(message)
115+
116+
def write(self, oprot):
117+
if oprot._fast_encode is not None and self.thrift_spec is not None:
118+
oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
119+
return
120+
oprot.writeStructBegin('MutableException')
121+
if self.message is not None:
122+
oprot.writeFieldBegin('message', 11, 1)
123+
oprot.writeString(self.message)
124+
oprot.writeFieldEnd()
125+
oprot.writeFieldStop()
126+
oprot.writeStructEnd()
127+
128+
@classmethod
129+
def read(cls, iprot):
130+
if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and cls.thrift_spec is not None:
131+
return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec])
132+
return iprot.readStruct(cls, cls.thrift_spec, False)
133+
134+
135+
class TestImmutableExceptionDeserialization(unittest.TestCase):
136+
"""Test that immutable exceptions can be properly deserialized."""
137+
138+
def _roundtrip(self, exc, protocol_class):
139+
"""Serialize and deserialize an exception."""
140+
otrans = TTransport.TMemoryBuffer()
141+
oproto = protocol_class.getProtocol(otrans)
142+
exc.write(oproto)
143+
itrans = TTransport.TMemoryBuffer(otrans.getvalue())
144+
iproto = protocol_class.getProtocol(itrans)
145+
return exc.__class__.read(iproto)
146+
147+
def test_immutable_exception_is_hashable(self):
148+
"""Verify that immutable exceptions are hashable (required for caching/logging)."""
149+
exc = ImmutableException(message="test")
150+
self.assertTrue(isinstance(exc, Hashable))
151+
self.assertEqual(hash(exc), hash(ImmutableException(message="test")))
152+
153+
def test_immutable_exception_blocks_modification(self):
154+
"""Verify that immutable exceptions raise TypeError on attribute modification."""
155+
exc = ImmutableException(message="test")
156+
with self.assertRaises(TypeError) as cm:
157+
exc.message = "modified"
158+
self.assertIn("immutable", str(cm.exception))
159+
160+
def test_immutable_exception_blocks_deletion(self):
161+
"""Verify that immutable exceptions raise TypeError on attribute deletion."""
162+
exc = ImmutableException(message="test")
163+
with self.assertRaises(TypeError) as cm:
164+
del exc.message
165+
self.assertIn("immutable", str(cm.exception))
166+
167+
def test_immutable_exception_binary_protocol(self):
168+
"""Test immutable exception deserialization with TBinaryProtocol."""
169+
exc = ImmutableException(message="test error")
170+
deserialized = self._roundtrip(exc, TBinaryProtocol.TBinaryProtocolFactory())
171+
self.assertEqual(exc.message, deserialized.message)
172+
self.assertEqual(exc, deserialized)
173+
174+
def test_immutable_exception_compact_protocol(self):
175+
"""Test immutable exception deserialization with TCompactProtocol."""
176+
exc = ImmutableException(message="test error")
177+
deserialized = self._roundtrip(exc, TCompactProtocol.TCompactProtocolFactory())
178+
self.assertEqual(exc.message, deserialized.message)
179+
self.assertEqual(exc, deserialized)
180+
181+
def test_mutable_exception_can_be_modified(self):
182+
"""Verify that mutable exceptions can be modified (control test)."""
183+
exc = MutableException(message="original")
184+
exc.message = "modified"
185+
self.assertEqual(exc.message, "modified")
186+
187+
188+
class TestImmutableExceptionAccelerated(unittest.TestCase):
189+
"""Test immutable exception deserialization with accelerated protocols (C extension)."""
190+
191+
def setUp(self):
192+
try:
193+
# The import is intentionally unused - it only checks if the C extension
194+
# is available by catching ImportError. The noqa comment documents this.
195+
from thrift.protocol import fastbinary # noqa: F401
196+
self._has_c_extension = True
197+
except ImportError:
198+
self._has_c_extension = False
199+
200+
def _roundtrip(self, exc, protocol_class):
201+
"""Serialize and deserialize an exception."""
202+
otrans = TTransport.TMemoryBuffer()
203+
oproto = protocol_class.getProtocol(otrans)
204+
exc.write(oproto)
205+
itrans = TTransport.TMemoryBuffer(otrans.getvalue())
206+
iproto = protocol_class.getProtocol(itrans)
207+
return exc.__class__.read(iproto)
208+
209+
def test_immutable_exception_binary_accelerated(self):
210+
"""Test immutable exception with TBinaryProtocolAccelerated."""
211+
if not self._has_c_extension:
212+
self.skipTest("C extension not available")
213+
exc = ImmutableException(message="test error")
214+
deserialized = self._roundtrip(
215+
exc,
216+
TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False)
217+
)
218+
self.assertEqual(exc.message, deserialized.message)
219+
self.assertEqual(exc, deserialized)
220+
221+
def test_immutable_exception_compact_accelerated(self):
222+
"""Test immutable exception with TCompactProtocolAccelerated."""
223+
if not self._has_c_extension:
224+
self.skipTest("C extension not available")
225+
exc = ImmutableException(message="test error")
226+
deserialized = self._roundtrip(
227+
exc,
228+
TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False)
229+
)
230+
self.assertEqual(exc.message, deserialized.message)
231+
self.assertEqual(exc, deserialized)
232+
233+
234+
def suite():
235+
suite = unittest.TestSuite()
236+
loader = unittest.TestLoader()
237+
suite.addTest(loader.loadTestsFromTestCase(TestImmutableExceptionDeserialization))
238+
suite.addTest(loader.loadTestsFromTestCase(TestImmutableExceptionAccelerated))
239+
return suite
240+
241+
242+
if __name__ == "__main__":
243+
unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))

0 commit comments

Comments
 (0)