Skip to content

Commit 9a8aa9c

Browse files
authored
Merge pull request #42 from target/add-stats
Column Stats
2 parents 9805204 + b05b915 commit 9a8aa9c

18 files changed

Lines changed: 767 additions & 8 deletions

README.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,25 @@ This check sums a column in all rows. If the sum applied to the `column` doesn't
326326

327327
**Note:** If bounds are non-inclusive, and the actual sum is equal to one of the bounds, the relative error percentage will be undefined.
328328

329+
#### `colstats`
330+
331+
This check generates column statistics about the specified column.
332+
333+
| Arg | Type | Description |
334+
|-------------|-------------|--------------------------------------------|
335+
| `column` | String | The column on which to collect statistics. |
336+
337+
These values will appear in the check's JSON summary when using the JSON report output mode:
338+
339+
| Arg | Type | Description |
340+
|-------------|-------------|-------------------------------------------------------------------------------------------------------------------------|
341+
| `count` | Integer | Count of non-null entries in the `column`. |
342+
| `mean` | Double | Mean/Average of the values in the `column`. |
343+
| `min` | Double | Smallest value in the `column`. |
344+
| `max` | Double | Largest value in the `column`. |
345+
| `stdDev` | Double | Standard deviation of the values in the `column`. |
346+
| `histogram` | Complex | Summary of an equi-width histogram, counts of values appearing in 10 equally sized buckets over the range `[min, max]`. |
347+
329348
## Example Config
330349

