diff --git a/conformance/binary_json_conformance_suite.cc b/conformance/binary_json_conformance_suite.cc index 62c9294893a09..91b9f4b3eeba1 100644 --- a/conformance/binary_json_conformance_suite.cc +++ b/conformance/binary_json_conformance_suite.cc @@ -524,6 +524,61 @@ void BinaryAndJsonConformanceSuite::RunRecursionLimitTests() { proto2_msg.SerializeAsString(), "EnforceDepthLimit.MessageSetExtension", RECOMMENDED); } + + auto expect_json_parse_failure = [&](const std::string& name, + const std::string& json) { + TestAllTypesProto3 prototype; + ConformanceRequestSetting setting( + REQUIRED, ::conformance::JSON, ::conformance::PROTOBUF, + ::conformance::JSON_TEST, prototype, name, json); + const ConformanceRequest& request = setting.GetRequest(); + ConformanceResponse response; + std::string effective_test_name = + absl::StrCat("Required.Proto3.JsonInput.", name); + + if (!RunTest(effective_test_name, request, &response)) { + return; + } + + TestStatus test; + test.set_name(effective_test_name); + if (response.result_case() == ConformanceResponse::kParseError) { + ReportSuccess(test); + } else if (response.result_case() == ConformanceResponse::kSkipped) { + ReportSkip(test, request, response); + } else { + test.set_failure_message("Should have failed to parse, but didn't."); + ReportFailure(test, REQUIRED, request, response); + } + }; + + // Test deep Struct nesting. + { + std::string json = "{\"optionalStruct\": "; + for (int i = 0; i < 50; i++) { + json += "{\"a\": "; + } + json += '1'; + for (int i = 0; i < 50; i++) { + json += '}'; + } + json += '}'; + expect_json_parse_failure("EnforceDepthLimit.Struct", json); + } + + // Test deep ListValue nesting. + { + std::string json = "{\"optionalValue\": "; + for (int i = 0; i < 100; i++) { + json += '['; + } + json += '1'; + for (int i = 0; i < 100; i++) { + json += ']'; + } + json += '}'; + expect_json_parse_failure("EnforceDepthLimit.ListValue", json); + } } template diff --git a/java/util/src/main/java/com/google/protobuf/util/JsonFormat.java b/java/util/src/main/java/com/google/protobuf/util/JsonFormat.java index 8ba20be6d056c..aff462b1f3d1b 100644 --- a/java/util/src/main/java/com/google/protobuf/util/JsonFormat.java +++ b/java/util/src/main/java/com/google/protobuf/util/JsonFormat.java @@ -1561,13 +1561,21 @@ public void merge(ParserImpl parser, JsonElement json, Message.Builder builder) private void merge(JsonElement json, Message.Builder builder) throws InvalidProtocolBufferException { - WellKnownTypeParser specialParser = - wellKnownTypeParsers.get(builder.getDescriptorForType().getFullName()); - if (specialParser != null) { - specialParser.merge(this, json, builder); - return; + if (currentDepth >= recursionLimit) { + throw new InvalidProtocolBufferException("Hit recursion limit."); + } + ++currentDepth; + try { + WellKnownTypeParser specialParser = + wellKnownTypeParsers.get(builder.getDescriptorForType().getFullName()); + if (specialParser != null) { + specialParser.merge(this, json, builder); + return; + } + mergeMessage(json, builder, false); + } finally { + --currentDepth; } - mergeMessage(json, builder, false); } // Maps from camel-case field names to FieldDescriptor. @@ -1666,19 +1674,22 @@ private void mergeAny(JsonElement json, Message.Builder builder) DynamicMessage.getDefaultInstance(contentType).newBuilderForType(); WellKnownTypeParser specialParser = wellKnownTypeParsers.get(contentType.getFullName()); - if (currentDepth >= recursionLimit) { - throw new InvalidProtocolBufferException("Hit recursion limit."); - } - ++currentDepth; if (specialParser != null) { JsonElement value = object.get("value"); if (value != null) { - specialParser.merge(this, value, contentBuilder); + merge(value, contentBuilder); } } else { - mergeMessage(json, contentBuilder, true); + if (currentDepth >= recursionLimit) { + throw new InvalidProtocolBufferException("Hit recursion limit."); + } + ++currentDepth; + try { + mergeMessage(json, contentBuilder, true); + } finally { + --currentDepth; + } } - --currentDepth; builder.setField(valueField, contentBuilder.build().toByteString()); } @@ -1820,19 +1831,27 @@ private void mergeMapField(FieldDescriptor field, JsonElement json, Message.Buil } JsonObject object = (JsonObject) json; for (Map.Entry entry : object.entrySet()) { - Message.Builder entryBuilder = builder.newBuilderForField(field); - Object key = parseFieldValue(keyField, new JsonPrimitive(entry.getKey()), entryBuilder); - Object value = parseFieldValue(valueField, entry.getValue(), entryBuilder); - if (value == null) { - if (ignoringUnknownFields && valueField.getType() == FieldDescriptor.Type.ENUM) { - continue; - } else { - throw new InvalidProtocolBufferException("Map value cannot be null."); + if (currentDepth >= recursionLimit) { + throw new InvalidProtocolBufferException("Hit recursion limit."); + } + ++currentDepth; + try { + Message.Builder entryBuilder = builder.newBuilderForField(field); + Object key = parseFieldValue(keyField, new JsonPrimitive(entry.getKey()), entryBuilder); + Object value = parseFieldValue(valueField, entry.getValue(), entryBuilder); + if (value == null) { + if (ignoringUnknownFields && valueField.getType() == FieldDescriptor.Type.ENUM) { + continue; + } else { + throw new InvalidProtocolBufferException("Map value cannot be null."); + } } + entryBuilder.setField(keyField, key); + entryBuilder.setField(valueField, value); + builder.addRepeatedField(field, entryBuilder.build()); + } finally { + --currentDepth; } - entryBuilder.setField(keyField, key); - entryBuilder.setField(valueField, value); - builder.addRepeatedField(field, entryBuilder.build()); } } @@ -2149,13 +2168,8 @@ private Object parseFieldValue(FieldDescriptor field, JsonElement json, Message. case MESSAGE: case GROUP: - if (currentDepth >= recursionLimit) { - throw new InvalidProtocolBufferException("Hit recursion limit."); - } - ++currentDepth; Message.Builder subBuilder = builder.newBuilderForField(field); merge(json, subBuilder); - --currentDepth; return subBuilder.build(); default: diff --git a/src/google/protobuf/json/internal/lexer.h b/src/google/protobuf/json/internal/lexer.h index d588193c2f393..c3056090a4b8b 100644 --- a/src/google/protobuf/json/internal/lexer.h +++ b/src/google/protobuf/json/internal/lexer.h @@ -218,20 +218,38 @@ class JsonLexer { LocationWith BeginMark() { return {stream_.BeginMark(), json_loc_}; } - private: - friend BufferingGuard; - friend Mark; - friend MaybeOwnedString; - - absl::Status Push() { - if (options_.recursion_depth == 0) { + absl::Status Push(int depth = 1) { + if (options_.recursion_depth < depth) { return Invalid("JSON content was too deeply nested"); } - --options_.recursion_depth; + options_.recursion_depth -= depth; return absl::OkStatus(); } - void Pop() { ++options_.recursion_depth; } + void Pop(int depth = 1) { options_.recursion_depth += depth; } + + // A RAII helper to push and pop recursion depth. + class ScopedRecursion { + public: + explicit ScopedRecursion(JsonLexer& lex, int depth = 1) + : lex_(lex), depth_(depth) { + status_ = lex_.Push(depth_); + } + ~ScopedRecursion() { + if (status_.ok()) lex_.Pop(depth_); + } + absl::Status status() const { return status_; } + + private: + JsonLexer& lex_; + int depth_; + absl::Status status_; + }; + + private: + friend BufferingGuard; + friend Mark; + friend MaybeOwnedString; // Parses the next four bytes as a 16-bit hex numeral. absl::StatusOr ParseU16HexCodepoint(); diff --git a/src/google/protobuf/json/internal/parser.cc b/src/google/protobuf/json/internal/parser.cc index e2631bde67374..b21ff415430c3 100644 --- a/src/google/protobuf/json/internal/parser.cc +++ b/src/google/protobuf/json/internal/parser.cc @@ -745,6 +745,8 @@ template absl::Status ParseMapEntry(JsonLexer& lex, Field map_field, Msg& parent_msg, LocationWith& key) { + JsonLexer::ScopedRecursion recursion(lex); + RETURN_IF_ERROR(recursion.status()); bool is_map_of_enums = false; RETURN_IF_ERROR(Traits::WithFieldType( map_field, [&is_map_of_enums](const Desc& desc) { @@ -1107,6 +1109,8 @@ absl::Status ParseListValue(JsonLexer& lex, const Desc& desc, template absl::Status ParseValue(JsonLexer& lex, const Desc& desc, Msg& msg) { + JsonLexer::ScopedRecursion recursion(lex); + RETURN_IF_ERROR(recursion.status()); auto kind = lex.PeekKind(); RETURN_IF_ERROR(kind.status()); // NOTE: The field numbers 1 through 6 are the numbers of the oneof fields @@ -1191,6 +1195,8 @@ absl::Status ParseValue(JsonLexer& lex, const Desc& desc, template absl::Status ParseStructValue(JsonLexer& lex, const Desc& desc, Msg& msg) { + JsonLexer::ScopedRecursion recursion(lex); + RETURN_IF_ERROR(recursion.status()); auto entry_field = Traits::MustHaveField(desc, 1); auto pop = lex.path().Push("", FieldDescriptor::TYPE_MESSAGE, Traits::FieldTypeName(entry_field)); @@ -1209,6 +1215,8 @@ absl::Status ParseStructValue(JsonLexer& lex, const Desc& desc, template absl::Status ParseListValue(JsonLexer& lex, const Desc& desc, Msg& msg) { + JsonLexer::ScopedRecursion recursion(lex); + RETURN_IF_ERROR(recursion.status()); auto entry_field = Traits::MustHaveField(desc, 1); auto pop = lex.path().Push("", FieldDescriptor::TYPE_MESSAGE, Traits::FieldTypeName(entry_field)); @@ -1305,15 +1313,26 @@ absl::Status ParseMessage(JsonLexer& lex, const Desc& desc, case MessageType::kList: return ParseListValue(lex, desc, msg); case MessageType::kWrapper: { + JsonLexer::ScopedRecursion recursion(lex); + RETURN_IF_ERROR(recursion.status()); return ParseSingular(lex, Traits::MustHaveField(desc, 1), msg); } - case MessageType::kTimestamp: + case MessageType::kTimestamp: { + JsonLexer::ScopedRecursion recursion(lex); + RETURN_IF_ERROR(recursion.status()); return ParseTimestamp(lex, desc, msg); - case MessageType::kDuration: + } + case MessageType::kDuration: { + JsonLexer::ScopedRecursion recursion(lex); + RETURN_IF_ERROR(recursion.status()); return ParseDuration(lex, desc, msg); - case MessageType::kFieldMask: + } + case MessageType::kFieldMask: { + JsonLexer::ScopedRecursion recursion(lex); + RETURN_IF_ERROR(recursion.status()); return ParseFieldMask(lex, desc, msg); + } default: break; }