Skip to content

Commit 895752c

Browse files
fix: Mitigate unsafe Java deserialization (CWE-502) (#2513)
--------- Co-authored-by: Brendan Walsh <37676373+BrendanWalsh@users.noreply.github.com>
1 parent 423b969 commit 895752c

4 files changed

Lines changed: 260 additions & 11 deletions

File tree

.agents/skills/code-review/SKILL.md

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
---
22
name: code-review
3-
description: Quick review checklist for python and scala code changes before callings it done.
3+
description: Quick review checklist for Python and Scala code changes before calling it done.
44
---
55

66
# Code Review
@@ -13,4 +13,67 @@ Use this skill when reviewing SynapseML changes.
1313
2. Run Scala style: `sbt scalastyle test:scalastyle`
1414
3. Run Python format check: `black --check --extend-exclude 'docs/' .`
1515
4. Run targeted tests for touched code.
16-
5. Report only concrete issues with file paths and fixes.
16+
5. Apply the checklists below to every changed file.
17+
6. Report only concrete issues with file paths and fixes.
18+
19+
## Security Checklist
20+
21+
Apply when changes touch serialization, I/O, network, or authentication code.
22+
23+
### Deserialization (CWE-502)
24+
- [ ] No raw `ObjectInputStream.readObject()` — use `SafeObjectInputStream` with an allowlist
25+
- [ ] `resolveClass` allowlist validates array component types — never allowlist the `[` prefix
26+
directly; array handling must extract and validate the component class name
27+
- [ ] `resolveProxyClass` is overridden to block or validate dynamic proxy interfaces
28+
- [ ] Allowlist uses package-prefix matching, not blocklisting
29+
- [ ] Allowlist presets contain no catch-all entries that bypass the filtering logic
30+
31+
### Input Validation
32+
- [ ] File paths, URLs, and user-supplied strings are validated before use
33+
- [ ] No unsanitized string interpolation into SQL, shell commands, or config
34+
35+
### Resource Management
36+
- [ ] Streams, connections, and closeable resources use `using()` or `try-finally`
37+
- [ ] Cleanup runs even when assertions or exceptions are thrown (especially in tests)
38+
39+
### Secrets & Credentials
40+
- [ ] No hardcoded secrets, tokens, or passwords in source or test code
41+
- [ ] Credentials loaded from environment variables or secure config only
42+
43+
## API Compatibility Checklist
44+
45+
Apply when changes modify public classes, traits, or companion objects.
46+
47+
### Binary Compatibility (JVM)
48+
- [ ] No method signature changes on existing public methods (default parameters
49+
generate synthetic bridges — use explicit overloads instead)
50+
- [ ] No removed or renamed public classes, traits, or objects
51+
- [ ] Companion object `extends DefaultParamsReadable[T]` preserved if it existed
52+
53+
### Source Compatibility
54+
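- [ ] New parameters are introduced via explicit overloads, or carry defaults, so existing callers compile unchanged (defaults alone do not preserve binary compatibility — see Binary Compatibility above)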
55+
- [ ] No narrowed return types or widened parameter types on public methods
56+
- [ ] Import changes don't break wildcard imports in downstream code
57+
58+
## Scala Checklist
59+
60+
- [ ] License header present (enforced by scalastyle)
61+
- [ ] `Wrappable` trait mixed in if the class needs a Python wrapper
62+
- [ ] `SynapseMLLogging` trait mixed in; `logClass()` called in constructor
63+
- [ ] No wildcard imports where explicit imports suffice (`java.io._` → named imports)
64+
- [ ] No RDD API usage — DataFrame/Dataset only
65+
- [ ] Lines ≤ 120 chars, files ≤ 800 lines
66+
67+
## Python Checklist
68+
69+
- [ ] License header present
70+
- [ ] Formatted with `black==22.3.0`
71+
- [ ] No edits to files under `target/` (auto-generated)
72+
- [ ] Hand-written overrides in `src/main/python/` extend the generated `_ClassName`
73+
74+
## Test Checklist
75+
76+
- [ ] New functionality has corresponding tests
77+
- [ ] Tests use `using()` for resource cleanup (no bare `.close()` after assertions)
78+
- [ ] Negative tests verify rejection/error cases, not just happy paths
79+
- [ ] No test-only dependencies leaked into main scope
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// Copyright (C) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License. See LICENSE in project root for information.
3+
4+
package com.microsoft.azure.synapse.ml.core.utils
5+
6+
import java.io.{InputStream, InvalidClassException, ObjectStreamClass}
7+
8+
/** An ObjectInputStream that restricts deserialization to an allowlist of class name prefixes.
  *
  * This mitigates Java deserialization attacks (CWE-502) by rejecting any class
  * whose fully-qualified name does not start with one of the allowed prefixes.
  * It also inherits the context-classloader resolution from [[ContextObjectInputStream]].
  *
  * Matching is a plain `String.startsWith` test, so prefixes should end with a
  * package separator (e.g. `"java.lang."`). Note that an empty-string prefix
  * matches every class name and therefore disables the filter entirely.
  *
  * @param input           the underlying input stream
  * @param allowedPrefixes set of class name prefixes that are permitted for deserialization
  */
class SafeObjectInputStream(
  input: InputStream,
  allowedPrefixes: Set[String]
) extends ContextObjectInputStream(input) {

  /** Returns true when `className` starts with at least one allowlisted prefix. */
  private def isAllowed(className: String): Boolean =
    allowedPrefixes.exists(prefix => className.startsWith(prefix))

  /** Decides whether a JVM array descriptor may be deserialized.
    *
    * All leading `[` characters are stripped in one pass, so arrays of any
    * dimension are judged by their innermost element type:
    *  - a single primitive code (`B`, `C`, `D`, `F`, `I`, `J`, `S`, `Z`) is always accepted;
    *  - an object element (`Lcom.example.Foo;`) is accepted only if the class
    *    name passes the prefix allowlist;
    *  - anything else is malformed and is rejected rather than being treated
    *    as a primitive array.
    *
    * @param descriptor array class name as reported by the stream, e.g. `[[Ljava.lang.String;`
    */
  private def isArrayAllowed(descriptor: String): Boolean =
    descriptor.dropWhile(_ == '[') match {
      case "B" | "C" | "D" | "F" | "I" | "J" | "S" | "Z" => true
      case element if element.length > 2 && element.startsWith("L") && element.endsWith(";") =>
        isAllowed(element.substring(1, element.length - 1))
      case _ => false // fail closed on descriptors that are neither primitive nor `L...;`
    }

  /** Checks the class named by `desc` against the allowlist before delegating
    * the actual class lookup to the parent stream.
    *
    * @throws InvalidClassException if the class (or array element class) is not allowlisted
    */
  protected override def resolveClass(desc: ObjectStreamClass): Class[_] = {
    val className = desc.getName
    val allowed =
      if (className.startsWith("[")) isArrayAllowed(className) else isAllowed(className)

    if (!allowed) {
      throw new InvalidClassException(
        className,
        "Deserialization of this class is not allowed. " +
          "Only classes with approved package prefixes may be deserialized."
      )
    }
    super.resolveClass(desc)
  }

  /** Rejects dynamic proxy deserialization unless every interface is allowlisted.
    *
    * Dynamic proxies are a known deserialization attack vector (e.g. via
    * `java.lang.reflect.Proxy` with malicious `InvocationHandler` chains).
    * A proxy is accepted only when all of its interface names pass the prefix
    * allowlist; a proxy that declares no interfaces passes this check
    * vacuously. Callers that never expect proxies should keep broad prefixes
    * such as `java.lang.` out of their allowlist or be aware that proxies of
    * those interfaces are admitted.
    *
    * @throws InvalidClassException naming the interfaces that are not allowlisted
    */
  protected override def resolveProxyClass(interfaces: Array[String]): Class[_] = {
    val disallowed = interfaces.filterNot(isAllowed)
    if (disallowed.nonEmpty) {
      throw new InvalidClassException(
        disallowed.mkString(", "),
        "Deserialization of dynamic proxy is not allowed. " +
          "Proxy interface(s) not in the approved allowlist."
      )
    }
    super.resolveProxyClass(interfaces)
  }
}
80+
81+
object SafeObjectInputStream {

  /** Prefix allowlist used by default when deserializing objects from the
    * SynapseML `nn` package (BallTree, ConditionalBallTree, and their object graphs).
    */
  val DefaultNNAllowedPrefixes: Set[String] = {
    // Grouped by origin purely for readability; the resulting set is flat.
    val synapseMl = Seq("com.microsoft.azure.synapse.ml.nn.")
    val libraries = Seq("breeze.", "scala.")
    val jdk = Seq("java.lang.", "java.util.", "java.io.", "java.math.")
    (synapseMl ++ libraries ++ jdk).toSet
  }
}

core/src/main/scala/com/microsoft/azure/synapse/ml/nn/BallTree.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ package com.microsoft.azure.synapse.ml.nn
66
import breeze.linalg.{DenseVector, norm, _}
77
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
88

9-
import java.io._
9+
import com.microsoft.azure.synapse.ml.core.utils.SafeObjectInputStream
10+
11+
import java.io.{FileInputStream, FileOutputStream, ObjectOutputStream, Serializable}
1012
import scala.collection.JavaConverters._
1113

1214
private case class Query(point: DenseVector[Double],
@@ -170,8 +172,14 @@ object ConditionalBallTree {
170172
}
171173

172174
/** Loads a serialized tree from `filename`, restricting deserialization to
  * `SafeObjectInputStream.DefaultNNAllowedPrefixes`.
  */
def load[L, V](filename: String): ConditionalBallTree[L, V] =
  load[L, V](filename, SafeObjectInputStream.DefaultNNAllowedPrefixes)
177+
178+
def load[L, V](filename: String,
179+
allowedPrefixes: Set[String]
180+
): ConditionalBallTree[L, V] = {
173181
using(new FileInputStream(filename)) { fileIn =>
174-
using(new ObjectInputStream(fileIn)) { in =>
182+
using(new SafeObjectInputStream(fileIn, allowedPrefixes)) { in =>
175183
in.readObject().asInstanceOf[ConditionalBallTree[L, V]]
176184
}
177185
}.get.get

core/src/test/scala/com/microsoft/azure/synapse/ml/nn/VerifySchemas.scala

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ package com.microsoft.azure.synapse.ml.nn
66
import breeze.linalg.DenseVector
77
import com.microsoft.azure.synapse.ml.core.test.base.TestBase
88

9-
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
9+
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
10+
import com.microsoft.azure.synapse.ml.core.utils.SafeObjectInputStream
11+
12+
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectOutputStream}
1013

1114
class VerifySchemas extends TestBase {
1215

@@ -52,17 +55,97 @@ class VerifySchemas extends TestBase {
5255
test("BestMatch is serializable") {
5356
val bm = BestMatch(7, 2.5)
5457
val baos = new ByteArrayOutputStream()
55-
val oos = new ObjectOutputStream(baos)
56-
oos.writeObject(bm)
57-
oos.close()
58+
using(new ObjectOutputStream(baos)) { oos =>
59+
oos.writeObject(bm)
60+
}
5861

5962
val bais = new ByteArrayInputStream(baos.toByteArray)
60-
val ois = new ObjectInputStream(bais)
61-
val deserialized = ois.readObject().asInstanceOf[BestMatch]
62-
ois.close()
63+
val deserialized = using(new SafeObjectInputStream(bais, SafeObjectInputStream.DefaultNNAllowedPrefixes)) { ois =>
64+
ois.readObject().asInstanceOf[BestMatch]
65+
}.get
6366

6467
assert(deserialized.index === 7)
6568
assert(deserialized.distance === 2.5)
6669
}
6770

71+
test("SafeObjectInputStream rejects unauthorized classes") {
  // AtomicReference lives in java.util.concurrent.atomic, which the nn-only
  // allowlist used below does not cover.
  val malicious = new java.util.concurrent.atomic.AtomicReference[String]("payload")
  val baos = new ByteArrayOutputStream()
  // Force the Try returned by `using`: a failed write must abort the test here
  // instead of being dropped and surfacing later as an unrelated read error.
  using(new ObjectOutputStream(baos)) { oos =>
    oos.writeObject(malicious)
  }.get

  val bais = new ByteArrayInputStream(baos.toByteArray)
  val restrictedPrefixes = Set("com.microsoft.azure.synapse.ml.nn.")
  val result = using(new SafeObjectInputStream(bais, restrictedPrefixes)) { ois =>
    ois.readObject()
  }
  assert(result.isFailure)
  assert(result.failed.get.isInstanceOf[java.io.InvalidClassException])
}
86+
87+
test("SafeObjectInputStream allows primitive arrays") {
  val data = Array(1.0, 2.0, 3.0)
  val baos = new ByteArrayOutputStream()
  // Force the Try returned by `using` so a serialization failure is reported
  // at its source rather than being silently discarded.
  using(new ObjectOutputStream(baos)) { oos =>
    oos.writeObject(data)
  }.get

  val bais = new ByteArrayInputStream(baos.toByteArray)
  val deserialized = using(new SafeObjectInputStream(bais, Set("java.lang."))) { ois =>
    ois.readObject().asInstanceOf[Array[Double]]
  }.get

  assert(deserialized.sameElements(data))
}
101+
102+
test("SafeObjectInputStream allows object arrays with permitted component type") {
  val data = Array("hello", "world")
  val baos = new ByteArrayOutputStream()
  // Force the Try returned by `using` so a serialization failure is reported
  // at its source rather than being silently discarded.
  using(new ObjectOutputStream(baos)) { oos =>
    oos.writeObject(data)
  }.get

  val bais = new ByteArrayInputStream(baos.toByteArray)
  // String[] is judged by its element class java.lang.String, which matches "java.lang.".
  val deserialized = using(new SafeObjectInputStream(bais, Set("java.lang."))) { ois =>
    ois.readObject().asInstanceOf[Array[String]]
  }.get

  assert(deserialized.sameElements(data))
}
116+
117+
test("SafeObjectInputStream rejects object arrays with disallowed component type") {
  // AtomicInteger[] has element class java.util.concurrent.atomic.AtomicInteger,
  // which the "java.lang." allowlist below does not cover.
  val data = Array(new java.util.concurrent.atomic.AtomicInteger(1))
  val baos = new ByteArrayOutputStream()
  // Force the Try returned by `using`: a failed write must abort the test here
  // instead of being dropped and surfacing later as an unrelated read error.
  using(new ObjectOutputStream(baos)) { oos =>
    oos.writeObject(data)
  }.get

  val bais = new ByteArrayInputStream(baos.toByteArray)
  val result = using(new SafeObjectInputStream(bais, Set("java.lang."))) { ois =>
    ois.readObject()
  }
  assert(result.isFailure)
  assert(result.failed.get.isInstanceOf[java.io.InvalidClassException])
}
131+
132+
test("SafeObjectInputStream rejects dynamic proxies with disallowed interfaces") {
  // Handler that answers every call with null; only the proxy's interface list matters here.
  val handler: java.lang.reflect.InvocationHandler =
    (_: Any, _: java.lang.reflect.Method, _: Array[AnyRef]) => None.orNull
  val runnableProxy = java.lang.reflect.Proxy.newProxyInstance(
    getClass.getClassLoader,
    Array(classOf[Runnable]),
    handler
  )

  val serialized = new ByteArrayOutputStream()
  // NOTE(review): the Try returned by this `using` is not inspected, so a failure
  // while writing the proxy would go unnoticed — confirm the handler lambda is
  // actually serializable before tightening this.
  using(new ObjectOutputStream(serialized)) { out =>
    out.writeObject(runnableProxy)
  }

  // java.lang.Runnable is not covered by the nn-only allowlist.
  val nnOnly = Set("com.microsoft.azure.synapse.ml.nn.")
  val source = new ByteArrayInputStream(serialized.toByteArray)
  val outcome = using(new SafeObjectInputStream(source, nnOnly)) { in =>
    in.readObject()
  }
  assert(outcome.isFailure)
  assert(outcome.failed.get.isInstanceOf[java.io.InvalidClassException])
}
150+
68151
}

0 commit comments

Comments
 (0)