Skip to content

Commit 55bddbb

Browse files
committed
feat(java): support limit deserialization depth (#2578)
## Why? <!-- Describe the purpose of this PR. --> ## What does this PR do? <!-- Describe the details of this PR. --> ## Related issues <!-- Is there any related issue? If this PR closes them you say say fix/closes: - #xxxx0 - #xxxx1 - Fixes #xxxx2 --> ## Does this PR introduce any user-facing change? <!-- If any user-facing interface changes, please [open an issue](https://github.com/apache/fory/issues/new/choose) describing the need to do so and update the document if necessary. Delete section if not applicable. --> - [ ] Does this PR introduce any public API change? - [ ] Does this PR introduce any binary protocol compatibility change? ## Benchmark <!-- When the PR has an impact on performance (if you don't know whether the PR will have an impact on performance, you can submit the PR first, and if it will have impact on performance, the code reviewer will explain it), be sure to attach a benchmark data here. Delete section if not applicable. -->
1 parent ab49713 commit 55bddbb

16 files changed

Lines changed: 334 additions & 153 deletions

docs/guide/java_serialization_guide.md

Lines changed: 33 additions & 23 deletions
Large diffs are not rendered by default.

java/fory-core/src/main/java/org/apache/fory/Fory.java

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import org.apache.fory.exception.CopyException;
4141
import org.apache.fory.exception.DeserializationException;
4242
import org.apache.fory.exception.ForyException;
43+
import org.apache.fory.exception.InsecureException;
4344
import org.apache.fory.exception.SerializationException;
4445
import org.apache.fory.io.ForyInputStream;
4546
import org.apache.fory.io.ForyReadableChannel;
@@ -126,6 +127,7 @@ public final class Fory implements BaseFory {
126127
private Iterator<MemoryBuffer> outOfBandBuffers;
127128
private boolean peerOutOfBandEnabled;
128129
private int depth;
130+
private final int maxDepth;
129131
private int copyDepth;
130132
private final boolean copyRefTracking;
131133
private final IdentityMap<Object, Object> originToCopyMap;
@@ -141,6 +143,7 @@ public Fory(ForyBuilder builder, ClassLoader classLoader) {
141143
this.shareMeta = config.isMetaShareEnabled();
142144
compressInt = config.compressInt();
143145
longEncoding = config.longEncoding();
146+
maxDepth = config.maxDepth();
144147
if (refTracking) {
145148
this.refResolver = new MapRefResolver();
146149
} else {
@@ -653,17 +656,6 @@ private void writeData(MemoryBuffer buffer, ClassInfo classInfo, Object obj) {
653656
case ClassResolver.STRING_CLASS_ID:
654657
stringSerializer.writeJavaString(buffer, (String) obj);
655658
break;
656-
case ClassResolver.ARRAYLIST_CLASS_ID:
657-
depth++;
658-
arrayListSerializer.write(buffer, (ArrayList) obj);
659-
depth--;
660-
break;
661-
case ClassResolver.HASHMAP_CLASS_ID:
662-
depth++;
663-
hashMapSerializer.write(buffer, (HashMap) obj);
664-
depth--;
665-
break;
666-
// TODO(add fastpath for other types)
667659
default:
668660
depth++;
669661
classInfo.getSerializer().write(buffer, obj);
@@ -1024,7 +1016,7 @@ public Object readNullable(MemoryBuffer buffer, ClassInfoHolder classInfoHolder)
10241016

10251017
/** Class should be read already. */
10261018
public Object readData(MemoryBuffer buffer, ClassInfo classInfo) {
1027-
depth++;
1019+
incReadDepth();
10281020
Serializer<?> serializer = classInfo.getSerializer();
10291021
Object read = serializer.read(buffer);
10301022
depth--;
@@ -1055,19 +1047,8 @@ private Object readDataInternal(MemoryBuffer buffer, ClassInfo classInfo) {
10551047
return buffer.readFloat64();
10561048
case ClassResolver.STRING_CLASS_ID:
10571049
return stringSerializer.readJavaString(buffer);
1058-
case ClassResolver.ARRAYLIST_CLASS_ID:
1059-
depth++;
1060-
Object list = arrayListSerializer.read(buffer);
1061-
depth--;
1062-
return list;
1063-
case ClassResolver.HASHMAP_CLASS_ID:
1064-
depth++;
1065-
Object map = hashMapSerializer.read(buffer);
1066-
depth--;
1067-
return map;
1068-
// TODO(add fastpath for other types)
10691050
default:
1070-
depth++;
1051+
incReadDepth();
10711052
Object read = classInfo.getSerializer().read(buffer);
10721053
depth--;
10731054
return read;
@@ -1112,7 +1093,7 @@ public Object xreadNonRef(MemoryBuffer buffer) {
11121093
}
11131094

11141095
public Object xreadNonRef(MemoryBuffer buffer, Serializer<?> serializer) {
1115-
depth++;
1096+
incReadDepth();
11161097
Object o = serializer.xread(buffer);
11171098
depth--;
11181099
return o;
@@ -1142,7 +1123,7 @@ public Object xreadNonRef(MemoryBuffer buffer, ClassInfo classInfo) {
11421123
return buffer.readFloat64();
11431124
// TODO(add fastpath for other types)
11441125
default:
1145-
depth++;
1126+
incReadDepth();
11461127
Object o = classInfo.getSerializer().xread(buffer);
11471128
depth--;
11481129
return o;
@@ -1682,6 +1663,29 @@ public void incDepth(int diff) {
16821663
this.depth += diff;
16831664
}
16841665

1666+
public void incDepth() {
1667+
this.depth += 1;
1668+
}
1669+
1670+
public void decDepth() {
1671+
this.depth -= 1;
1672+
}
1673+
1674+
public void incReadDepth() {
1675+
if ((this.depth += 1) > maxDepth) {
1676+
throwReadDepthExceedException();
1677+
}
1678+
}
1679+
1680+
private void throwReadDepthExceedException() {
1681+
throw new InsecureException(
1682+
String.format(
1683+
"Read depth exceed max depth %s, "
1684+
+ "the deserialization data may be malicious. If it's not malicious, "
1685+
+ "please increase max read depth by ForyBuilder#withMaxDepth(largerDepth)",
1686+
maxDepth));
1687+
}
1688+
16851689
public void incCopyDepth(int diff) {
16861690
this.copyDepth += diff;
16871691
}

java/fory-core/src/main/java/org/apache/fory/builder/BaseObjectCodecBuilder.java

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1636,13 +1636,13 @@ protected Expression deserializeForNotNull(
16361636
obj = deserializeForMap(buffer, typeRef, serializer, invokeHint);
16371637
} else {
16381638
if (serializer != null) {
1639-
return new Invoke(serializer, "read", OBJECT_TYPE, buffer);
1639+
return read(serializer, buffer, OBJECT_TYPE);
16401640
}
16411641
if (isMonomorphic(cls)) {
16421642
serializer = getOrCreateSerializer(cls);
16431643
Class<?> returnType =
16441644
ReflectionUtils.getReturnType(getRawType(serializer.type()), "read");
1645-
obj = new Invoke(serializer, "read", TypeRef.of(returnType), buffer);
1645+
obj = read(serializer, buffer, TypeRef.of(returnType));
16461646
} else {
16471647
obj = readForNotNullNonFinal(buffer, typeRef, serializer);
16481648
}
@@ -1651,13 +1651,24 @@ protected Expression deserializeForNotNull(
16511651
}
16521652
}
16531653

1654+
protected Expression read(Expression serializer, Expression buffer, TypeRef<?> returnType) {
1655+
Class<?> type = returnType.getRawType();
1656+
Expression read = new Invoke(serializer, "read", returnType, buffer);
1657+
if (ReflectionUtils.isMonomorphic(type) && !TypeUtils.hasExpandableLeafs(type)) {
1658+
return read;
1659+
}
1660+
read = uninline(read);
1661+
return new ListExpression(
1662+
new Invoke(foryRef, "incReadDepth"), read, new Invoke(foryRef, "decDepth"), read);
1663+
}
1664+
16541665
protected Expression readForNotNullNonFinal(
16551666
Expression buffer, TypeRef<?> typeRef, Expression serializer) {
16561667
if (serializer == null) {
16571668
Expression classInfo = readClassInfo(getRawType(typeRef), buffer);
16581669
serializer = inlineInvoke(classInfo, "getSerializer", SERIALIZER_TYPE);
16591670
}
1660-
return new Invoke(serializer, "read", OBJECT_TYPE, buffer);
1671+
return read(serializer, buffer, OBJECT_TYPE);
16611672
}
16621673

16631674
/**
@@ -1693,7 +1704,7 @@ protected Expression deserializeForCollection(
16931704
new If(
16941705
supportHook,
16951706
new ListExpression(collection, hookRead),
1696-
new Invoke(serializer, "read", OBJECT_TYPE, buffer),
1707+
read(serializer, buffer, OBJECT_TYPE),
16971708
false);
16981709
if (invokeHint != null && invokeHint.genNewMethod) {
16991710
invokeHint.add(buffer);
@@ -1969,8 +1980,7 @@ chunkHeader, cast(bitand(sizeAndHeader2, ofInt(0xff)), PRIMITIVE_INT_TYPE)),
19691980
expressions.add(chunksLoop, newMap);
19701981
// first newMap to create map, last newMap as expr value
19711982
Expression map = inlineInvoke(serializer, "onMapRead", OBJECT_TYPE, expressions);
1972-
Expression action =
1973-
new If(supportHook, map, new Invoke(serializer, "read", OBJECT_TYPE, buffer), false);
1983+
Expression action = new If(supportHook, map, read(serializer, buffer, OBJECT_TYPE), false);
19741984
if (invokeHint != null && invokeHint.genNewMethod) {
19751985
invokeHint.add(buffer);
19761986
invokeHint.add(serializer);

java/fory-core/src/main/java/org/apache/fory/builder/CodecBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public abstract class CodecBuilder {
9595
protected final boolean isRecord;
9696
protected final boolean isInterface;
9797
private final Set<String> duplicatedFields;
98-
protected Reference foryRef = new Reference(FORY_NAME, TypeRef.of(Fory.class));
98+
protected Reference foryRef = Reference.fieldRef(FORY_NAME, TypeRef.of(Fory.class));
9999
public static final Reference recordComponentDefaultValues =
100100
new Reference("recordComponentDefaultValues", OBJECT_ARRAY_TYPE);
101101
protected final Map<String, Reference> fieldMap = new HashMap<>();

java/fory-core/src/main/java/org/apache/fory/builder/ObjectCodecOptimizer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ private void buildGroups() {
121121
MutableTuple3.of(
122122
new ArrayList<>(descriptorGrouper.getFinalDescriptors()), 5, finalReadGroups),
123123
MutableTuple3.of(
124-
new ArrayList<>(descriptorGrouper.getOtherDescriptors()), 5, otherReadGroups),
124+
new ArrayList<>(descriptorGrouper.getOtherDescriptors()), 4, otherReadGroups),
125125
MutableTuple3.of(
126126
new ArrayList<>(descriptorGrouper.getOtherDescriptors()), 9, otherWriteGroups));
127127
for (MutableTuple3<List<Descriptor>, Integer, List<List<Descriptor>>> decs : groups) {

java/fory-core/src/main/java/org/apache/fory/config/Config.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ public class Config implements Serializable {
6363
private final boolean deserializeNonexistentEnumValueAsNull;
6464
private final boolean serializeEnumByName;
6565
private final int bufferSizeLimitBytes;
66+
private final int maxDepth;
6667

6768
public Config(ForyBuilder builder) {
6869
name = builder.name;
@@ -101,6 +102,7 @@ public Config(ForyBuilder builder) {
101102
deserializeNonexistentEnumValueAsNull = builder.deserializeNonexistentEnumValueAsNull;
102103
serializeEnumByName = builder.serializeEnumByName;
103104
bufferSizeLimitBytes = builder.bufferSizeLimitBytes;
105+
maxDepth = builder.maxDepth;
104106
}
105107

106108
/** Returns the name for Fory serialization. */
@@ -357,4 +359,9 @@ public int getConfigHash() {
357359
}
358360
return configHash;
359361
}
362+
363+
/** Returns max depth for deserialization, when depth exceeds, an exception will be thrown. */
364+
public int maxDepth() {
365+
return maxDepth;
366+
}
360367
}

java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.apache.fory.serializer.TimeSerializers;
3939
import org.apache.fory.serializer.collection.GuavaCollectionSerializers;
4040
import org.apache.fory.util.GraalvmSupport;
41+
import org.apache.fory.util.Preconditions;
4142

4243
/** Builder class to config and create {@link Fory}. */
4344
// Method naming style for this builder:
@@ -86,6 +87,7 @@ public final class ForyBuilder {
8687
boolean serializeEnumByName = false;
8788
int bufferSizeLimitBytes = 128 * 1024;
8889
MetaCompressor metaCompressor = new DeflaterMetaCompressor();
90+
int maxDepth = 50;
8991

9092
public ForyBuilder() {}
9193

@@ -348,6 +350,16 @@ public ForyBuilder withAsyncCompilation(boolean asyncCompilation) {
348350
return this;
349351
}
350352

353+
/**
354+
* Set max depth for deserialization, when depth exceeds, an exception will be thrown. Default max
355+
* depth is 50.
356+
*/
357+
public ForyBuilder withMaxDepth(int maxDepth) {
358+
Preconditions.checkArgument(maxDepth >= 2, "maxDepth must >= 2 but got %s", maxDepth);
359+
this.maxDepth = maxDepth;
360+
return this;
361+
}
362+
351363
/** Whether enable scala-specific serialization optimization. */
352364
public ForyBuilder withScalaOptimizationEnabled(boolean enableScalaOptimization) {
353365
this.scalaOptimizationEnabled = enableScalaOptimization;

java/fory-core/src/main/java/org/apache/fory/resolver/XtypeResolver.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -647,19 +647,19 @@ private void throwUnexpectTypeIdException(long xtypeId) {
647647
}
648648

649649
private ClassInfo getListClassInfo() {
650-
fory.incDepth(1);
650+
fory.incReadDepth();
651651
GenericType genericType = generics.nextGenericType();
652-
fory.incDepth(-1);
652+
fory.decDepth();
653653
if (genericType != null) {
654654
return getOrBuildClassInfo(genericType.getCls());
655655
}
656656
return xtypeIdToClassMap.get(Types.LIST);
657657
}
658658

659659
private ClassInfo getGenericClassInfo() {
660-
fory.incDepth(1);
660+
fory.incReadDepth();
661661
GenericType genericType = generics.nextGenericType();
662-
fory.incDepth(-1);
662+
fory.decDepth();
663663
if (genericType != null) {
664664
return getOrBuildClassInfo(genericType.getCls());
665665
}

java/fory-core/src/main/java/org/apache/fory/serializer/AbstractObjectSerializer.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,17 @@ static Object readFinalObjectFieldValue(
8989
boolean isFinal,
9090
MemoryBuffer buffer) {
9191
Serializer<Object> serializer = fieldInfo.classInfo.getSerializer();
92+
binding.incReadDepth();
9293
Object fieldValue;
9394
boolean nullable = fieldInfo.nullable;
9495
if (isFinal) {
9596
if (!fieldInfo.trackingRef) {
96-
return binding.readNullable(buffer, serializer, nullable);
97+
fieldValue = binding.readNullable(buffer, serializer, nullable);
98+
} else {
99+
// whether tracking ref is recorded in `fieldInfo.serializer`, so it's still
100+
// consistent with jit serializer.
101+
fieldValue = binding.readRef(buffer, serializer);
97102
}
98-
// whether tracking ref is recorded in `fieldInfo.serializer`, so it's still
99-
// consistent with jit serializer.
100-
fieldValue = binding.readRef(buffer, serializer);
101103
} else {
102104
if (serializer.needToWriteRef()) {
103105
int nextReadRefId = refResolver.tryPreserveRefId(buffer);
@@ -112,13 +114,15 @@ static Object readFinalObjectFieldValue(
112114
if (nullable) {
113115
byte headFlag = buffer.readByte();
114116
if (headFlag == Fory.NULL_FLAG) {
117+
binding.decDepth();
115118
return null;
116119
}
117120
}
118121
typeResolver.readClassInfo(buffer, fieldInfo.classInfo);
119122
fieldValue = serializer.read(buffer);
120123
}
121124
}
125+
binding.decDepth();
122126
return fieldValue;
123127
}
124128

0 commit comments

Comments
 (0)