@@ -6,7 +6,10 @@ package com.microsoft.azure.synapse.ml.nn
66import breeze .linalg .DenseVector
77import com .microsoft .azure .synapse .ml .core .test .base .TestBase
88
9- import java .io .{ByteArrayInputStream , ByteArrayOutputStream , ObjectInputStream , ObjectOutputStream }
9+ import com .microsoft .azure .synapse .ml .core .env .StreamUtilities .using
10+ import com .microsoft .azure .synapse .ml .core .utils .SafeObjectInputStream
11+
12+ import java .io .{ByteArrayInputStream , ByteArrayOutputStream , ObjectOutputStream }
1013
1114class VerifySchemas extends TestBase {
1215
@@ -52,17 +55,97 @@ class VerifySchemas extends TestBase {
5255 test(" BestMatch is serializable" ) {
5356 val bm = BestMatch (7 , 2.5 )
5457 val baos = new ByteArrayOutputStream ()
55- val oos = new ObjectOutputStream (baos)
56- oos.writeObject(bm)
57- oos.close()
58+ using( new ObjectOutputStream (baos)) { oos =>
59+ oos.writeObject(bm)
60+ }
5861
5962 val bais = new ByteArrayInputStream (baos.toByteArray)
60- val ois = new ObjectInputStream (bais)
61- val deserialized = ois.readObject().asInstanceOf [BestMatch ]
62- ois.close()
63+ val deserialized = using( new SafeObjectInputStream (bais, SafeObjectInputStream . DefaultNNAllowedPrefixes )) { ois =>
64+ ois.readObject().asInstanceOf [BestMatch ]
65+ }.get
6366
6467 assert(deserialized.index === 7 )
6568 assert(deserialized.distance === 2.5 )
6669 }
6770
71+ test(" SafeObjectInputStream rejects unauthorized classes" ) {
72+ val malicious = new java.util.concurrent.atomic.AtomicReference [String ](" payload" )
73+ val baos = new ByteArrayOutputStream ()
74+ using(new ObjectOutputStream (baos)) { oos =>
75+ oos.writeObject(malicious)
76+ }
77+
78+ val bais = new ByteArrayInputStream (baos.toByteArray)
79+ val restrictedPrefixes = Set (" com.microsoft.azure.synapse.ml.nn." )
80+ val result = using(new SafeObjectInputStream (bais, restrictedPrefixes)) { ois =>
81+ ois.readObject()
82+ }
83+ assert(result.isFailure)
84+ assert(result.failed.get.isInstanceOf [java.io.InvalidClassException ])
85+ }
86+
87+ test(" SafeObjectInputStream allows primitive arrays" ) {
88+ val data = Array (1.0 , 2.0 , 3.0 )
89+ val baos = new ByteArrayOutputStream ()
90+ using(new ObjectOutputStream (baos)) { oos =>
91+ oos.writeObject(data)
92+ }
93+
94+ val bais = new ByteArrayInputStream (baos.toByteArray)
95+ val deserialized = using(new SafeObjectInputStream (bais, Set (" java.lang." ))) { ois =>
96+ ois.readObject().asInstanceOf [Array [Double ]]
97+ }.get
98+
99+ assert(deserialized.sameElements(data))
100+ }
101+
102+ test(" SafeObjectInputStream allows object arrays with permitted component type" ) {
103+ val data = Array (" hello" , " world" )
104+ val baos = new ByteArrayOutputStream ()
105+ using(new ObjectOutputStream (baos)) { oos =>
106+ oos.writeObject(data)
107+ }
108+
109+ val bais = new ByteArrayInputStream (baos.toByteArray)
110+ val deserialized = using(new SafeObjectInputStream (bais, Set (" java.lang." ))) { ois =>
111+ ois.readObject().asInstanceOf [Array [String ]]
112+ }.get
113+
114+ assert(deserialized.sameElements(data))
115+ }
116+
117+ test(" SafeObjectInputStream rejects object arrays with disallowed component type" ) {
118+ val data = Array (new java.util.concurrent.atomic.AtomicInteger (1 ))
119+ val baos = new ByteArrayOutputStream ()
120+ using(new ObjectOutputStream (baos)) { oos =>
121+ oos.writeObject(data)
122+ }
123+
124+ val bais = new ByteArrayInputStream (baos.toByteArray)
125+ val result = using(new SafeObjectInputStream (bais, Set (" java.lang." ))) { ois =>
126+ ois.readObject()
127+ }
128+ assert(result.isFailure)
129+ assert(result.failed.get.isInstanceOf [java.io.InvalidClassException ])
130+ }
131+
132+ test(" SafeObjectInputStream rejects dynamic proxies with disallowed interfaces" ) {
133+ val proxy = java.lang.reflect.Proxy .newProxyInstance(
134+ getClass.getClassLoader,
135+ Array (classOf [Runnable ]),
136+ (_ : Any , _ : java.lang.reflect.Method , _ : Array [AnyRef ]) => None .orNull
137+ )
138+ val baos = new ByteArrayOutputStream ()
139+ using(new ObjectOutputStream (baos)) { oos =>
140+ oos.writeObject(proxy)
141+ }
142+
143+ val bais = new ByteArrayInputStream (baos.toByteArray)
144+ val result = using(new SafeObjectInputStream (bais, Set (" com.microsoft.azure.synapse.ml.nn." ))) { ois =>
145+ ois.readObject()
146+ }
147+ assert(result.isFailure)
148+ assert(result.failed.get.isInstanceOf [java.io.InvalidClassException ])
149+ }
150+
68151}
0 commit comments