|
| 1 | +package com.target.data_validator.stats |
| 2 | + |
| 3 | +import org.apache.spark.sql.Row |
| 4 | +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} |
| 5 | +import org.apache.spark.sql.types._ |
| 6 | + |
/**
 * Spark UDAF that computes the count, mean, minimum and maximum of a numeric
 * column in a single pass and returns them as a [[FirstPassStats]].
 *
 * Null input values are skipped: they do not contribute to the count, sum,
 * min or max.
 *
 * If no non-null value is aggregated, the result has a count of 0, a mean of
 * `NaN` (0.0 / 0), a min of `Double.MaxValue` and a max of `Double.MinValue`,
 * i.e. the untouched initial buffer values.
 */
class FirstPassStatsAggregator extends UserDefinedAggregateFunction {

  /**
   * Input is a single column of `DoubleType`.
   */
  override def inputSchema: StructType = new StructType().add("value", DoubleType)

  /**
   * The buffer keeps running state for the count, sum, min and max.
   */
  override def bufferSchema: StructType = new StructType()
    .add(StructField("count", LongType))
    .add(StructField("sum", DoubleType))
    .add(StructField("min", DoubleType))
    .add(StructField("max", DoubleType))

  // Positions of each field inside the aggregation buffer, resolved once by name
  // so the methods below never rely on hard-coded ordinals.
  private val count = bufferSchema.fieldIndex("count")
  private val sum = bufferSchema.fieldIndex("sum")
  private val min = bufferSchema.fieldIndex("min")
  private val max = bufferSchema.fieldIndex("max")

  /**
   * Return type of the UDAF, as declared by `FirstPassStats.dataType`.
   */
  override def dataType: DataType = FirstPassStats.dataType

  /**
   * The same input always produces the same output.
   */
  override def deterministic: Boolean = true

  /**
   * Reset the buffer: zero count and sum, and min/max seeded with the extreme
   * finite doubles so that the first real value replaces them.
   *
   * Note that in Scala `Double.MinValue` is the most negative finite double
   * (`-Double.MaxValue`), not the smallest positive value.
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(count) = 0L
    buffer(sum) = 0.0
    buffer(min) = Double.MaxValue
    buffer(max) = Double.MinValue
  }

  /**
   * Fold one input row into the buffer.
   *
   * Null inputs are ignored; calling `getDouble` on a null cell would throw a
   * `NullPointerException` and abort the whole aggregation.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val value = input.getDouble(0)
      buffer(count) = buffer.getLong(count) + 1
      buffer(sum) = buffer.getDouble(sum) + value
      buffer(min) = math.min(value, buffer.getDouble(min))
      buffer(max) = math.max(value, buffer.getDouble(max))
    }
  }

  /**
   * Combine two partial buffers into `buffer1`: counts and sums are added,
   * min and max are reduced with `math.min` / `math.max`.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(count) = buffer1.getLong(count) + buffer2.getLong(count)
    buffer1(sum) = buffer1.getDouble(sum) + buffer2.getDouble(sum)
    buffer1(min) = math.min(buffer1.getDouble(min), buffer2.getDouble(min))
    buffer1(max) = math.max(buffer1.getDouble(max), buffer2.getDouble(max))
  }

  /**
   * Produce the final [[FirstPassStats]] (count, mean, min, max) from the buffer.
   *
   * The mean is `sum / count` in double arithmetic, so a count of 0 yields `NaN`
   * rather than throwing.
   */
  override def evaluate(buffer: Row): Any = {
    val n = buffer.getLong(count)
    FirstPassStats(
      n,
      buffer.getDouble(sum) / n,
      buffer.getDouble(min),
      buffer.getDouble(max)
    )
  }

}
0 commit comments