Skip to content

Commit d1c568e

Browse files
committed
Add unary tile log functions
Signed-off-by: Jason T. Brown <jason@astraea.earth>
1 parent 9b2f8c2 commit d1c568e

4 files changed

Lines changed: 133 additions & 50 deletions

File tree

core/src/main/scala/astraea/spark/rasterframes/RasterFunctions.scala

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -316,24 +316,16 @@ trait RasterFunctions {
316316
def log(tileCol: Column): TypedColumn[Any, Tile] =
317317
Log(tileCol)
318318

319-
/** Take logarithm of cell values with specified base. */
320-
def log[T: Numeric](tileCol: Column, base: T): TypedColumn[Any, Tile] =
321-
Log(tileCol, base)
322-
323-
/** Take logarithm of cell values with specified base. */
324-
def log(tileCol: Column, base: Column): TypedColumn[Any, Tile] =
325-
Log(tileCol, base)
326-
327319
/** Take base 10 logarithm of cell values. */
328320
def log10(tileCol: Column): TypedColumn[Any, Tile] =
329-
Log(tileCol, lit(10.0))
321+
Log10(tileCol)
330322

331323
/** Take base 2 logarithm of cell values. */
332324
def log2(tileCol: Column): TypedColumn[Any, Tile] =
333-
Log(tileCol, 2.0)
325+
Log2(tileCol)
334326

335327
/** Natural logarithm of one plus cell values. */
336328
def log1p(tileCol: Column): TypedColumn[Any, Tile] =
337-
Log(local_add(tileCol, 1.0))
329+
Log1p(tileCol)
338330

339331
}

core/src/main/scala/astraea/spark/rasterframes/expressions/localops/Log.scala

Lines changed: 72 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,41 +22,95 @@
2222
package astraea.spark.rasterframes.expressions.localops
2323

2424
import astraea.spark.rasterframes._
25-
import astraea.spark.rasterframes.expressions.UnaryRasterOp
26-
import astraea.spark.rasterframes.model.TileContext
25+
import astraea.spark.rasterframes.expressions.{UnaryLocalRasterOp, fpTile}
2726
import geotrellis.raster.Tile
28-
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription, UnaryExpression}
27+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
2928
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
30-
import org.apache.spark.sql.functions.lit
3129
import org.apache.spark.sql.types.DataType
3230
import org.apache.spark.sql.{Column, TypedColumn}
3331

