|
21 | 21 | import java.io.ByteArrayOutputStream; |
22 | 22 | import java.io.File; |
23 | 23 | import java.io.IOException; |
| 24 | +import java.io.InvalidClassException; |
| 25 | +import java.io.ObjectInputFilter; |
24 | 26 | import java.io.ObjectInputStream; |
25 | 27 | import java.io.ObjectOutputStream; |
26 | 28 | import java.io.Serializable; |
@@ -162,4 +164,142 @@ File getTempFile() { |
162 | 164 | return null; |
163 | 165 | } |
164 | 166 | } |
| 167 | + |
| 168 | + @Test |
| 169 | + public void filterRejectsUnauthorizedClasses() throws Exception { |
| 170 | + // Arrange: Create filter that only allows java.lang and java.util classes |
| 171 | + ObjectInputFilter filter = ObjectInputFilter.Config.createFilter("java.lang.*;java.util.*;!*"); |
| 172 | + TestSerializable testObject = new TestSerializable("test"); |
| 173 | + byte[] serializedData = serialize(testObject); |
| 174 | + |
| 175 | + // Act & Assert: Deserialization should be rejected by filter |
| 176 | + assertThatThrownBy(() -> { |
| 177 | + try (ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream( |
| 178 | + new ByteArrayInputStream(serializedData), |
| 179 | + Thread.currentThread().getContextClassLoader(), |
| 180 | + filter)) { |
| 181 | + ois.readObject(); |
| 182 | + } |
| 183 | + }).isInstanceOf(InvalidClassException.class); |
| 184 | + } |
| 185 | + |
| 186 | + @Test |
| 187 | + public void filterAllowsAuthorizedClasses() throws Exception { |
| 188 | + // Arrange: Create filter that allows this test class package |
| 189 | + ObjectInputFilter filter = ObjectInputFilter.Config.createFilter( |
| 190 | + "java.lang.*;java.util.*;org.apache.geode.modules.util.**;!*"); |
| 191 | + TestSerializable testObject = new TestSerializable("test data"); |
| 192 | + byte[] serializedData = serialize(testObject); |
| 193 | + |
| 194 | + // Act: Deserialize with filter |
| 195 | + Object deserialized; |
| 196 | + try (ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream( |
| 197 | + new ByteArrayInputStream(serializedData), |
| 198 | + Thread.currentThread().getContextClassLoader(), |
| 199 | + filter)) { |
| 200 | + deserialized = ois.readObject(); |
| 201 | + } |
| 202 | + |
| 203 | + // Assert: Object should be successfully deserialized |
| 204 | + assertThat(deserialized).isInstanceOf(TestSerializable.class); |
| 205 | + assertThat(((TestSerializable) deserialized).getData()).isEqualTo("test data"); |
| 206 | + } |
| 207 | + |
| 208 | + @Test |
| 209 | + public void nullFilterAllowsAllClasses() throws Exception { |
| 210 | + // Arrange: Null filter means no filtering (backward compatibility) |
| 211 | + TestSerializable testObject = new TestSerializable("unfiltered data"); |
| 212 | + byte[] serializedData = serialize(testObject); |
| 213 | + |
| 214 | + // Act: Deserialize with null filter |
| 215 | + Object deserialized; |
| 216 | + try (ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream( |
| 217 | + new ByteArrayInputStream(serializedData), |
| 218 | + Thread.currentThread().getContextClassLoader(), |
| 219 | + null)) { |
| 220 | + deserialized = ois.readObject(); |
| 221 | + } |
| 222 | + |
| 223 | + // Assert: Object should be successfully deserialized |
| 224 | + assertThat(deserialized).isInstanceOf(TestSerializable.class); |
| 225 | + assertThat(((TestSerializable) deserialized).getData()).isEqualTo("unfiltered data"); |
| 226 | + } |
| 227 | + |
| 228 | + @Test |
| 229 | + public void deprecatedConstructorStillWorks() throws Exception { |
| 230 | + // Arrange: Use deprecated constructor without filter |
| 231 | + TestSerializable testObject = new TestSerializable("legacy code"); |
| 232 | + byte[] serializedData = serialize(testObject); |
| 233 | + |
| 234 | + // Act: Deserialize using deprecated constructor |
| 235 | + Object deserialized; |
| 236 | + try (ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream( |
| 237 | + new ByteArrayInputStream(serializedData), |
| 238 | + Thread.currentThread().getContextClassLoader())) { |
| 239 | + deserialized = ois.readObject(); |
| 240 | + } |
| 241 | + |
| 242 | + // Assert: Object should be successfully deserialized (backward compatibility) |
| 243 | + assertThat(deserialized).isInstanceOf(TestSerializable.class); |
| 244 | + assertThat(((TestSerializable) deserialized).getData()).isEqualTo("legacy code"); |
| 245 | + } |
| 246 | + |
| 247 | + @Test |
| 248 | + public void filterEnforcesResourceLimits() throws Exception { |
| 249 | + // Arrange: Create filter with very low depth limit |
| 250 | + ObjectInputFilter filter = ObjectInputFilter.Config.createFilter("maxdepth=2;*"); |
| 251 | + NestedSerializable nested = new NestedSerializable( |
| 252 | + new NestedSerializable( |
| 253 | + new NestedSerializable(null))); // Depth of 3 |
| 254 | + byte[] serializedData = serialize(nested); |
| 255 | + |
| 256 | + // Act & Assert: Should reject due to depth limit |
| 257 | + assertThatThrownBy(() -> { |
| 258 | + try (ClassLoaderObjectInputStream ois = new ClassLoaderObjectInputStream( |
| 259 | + new ByteArrayInputStream(serializedData), |
| 260 | + Thread.currentThread().getContextClassLoader(), |
| 261 | + filter)) { |
| 262 | + ois.readObject(); |
| 263 | + } |
| 264 | + }).isInstanceOf(InvalidClassException.class); |
| 265 | + } |
| 266 | + |
| 267 | + /** |
| 268 | + * Helper method to serialize an object to byte array |
| 269 | + */ |
| 270 | + private byte[] serialize(Object obj) throws IOException { |
| 271 | + ByteArrayOutputStream baos = new ByteArrayOutputStream(); |
| 272 | + try (ObjectOutputStream oos = new ObjectOutputStream(baos)) { |
| 273 | + oos.writeObject(obj); |
| 274 | + } |
| 275 | + return baos.toByteArray(); |
| 276 | + } |
| 277 | + |
| 278 | + /** |
| 279 | + * Test class for serialization testing |
| 280 | + */ |
| 281 | + static class TestSerializable implements Serializable { |
| 282 | + private static final long serialVersionUID = 1L; |
| 283 | + private final String data; |
| 284 | + |
| 285 | + TestSerializable(String data) { |
| 286 | + this.data = data; |
| 287 | + } |
| 288 | + |
| 289 | + String getData() { |
| 290 | + return data; |
| 291 | + } |
| 292 | + } |
| 293 | + |
| 294 | + /** |
| 295 | + * Nested test class for depth limit testing |
| 296 | + */ |
| 297 | + static class NestedSerializable implements Serializable { |
| 298 | + private static final long serialVersionUID = 1L; |
| 299 | + private final NestedSerializable nested; |
| 300 | + |
| 301 | + NestedSerializable(NestedSerializable nested) { |
| 302 | + this.nested = nested; |
| 303 | + } |
| 304 | + } |
165 | 305 | } |
0 commit comments