@@ -13,11 +13,11 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
1313
1414 import spark .implicits ._
1515
16- private val boundaryRules = ruleSet.getRules.filter(_.ruleType == " bounds " )
17- private val categoricalRules = ruleSet.getRules.filter(rule => rule.ruleType == " validNumerics " ||
18- rule.ruleType == " validStrings " )
19- private val dateTimeRules = ruleSet.getRules.filter(_.ruleType == " dateTime " )
20- private val complexRules = ruleSet.getRules.filter(_.ruleType == " complex " )
16+ private val boundaryRules = ruleSet.getRules.filter(_.ruleType == RuleType . ValidateBounds )
17+ private val categoricalRules = ruleSet.getRules.filter(rule => rule.ruleType == RuleType . ValidateNumerics ||
18+ rule.ruleType == RuleType . ValidateStrings )
19+ private val dateTimeRules = ruleSet.getRules.filter(_.ruleType == RuleType . ValidateDateTime )
20+ private val complexRules = ruleSet.getRules.filter(_.ruleType == RuleType . ValidateComplex )
2121 private val byCols = ruleSet.getGroupBys map col
2222
2323 /**
@@ -36,15 +36,15 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
3636 */
3737 private def buildValidationsByType (rule : Rule ): Column = {
3838 val nulls = mutable.Map [String , Column ](
39- " bounds " -> lit(null ).cast(ArrayType (DoubleType )).alias(" bounds " ),
40- " validNumerics " -> lit(null ).cast(ArrayType (DoubleType )).alias(" validNumerics " ),
41- " validStrings " -> lit(null ).cast(ArrayType (StringType )).alias(" validStrings " ),
42- " validDate " -> lit(null ).cast(LongType ).alias(" validDate " )
39+ RuleType . ValidateBounds .toString -> lit(null ).cast(ArrayType (DoubleType )).alias(RuleType . ValidateBounds .toString ),
40+ RuleType . ValidateNumerics .toString -> lit(null ).cast(ArrayType (DoubleType )).alias(RuleType . ValidateNumerics .toString ),
41+ RuleType . ValidateStrings .toString -> lit(null ).cast(ArrayType (StringType )).alias(RuleType . ValidateStrings .toString ),
42+ RuleType . ValidateDateTime .toString -> lit(null ).cast(LongType ).alias(RuleType . ValidateDateTime .toString )
4343 )
4444 rule.ruleType match {
45- case " bounds " => nulls(" bounds " ) = array(lit(rule.boundaries.lower), lit(rule.boundaries.upper)).alias(" bounds " )
46- case " validNumerics " => nulls(" validNumerics " ) = lit(rule.validNumerics).alias(" validNumerics " )
47- case " validStrings " => nulls(" validStrings " ) = lit(rule.validStrings).alias(" validStrings " )
45+ case RuleType . ValidateBounds => nulls(RuleType . ValidateBounds .toString ) = array(lit(rule.boundaries.lower), lit(rule.boundaries.upper)).alias(RuleType . ValidateBounds .toString )
46+ case RuleType . ValidateNumerics => nulls(RuleType . ValidateNumerics .toString ) = lit(rule.validNumerics).alias(RuleType . ValidateNumerics .toString )
47+ case RuleType . ValidateStrings => nulls(RuleType . ValidateStrings .toString ) = lit(rule.validStrings).alias(RuleType . ValidateStrings .toString )
4848 }
4949 val validationsByType = nulls.toMap.values.toSeq
5050 struct(
@@ -61,7 +61,7 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
6161 private def buildOutputStruct (rule : Rule , results : Seq [Column ]): Column = {
6262 struct(
6363 lit(rule.ruleName).alias(" Rule_Name" ),
64- lit(rule.ruleType).alias(" Rule_Type" ),
64+ lit(rule.ruleType.toString ).alias(" Rule_Type" ),
6565 buildValidationsByType(rule),
6666 struct(results : _* ).alias(" Results" )
6767 ).alias(" Validation" )
@@ -101,28 +101,42 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
101101
102102 // Results must have Invalid_Count & Failed
103103 rule.ruleType match {
104- case " bounds" =>
104+ case RuleType .ValidateBounds =>
105+ // Rule evaluation for NON-AGG RULES ONLY
105106 val invalid = rule.inputColumn < rule.boundaries.lower || rule.inputColumn > rule.boundaries.upper
107+ // This is the first select it must come before subsequent selects as it aliases the original column name
108+ // to that of the rule name. ADDITIONALLY, this evaluates the boundary rule WHEN the input col is not an Agg.
109+ // This can be confusing because for Non-agg columns it renames the column to the rule_name AND returns a 0
110+ // or 1 (not the original value)
111+ // IF the rule is NOT an AGG then the column is simply aliased to the rule name and no evaluation takes place
112+ // here.
113+ val first = if (! rule.isAgg) { // Not Agg
114+ sum(when(invalid, 1 ).otherwise(0 )).alias(rule.ruleName)
115+ } else { // Is Agg
116+ rule.inputColumn.alias(rule.ruleName)
117+ }
118+ // WHEN RULE IS AGG -- this is where the evaluation happens. The input column was renamed to the name of the
119+ // rule in the required previous select.
120+ // IMPORTANT: REMEMBER - that agg expressions evaluate to a single output value thus the invalid_count in
121+ // cases where agg is used cannot be > 1 since the sum of a single value cannot exceed 1.
122+
123+ // WHEN RULE NOT AGG - determine if the result of "first" select (0 or 1) is > 0, if it is, the rule has
124+ // failed since the sum(1 or more 1s) means that 1 or more rows have failed thus the rule has failed
106125 val failed = if (rule.isAgg) {
107126 when(
108127 col(rule.ruleName) < rule.boundaries.lower || col(rule.ruleName) > rule.boundaries.upper, true )
109128 .otherwise(false ).alias(" Failed" )
110129 } else {
111130 when(col(rule.ruleName) > 0 ,true ).otherwise(false ).alias(" Failed" )
112131 }
113- val first = if (! rule.isAgg) { // Not Agg
114- sum(when(invalid, 1 ).otherwise(0 )).alias(rule.ruleName)
115- } else { // Is Agg
116- rule.inputColumn.alias(rule.ruleName)
117- }
118132 val results = if (rule.isAgg) {
119133 Seq (when(failed, 1 ).otherwise(0 ).cast(LongType ).alias(" Invalid_Count" ), failed)
120134 } else {
121135 Seq (col(rule.ruleName).cast(LongType ).alias(" Invalid_Count" ), failed)
122136 }
123137 Selects (buildOutputStruct(rule, results), first)
124- case x if x == " validNumerics " || x == " validStrings " =>
125- val invalid = if (x == " validNumerics " ) {
138+ case x if x == RuleType . ValidateNumerics || x == RuleType . ValidateStrings =>
139+ val invalid = if (x == RuleType . ValidateNumerics ) {
126140 expr(s " size(array_except( ${rule.ruleName}, " +
127141 s " array( ${rule.validNumerics.mkString(" D," )}D))) " )
128142 } else {
@@ -134,8 +148,8 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
134148 val first = collect_set(rule.inputColumn).alias(rule.ruleName)
135149 val results = Seq (invalid.cast(LongType ).alias(" Invalid_Count" ), failed)
136150 Selects (buildOutputStruct(rule, results), first)
137- case " validDate " => ??? // TODO
138- case " complex " => ??? // TODO
151+ case RuleType . ValidateDateTime => ??? // TODO
152+ case RuleType . ValidateComplex => ??? // TODO
139153 }
140154 })
141155 }
0 commit comments