5151import java .util .Collections ;
5252import java .util .HashMap ;
5353import java .util .HashSet ;
54+ import java .util .IdentityHashMap ;
5455import java .util .List ;
5556import java .util .Map ;
5657import java .util .Optional ;
@@ -151,14 +152,14 @@ public MessageType convert(Schema avroSchema) {
151152 throw new IllegalArgumentException ("Avro schema must be a record." );
152153 }
153154 return new MessageType (
154- avroSchema .getFullName (), convertFields (avroSchema .getFields (), "" , new HashSet <Schema >()));
155+ avroSchema .getFullName (), convertFields (avroSchema .getFields (), "" , new IdentityHashMap <Schema , Void >()));
155156 }
156157
157158 private List <Type > convertFields (List <Schema .Field > fields , String schemaPath ) {
158- return convertFields (fields , schemaPath , new HashSet <Schema >());
159+ return convertFields (fields , schemaPath , new IdentityHashMap <Schema , Void >());
159160 }
160161
161- private List <Type > convertFields (List <Schema .Field > fields , String schemaPath , Set <Schema > seenSchemas ) {
162+ private List <Type > convertFields (List <Schema .Field > fields , String schemaPath , IdentityHashMap <Schema , Void > seenSchemas ) {
162163 List <Type > types = new ArrayList <Type >();
163164 for (Schema .Field field : fields ) {
164165 if (field .schema ().getType ().equals (Schema .Type .NULL )) {
@@ -173,29 +174,29 @@ private Type convertField(String fieldName, Schema schema, String schemaPath) {
173174 return convertField (fieldName , schema , Type .Repetition .REQUIRED , schemaPath );
174175 }
175176
176- private Type convertField (String fieldName , Schema schema , String schemaPath , Set <Schema > seenSchemas ) {
177+ private Type convertField (String fieldName , Schema schema , String schemaPath , IdentityHashMap <Schema , Void > seenSchemas ) {
177178 return convertField (fieldName , schema , Type .Repetition .REQUIRED , schemaPath , seenSchemas );
178179 }
179180
180181 @ SuppressWarnings ("deprecation" )
181182 private Type convertField (String fieldName , Schema schema , Type .Repetition repetition , String schemaPath ) {
182- return convertField (fieldName , schema , repetition , schemaPath , new HashSet <Schema >());
183+ return convertField (fieldName , schema , repetition , schemaPath , new IdentityHashMap <Schema , Void >());
183184 }
184185
185186 @ SuppressWarnings ("deprecation" )
186187 private Type convertField (
187- String fieldName , Schema schema , Type .Repetition repetition , String schemaPath , Set <Schema > seenSchemas ) {
188+ String fieldName , Schema schema , Type .Repetition repetition , String schemaPath , IdentityHashMap <Schema , Void > seenSchemas ) {
188189 Schema .Type type = schema .getType ();
189190 LogicalType logicalType = schema .getLogicalType ();
190191
191192 if (type .equals (Schema .Type .RECORD ) || type .equals (Schema .Type .ENUM ) || type .equals (Schema .Type .FIXED )) {
192193 // If this schema has already been seen in the current branch, we have a recursion loop
193- if (seenSchemas .contains (schema )) {
194+ if (seenSchemas .containsKey (schema )) {
194195 throw new UnsupportedOperationException (
195196 "Recursive Avro schemas are not supported by parquet-avro: " + schema .getFullName ());
196197 }
197- seenSchemas = new HashSet <>(seenSchemas );
198- seenSchemas .add (schema );
198+ seenSchemas = new IdentityHashMap <>(seenSchemas );
199+ seenSchemas .put (schema , null );
199200 }
200201
201202 Types .PrimitiveBuilder <PrimitiveType > builder ;
@@ -275,11 +276,11 @@ private Type convertField(
275276 }
276277
277278 private Type convertUnion (String fieldName , Schema schema , Type .Repetition repetition , String schemaPath ) {
278- return convertUnion (fieldName , schema , repetition , schemaPath , new HashSet <Schema >());
279+ return convertUnion (fieldName , schema , repetition , schemaPath , new IdentityHashMap <Schema , Void >());
279280 }
280281
281282 private Type convertUnion (
282- String fieldName , Schema schema , Type .Repetition repetition , String schemaPath , Set <Schema > seenSchemas ) {
283+ String fieldName , Schema schema , Type .Repetition repetition , String schemaPath , IdentityHashMap <Schema , Void > seenSchemas ) {
283284 List <Schema > nonNullSchemas = new ArrayList <Schema >(schema .getTypes ().size ());
284285 // Found any null schemas in the union? Required for the edge case where the union contains only a single type.
285286 boolean foundNullSchema = false ;
@@ -311,15 +312,15 @@ private Type convertUnion(
311312
312313 private Type convertUnionToGroupType (
313314 String fieldName , Type .Repetition repetition , List <Schema > nonNullSchemas , String schemaPath ) {
314- return convertUnionToGroupType (fieldName , repetition , nonNullSchemas , schemaPath , new HashSet <Schema >());
315+ return convertUnionToGroupType (fieldName , repetition , nonNullSchemas , schemaPath , new IdentityHashMap <Schema , Void >());
315316 }
316317
317318 private Type convertUnionToGroupType (
318319 String fieldName ,
319320 Type .Repetition repetition ,
320321 List <Schema > nonNullSchemas ,
321322 String schemaPath ,
322- Set <Schema > seenSchemas ) {
323+ IdentityHashMap <Schema , Void > seenSchemas ) {
323324 List <Type > unionTypes = new ArrayList <Type >(nonNullSchemas .size ());
324325 int index = 0 ;
325326 for (Schema childSchema : nonNullSchemas ) {
@@ -333,7 +334,7 @@ private Type convertField(Schema.Field field, String schemaPath) {
333334 return convertField (field .name (), field .schema (), schemaPath );
334335 }
335336
336- private Type convertField (Schema .Field field , String schemaPath , Set <Schema > seenSchemas ) {
337+ private Type convertField (Schema .Field field , String schemaPath , IdentityHashMap <Schema , Void > seenSchemas ) {
337338 return convertField (field .name (), field .schema (), schemaPath , seenSchemas );
338339 }
339340
0 commit comments