331350
```yaml

src/main/scala/com/target/data_validator/JsonEncoders.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ object JsonEncoders extends LazyLogging {
1919
Json.fromString(a.toString)
2020
}
2121

22+
// scalastyle:off cyclomatic.complexity
2223
implicit val eventEncoder: Encoder[ValidatorEvent] = new Encoder[ValidatorEvent] {
2324
override def apply(a: ValidatorEvent): Json = a match {
2425
case vc: ValidatorCounter => Json.obj(
@@ -69,8 +70,10 @@ object JsonEncoders extends LazyLogging {
6970
("src", Json.fromString(vs.src)),
7071
("dest", vs.dest)
7172
)
73+
case vj: JsonEvent => vj.json
7274
}
7375
}
76+
// scalastyle:on cyclomatic.complexity
7477

7578
implicit val baseEncoder: Encoder[ValidatorBase] = new Encoder[ValidatorBase] {
7679
final def apply(a: ValidatorBase): Json = a.toJson
@@ -98,6 +101,13 @@ object JsonEncoders extends LazyLogging {
98101
("keyColumns", vp.keyColumns.asJson),
99102
("checks", vp.checks.asJson),
100103
("events", vp.getEvents.asJson))
104+
case vdf: ValidatorDataFrame => Json.obj(
105+
("dfLabel", vdf.label.asJson),
106+
("failed", vdf.failed.asJson),
107+
("keyColumns", vdf.keyColumns.asJson),
108+
("checks", vdf.checks.asJson),
109+
("events", vdf.getEvents.asJson)
110+
)
101111
}
102112
}
103113

src/main/scala/com/target/data_validator/ValidatorEvent.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,9 @@ case class VarSubJsonEvent(src: String, dest: Json) extends ValidatorEvent {
9999
override def toString: String = s"VarSub src: $src dest: ${dest.noSpaces}"
100100
override def toHTML: Text.all.Tag = div(cls:="subEvent")(toString)
101101
}
102+
103+
case class JsonEvent(json: Json) extends ValidatorEvent {
104+
override def failed: Boolean = false
105+
override def toString: String = s"JsonEvent: json:${json.noSpaces}"
106+
override def toHTML: Text.all.Tag = div(cls := "jsonEvent")(toString)
107+
}

src/main/scala/com/target/data_validator/ValidatorTable.scala

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
package com.target.data_validator
22

3-
import com.target.data_validator.validator.{CheapCheck, ColumnBased, CostlyCheck, RowBased, ValidatorBase}
4-
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
3+
import com.target.data_validator.validator._
4+
import org.apache.spark.sql._
55
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
66
import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Sum}
77

88
import scala.collection.mutable.ListBuffer
9-
import scala.util.{Failure, Success, Try}
9+
import scala.util._
1010
import scalatags.Text.all._
1111

1212
abstract class ValidatorTable(
@@ -73,16 +73,37 @@ abstract class ValidatorTable(
7373
ret
7474
}
7575

76+
private def performFirstPass(df: DataFrame, checks: List[TwoPassCheapCheck]): Unit = {
77+
if (checks.nonEmpty) {
78+
val firstPassTimer = new ValidatorTimer(s"$label: pre-processing stage")
79+
80+
addEvent(firstPassTimer)
81+
82+
firstPassTimer.time {
83+
val cols = checks.map { _.firstPassSelect() }
84+
val row = df.select(cols: _*).head
85+
86+
checks foreach { _ sinkFirstPassRow row }
87+
}
88+
}
89+
}
90+
91+
private def cheapExpression(dataFrame: DataFrame, dict: VarSubstitution): PartialFunction[CheapCheck, Expression] = {
92+
case tp: TwoPassCheapCheck => tp.select(dataFrame.schema, dict)
93+
case colChk: ColumnBased => colChk.select(dataFrame.schema, dict)
94+
case chk: RowBased => Sum(chk.select(dataFrame.schema, dict)).toAggregateExpression()
95+
}
96+
7697
def quickChecks(session: SparkSession, dict: VarSubstitution)(implicit vc: ValidatorConfig): Boolean = {
7798
val dataFrame = open(session).get
99+
100+
performFirstPass(dataFrame, checks.collect { case tp: TwoPassCheapCheck => tp })
101+
78102
val qc: List[CheapCheck] = checks.flatMap {
79103
case cc: CheapCheck => Some(cc)
80104
case _ => None
81105
}
82-
val checkSelects: Seq[Expression] = qc.map {
83-
case colChk: ColumnBased => colChk.select(dataFrame.schema, dict)
84-
case chk: RowBased => Sum(chk.select(dataFrame.schema, dict)).toAggregateExpression()
85-
}
106+
val checkSelects = qc.map(cheapExpression(dataFrame, dict))
86107

87108
val cols: Seq[Column] = createCountSelect() ++ checkSelects.zipWithIndex.map {
88109
case (chkSelect: Expression, idx: Int) => new Column(Alias(chkSelect, s"qc$idx")())
@@ -278,6 +299,9 @@ case class ValidatorDataFrame(
278299
checks,
279300
"DataFrame" + condition.map(x => s" with condition($x)").getOrElse("")
280301
) {
302+
303+
final def label: String = "DataFrame" + condition.map(x => s" with condition($x)").getOrElse("")
304+
281305
override def getDF(session: SparkSession): Try[DataFrame] = Success(df)
282306

283307
override def substituteVariables(dict: VarSubstitution): ValidatorTable = {
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package com.target.data_validator.stats
2+
3+
case class Bin(lowerBound: Double, upperBound: Double, count: Long)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package com.target.data_validator.stats
2+
3+
import io.circe._
4+
import io.circe.generic.semiauto._
5+
6+
case class CompleteStats(
7+
name: String,
8+
column: String,
9+
count: Long,
10+
mean: Double,
11+
min: Double,
12+
max: Double,
13+
stdDev: Double,
14+
histogram: Histogram
15+
)
16+
17+
object CompleteStats {
18+
implicit val binEncoder: Encoder[Bin] = deriveEncoder
19+
implicit val histogramEncoder: Encoder[Histogram] = deriveEncoder
20+
implicit val encoder: Encoder[CompleteStats] = deriveEncoder
21+
22+
implicit val binDecoder: Decoder[Bin] = deriveDecoder
23+
implicit val histogramDecoder: Decoder[Histogram] = deriveDecoder
24+
implicit val decoder: Decoder[CompleteStats] = deriveDecoder
25+
26+
def apply(
27+
name: String,
28+
column: String,
29+
firstPassStats: FirstPassStats,
30+
secondPassStats: SecondPassStats
31+
): CompleteStats = CompleteStats(
32+
name = name,
33+
column = column,
34+
count = firstPassStats.count,
35+
mean = firstPassStats.mean,
36+
min = firstPassStats.min,
37+
max = firstPassStats.max,
38+
stdDev = secondPassStats.stdDev,
39+
histogram = secondPassStats.histogram
40+
)
41+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package com.target.data_validator.stats
2+
3+
import org.apache.spark.sql.Row
4+
import org.apache.spark.sql.catalyst.ScalaReflection
5+
import org.apache.spark.sql.types.DataType
6+
7+
case class FirstPassStats(count: Long, mean: Double, min: Double, max: Double)
8+
9+
object FirstPassStats {
10+
def dataType: DataType = ScalaReflection
11+
.schemaFor[FirstPassStats]
12+
.dataType
13+
14+
/**
15+
* Convert from Spark SQL row format to case class [[FirstPassStats]] format.
16+
*
17+
* @param row a complex column of [[org.apache.spark.sql.types.StructType]] output of [[FirstPassStatsAggregator]]
18+
* @return struct format converted to [[FirstPassStats]]
19+
*/
20+
def fromRowRepr(row: Row): FirstPassStats = {
21+
FirstPassStats(
22+
count = row.getLong(0),
23+
mean = row.getDouble(1),
24+
min = row.getDouble(2),
25+
max = row.getDouble(3)
26+
)
27+
}
28+
29+
}
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package com.target.data_validator.stats
2+
3+
import org.apache.spark.sql.Row
4+
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
5+
import org.apache.spark.sql.types._
6+
7+
/**
8+
* Calculate the count, mean, min and maximum values of a numeric column.
9+
*/
10+
class FirstPassStatsAggregator extends UserDefinedAggregateFunction {
11+
12+
/**
13+
* input is a single column of `DoubleType`
14+
*/
15+
override def inputSchema: StructType = new StructType().add("value", DoubleType)
16+
17+
/**
18+
* buffer keeps state for the count, sum, min and max
19+
*/
20+
override def bufferSchema: StructType = new StructType()
21+
.add(StructField("count", LongType))
22+
.add(StructField("sum", DoubleType))
23+
.add(StructField("min", DoubleType))
24+
.add(StructField("max", DoubleType))
25+
26+
private val count = bufferSchema.fieldIndex("count")
27+
private val sum = bufferSchema.fieldIndex("sum")
28+
private val min = bufferSchema.fieldIndex("min")
29+
private val max = bufferSchema.fieldIndex("max")
30+
31+
/**
32+
* specifies the return type when using the UDAF
33+
*/
34+
override def dataType: DataType = FirstPassStats.dataType
35+
36+
/**
37+
* These calculations are deterministic
38+
*/
39+
override def deterministic: Boolean = true
40+
41+
/**
42+
* set the initial values for count, sum, min and max
43+
*/
44+
override def initialize(buffer: MutableAggregationBuffer): Unit = {
45+
buffer(count) = 0L
46+
buffer(sum) = 0.0
47+
buffer(min) = Double.MaxValue
48+
buffer(max) = Double.MinValue
49+
}
50+
51+
/**
52+
* update the count, sum, min and max buffer values
53+
*/
54+
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
55+
buffer(count) = buffer.getLong(count) + 1
56+
buffer(sum) = buffer.getDouble(sum) + input.getDouble(0)
57+
buffer(min) = math.min(input.getDouble(0), buffer.getDouble(min))
58+
buffer(max) = math.max(input.getDouble(0), buffer.getDouble(max))
59+
}
60+
61+
/**
62+
* reduce the count, sum, min and max values of two buffers
63+
*/
64+
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
65+
buffer1(count) = buffer1.getLong(count) + buffer2.getLong(count)
66+
buffer1(sum) = buffer1.getDouble(sum) + buffer2.getDouble(sum)
67+
buffer1(min) = math.min(buffer1.getDouble(min), buffer2.getDouble(min))
68+
buffer1(max) = math.max(buffer1.getDouble(max), buffer2.getDouble(max))
69+
}
70+
71+
/**
72+
* evaluate the count, mean, min and max values of a column
73+
*/
74+
override def evaluate(buffer: Row): Any = {
75+
FirstPassStats(
76+
buffer.getLong(count),
77+
buffer.getDouble(sum) / buffer.getLong(count),
78+
buffer.getDouble(min),
79+
buffer.getDouble(max)
80+
)
81+
}
82+
83+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
package com.target.data_validator.stats
2+
3+
case class Histogram(bins: Seq[Bin])
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package com.target.data_validator.stats
2+
3+
import org.apache.spark.sql.Row
4+
import org.apache.spark.sql.catalyst.ScalaReflection
5+
import org.apache.spark.sql.types.DataType
6+
7+
case class SecondPassStats(stdDev: Double, histogram: Histogram)
8+
9+
object SecondPassStats {
10+
def dataType: DataType = ScalaReflection
11+
.schemaFor[SecondPassStats]
12+
.dataType
13+
14+
/**
15+
* Convert from Spark SQL row format to case class [[SecondPassStats]] format.
16+
*
17+
* @param row a complex column of [[org.apache.spark.sql.types.StructType]] output of [[SecondPassStatsAggregator]]
18+
* @return struct format converted to [[SecondPassStats]]
19+
*/
20+
def fromRowRepr(row: Row): SecondPassStats = {
21+
SecondPassStats(
22+
stdDev = row.getDouble(0),
23+
histogram = Histogram(
24+
row.getStruct(1).getSeq[Row](0) map {
25+
bin => Bin(
26+
lowerBound = bin.getDouble(0),
27+
upperBound = bin.getDouble(1),
28+
count = bin.getLong(2)
29+
)
30+
}
31+
)
32+
)
33+
}
34+
35+
}

0 commit comments

Comments
 (0)