Skip to content

Commit 1bef366

Browse files
committed
PR6 fixes and docs
1 parent 0446747 commit 1bef366

3 files changed

Lines changed: 59 additions & 8 deletions

File tree

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.databricks.labs.validation
2+
3+
import com.databricks.labs.validation.utils.Structures.{Bounds, MinMaxRuleDef}
4+
import com.databricks.labs.validation.utils.SparkSessionWrapper
5+
import org.apache.spark.sql.functions._
6+
7+
object QuickTest extends App with SparkSessionWrapper {
8+
9+
import spark.implicits._
10+
11+
val testDF = Seq(
12+
(1, 2, 3),
13+
(4, 5, 6),
14+
(7, 8, 9)
15+
).toDF("retail_price", "scan_price", "cost")
16+
17+
Rule("Reasonable_sku_counts", count(col("sku")), Bounds(lower = 20.0, upper = 200.0))
18+
19+
val minMaxPriceDefs = Array(
20+
MinMaxRuleDef("MinMax_Retail_Price_Minus_Scan_Price", col("retail_price")-col("scan_price"), Bounds(0.0, 29.99)),
21+
MinMaxRuleDef("MinMax_Scan_Price_Minus_Retail_Price", col("scan_price")-col("retail_price"), Bounds(0.0, 29.99))
22+
)
23+
24+
val someRuleSet = RuleSet(testDF)
25+
.add(Rule("retail_pass", col("retail_price"), Bounds(lower = 1.0, upper = 7.0)))
26+
.add(Rule("retail_agg_pass_high", max(col("retail_price")), Bounds(lower = 0.0, upper = 7.1)))
27+
.add(Rule("retail_agg_pass_low", min(col("retail_price")), Bounds(lower = 0.0, upper = 7.0)))
28+
.add(Rule("retail_fail_low", col("retail_price"), Bounds(lower = 1.1, upper = 7.0)))
29+
.add(Rule("retail_fail_high", col("retail_price"), Bounds(lower = 0.0, upper = 6.9)))
30+
.add(Rule("retail_agg_fail_high", max(col("retail_price")), Bounds(lower = 0.0, upper = 6.9)))
31+
.add(Rule("retail_agg_fail_low", min(col("retail_price")), Bounds(lower = 1.1, upper = 7.0)))
32+
.addMinMaxRules(minMaxPriceDefs: _*)
33+
val (rulesReport, passed) = someRuleSet.validate()
34+
35+
testDF.show(20, false)
36+
rulesReport.show(20, false)
37+
}

src/main/scala/com/databricks/labs/validation/Validator.scala

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,33 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
102102
// Results must have Invalid_Count & Failed
103103
rule.ruleType match {
104104
case RuleType.ValidateBounds =>
105+
// Rule evaluation for NON-AGG RULES ONLY
105106
val invalid = rule.inputColumn < rule.boundaries.lower || rule.inputColumn > rule.boundaries.upper
107+
// This is the first select it must come before subsequent selects as it aliases the original column name
108+
// to that of the rule name. ADDITIONALLY, this evaluates the boundary rule WHEN the input col is not an Agg.
109+
// This can be confusing because for Non-agg columns it renames the column to the rule_name AND returns a 0
110+
// or 1 (not the original value)
111+
// IF the rule is NOT an AGG then the column is simply aliased to the rule name and no evaluation takes place
112+
// here.
113+
val first = if (!rule.isAgg) { // Not Agg
114+
sum(when(invalid, 1).otherwise(0)).alias(rule.ruleName)
115+
} else { // Is Agg
116+
rule.inputColumn.alias(rule.ruleName)
117+
}
118+
// WHEN RULE IS AGG -- this is where the evaluation happens. The input column was renamed to the name of the
119+
// rule in the required previous select.
120+
// IMPORTANT: REMEMBER - that agg expressions evaluate to a single output value thus the invalid_count in
121+
// cases where agg is used cannot be > 1 since the sum of a single value cannot exceed 1.
122+
123+
// WHEN RULE NOT AGG - determine if the result of "first" select (0 or 1) is > 0, if it is, the rule has
124+
// failed since the sum(1 or more 1s) means that 1 or more rows have failed thus the rule has failed
106125
val failed = if (rule.isAgg) {
107126
when(
108127
col(rule.ruleName) < rule.boundaries.lower || col(rule.ruleName) > rule.boundaries.upper, true)
109128
.otherwise(false).alias("Failed")
110129
} else{
111130
when(col(rule.ruleName) > 0,true).otherwise(false).alias("Failed")
112131
}
113-
val first = if (!rule.isAgg) { // Not Agg
114-
sum(when(invalid, 1).otherwise(0)).alias(rule.ruleName)
115-
} else { // Is Agg
116-
rule.inputColumn.alias(rule.ruleName)
117-
}
118132
val results = if (rule.isAgg) {
119133
Seq(when(failed, 1).otherwise(0).cast(LongType).alias("Invalid_Count"), failed)
120134
} else {

src/test/scala/com/databricks/labs/validation/ValidatorTestSuite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ class ValidatorTestSuite extends org.scalatest.FunSuite with SparkSessionFixture
5050
assert(rulesReport.count() == 10)
5151
}
5252

53-
test("The input rule should have 3 invalid count for MinMax_Scan_Price_Minus_Retail_Price_min for failing complex type.") {
53+
test("The input rule should have 1 invalid count for MinMax_Scan_Price_Minus_Retail_Price_min and max for failing complex type.") {
5454
val expectedDF = Seq(
55-
("MinMax_Retail_Price_Minus_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
56-
("MinMax_Retail_Price_Minus_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),3,true),
55+
("MinMax_Retail_Price_Minus_Scan_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),1,true),
56+
("MinMax_Retail_Price_Minus_Scan_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),1,true),
5757
("MinMax_Scan_Price_Minus_Retail_Price_max","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false),
5858
("MinMax_Scan_Price_Minus_Retail_Price_min","bounds",ValidationValue(null,null,Array(0.0, 29.99),null),0,false)
5959
).toDF("Rule_Name","Rule_Type","Validation_Values","Invalid_Count","Failed")

0 commit comments

Comments
 (0)