Skip to content

Commit c491410

Browse files
committed
Add simple NN resample function
Signed-off-by: Jason T. Brown <jason@astraea.earth>
1 parent f1cd8d4 commit c491410

4 files changed

Lines changed: 117 additions & 1 deletion

File tree

core/src/main/scala/astraea/spark/rasterframes/RasterFunctions.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,4 +344,10 @@ trait RasterFunctions {
344344
def expm1(tileCol: Column): TypedColumn[Any, Tile] =
345345
ExpM1(tileCol)
346346

347+
/**
 * Resample a tile by a scalar factor using nearest-neighbor sampling.
 * A factor less than one downsamples the tile; a factor greater than one upsamples it.
 */
def resample[T: Numeric](tileCol: Column, value: T): TypedColumn[Any, Tile] = Resample(tileCol, value)
349+
350+
/**
 * Resample a tile using nearest-neighbor sampling. `column2` is either a scalar
 * factor column or a tile column whose dimensions the result should match.
 */
def resample(tileCol: Column, column2: Column): TypedColumn[Any, Tile] = Resample(tileCol, column2)
352+
347353
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package astraea.spark.rasterframes.expressions.localops
23+
24+
import astraea.spark.rasterframes._
25+
import astraea.spark.rasterframes.expressions.DynamicExtractors.tileExtractor
26+
import astraea.spark.rasterframes.expressions.BinaryLocalRasterOp
27+
import geotrellis.raster.Tile
28+
import geotrellis.raster.resample.NearestNeighbor
29+
import org.apache.spark.sql.rf._
30+
import org.apache.spark.sql.catalyst.InternalRow
31+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
32+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
33+
import org.apache.spark.sql.functions.lit
34+
import org.apache.spark.sql.{Column, TypedColumn}
35+
36+
@ExpressionDescription(
37+
usage = "_FUNC_(tile, factor) - Resample tile to different size based on scalar factor or tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses nearest-neighbor value.",
38+
arguments = """
39+
Arguments:
40+
* tile - tile
41+
* rhs - scalar or tile to match dimension""",
42+
examples = """
43+
Examples:
44+
> SELECT _FUNC_(tile, 2.0);
45+
...
46+
> SELECT _FUNC_(tile1, tile2);
47+
..."""
48+
)
49+
/**
 * Catalyst expression that resamples the `left` tile with nearest-neighbor sampling.
 *
 * The `right` operand is either a scalar factor applied to both dimensions, or a
 * tile whose column and row counts become the target dimensions.
 */
case class Resample(left: Expression, right: Expression) extends BinaryLocalRasterOp
  with CodegenFallback {
  override val nodeName: String = "resample"

  /** Resamples `left` to the column/row counts of `right`. */
  override protected def op(left: Tile, right: Tile): Tile =
    left.resample(right.cols, right.rows, NearestNeighbor)

  /** Scales both dimensions of `left` by `right`; `toInt` truncates each product. */
  override protected def op(left: Tile, right: Double): Tile =
    left.resample((left.cols * right).toInt, (left.rows * right).toInt, NearestNeighbor)

  /** Integer factors are handled by the `Double` variant. */
  override protected def op(left: Tile, right: Int): Tile = op(left, right.toDouble)

  /**
   * Null-aware evaluation (Catalyst represents SQL NULL as `null`).
   *
   *  - null input row or null left tile: the result is null, since there is nothing
   *    to resample. The right operand is never returned in place of a tile.
   *  - null right operand of tile type: the left tile is returned unchanged.
   *  - null right operand of any other type: the result is null.
   *  - otherwise: delegates to `nullSafeEval`.
   */
  override def eval(input: InternalRow): Any = {
    if (input == null) null
    else {
      val l = left.eval(input)
      if (l == null) null
      else {
        val r = right.eval(input)
        if (r != null) nullSafeEval(l, r)
        else if (tileExtractor.isDefinedAt(right.dataType)) l
        else null
      }
    }
  }
}
70+
/** Column-level constructors for the [[Resample]] expression. */
object Resample {

  /** Builds a typed tile column that resamples `left` according to the `right` column. */
  def apply(left: Column, right: Column): TypedColumn[Any, Tile] = {
    val expression = Resample(left.expr, right.expr)
    new Column(expression).as[Tile]
  }

  /** Builds a typed tile column that resamples `tile` by the literal numeric `value`. */
  def apply[N: Numeric](tile: Column, value: N): TypedColumn[Any, Tile] =
    apply(tile, lit(value))
}

core/src/main/scala/astraea/spark/rasterframes/expressions/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ package object expressions {
8686
registry.registerExpression[Exp10]("rf_exp10")
8787
registry.registerExpression[Exp2]("rf_exp2")
8888
registry.registerExpression[ExpM1]("rf_expm1")
89+
registry.registerExpression[Resample]("rf_resample")
8990
registry.registerExpression[TileToArrayDouble]("rf_tile_to_array_double")
9091
registry.registerExpression[TileToArrayInt]("rf_tile_to_array_int")
9192
registry.registerExpression[DataCells]("rf_data_cells")

core/src/test/scala/astraea/spark/rasterframes/RasterFunctionsSpec.scala

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import astraea.spark.rasterframes.tiles.ProjectedRasterTile
2727
import geotrellis.proj4.LatLng
2828
import geotrellis.raster
2929
import geotrellis.raster.testkit.RasterMatchers
30-
import geotrellis.raster.{BitCellType, ByteUserDefinedNoDataCellType, DoubleConstantNoDataCellType, ShortConstantNoDataCellType, Tile, UByteConstantNoDataCellType}
30+
import geotrellis.raster.{ArrayTile, BitCellType, ByteUserDefinedNoDataCellType, DoubleConstantNoDataCellType, ShortConstantNoDataCellType, Tile, UByteConstantNoDataCellType}
3131
import geotrellis.vector.Extent
3232
import org.apache.spark.sql.{AnalysisException, Encoders}
3333
import org.apache.spark.sql.functions._
@@ -659,4 +659,37 @@ class RasterFunctionsSpec extends FunSpec
659659

660660
}
661661
}
662+
it("should resample") {
  // 2x2 source tile.
  val coarse = ProjectedRasterTile(
    ArrayTile(Array(1, 2, 3, 4), 2, 2).convert(ct), extent, crs)

  // Expected 4x4 result: each source cell repeated as a 2x2 block.
  val fine = ProjectedRasterTile(
    ArrayTile(Array(
      1, 1, 2, 2,
      1, 1, 2, 2,
      3, 3, 4, 4,
      3, 3, 4, 4
    ), 4, 4).convert(ct), extent, crs)

  // A 4x4 tile used only to supply the target dimensions.
  val shapeDonor = TestData.projectedRasterTile(4, 4, 0, extent, crs, ct)

  // Upsample by a literal factor column.
  val byFactor = Seq(coarse).toDF("tile")
    .select(resample($"tile", lit(2)))
    .as[ProjectedRasterTile]
    .first()
  assertEqual(byFactor, fine)

  // Upsample to the dimensions of another tile column.
  val byShape = Seq((coarse, shapeDonor)).toDF("tile1", "tile2")
    .select(resample($"tile1", $"tile2"))
    .as[ProjectedRasterTile]
    .first()
  assertEqual(byShape, fine)

  // Downsample by double argument < 1, both as a SQL literal and as a column.
  val withFactor = Seq(fine).toDF("tile").withColumn("factor", lit(0.5))
  val viaLiteral = withFactor.selectExpr("rf_resample(tile, 0.5)").as[ProjectedRasterTile].first()
  assertEqual(viaLiteral, coarse)
  val viaColumn = withFactor.selectExpr("rf_resample(tile, factor)").as[ProjectedRasterTile].first()
  assertEqual(viaColumn, coarse)

  checkDocs("rf_resample")
}
662695
}

0 commit comments

Comments
 (0)