Skip to content

Commit b8a274e

Browse files
committed
update
1 parent a1ad19a commit b8a274e

3 files changed

Lines changed: 144 additions & 0 deletions

File tree

parquet-avro/src/main/java/org/apache/parquet/avro/AvroConverters.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
import java.lang.reflect.Constructor;
2222
import java.lang.reflect.InvocationTargetException;
23+
import java.math.BigDecimal;
24+
import java.math.BigInteger;
2325
import java.nio.ByteBuffer;
2426
import org.apache.avro.Schema;
2527
import org.apache.avro.generic.GenericData;
@@ -29,6 +31,7 @@
2931
import org.apache.parquet.io.api.Binary;
3032
import org.apache.parquet.io.api.GroupConverter;
3133
import org.apache.parquet.io.api.PrimitiveConverter;
34+
import org.apache.parquet.schema.LogicalTypeAnnotation;
3235
import org.apache.parquet.schema.PrimitiveStringifier;
3336
import org.apache.parquet.schema.PrimitiveType;
3437

@@ -339,4 +342,74 @@ public String convert(Binary binary) {
339342
return stringifier.stringify(binary);
340343
}
341344
}
345+
346+
static final class FieldDecimalIntConverter extends AvroPrimitiveConverter {
347+
private final int scale;
348+
private int[] dict = null;
349+
350+
public FieldDecimalIntConverter(ParentValueContainer parent, PrimitiveType type) {
351+
super(parent);
352+
LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalType =
353+
(LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation();
354+
this.scale = decimalType.getScale();
355+
}
356+
357+
@Override
358+
public void addInt(int value) {
359+
parent.add(new BigDecimal(BigInteger.valueOf(value), scale));
360+
}
361+
362+
@Override
363+
public boolean hasDictionarySupport() {
364+
return true;
365+
}
366+
367+
@Override
368+
public void setDictionary(Dictionary dictionary) {
369+
dict = new int[dictionary.getMaxId() + 1];
370+
for (int i = 0; i <= dictionary.getMaxId(); i++) {
371+
dict[i] = dictionary.decodeToInt(i);
372+
}
373+
}
374+
375+
@Override
376+
public void addValueFromDictionary(int dictionaryId) {
377+
addInt(dict[dictionaryId]);
378+
}
379+
}
380+
381+
static final class FieldDecimalLongConverter extends AvroPrimitiveConverter {
382+
private final int scale;
383+
private long[] dict = null;
384+
385+
public FieldDecimalLongConverter(ParentValueContainer parent, PrimitiveType type) {
386+
super(parent);
387+
LogicalTypeAnnotation.DecimalLogicalTypeAnnotation decimalType =
388+
(LogicalTypeAnnotation.DecimalLogicalTypeAnnotation) type.getLogicalTypeAnnotation();
389+
this.scale = decimalType.getScale();
390+
}
391+
392+
@Override
393+
public void addLong(long value) {
394+
parent.add(new BigDecimal(BigInteger.valueOf(value), scale));
395+
}
396+
397+
@Override
398+
public boolean hasDictionarySupport() {
399+
return true;
400+
}
401+
402+
@Override
403+
public void setDictionary(Dictionary dictionary) {
404+
dict = new long[dictionary.getMaxId() + 1];
405+
for (int i = 0; i <= dictionary.getMaxId(); i++) {
406+
dict[i] = dictionary.decodeToLong(i);
407+
}
408+
}
409+
410+
@Override
411+
public void addValueFromDictionary(int dictionaryId) {
412+
addLong(dict[dictionaryId]);
413+
}
414+
}
342415
}

parquet-avro/src/main/java/org/apache/parquet/avro/AvroRecordConverter.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,14 @@ private static Converter newConverter(
337337
return newConverter(schema, type, model, null, setter, validator);
338338
}
339339

