Skip to content

Commit f7ebf98

Browse files
authored
Merge pull request #6 from stikkireddy/issue-5
Adding unit tests and test reporting to the dataframe-rules-engine
2 parents d008c4f + 1bef366 commit f7ebf98

12 files changed

Lines changed: 439 additions & 33 deletions

File tree

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ project/project/
99
project/target/
1010
*.DS_Store
1111
/target/
12-
/project/
12+
/project/build.properties

CONTRIBUTING.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
We happily welcome contributions to *Databricks Labs - dataframe-rules-engine*.
2-
We use GitHub Issues to track community reported issues and GitHub Pull Requests for accepting changes.
2+
We use GitHub Issues to track community reported issues and GitHub Pull Requests for accepting changes.
3+
Please make a fork of this repository and submit a pull request.

README.md

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,21 @@ cd Downloads
182182
git pull repo
183183
sbt clean package
184184
```
185+
186+
## Running tests
187+
To run tests on the project: <br>
188+
```
189+
sbt test
190+
```
191+
192+
Make sure that your JAVA_HOME is set up for sbt to run the tests properly. You will need JDK 8 as Spark 2.4 does
193+
not support newer versions of the JDK.
194+
195+
## Test reports for test coverage
196+
To get a test coverage report for the project: <br>
197+
```
198+
sbt jacoco
199+
```
200+
201+
The test reports can be found in `target/scala-<version>/jacoco/`.
202+

build.sbt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,26 @@ scalacOptions ++= Seq("-Xmax-classfile-name", "78")
99

1010
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.0"
1111
libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.0"
12+
libraryDependencies += "org.scalactic" %% "scalactic" % "3.1.1"
13+
libraryDependencies += "org.scalatest" %% "scalatest" % "3.1.1" % "test"
14+
15+
lazy val excludes = jacocoExcludes in Test := Seq()
16+
17+
lazy val jacoco = jacocoReportSettings in test :=JacocoReportSettings(
18+
"Jacoco Scala Example Coverage Report",
19+
None,
20+
JacocoThresholds (branch = 100),
21+
Seq(JacocoReportFormats.ScalaHTML,
22+
JacocoReportFormats.CSV),
23+
"utf-8")
24+
25+
val jacocoSettings = Seq(jacoco)
26+
lazy val jse = (project in file (".")).settings(jacocoSettings: _*)
27+
28+
fork in Test := true
29+
javaOptions ++= Seq("-Xms512M", "-Xmx2048M", "-XX:+CMSClassUnloadingEnabled")
30+
testOptions in Test += Tests.Argument(TestFrameworks.ScalaTest, "-oD")
31+
1232

1333
lazy val commonSettings = Seq(
1434
version := "0.1.1",

project/plugins.sbt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
addSbtPlugin("com.github.sbt" % "sbt-jacoco" % "3.0.3")
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.databricks.labs.validation
2+
3+
import com.databricks.labs.validation.utils.Structures.{Bounds, MinMaxRuleDef}
4+
import com.databricks.labs.validation.utils.SparkSessionWrapper
5+
import org.apache.spark.sql.functions._
6+
7+
object QuickTest extends App with SparkSessionWrapper {
8+
9+
import spark.implicits._
10+
11+
val testDF = Seq(
12+
(1, 2, 3),
13+
(4, 5, 6),
14+
(7, 8, 9)
15+
).toDF("retail_price", "scan_price", "cost")
16+
17+
Rule("Reasonable_sku_counts", count(col("sku")), Bounds(lower = 20.0, upper = 200.0))
18+
19+
val minMaxPriceDefs = Array(
20+
MinMaxRuleDef("MinMax_Retail_Price_Minus_Scan_Price", col("retail_price")-col("scan_price"), Bounds(0.0, 29.99)),
21+
MinMaxRuleDef("MinMax_Scan_Price_Minus_Retail_Price", col("scan_price")-col("retail_price"), Bounds(0.0, 29.99))
22+
)
23+
24+
val someRuleSet = RuleSet(testDF)
25+
.add(Rule("retail_pass", col("retail_price"), Bounds(lower = 1.0, upper = 7.0)))
26+
.add(Rule("retail_agg_pass_high", max(col("retail_price")), Bounds(lower = 0.0, upper = 7.1)))
27+
.add(Rule("retail_agg_pass_low", min(col("retail_price")), Bounds(lower = 0.0, upper = 7.0)))
28+
.add(Rule("retail_fail_low", col("retail_price"), Bounds(lower = 1.1, upper = 7.0)))
29+
.add(Rule("retail_fail_high", col("retail_price"), Bounds(lower = 0.0, upper = 6.9)))
30+
.add(Rule("retail_agg_fail_high", max(col("retail_price")), Bounds(lower = 0.0, upper = 6.9)))
31+
.add(Rule("retail_agg_fail_low", min(col("retail_price")), Bounds(lower = 1.1, upper = 7.0)))
32+
.addMinMaxRules(minMaxPriceDefs: _*)
33+
val (rulesReport, passed) = someRuleSet.validate()
34+
35+
testDF.show(20, false)
36+
rulesReport.show(20, false)
37+
}

src/main/scala/com/databricks/labs/validation/Rule.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Rule {
2020
private var _validNumerics: Array[Double] = _
2121
private var _validStrings: Array[String] = _
2222
private var _dateTimeLogic: Column = _
23-
private var _ruleType: String = _
23+
private var _ruleType: RuleType.Value = _
2424
private var _isAgg: Boolean = _
2525

2626
private def setRuleName(value: String): this.type = {
@@ -69,7 +69,7 @@ class Rule {
6969
this
7070
}
7171

72-
private def setRuleType(value: String): this.type = {
72+
private def setRuleType(value: RuleType.Value): this.type = {
7373
_ruleType = value
7474
this
7575
}
@@ -99,7 +99,7 @@ class Rule {
9999

100100
def dateTimeLogic: Column = _dateTimeLogic
101101

102-
def ruleType: String = _ruleType
102+
def ruleType: RuleType.Value = _ruleType
103103

104104
private[validation] def isAgg: Boolean = _isAgg
105105

@@ -121,7 +121,7 @@ object Rule {
121121
.setRuleName(ruleName)
122122
.setColumn(column)
123123
.setBoundaries(boundaries)
124-
.setRuleType("bounds")
124+
.setRuleType(RuleType.ValidateBounds)
125125
.setIsAgg
126126
}
127127

@@ -135,7 +135,7 @@ object Rule {
135135
.setRuleName(ruleName)
136136
.setColumn(column)
137137
.setValidNumerics(validNumerics)
138-
.setRuleType("validNumerics")
138+
.setRuleType(RuleType.ValidateNumerics)
139139
.setIsAgg
140140
}
141141

@@ -149,7 +149,7 @@ object Rule {
149149
.setRuleName(ruleName)
150150
.setColumn(column)
151151
.setValidNumerics(validNumerics.map(_.toString.toDouble))
152-
.setRuleType("validNumerics")
152+
.setRuleType(RuleType.ValidateNumerics)
153153
.setIsAgg
154154
}
155155

@@ -163,7 +163,7 @@ object Rule {
163163
.setRuleName(ruleName)
164164
.setColumn(column)
165165
.setValidNumerics(validNumerics.map(_.toString.toDouble))
166-
.setRuleType("validNumerics")
166+
.setRuleType(RuleType.ValidateNumerics)
167167
.setIsAgg
168168
}
169169

@@ -177,7 +177,7 @@ object Rule {
177177
.setRuleName(ruleName)
178178
.setColumn(column)
179179
.setValidStrings(validStrings)
180-
.setRuleType("validStrings")
180+
.setRuleType(RuleType.ValidateStrings)
181181
.setIsAgg
182182
}
183183

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package com.databricks.labs.validation
2+
3+
/**
4+
* Definition of the Rule Types as an Enumeration for better type matching
5+
*/
6+
object RuleType extends Enumeration {
7+
val ValidateBounds = Value("bounds")
8+
val ValidateNumerics = Value("validNumerics")
9+
val ValidateStrings = Value("validStrings")
10+
val ValidateDateTime = Value("validDateTime")
11+
val ValidateComplex = Value("complex")
12+
}

src/main/scala/com/databricks/labs/validation/Validator.scala

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
1313

1414
import spark.implicits._
1515

16-
private val boundaryRules = ruleSet.getRules.filter(_.ruleType == "bounds")
17-
private val categoricalRules = ruleSet.getRules.filter(rule => rule.ruleType == "validNumerics" ||
18-
rule.ruleType == "validStrings")
19-
private val dateTimeRules = ruleSet.getRules.filter(_.ruleType == "dateTime")
20-
private val complexRules = ruleSet.getRules.filter(_.ruleType == "complex")
16+
private val boundaryRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateBounds)
17+
private val categoricalRules = ruleSet.getRules.filter(rule => rule.ruleType == RuleType.ValidateNumerics ||
18+
rule.ruleType == RuleType.ValidateStrings)
19+
private val dateTimeRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateDateTime)
20+
private val complexRules = ruleSet.getRules.filter(_.ruleType == RuleType.ValidateComplex)
2121
private val byCols = ruleSet.getGroupBys map col
2222

2323
/**
@@ -36,15 +36,15 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
3636
*/
3737
private def buildValidationsByType(rule: Rule): Column = {
3838
val nulls = mutable.Map[String, Column](
39-
"bounds" -> lit(null).cast(ArrayType(DoubleType)).alias("bounds"),
40-
"validNumerics" -> lit(null).cast(ArrayType(DoubleType)).alias("validNumerics"),
41-
"validStrings" -> lit(null).cast(ArrayType(StringType)).alias("validStrings"),
42-
"validDate" -> lit(null).cast(LongType).alias("validDate")
39+
RuleType.ValidateBounds.toString -> lit(null).cast(ArrayType(DoubleType)).alias(RuleType.ValidateBounds.toString),
40+
RuleType.ValidateNumerics.toString -> lit(null).cast(ArrayType(DoubleType)).alias(RuleType.ValidateNumerics.toString),
41+
RuleType.ValidateStrings.toString -> lit(null).cast(ArrayType(StringType)).alias(RuleType.ValidateStrings.toString),
42+
RuleType.ValidateDateTime.toString -> lit(null).cast(LongType).alias(RuleType.ValidateDateTime.toString)
4343
)
4444
rule.ruleType match {
45-
case "bounds" => nulls("bounds") = array(lit(rule.boundaries.lower), lit(rule.boundaries.upper)).alias("bounds")
46-
case "validNumerics" => nulls("validNumerics") = lit(rule.validNumerics).alias("validNumerics")
47-
case "validStrings" => nulls("validStrings") = lit(rule.validStrings).alias("validStrings")
45+
case RuleType.ValidateBounds => nulls(RuleType.ValidateBounds.toString) = array(lit(rule.boundaries.lower), lit(rule.boundaries.upper)).alias(RuleType.ValidateBounds.toString)
46+
case RuleType.ValidateNumerics => nulls(RuleType.ValidateNumerics.toString) = lit(rule.validNumerics).alias(RuleType.ValidateNumerics.toString)
47+
case RuleType.ValidateStrings => nulls(RuleType.ValidateStrings.toString) = lit(rule.validStrings).alias(RuleType.ValidateStrings.toString)
4848
}
4949
val validationsByType = nulls.toMap.values.toSeq
5050
struct(
@@ -61,7 +61,7 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
6161
private def buildOutputStruct(rule: Rule, results: Seq[Column]): Column = {
6262
struct(
6363
lit(rule.ruleName).alias("Rule_Name"),
64-
lit(rule.ruleType).alias("Rule_Type"),
64+
lit(rule.ruleType.toString).alias("Rule_Type"),
6565
buildValidationsByType(rule),
6666
struct(results: _*).alias("Results")
6767
).alias("Validation")
@@ -101,28 +101,42 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
101101

102102
// Results must have Invalid_Count & Failed
103103
rule.ruleType match {
104-
case "bounds" =>
104+
case RuleType.ValidateBounds =>
105+
// Rule evaluation for NON-AGG RULES ONLY
105106
val invalid = rule.inputColumn < rule.boundaries.lower || rule.inputColumn > rule.boundaries.upper
107+
// This is the first select it must come before subsequent selects as it aliases the original column name
108+
// to that of the rule name. ADDITIONALLY, this evaluates the boundary rule WHEN the input col is not an Agg.
109+
// This can be confusing because for Non-agg columns it renames the column to the rule_name AND returns the
110+
// count of invalid rows (not the original value)
111+
// IF the rule IS an AGG then the column is simply aliased to the rule name and no evaluation takes place
112+
// here.
113+
val first = if (!rule.isAgg) { // Not Agg
114+
sum(when(invalid, 1).otherwise(0)).alias(rule.ruleName)
115+
} else { // Is Agg
116+
rule.inputColumn.alias(rule.ruleName)
117+
}
118+
// WHEN RULE IS AGG -- this is where the evaluation happens. The input column was renamed to the name of the
119+
// rule in the required previous select.
120+
// IMPORTANT: REMEMBER - that agg expressions evaluate to a single output value thus the invalid_count in
121+
// cases where agg is used cannot be > 1 since the sum of a single value cannot exceed 1.
122+
123+
// WHEN RULE NOT AGG - determine if the result of "first" select (0 or 1) is > 0, if it is, the rule has
124+
// failed since the sum(1 or more 1s) means that 1 or more rows have failed thus the rule has failed
106125
val failed = if (rule.isAgg) {
107126
when(
108127
col(rule.ruleName) < rule.boundaries.lower || col(rule.ruleName) > rule.boundaries.upper, true)
109128
.otherwise(false).alias("Failed")
110129
} else{
111130
when(col(rule.ruleName) > 0,true).otherwise(false).alias("Failed")
112131
}
113-
val first = if (!rule.isAgg) { // Not Agg
114-
sum(when(invalid, 1).otherwise(0)).alias(rule.ruleName)
115-
} else { // Is Agg
116-
rule.inputColumn.alias(rule.ruleName)
117-
}
118132
val results = if (rule.isAgg) {
119133
Seq(when(failed, 1).otherwise(0).cast(LongType).alias("Invalid_Count"), failed)
120134
} else {
121135
Seq(col(rule.ruleName).cast(LongType).alias("Invalid_Count"), failed)
122136
}
123137
Selects(buildOutputStruct(rule, results), first)
124-
case x if x == "validNumerics" || x == "validStrings" =>
125-
val invalid = if (x == "validNumerics") {
138+
case x if x == RuleType.ValidateNumerics || x == RuleType.ValidateStrings =>
139+
val invalid = if (x == RuleType.ValidateNumerics) {
126140
expr(s"size(array_except(${rule.ruleName}," +
127141
s"array(${rule.validNumerics.mkString("D,")}D)))")
128142
} else {
@@ -134,8 +148,8 @@ class Validator(ruleSet: RuleSet, detailLvl: Int) extends SparkSessionWrapper {
134148
val first = collect_set(rule.inputColumn).alias(rule.ruleName)
135149
val results = Seq(invalid.cast(LongType).alias("Invalid_Count"), failed)
136150
Selects(buildOutputStruct(rule, results), first)
137-
case "validDate" => ??? // TODO
138-
case "complex" => ??? // TODO
151+
case RuleType.ValidateDateTime => ??? // TODO
152+
case RuleType.ValidateComplex => ??? // TODO
139153
}
140154
})
141155
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Log messages at ERROR level and above to the console (System.err)
2+
log4j.rootCategory=ERROR, console
3+
log4j.appender.console=org.apache.log4j.ConsoleAppender
4+
log4j.appender.console.target=System.err
5+
log4j.appender.console.layout=org.apache.log4j.PatternLayout
6+
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
7+
8+
# Settings to quiet third party logs that are too verbose
9+
log4j.logger.org.eclipse.jetty=WARN
10+
log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
11+
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=WARN
12+
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=WARN
13+
log4j.logger.org.apache.spark.sql.SparkSession$Builder=ERROR

0 commit comments

Comments
 (0)