@@ -220,9 +220,9 @@ object RandomDataGenerator {
220220 val numCompleteWeeks = s " CAST( $remainingDays / 7 AS INT) "
221221
222222 // Generate random week (0 to numCompleteWeeks-1) and weekday (0-4)
223- // Use separate RAND() calls for week and weekday to ensure proper independence
223+ // Use separate seeded RAND calls for week and weekday to ensure proper independence
224224 val randWeek = s " CAST( $sqlRandom * $numCompleteWeeks AS INT) "
225- val randWeekday = s " CAST(RAND() * 5 AS INT) "
225+ val randWeekday = s " CAST( ${sqlRandomWithOffset( 1 )} * 5 AS INT) "
226226
227227 s " DATE_ADD( $firstMonday, $randWeek * 7 + $randWeekday) "
228228 } else {
@@ -291,7 +291,7 @@ object RandomDataGenerator {
291291 }
292292
293293 override def generateSqlExpression : String = {
294- s " TO_BINARY(ARRAY_JOIN(TRANSFORM(ARRAY_REPEAT(1, CAST( $sqlRandom * ${maxLength - minLength} + $minLength AS INT)), i -> CHAR(ROUND( $sqlRandom * 94 + 32, 0))), ''), 'utf-8') "
294+ s " TO_BINARY(ARRAY_JOIN(TRANSFORM(ARRAY_REPEAT(1, CAST( $sqlRandom * ${maxLength - minLength} + $minLength AS INT)), i -> CHAR(ROUND( ${sqlRandomWithIndex( " i " )} * 94 + 32, 0))), ''), 'utf-8') "
295295 }
296296 }
297297
@@ -303,7 +303,7 @@ object RandomDataGenerator {
303303 }
304304
305305 override def generateSqlExpression : String = {
306- s " TO_BINARY(CHAR(ROUND( $sqlRandom * 94 + 32, 0))) "
306+ s " TO_BINARY(CHAR(ROUND( $sqlRandom * 94 + 32, 0)), 'utf-8' ) "
307307 }
308308 }
309309
@@ -385,7 +385,7 @@ object RandomDataGenerator {
385385 val arraySize = valuesStr.split(" ," ).length
386386 val randomIndexExpr = s " CAST( $sqlRandom * $arraySize + 1 AS INT) "
387387 val sizeExpr = s " CAST( $sqlRandom * ${arrayMaxSize - arrayMinSize} + $arrayMinSize AS INT) "
388- val arrayExpr = s " TRANSFORM(ARRAY_REPEAT(1, $sizeExpr), i -> ELEMENT_AT( $valuesArray, CAST(RAND() * $arraySize + 1 AS INT))) "
388+ val arrayExpr = s " TRANSFORM(ARRAY_REPEAT(1, $sizeExpr), i -> ELEMENT_AT( $valuesArray, CAST( ${sqlRandomWithIndex( " i " )} * $arraySize + 1 AS INT))) "
389389
390390 applyEmptyProbability(arrayExpr)
391391 }
@@ -406,22 +406,25 @@ object RandomDataGenerator {
406406
407407 // Calculate cumulative weights
408408 val totalWeight = weightedPairs.map(_._2).sum
409+ require(totalWeight > 0 ,
410+ s " Invalid weights in field ' ${structField.name}': total weight must be greater than 0. " )
409411 val normalizedWeights = weightedPairs.map { case (v, w) => (v, w / totalWeight) }
410-
411- // Build CASE WHEN statement for weighted selection
412- // Use RAND() directly in each WHEN clause to avoid nested subquery
413- var cumulativeProb = 0.0
414- val caseStatements = normalizedWeights.zipWithIndex.map { case ((value, weight), idx) =>
415- val prevCumulativeProb = cumulativeProb
416- cumulativeProb += weight
417- if (idx == normalizedWeights.length - 1 ) {
418- s " ELSE $value"
419- } else {
420- s " WHEN RAND() < $cumulativeProb THEN $value"
421- }
422- }.mkString(" " )
423-
424- val weightedSelectExpr = s " (CASE $caseStatements END) "
412+ val cumulativeThresholds = normalizedWeights.scanLeft(0.0 )(_ + _._2).tail
413+
414+ val valuesArrayExpr = s " ARRAY( ${normalizedWeights.map(_._1).mkString(" ," )}) "
415+ val thresholdsArrayExpr = s " ARRAY( ${cumulativeThresholds.map(_.toString).mkString(" ," )}) "
416+ val zippedExpr = s " ZIP_WITH( $valuesArrayExpr, $thresholdsArrayExpr, (v, t) -> named_struct('value', v, 'threshold', t)) "
417+
418+ // Use a single RAND() per element to compare against cumulative thresholds
419+ val weightedSelectExpr =
420+ s """ AGGREGATE(
421+ | $zippedExpr,
422+ | named_struct('r', ${sqlRandomWithIndex(" i" )}, 'picked', CAST(NULL AS ${dataType.sql})),
423+ | (acc, x) -> IF(acc.picked IS NOT NULL, acc,
424+ | IF(acc.r < x.threshold, named_struct('r', acc.r, 'picked', x.value), acc)
425+ | ),
426+ | acc -> acc.picked
427+ |) """ .stripMargin
425428 val sizeExpr = s " CAST( $sqlRandom * ${arrayMaxSize - arrayMinSize} + $arrayMinSize AS INT) "
426429 val arrayExpr = s " TRANSFORM(ARRAY_REPEAT(1, $sizeExpr), i -> $weightedSelectExpr) "
427430
@@ -442,17 +445,28 @@ object RandomDataGenerator {
442445 assert(mapMinSize <= mapMaxSize, s " mapMinSize has to be less than or equal to mapMaxSize, " +
443446 s " field-name= ${structField.name}, mapMinSize= $mapMinSize, mapMaxSize= $mapMaxSize" )
444447
445- override def keyGenerator : DataGenerator [T ] = getGeneratorForDataType(keyDataType).asInstanceOf [DataGenerator [T ]]
448+ private def seededMetadata : Metadata = optRandomSeed match {
449+ case Some (seed) => new MetadataBuilder ().putString(RANDOM_SEED , seed.toString).build()
450+ case None => Metadata .empty
451+ }
452+
453+ override def keyGenerator : DataGenerator [T ] =
454+ getGeneratorForStructField(StructField (structField.name, keyDataType, nullable = true , seededMetadata), faker)
455+ .asInstanceOf [DataGenerator [T ]]
446456
447- override def valueGenerator : DataGenerator [K ] = getGeneratorForDataType(valueDataType).asInstanceOf [DataGenerator [K ]]
457+ override def valueGenerator : DataGenerator [K ] =
458+ getGeneratorForStructField(StructField (structField.name, valueDataType, nullable = true , seededMetadata), faker)
459+ .asInstanceOf [DataGenerator [K ]]
448460
449461 // how to make it empty map when size is 0
450462 override def generateSqlExpression : String = {
451- val keyDataGenerator = getGeneratorForDataType( keyDataType)
452- val valueDataGenerator = getGeneratorForDataType( valueDataType)
463+ val keyDataGenerator = getGeneratorForStructField( StructField (structField.name, keyDataType, nullable = true , seededMetadata), faker )
464+ val valueDataGenerator = getGeneratorForStructField( StructField (structField.name, valueDataType, nullable = true , seededMetadata), faker )
453465 val keySql = keyDataGenerator.generateSqlExpressionWrapper
454466 val valueSql = valueDataGenerator.generateSqlExpressionWrapper
455- s " STR_TO_MAP(CONCAT_WS(',', TRANSFORM(ARRAY_REPEAT(1, CAST( $sqlRandom * ${mapMaxSize - mapMinSize} + $mapMinSize AS INT)), i -> CONCAT( $keySql, '->', $valueSql))), '->', ',') "
467+ val keySqlWithIndex = optRandomSeed.map(_ => keySql.replace(sqlRandom, sqlRandomWithIndex(" i" ))).getOrElse(keySql)
468+ val valueSqlWithIndex = optRandomSeed.map(_ => valueSql.replace(sqlRandom, sqlRandomWithIndex(" i" ))).getOrElse(valueSql)
469+ s " STR_TO_MAP(CONCAT_WS(',', TRANSFORM(ARRAY_REPEAT(1, CAST( $sqlRandom * ${mapMaxSize - mapMinSize} + $mapMinSize AS INT)), i -> CONCAT( $keySqlWithIndex, '->', $valueSqlWithIndex))), '->', ',') "
456470 }
457471 }
458472
0 commit comments