Skip to content

Commit 823d53d

Browse files
SunWeb3Secvy
andauthored
Harden readObject(ObjectInputStream) method argument checks (#4098)
Signed-off-by: SunWeb3Sec <infosecpt@gmail.com> Co-authored-by: Volkan Yazıcı <volkan@yazi.ci>
1 parent dec4351 commit 823d53d

28 files changed

Lines changed: 196 additions & 42 deletions

File tree

log4j-1.2-api/src/main/java/org/apache/log4j/Level.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.io.Serializable;
2626
import org.apache.log4j.helpers.OptionConverter;
2727
import org.apache.logging.log4j.util.Strings;
28+
import org.apache.logging.log4j.util.internal.SerializationUtil;
2829

2930
/**
3031
* Defines the minimum set of levels recognized by the system, that is
@@ -214,6 +215,7 @@ public static Level toLevel(final String sArg, final Level defaultLevel) {
214215
* @throws ClassNotFoundException if class not found.
215216
*/
216217
private void readObject(final ObjectInputStream s) throws IOException, ClassNotFoundException {
218+
SerializationUtil.assertFiltered(s);
217219
s.defaultReadObject();
218220
level = s.readInt();
219221
syslogEquivalent = s.readInt();

log4j-1.2-api/src/test/java/org/apache/log4j/util/SerializationTestHelper.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,14 @@
2424
import java.io.File;
2525
import java.io.FileInputStream;
2626
import java.io.IOException;
27+
import java.io.InputStream;
2728
import java.io.ObjectInputStream;
2829
import java.io.ObjectOutputStream;
30+
import java.util.Arrays;
31+
import java.util.Collection;
2932
import org.apache.commons.io.FileUtils;
33+
import org.apache.logging.log4j.util.Constants;
34+
import org.apache.logging.log4j.util.FilteredObjectInputStream;
3035

3136
/**
3237
* Utiities for serialization tests.
@@ -103,11 +108,23 @@ public static void assertStreamEquals(
103108
* @throws Exception thrown on IO or deserialization exception.
104109
*/
105110
public static Object deserializeStream(final String witness) throws Exception {
106-
try (final ObjectInputStream objIs = new ObjectInputStream(new FileInputStream(witness))) {
111+
try (final ObjectInputStream objIs = newObjectInputStream(new FileInputStream(witness))) {
107112
return objIs.readObject();
108113
}
109114
}
110115

116+
private static ObjectInputStream newObjectInputStream(final InputStream in) throws IOException {
117+
if (Constants.JAVA_MAJOR_VERSION == 8) {
118+
// FilteredObjectInputStream's default allow-list covers `org.apache.logging.log4j.` but
119+
// not the `org.apache.log4j.` 1.2-compatibility namespace, so we have to enumerate the
120+
// 1.2 classes that the tests in this module deserialize on Java 8.
121+
final Collection<String> allowedLog4j12Classes =
122+
Arrays.asList("org.apache.log4j.Level", "org.apache.log4j.LevelTest$CustomLevel");
123+
return new FilteredObjectInputStream(in, allowedLog4j12Classes);
124+
}
125+
return new ObjectInputStream(in);
126+
}
127+
111128
/**
112129
* Creates a clone by serializing object and deserializing byte stream.
113130
*
@@ -123,7 +140,7 @@ public static Object serializeClone(final Object obj) throws IOException, ClassN
123140
}
124141

125142
final ByteArrayInputStream src = new ByteArrayInputStream(memOut.toByteArray());
126-
final ObjectInputStream objIs = new ObjectInputStream(src);
143+
final ObjectInputStream objIs = newObjectInputStream(src);
127144

128145
return objIs.readObject();
129146
}

log4j-api-test/src/main/java/org/apache/logging/log4j/test/SerializableMatchers.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import static org.hamcrest.core.IsInstanceOf.any;
2121

2222
import java.io.Serializable;
23-
import org.apache.commons.lang3.SerializationUtils;
23+
import java.util.Collection;
24+
import java.util.Collections;
25+
import org.apache.logging.log4j.test.junit.SerialUtil;
2426
import org.hamcrest.FeatureMatcher;
2527
import org.hamcrest.Matcher;
2628

@@ -32,10 +34,19 @@
3234
public final class SerializableMatchers {
3335

3436
public static <T extends Serializable> Matcher<T> serializesRoundTrip(final Matcher<T> matcher) {
37+
return serializesRoundTrip(matcher, Collections.emptySet());
38+
}
39+
40+
/**
41+
* Same as {@link #serializesRoundTrip(Matcher)} but extends the default deserialization
42+
* allow-list on Java 8 (see {@link SerialUtil#deserialize(byte[], Collection)}).
43+
*/
44+
public static <T extends Serializable> Matcher<T> serializesRoundTrip(
45+
final Matcher<T> matcher, final Collection<String> allowedExtraClasses) {
3546
return new FeatureMatcher<T, T>(matcher, "serializes round trip", "serializes round trip") {
3647
@Override
3748
protected T featureValueOf(final T actual) {
38-
return SerializationUtils.roundtrip(actual);
49+
return SerialUtil.deserialize(SerialUtil.serialize(actual), allowedExtraClasses);
3950
}
4051
};
4152
}
@@ -52,5 +63,13 @@ public static Matcher<? super Serializable> serializesRoundTrip() {
5263
return serializesRoundTrip(any(Serializable.class));
5364
}
5465

66+
/**
67+
* Same as {@link #serializesRoundTrip()} but extends the default deserialization allow-list on
68+
* Java 8 (see {@link SerialUtil#deserialize(byte[], Collection)}).
69+
*/
70+
public static Matcher<? super Serializable> serializesRoundTrip(final Collection<String> allowedExtraClasses) {
71+
return serializesRoundTrip(any(Serializable.class), allowedExtraClasses);
72+
}
73+
5574
private SerializableMatchers() {}
5675
}

log4j-api-test/src/main/java/org/apache/logging/log4j/test/junit/SerialUtil.java

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import java.io.ObjectOutput;
2525
import java.io.ObjectOutputStream;
2626
import java.io.Serializable;
27+
import java.util.Collection;
28+
import java.util.Collections;
2729
import org.apache.logging.log4j.test.internal.annotation.SuppressFBWarnings;
2830
import org.apache.logging.log4j.util.Constants;
2931
import org.apache.logging.log4j.util.FilteredObjectInputStream;
@@ -68,11 +70,25 @@ public static byte[] serialize(final Serializable... objs) {
6870
* @param data byte array representing the serialized object
6971
* @return the deserialized object
7072
*/
71-
@SuppressWarnings("unchecked")
7273
@SuppressFBWarnings("OBJECT_DESERIALIZATION")
7374
public static <T> T deserialize(final byte[] data) {
75+
return deserialize(data, Collections.emptySet());
76+
}
77+
78+
/**
79+
* Deserialize an object from the specified byte array using a {@link FilteredObjectInputStream}
80+
* extended with the supplied allow-list (Java 8 only — Java 9+ uses the JVM's serialization
81+
* filter, so the allow-list is ignored).
82+
* @param data byte array representing the serialized object
83+
* @param allowedExtraClasses fully-qualified class names to add to {@link
84+
* FilteredObjectInputStream}'s default allow-list on Java 8
85+
* @return the deserialized object
86+
*/
87+
@SuppressWarnings("unchecked")
88+
@SuppressFBWarnings("OBJECT_DESERIALIZATION")
89+
public static <T> T deserialize(final byte[] data, final Collection<String> allowedExtraClasses) {
7490
try {
75-
final ObjectInputStream ois = getObjectInputStream(data);
91+
final ObjectInputStream ois = getObjectInputStream(data, allowedExtraClasses);
7692
return (T) ois.readObject();
7793
} catch (final Exception ex) {
7894
throw new IllegalStateException("Could not deserialize", ex);
@@ -86,8 +102,18 @@ public static <T> T deserialize(final byte[] data) {
86102
*/
87103
@SuppressFBWarnings("OBJECT_DESERIALIZATION")
88104
public static ObjectInputStream getObjectInputStream(final byte[] data) throws IOException {
105+
return getObjectInputStream(data, Collections.emptySet());
106+
}
107+
108+
/**
109+
* Creates an {@link ObjectInputStream} adapted to the current Java version, extended with the
110+
* supplied allow-list on Java 8.
111+
*/
112+
@SuppressFBWarnings("OBJECT_DESERIALIZATION")
113+
public static ObjectInputStream getObjectInputStream(
114+
final byte[] data, final Collection<String> allowedExtraClasses) throws IOException {
89115
final ByteArrayInputStream bas = new ByteArrayInputStream(data);
90-
return getObjectInputStream(bas);
116+
return getObjectInputStream(bas, allowedExtraClasses);
91117
}
92118

93119
/**
@@ -97,8 +123,18 @@ public static ObjectInputStream getObjectInputStream(final byte[] data) throws I
97123
*/
98124
@SuppressFBWarnings("OBJECT_DESERIALIZATION")
99125
public static ObjectInputStream getObjectInputStream(final InputStream stream) throws IOException {
126+
return getObjectInputStream(stream, Collections.emptySet());
127+
}
128+
129+
/**
130+
* Creates an {@link ObjectInputStream} adapted to the current Java version, extended with the
131+
* supplied allow-list on Java 8.
132+
*/
133+
@SuppressFBWarnings("OBJECT_DESERIALIZATION")
134+
public static ObjectInputStream getObjectInputStream(
135+
final InputStream stream, final Collection<String> allowedExtraClasses) throws IOException {
100136
return Constants.JAVA_MAJOR_VERSION == 8
101-
? new FilteredObjectInputStream(stream)
137+
? new FilteredObjectInputStream(stream, allowedExtraClasses)
102138
: new ObjectInputStream(stream);
103139
}
104140
}

log4j-api-test/src/main/java/org/apache/logging/log4j/test/junit/package-info.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* limitations under the license.
1616
*/
1717
@Export
18-
@Version("2.25.3")
18+
@Version("2.27.0")
1919
package org.apache.logging.log4j.test.junit;
2020

2121
import org.osgi.annotation.bundle.Export;

log4j-api-test/src/main/java/org/apache/logging/log4j/test/package-info.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* limitations under the license.
1616
*/
1717
@Export
18-
@Version("2.25.3")
18+
@Version("2.27.0")
1919
package org.apache.logging.log4j.test;
2020

2121
import org.osgi.annotation.bundle.Export;

log4j-api-test/src/test/java/org/apache/logging/log4j/message/FormattedMessageTest.java

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,9 @@
1919
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
2020
import static org.junit.jupiter.api.Assertions.assertEquals;
2121

22-
import java.io.ByteArrayInputStream;
23-
import java.io.ByteArrayOutputStream;
24-
import java.io.IOException;
25-
import java.io.ObjectInputStream;
26-
import java.io.ObjectOutputStream;
2722
import java.util.Locale;
2823
import org.apache.logging.log4j.test.junit.Mutable;
24+
import org.apache.logging.log4j.test.junit.SerialUtil;
2925
import org.apache.logging.log4j.util.Constants;
3026
import org.junit.jupiter.api.Test;
3127
import org.junit.jupiter.api.parallel.ResourceAccessMode;
@@ -158,15 +154,9 @@ void testSafeAfterGetFormattedMessageIsCalled() { // LOG4J2-763
158154
}
159155

160156
@Test
161-
void testSerialization() throws IOException, ClassNotFoundException {
157+
void testSerialization() {
162158
final FormattedMessage expected = new FormattedMessage("Msg", "a", "b", "c");
163-
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
164-
try (final ObjectOutputStream out = new ObjectOutputStream(baos)) {
165-
out.writeObject(expected);
166-
}
167-
final ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
168-
final ObjectInputStream in = new ObjectInputStream(bais);
169-
final FormattedMessage actual = (FormattedMessage) in.readObject();
159+
final FormattedMessage actual = SerialUtil.deserialize(SerialUtil.serialize(expected));
170160
assertEquals(expected, actual);
171161
assertEquals(expected.getFormat(), actual.getFormat());
172162
assertEquals(expected.getFormattedMessage(), actual.getFormattedMessage());

log4j-api-test/src/test/java/org/apache/logging/log4j/message/LocalizedMessageTest.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
2020

21-
import java.io.Serializable;
2221
import java.util.Locale;
23-
import org.apache.commons.lang3.SerializationUtils;
2422
import org.apache.logging.log4j.test.junit.Mutable;
23+
import org.apache.logging.log4j.test.junit.SerialUtil;
2524
import org.junit.jupiter.api.Test;
2625
import org.junit.jupiter.api.parallel.ResourceAccessMode;
2726
import org.junit.jupiter.api.parallel.ResourceLock;
@@ -33,8 +32,8 @@
3332
@ResourceLock(value = Resources.LOCALE, mode = ResourceAccessMode.READ)
3433
class LocalizedMessageTest {
3534

36-
private <T extends Serializable> T roundtrip(final T msg) {
37-
return SerializationUtils.roundtrip(msg);
35+
private LocalizedMessage roundtrip(final LocalizedMessage msg) {
36+
return SerialUtil.deserialize(SerialUtil.serialize(msg));
3837
}
3938

4039
@Test

log4j-api-test/src/test/java/org/apache/logging/log4j/message/ObjectArrayMessageTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
2020
import static org.junit.jupiter.api.Assertions.assertNull;
2121

22+
import org.apache.logging.log4j.test.junit.SerialUtil;
2223
import org.junit.jupiter.api.Test;
2324

2425
/**
@@ -38,4 +39,16 @@ void testGetParameters() {
3839
void testGetThrowable() {
3940
assertNull(OBJECT_ARRAY_MESSAGE.getThrowable());
4041
}
42+
43+
/**
44+
* Round-trips through a filtered stream (see {@link SerialUtil#getObjectInputStream})
45+
* to verify that {@code readObject}'s new {@code SerializationUtil.assertFiltered}
46+
* check accepts streams that carry a filter.
47+
*/
48+
@Test
49+
void testSerializableRoundTripThroughFilteredStream() {
50+
final ObjectArrayMessage original = new ObjectArrayMessage("A", "B", "C");
51+
final ObjectArrayMessage restored = SerialUtil.deserialize(SerialUtil.serialize(original));
52+
assertArrayEquals(original.getParameters(), restored.getParameters());
53+
}
4154
}

log4j-api-test/src/test/java/org/apache/logging/log4j/message/StringFormattedMessageTest.java

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,9 @@
2020
import static org.junit.jupiter.api.Assertions.assertEquals;
2121
import static org.junit.jupiter.api.Assertions.assertNotNull;
2222

23-
import java.io.ByteArrayInputStream;
24-
import java.io.ByteArrayOutputStream;
25-
import java.io.IOException;
26-
import java.io.ObjectInputStream;
27-
import java.io.ObjectOutputStream;
2823
import java.util.Locale;
2924
import org.apache.logging.log4j.test.junit.Mutable;
25+
import org.apache.logging.log4j.test.junit.SerialUtil;
3026
import org.junit.jupiter.api.Test;
3127
import org.junit.jupiter.api.parallel.ResourceAccessMode;
3228
import org.junit.jupiter.api.parallel.ResourceLock;
@@ -115,15 +111,9 @@ void testSafeAfterGetFormattedMessageIsCalled() { // LOG4J2-763
115111
}
116112

117113
@Test
118-
void testSerialization() throws IOException, ClassNotFoundException {
114+
void testSerialization() {
119115
final StringFormattedMessage expected = new StringFormattedMessage("Msg", "a", "b", "c");
120-
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
121-
try (final ObjectOutputStream out = new ObjectOutputStream(baos)) {
122-
out.writeObject(expected);
123-
}
124-
final ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
125-
final ObjectInputStream in = new ObjectInputStream(bais);
126-
final StringFormattedMessage actual = (StringFormattedMessage) in.readObject();
116+
final StringFormattedMessage actual = SerialUtil.deserialize(SerialUtil.serialize(expected));
127117
assertEquals(expected, actual);
128118
assertEquals(expected.getFormat(), actual.getFormat());
129119
assertEquals(expected.getFormattedMessage(), actual.getFormattedMessage());

0 commit comments

Comments
 (0)