32+
3433
@ExpressionDescription(
35-
usage = "_FUNC_(tile, base) - Performs cell-wise logarithm.",
34+
usage = "_FUNC_(tile) - Performs cell-wise natural logarithm.",
3635
arguments = """
3736
Arguments:
38-
* tile - input tile
39-
* base - base for which to compute logarithm """,
37+
* tile - input tile""",
4038
examples = """
4139
Examples:
4240
> SELECT _FUNC_(tile);
43-
...
44-
> SELECT _FUNC_(tile, 10);
4541
..."""
4642
)
47-
case class Log(left: Expression, right: Expression) extends BinaryExpression with CodegenFallback {
43+
case class Log(child: Expression) extends UnaryLocalRasterOp with CodegenFallback {
4844
override val nodeName: String = "log"
49-
protected def op(left: Tile, right: Double): Tile = left.localLog() / math.log(right)
50-
protected def op(left: Tile, right: Int): Tile = op(left, right.toDouble)
5145

52-
override def dataType: DataType = left.dataType
46+
override protected def op(tile: Tile): Tile = fpTile(tile).localLog()
47+
48+
override def dataType: DataType = child.dataType
5349
}
5450
object Log {
5551
def apply(tile: Column): TypedColumn[Any, Tile] =
56-
new Column(Log(tile.expr, lit(math.E).expr)).as[Tile]
57-
def apply[N: Numeric](tile: Column, value: N): TypedColumn[Any, Tile] =
58-
new Column(Log(tile.expr, lit(value).expr)).as[Tile]
59-
def apply(tile: Column, value: Column): TypedColumn[Any, Tile] =
60-
new Column(Log(tile.expr, value.expr)).as[Tile]
52+
new Column(Log(tile.expr)).as[Tile]
53+
}
54+
55+
@ExpressionDescription(
56+
usage = "_FUNC_(tile) - Performs cell-wise logarithm with base 10.",
57+
arguments = """
58+
Arguments:
59+
* tile - input tile""",
60+
examples = """
61+
Examples:
62+
> SELECT _FUNC_(tile);
63+
..."""
64+
)
65+
case class Log10(child: Expression) extends UnaryLocalRasterOp with CodegenFallback {
66+
override val nodeName: String = "log10"
67+
68+
override protected def op(tile: Tile): Tile = fpTile(tile).localLog10()
69+
70+
override def dataType: DataType = child.dataType
71+
}
72+
object Log10 {
73+
def apply(tile: Column): TypedColumn[Any, Tile] = new Column(Log10(tile.expr)).as[Tile]
74+
}
75+
76+
@ExpressionDescription(
77+
usage = "_FUNC_(tile) - Performs cell-wise logarithm with base 2.",
78+
arguments = """
79+
Arguments:
80+
* tile - input tile""",
81+
examples = """
82+
Examples:
83+
> SELECT _FUNC_(tile);
84+
..."""
85+
)
86+
case class Log2(child: Expression) extends UnaryLocalRasterOp with CodegenFallback {
87+
override val nodeName: String = "log2"
88+
89+
override protected def op(tile: Tile): Tile = fpTile(tile).localLog() / math.log(2.0)
90+
91+
override def dataType: DataType = child.dataType
92+
}
93+
object Log2{
94+
def apply(tile: Column): TypedColumn[Any, Tile] = new Column(Log2(tile.expr)).as[Tile]
6195
}
6296

97+
@ExpressionDescription(
98+
usage = "_FUNC_(tile) - Performs natural logarithm of cell values plus one.",
99+
arguments = """
100+
Arguments:
101+
* tile - input tile""",
102+
examples = """
103+
Examples:
104+
> SELECT _FUNC_(tile);
105+
..."""
106+
)
107+
case class Log1p(child: Expression) extends UnaryLocalRasterOp with CodegenFallback {
108+
override val nodeName: String = "log1p"
109+
110+
override protected def op(tile: Tile): Tile = fpTile(tile).localAdd(1.0).localLog()
111+
112+
override def dataType: DataType = child.dataType
113+
}
114+
object Log1p{
115+
def apply(tile: Column): TypedColumn[Any, Tile] = new Column(Log1p(tile.expr)).as[Tile]
116+
}

core/src/main/scala/astraea/spark/rasterframes/expressions/package.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ package object expressions {
7777
registry.registerExpression[Equal]("rf_local_equal")
7878
registry.registerExpression[Unequal]("rf_local_unequal")
7979
registry.registerExpression[Sum]("rf_tile_sum")
80+
registry.registerExpression[Round]("rf_round")
81+
registry.registerExpression[Log]("rf_log")
82+
registry.registerExpression[Log10]("rf_log10")
83+
registry.registerExpression[Log2]("rf_log2")
84+
registry.registerExpression[Log1p]("rf_log1p")
8085
registry.registerExpression[TileToArrayDouble]("rf_tile_to_array_double")
8186
registry.registerExpression[TileToArrayInt]("rf_tile_to_array_int")
8287
registry.registerExpression[DataCells]("rf_data_cells")

core/src/test/scala/astraea/spark/rasterframes/RasterFunctionsSpec.scala

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import geotrellis.raster
2929
import geotrellis.raster.testkit.RasterMatchers
3030
import geotrellis.raster.{BitCellType, ByteUserDefinedNoDataCellType, DoubleConstantNoDataCellType, ShortConstantNoDataCellType, Tile, UByteConstantNoDataCellType}
3131
import geotrellis.vector.Extent
32-
import org.apache.spark.sql.Encoders
32+
import org.apache.spark.sql.{AnalysisException, Encoders}
3333
import org.apache.spark.sql.functions._
3434
import org.scalatest.{FunSpec, Matchers}
3535

@@ -45,6 +45,7 @@ class RasterFunctionsSpec extends FunSpec
4545
val tileSize = cols * rows
4646
val tileCount = 10
4747
val numND = 4
48+
lazy val zero = TestData.projectedRasterTile(cols, rows, 0, extent, crs, ct)
4849
lazy val one = TestData.projectedRasterTile(cols, rows, 1, extent, crs, ct)
4950
lazy val two = TestData.projectedRasterTile(cols, rows, 2, extent, crs, ct)
5051
lazy val three = TestData.projectedRasterTile(cols, rows, 3, extent, crs, ct)
@@ -55,6 +56,7 @@ class RasterFunctionsSpec extends FunSpec
5556

5657
lazy val randDoubleTile = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextGaussian(), extent, crs, DoubleConstantNoDataCellType)
5758
lazy val randDoubleNDTile = TestData.injectND(numND)(randDoubleTile)
59+
lazy val randPositiveDoubleTile = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextDouble() + 1e-6, extent, crs, DoubleConstantNoDataCellType)
5860

5961
val expectedRandNoData: Long = numND * tileCount
6062
val expectedRandData: Long = cols * rows * tileCount - expectedRandNoData
@@ -113,6 +115,9 @@ class RasterFunctionsSpec extends FunSpec
113115

114116
assertEqual(df.selectExpr("rf_local_divide(six, two)").as[ProjectedRasterTile].first(), three)
115117

118+
assertEqual(df.selectExpr("rf_local_multiply(rf_local_divide(six, 2.0), two)")
119+
.as[ProjectedRasterTile].first(), six)
120+
116121
val maybeThreeTile =
117122
df.select(local_divide(ExtractTile($"six"), ExtractTile($"two"))).as[Tile]
118123
assertEqual(maybeThreeTile.first(), three.toArrayTile())
@@ -540,40 +545,67 @@ class RasterFunctionsSpec extends FunSpec
540545

541546
val df = Seq((three_plus, three_less, three)).toDF("three_plus", "three_less", "three")
542547

543-
assertEqual(df.select(round($"three_plus")).as[Tile].first(), three_double)
544-
assertEqual(df.select(round($"three_less")).as[Tile].first(), three_double)
545-
assertEqual(df.select(round($"three")).as[Tile].first(), three)
548+
assertEqual(df.select(round($"three")).as[ProjectedRasterTile].first(), three)
549+
assertEqual(df.select(round($"three_plus")).as[ProjectedRasterTile].first(), three_double)
550+
assertEqual(df.select(round($"three_less")).as[ProjectedRasterTile].first(), three_double)
546551

552+
assertEqual(df.selectExpr("rf_round(three)").as[ProjectedRasterTile].first(), three)
547553
assertEqual(df.selectExpr("rf_round(three_plus)").as[ProjectedRasterTile].first(), three_double)
548554
assertEqual(df.selectExpr("rf_round(three_less)").as[ProjectedRasterTile].first(), three_double)
549-
assertEqual(df.selectExpr("rf_round(three)").as[ProjectedRasterTile].first(), three)
550555

551556
checkDocs("rf_round")
552557
}
553558

554-
it("should take logarithms"){
555-
// tile zeros ==> nodata
556-
val zeros = TestData.projectedRasterTile(cols, rows, 0, extent, crs, ct)
557-
val nd_float = TestData.projectedRasterTile(cols, rows, Double.NaN, extent, crs, DoubleConstantNoDataCellType)
558-
val df_0 = Seq(zeros).toDF("tile")
559-
assertEqual(df_0.select(log($"tile")).as[Tile].first(), nd_float)
560-
559+
it("should take logarithms positive cell values"){
561560
// log10 1000 == 3
562-
val one_k = TestData.projectedRasterTile(cols, rows, 1000, extent, crs, ShortConstantNoDataCellType)
563-
val threes_dbl = TestData.projectedRasterTile(cols, rows, 3.0, extent, crs, DoubleConstantNoDataCellType)
561+
val thousand = TestData.projectedRasterTile(cols, rows, 1000, extent, crs, ShortConstantNoDataCellType)
562+
val threesDouble = TestData.projectedRasterTile(cols, rows, 3.0, extent, crs, DoubleConstantNoDataCellType)
563+
val zerosDouble = TestData.projectedRasterTile(cols, rows, 0.0, extent, crs, DoubleConstantNoDataCellType)
564564

565-
val df_1 = Seq(one_k).toDF("tile")
566-
assertEqual(df_1.select(log10($"tile")).as[Tile].first(), threes_dbl)
565+
val df1 = Seq(thousand).toDF("tile")
566+
assertEqual(df1.select(log10($"tile")).as[ProjectedRasterTile].first(), threesDouble)
567567

568-
// ln random tile == log10 random tile / log10(e)
569-
val df_2 = Seq(randDoubleTile).toDF("tile")
568+
// ln random tile == log10 random tile / log10(e); the random tile is offset by a small positive constant so all cell values are positive
569+
val df2 = Seq(randPositiveDoubleTile).toDF("tile")
570570
val log10e = math.log10(math.E)
571-
assertEqual(df_2.select(log($"tile")).as[Tile].first(), df_2.select(log10($"tile")).as[Tile].first / log10e)
571+
assertEqual(df2.select(log($"tile")).as[ProjectedRasterTile].first(),
572+
df2.select(log10($"tile")).as[ProjectedRasterTile].first() / log10e)
573+
574+
lazy val maybeZeros = df2
575+
.selectExpr(s"rf_local_subtract(rf_log(tile), rf_local_divide(rf_log10(tile), ${log10e}))")
576+
.as[ProjectedRasterTile].first()
577+
assertEqual(maybeZeros, zerosDouble)
572578

573-
val maybe_all = df_2.selectExpr(s"rf_local_equal(rf_log(tile), rf_local_divide(rf_log10(tile), ${log10e})").as[Tile].first()
574-
assertEqual(maybe_all, TestData.projectedRasterTile(cols, rows, 1, extent, crs, BitCellType))
579+
// log1p for zeros should be ln(1)
580+
val ln1 = math.log1p(0.0)
581+
val df3 = Seq(zero).toDF("tile")
582+
val maybeLn1 = df3.selectExpr(s"rf_log1p(tile)").as[ProjectedRasterTile].first()
583+
assert(maybeLn1.toArrayDouble().forall(_ == ln1))
575584

576585
checkDocs("rf_log")
586+
checkDocs("rf_log2")
587+
checkDocs("rf_log10")
588+
checkDocs("rf_log1p")
589+
}
590+
591+
it("should take logarithms with non-positive cell values") {
592+
val ni_float = TestData.projectedRasterTile(cols, rows, Double.NegativeInfinity, extent, crs, DoubleConstantNoDataCellType)
593+
val zero_float =TestData.projectedRasterTile(cols, rows, 0.0, extent, crs, DoubleConstantNoDataCellType)
594+
595+
// tile zeros ==> -Infinity
596+
val df_0 = Seq(zero).toDF("tile")
597+
assertEqual(df_0.select(log($"tile")).as[ProjectedRasterTile].first(), ni_float)
598+
assertEqual(df_0.select(log10($"tile")).as[ProjectedRasterTile].first(), ni_float)
599+
assertEqual(df_0.select(log2($"tile")).as[ProjectedRasterTile].first(), ni_float)
600+
// log1p of zeros should be 0.
601+
assertEqual(df_0.select(log1p($"tile")).as[ProjectedRasterTile].first(), zero_float)
602+
603+
// tile negative values ==> NaN
604+
assert(df_0.selectExpr("rf_log(rf_local_subtract(tile, 42))").as[ProjectedRasterTile].first().isNoDataTile)
605+
assert(df_0.selectExpr("rf_log2(rf_local_subtract(tile, 42))").as[ProjectedRasterTile].first().isNoDataTile)
606+
assert(df_0.select(log1p(local_subtract($"tile", 42))).as[ProjectedRasterTile].first().isNoDataTile)
607+
assert(df_0.select(log10(local_subtract($"tile", lit(0.01)))).as[ProjectedRasterTile].first().isNoDataTile)
608+
577609
}
578610
}
579611
}

0 commit comments

Comments
 (0)