Skip to content

Commit 1d84132

Browse files
committed
feat: Enhance SQL generation with seeded randomness and improved determinism
- Introduced new methods for generating SQL expressions with seeded randomness in various data generators, ensuring consistent and varied outputs. - Updated `DataGeneratorFactory` to utilize new random expression methods for weight calculations. - Refactored SQL generation logic in `RandomDataGenerator`, `OneOfDataGenerator`, and `RegexNode` to support indexed random values. - Added tests to verify deterministic behavior of seeded generators, ensuring expected outputs across multiple runs. - Enhanced `DataGeneratorDeterminismTest` to validate the consistency and variability of generated values with seeded configurations. These changes improve the reliability of data generation processes, particularly in scenarios requiring reproducible results.
1 parent c68aede commit 1d84132

8 files changed

Lines changed: 308 additions & 85 deletions

File tree

app/src/integrationTest/scala/io/github/datacatering/datacaterer/core/generator/EnhancedForeignKeyIntegrationTest.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,14 +461,14 @@ class EnhancedForeignKeyIntegrationTest extends SparkSuite with Matchers with Be
461461
class RatioCardinalityTestPlan(customersPath: String, ordersPath: String) extends PlanRun {
462462
val customers = csv("customers", customersPath, Map("saveMode" -> "overwrite", "header" -> "true"))
463463
.fields(
464-
field.name("customer_id").regex("CUST[0-9]{6}"),
464+
field.name("customer_id").regex("CUST[0-9]{10}"),
465465
field.name("name").expression("#{Name.name}")
466466
)
467467
.count(count.records(50))
468468

469469
val orders = csv("orders", ordersPath, Map("saveMode" -> "overwrite", "header" -> "true"))
470470
.fields(
471-
field.name("order_id").regex("ORD[0-9]{8}"),
471+
field.name("order_id").regex("ORD[0-9]{10}"),
472472
field.name("customer_id"),
473473
field.name("amount").`type`(DoubleType).min(10.0).max(1000.0)
474474
)

app/src/main/scala/io/github/datacatering/datacaterer/core/generator/DataGeneratorFactory.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ class DataGeneratorFactory(faker: Faker, enableFastGeneration: Boolean = false)(
130130
val namedStruct = dataGenerators.map(dg => s"'${dg.structField.name}', CAST(${dg.generateSqlExpressionWrapper} AS ${dg.structField.dataType.sql})").mkString(",")
131131
//if it is using a weighted oneOf generator, it will have a weight column
132132
val countGeneratorSql = countGenerator.generateSqlExpressionWrapper
133-
val optWeightCol = if (countGeneratorSql.contains(RECORD_COUNT_GENERATOR_WEIGHT_FIELD)) Array(s"RAND() AS $RECORD_COUNT_GENERATOR_WEIGHT_FIELD") else Array[String]()
133+
val optWeightCol = if (countGeneratorSql.contains(RECORD_COUNT_GENERATOR_WEIGHT_FIELD)) {
134+
Array(s"${countGenerator.sqlRandomWithOffset(1)} AS $RECORD_COUNT_GENERATOR_WEIGHT_FIELD")
135+
} else Array[String]()
134136
val perCountGeneratedExpr = df.columns ++ optWeightCol ++ Array(
135137
s"CAST($countGeneratorSql AS INT) AS $PER_FIELD_COUNT_GENERATED_NUM",
136138
s"CASE WHEN $PER_FIELD_COUNT_GENERATED_NUM = 0 THEN ARRAY() ELSE SEQUENCE(1, $PER_FIELD_COUNT_GENERATED_NUM) END AS $PER_FIELD_COUNT_GENERATED"

app/src/main/scala/io/github/datacatering/datacaterer/core/generator/provider/DataGenerator.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ trait DataGenerator[T] extends BaseGenerator[T] with Serializable {
2121

2222
lazy val optRandomSeed: Option[Long] = if (structField.metadata.contains(RANDOM_SEED)) Some(structField.metadata.getString(RANDOM_SEED).toLong) else None
2323
lazy val sqlRandom: String = optRandomSeed.map(seed => s"RAND($seed)").getOrElse("RAND()")
24+
def sqlRandomWithOffset(offset: Long): String = optRandomSeed.map(seed => s"RAND(${seed + offset})").getOrElse("RAND()")
25+
def sqlRandomWithIndex(indexExpr: String): String = optRandomSeed match {
26+
case Some(seed) =>
27+
s"(CAST((xxhash64(monotonically_increasing_id(), $seed, $indexExpr) & ${Long.MaxValue}) AS DOUBLE) / ${Long.MaxValue.toDouble})"
28+
case None => "RAND()"
29+
}
2430
lazy val random: Random = if (structField.metadata.contains(RANDOM_SEED)) new Random(structField.metadata.getString(RANDOM_SEED).toLong) else new Random()
2531
lazy val enabledNull: Boolean = if (structField.metadata.contains(ENABLED_NULL)) structField.metadata.getString(ENABLED_NULL).toBoolean else false
2632
lazy val enabledEdgeCases: Boolean = if (structField.metadata.contains(ENABLED_EDGE_CASE)) structField.metadata.getString(ENABLED_EDGE_CASE).toBoolean else false

app/src/main/scala/io/github/datacatering/datacaterer/core/generator/provider/FastDataGenerator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ object FastDataGenerator {
102102
parsedPattern match {
103103
case scala.util.Success(node) =>
104104
// Successfully parsed - use pure SQL generation
105-
node.toSql
105+
node.toSql(sqlRandom, sqlRandomWithIndex)
106106

107107
case scala.util.Failure(_) =>
108108
// Parser failed - fall back to UDF (slower but correct)

app/src/main/scala/io/github/datacatering/datacaterer/core/generator/provider/OneOfDataGenerator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ object OneOfDataGenerator {
2727

2828
override def generateSqlExpression: String = {
2929
val oneOfValuesString = oneOfValues.mkString("||")
30-
s"CAST(SPLIT('$oneOfValuesString', '\\\\|\\\\|')[CAST(RAND() * $oneOfArrayLength AS INT)] AS ${structField.dataType.sql})"
30+
s"CAST(SPLIT('$oneOfValuesString', '\\\\|\\\\|')[CAST($sqlRandom * $oneOfArrayLength AS INT)] AS ${structField.dataType.sql})"
3131
}
3232
}
3333

app/src/main/scala/io/github/datacatering/datacaterer/core/generator/provider/RandomDataGenerator.scala

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,9 @@ object RandomDataGenerator {
220220
val numCompleteWeeks = s"CAST($remainingDays / 7 AS INT)"
221221

222222
// Generate random week (0 to numCompleteWeeks-1) and weekday (0-4)
223-
// Use separate RAND() calls for week and weekday to ensure proper independence
223+
// Use separate seeded RAND calls for week and weekday to ensure proper independence
224224
val randWeek = s"CAST($sqlRandom * $numCompleteWeeks AS INT)"
225-
val randWeekday = s"CAST(RAND() * 5 AS INT)"
225+
val randWeekday = s"CAST(${sqlRandomWithOffset(1)} * 5 AS INT)"
226226

227227
s"DATE_ADD($firstMonday, $randWeek * 7 + $randWeekday)"
228228
} else {
@@ -291,7 +291,7 @@ object RandomDataGenerator {
291291
}
292292

293293
override def generateSqlExpression: String = {
294-
s"TO_BINARY(ARRAY_JOIN(TRANSFORM(ARRAY_REPEAT(1, CAST($sqlRandom * ${maxLength - minLength} + $minLength AS INT)), i -> CHAR(ROUND($sqlRandom * 94 + 32, 0))), ''), 'utf-8')"
294+
s"TO_BINARY(ARRAY_JOIN(TRANSFORM(ARRAY_REPEAT(1, CAST($sqlRandom * ${maxLength - minLength} + $minLength AS INT)), i -> CHAR(ROUND(${sqlRandomWithIndex("i")} * 94 + 32, 0))), ''), 'utf-8')"
295295
}
296296
}
297297

@@ -303,7 +303,7 @@ object RandomDataGenerator {
303303
}
304304

305305
override def generateSqlExpression: String = {
306-
s"TO_BINARY(CHAR(ROUND($sqlRandom * 94 + 32, 0)))"
306+
s"TO_BINARY(CHAR(ROUND($sqlRandom * 94 + 32, 0)), 'utf-8')"
307307
}
308308
}
309309

@@ -385,7 +385,7 @@ object RandomDataGenerator {
385385
val arraySize = valuesStr.split(",").length
386386
val randomIndexExpr = s"CAST($sqlRandom * $arraySize + 1 AS INT)"
387387
val sizeExpr = s"CAST($sqlRandom * ${arrayMaxSize - arrayMinSize} + $arrayMinSize AS INT)"
388-
val arrayExpr = s"TRANSFORM(ARRAY_REPEAT(1, $sizeExpr), i -> ELEMENT_AT($valuesArray, CAST(RAND() * $arraySize + 1 AS INT)))"
388+
val arrayExpr = s"TRANSFORM(ARRAY_REPEAT(1, $sizeExpr), i -> ELEMENT_AT($valuesArray, CAST(${sqlRandomWithIndex("i")} * $arraySize + 1 AS INT)))"
389389

390390
applyEmptyProbability(arrayExpr)
391391
}
@@ -406,22 +406,25 @@ object RandomDataGenerator {
406406

407407
// Calculate cumulative weights
408408
val totalWeight = weightedPairs.map(_._2).sum
409+
require(totalWeight > 0,
410+
s"Invalid weights in field '${structField.name}': total weight must be greater than 0.")
409411
val normalizedWeights = weightedPairs.map { case (v, w) => (v, w / totalWeight) }
410-
411-
// Build CASE WHEN statement for weighted selection
412-
// Use RAND() directly in each WHEN clause to avoid nested subquery
413-
var cumulativeProb = 0.0
414-
val caseStatements = normalizedWeights.zipWithIndex.map { case ((value, weight), idx) =>
415-
val prevCumulativeProb = cumulativeProb
416-
cumulativeProb += weight
417-
if (idx == normalizedWeights.length - 1) {
418-
s"ELSE $value"
419-
} else {
420-
s"WHEN RAND() < $cumulativeProb THEN $value"
421-
}
422-
}.mkString(" ")
423-
424-
val weightedSelectExpr = s"(CASE $caseStatements END)"
412+
val cumulativeThresholds = normalizedWeights.scanLeft(0.0)(_ + _._2).tail
413+
414+
val valuesArrayExpr = s"ARRAY(${normalizedWeights.map(_._1).mkString(",")})"
415+
val thresholdsArrayExpr = s"ARRAY(${cumulativeThresholds.map(_.toString).mkString(",")})"
416+
val zippedExpr = s"ZIP_WITH($valuesArrayExpr, $thresholdsArrayExpr, (v, t) -> named_struct('value', v, 'threshold', t))"
417+
418+
// Use a single RAND() per element to compare against cumulative thresholds
419+
val weightedSelectExpr =
420+
s"""AGGREGATE(
421+
| $zippedExpr,
422+
| named_struct('r', ${sqlRandomWithIndex("i")}, 'picked', CAST(NULL AS ${dataType.sql})),
423+
| (acc, x) -> IF(acc.picked IS NOT NULL, acc,
424+
| IF(acc.r < x.threshold, named_struct('r', acc.r, 'picked', x.value), acc)
425+
| ),
426+
| acc -> acc.picked
427+
|)""".stripMargin
425428
val sizeExpr = s"CAST($sqlRandom * ${arrayMaxSize - arrayMinSize} + $arrayMinSize AS INT)"
426429
val arrayExpr = s"TRANSFORM(ARRAY_REPEAT(1, $sizeExpr), i -> $weightedSelectExpr)"
427430

@@ -442,17 +445,28 @@ object RandomDataGenerator {
442445
assert(mapMinSize <= mapMaxSize, s"mapMinSize has to be less than or equal to mapMaxSize, " +
443446
s"field-name=${structField.name}, mapMinSize=$mapMinSize, mapMaxSize=$mapMaxSize")
444447

445-
override def keyGenerator: DataGenerator[T] = getGeneratorForDataType(keyDataType).asInstanceOf[DataGenerator[T]]
448+
private def seededMetadata: Metadata = optRandomSeed match {
449+
case Some(seed) => new MetadataBuilder().putString(RANDOM_SEED, seed.toString).build()
450+
case None => Metadata.empty
451+
}
452+
453+
override def keyGenerator: DataGenerator[T] =
454+
getGeneratorForStructField(StructField(structField.name, keyDataType, nullable = true, seededMetadata), faker)
455+
.asInstanceOf[DataGenerator[T]]
446456

447-
override def valueGenerator: DataGenerator[K] = getGeneratorForDataType(valueDataType).asInstanceOf[DataGenerator[K]]
457+
override def valueGenerator: DataGenerator[K] =
458+
getGeneratorForStructField(StructField(structField.name, valueDataType, nullable = true, seededMetadata), faker)
459+
.asInstanceOf[DataGenerator[K]]
448460

449461
//how to make it empty map when size is 0
450462
override def generateSqlExpression: String = {
451-
val keyDataGenerator = getGeneratorForDataType(keyDataType)
452-
val valueDataGenerator = getGeneratorForDataType(valueDataType)
463+
val keyDataGenerator = getGeneratorForStructField(StructField(structField.name, keyDataType, nullable = true, seededMetadata), faker)
464+
val valueDataGenerator = getGeneratorForStructField(StructField(structField.name, valueDataType, nullable = true, seededMetadata), faker)
453465
val keySql = keyDataGenerator.generateSqlExpressionWrapper
454466
val valueSql = valueDataGenerator.generateSqlExpressionWrapper
455-
s"STR_TO_MAP(CONCAT_WS(',', TRANSFORM(ARRAY_REPEAT(1, CAST($sqlRandom * ${mapMaxSize - mapMinSize} + $mapMinSize AS INT)), i -> CONCAT($keySql, '->', $valueSql))), '->', ',')"
467+
val keySqlWithIndex = optRandomSeed.map(_ => keySql.replace(sqlRandom, sqlRandomWithIndex("i"))).getOrElse(keySql)
468+
val valueSqlWithIndex = optRandomSeed.map(_ => valueSql.replace(sqlRandom, sqlRandomWithIndex("i"))).getOrElse(valueSql)
469+
s"STR_TO_MAP(CONCAT_WS(',', TRANSFORM(ARRAY_REPEAT(1, CAST($sqlRandom * ${mapMaxSize - mapMinSize} + $mapMinSize AS INT)), i -> CONCAT($keySqlWithIndex, '->', $valueSqlWithIndex))), '->', ',')"
456470
}
457471
}
458472

0 commit comments

Comments
 (0)