340+
private static boolean isDecimalType(Type type) {
341+
if (!type.isPrimitive()) {
342+
return false;
343+
}
344+
LogicalTypeAnnotation annotation = type.getLogicalTypeAnnotation();
345+
return annotation instanceof LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
346+
}
347+
340348
private static Converter newConverter(
341349
Schema schema,
342350
Type type,
@@ -359,6 +367,9 @@ private static Converter newConverter(
359367
case BOOLEAN:
360368
return new AvroConverters.FieldBooleanConverter(parent);
361369
case INT:
370+
if (isDecimalType(type)) {
371+
return new AvroConverters.FieldDecimalIntConverter(parent, type.asPrimitiveType());
372+
}
362373
Class<?> intDatumClass = getDatumClass(conversion, knownClass, schema, model);
363374
if (intDatumClass == null) {
364375
return new AvroConverters.FieldIntegerConverter(parent);
@@ -374,6 +385,9 @@ private static Converter newConverter(
374385
}
375386
return new AvroConverters.FieldIntegerConverter(parent);
376387
case LONG:
388+
if (isDecimalType(type)) {
389+
return new AvroConverters.FieldDecimalLongConverter(parent, type.asPrimitiveType());
390+
}
377391
return new AvroConverters.FieldLongConverter(parent);
378392
case FLOAT:
379393
return new AvroConverters.FieldFloatConverter(parent);

parquet-avro/src/test/java/org/apache/parquet/avro/TestReadWrite.java

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
package org.apache.parquet.avro;
2020

2121
import static org.apache.parquet.avro.AvroTestUtil.optional;
22+
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
23+
import static org.apache.parquet.schema.Type.Repetition.REQUIRED;
2224
import static org.junit.Assert.assertEquals;
2325
import static org.junit.Assert.assertNotNull;
2426

@@ -61,17 +63,23 @@
6163
import org.apache.parquet.conf.ParquetConfiguration;
6264
import org.apache.parquet.conf.PlainParquetConfiguration;
6365
import org.apache.parquet.example.data.Group;
66+
import org.apache.parquet.example.data.GroupFactory;
67+
import org.apache.parquet.example.data.simple.SimpleGroupFactory;
6468
import org.apache.parquet.hadoop.ParquetReader;
6569
import org.apache.parquet.hadoop.ParquetWriter;
6670
import org.apache.parquet.hadoop.api.WriteSupport;
71+
import org.apache.parquet.hadoop.example.ExampleParquetWriter;
6772
import org.apache.parquet.hadoop.example.GroupReadSupport;
6873
import org.apache.parquet.hadoop.util.HadoopCodecs;
6974
import org.apache.parquet.io.InputFile;
7075
import org.apache.parquet.io.LocalInputFile;
7176
import org.apache.parquet.io.LocalOutputFile;
7277
import org.apache.parquet.io.api.Binary;
7378
import org.apache.parquet.io.api.RecordConsumer;
79+
import org.apache.parquet.schema.LogicalTypeAnnotation;
80+
import org.apache.parquet.schema.MessageType;
7481
import org.apache.parquet.schema.MessageTypeParser;
82+
import org.apache.parquet.schema.PrimitiveType;
7583
import org.junit.Assert;
7684
import org.junit.Rule;
7785
import org.junit.Test;
@@ -400,6 +408,55 @@ public void testFixedDecimalValues() throws Exception {
400408
Assert.assertEquals("Content should match", expected, records);
401409
}
402410

411+
@Test
412+
public void testDecimalInt64Values() throws Exception {
413+
414+
File file = temp.newFile("test_decimal_int64_values.parquet");
415+
file.delete();
416+
Path path = new Path(file.toString());
417+
418+
MessageType parquetSchema = new MessageType(
419+
"test_decimal_int64_values",
420+
new PrimitiveType(REQUIRED, INT64, "decimal_salary").withLogicalTypeAnnotation(LogicalTypeAnnotation.decimalType(1, 10)));
421+
422+
try (ParquetWriter<Group> writer =
423+
ExampleParquetWriter.builder(path).withType(parquetSchema).build()) {
424+
425+
GroupFactory factory = new SimpleGroupFactory(parquetSchema);
426+
427+
Group group1 = factory.newGroup();
428+
group1.add("decimal_salary", 234L);
429+
writer.write(group1);
430+
431+
Group group2 = factory.newGroup();
432+
group2.add("decimal_salary", 1203L);
433+
writer.write(group2);
434+
}
435+
436+
GenericData decimalSupport = new GenericData();
437+
decimalSupport.addLogicalTypeConversion(new Conversions.DecimalConversion());
438+
439+
List<GenericRecord> records = Lists.newArrayList();
440+
try (ParquetReader<GenericRecord> reader = AvroParquetReader.<GenericRecord>builder(path)
441+
.withDataModel(decimalSupport)
442+
.build()) {
443+
GenericRecord rec;
444+
while ((rec = reader.read()) != null) {
445+
records.add(rec);
446+
}
447+
}
448+
449+
Assert.assertEquals("Should read 2 records", 2, records.size());
450+
451+
Object firstSalary = records.get(0).get("decimal_salary");
452+
Object secondSalary = records.get(1).get("decimal_salary");
453+
454+
Assert.assertTrue(
455+
"Should be BigDecimal, but is " + firstSalary.getClass(), firstSalary instanceof BigDecimal);
456+
Assert.assertEquals("Should be 23.4, but is " + firstSalary, new BigDecimal("23.4"), firstSalary);
457+
Assert.assertEquals("Should be 120.3, but is " + secondSalary, new BigDecimal("120.3"), secondSalary);
458+
}
459+
403460
@Test
404461
public void testAll() throws Exception {
405462
Schema schema =

0 commit comments

Comments
 (0)