Skip to content

Commit df11ac1

Browse files
authored
Merge pull request #57 from s22s/feature/tile-quantile-opts
Alternate approaches to exposing quantile summaries API
2 parents 456914d + 3f90067 commit df11ac1

8 files changed

Lines changed: 180 additions & 173 deletions

File tree

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,15 @@ trait RasterFunctions {
171171
/** Assign a `NoData` value to the tile column. */
172172
def rf_with_no_data(col: Column, nodata: Column): Column = SetNoDataValue(col, nodata)
173173

174-
/** Compute the full column aggregate floating point histogram. */
174+
/** Compute the approximate aggregate floating point histogram using a streaming algorithm, with the default of 80 buckets. */
175175
def rf_agg_approx_histogram(col: Column): TypedColumn[Any, CellHistogram] = HistogramAggregate(col)
176176

177+
/** Compute the approximate aggregate floating point histogram using a streaming algorithm, with the given number of buckets. */
178+
def rf_agg_approx_histogram(col: Column, numBuckets: Int): TypedColumn[Any, CellHistogram] = {
179+
require(numBuckets > 0, "Must provide a positive number of buckets")
180+
HistogramAggregate(col, numBuckets)
181+
}
182+
177183
/** Compute the full column aggregate floating point statistics. */
178184
def rf_agg_stats(col: Column): TypedColumn[Any, CellStatistics] = CellStatsAggregate(col)
179185

@@ -186,6 +192,23 @@ trait RasterFunctions {
186192
/** Computes the number of NoData cells in a column. */
187193
def rf_agg_no_data_cells(col: Column): TypedColumn[Any, Long] = CellCountAggregate.NoDataCells(col)
188194

195+
/**
196+
* Calculates the approximate quantiles of a tile column of a DataFrame.
197+
* @param tile tile column to extract cells from.
198+
* @param probabilities a list of quantile probabilities
199+
* Each number must belong to [0, 1].
200+
* For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
201+
* @param relativeError The relative target precision to achieve (greater than or equal to 0).
202+
* @return the approximate quantiles at the given probabilities of each column
203+
*/
204+
def rf_agg_approx_quantiles(
205+
tile: Column,
206+
probabilities: Seq[Double],
207+
relativeError: Double = 0.00001): TypedColumn[Any, Seq[Double]] = {
208+
require(probabilities.nonEmpty, "at least one quantile probability is required")
209+
ApproxCellQuantilesAggregate(tile, probabilities, relativeError)
210+
}
211+
189212
/** Compute the Tile-wise mean */
190213
def rf_tile_mean(col: Column): TypedColumn[Any, Double] =
191214
TileMean(col)
@@ -546,14 +569,17 @@ trait RasterFunctions {
546569
/** Return the incoming tile untouched. */
547570
def rf_identity(tileCol: Column): Column = Identity(tileCol)
548571

549-
/** Create a row for each cell in Tile. */
572+
/** Create a row for each cell in Tile.
573+
* The output will include the columns `column_index`, `row_index` indicating where in the tile the cell originated. */
550574
def rf_explode_tiles(cols: Column*): Column = rf_explode_tiles_sample(1.0, None, cols: _*)
551575

552-
/** Create a row for each cell in Tile with random sampling and optional seed. */
576+
/** Create a row for each cell in Tile with random sampling and optional seed.
577+
* The output will include the columns `column_index`, `row_index` indicating where in the tile the cell originated. */
553578
def rf_explode_tiles_sample(sampleFraction: Double, seed: Option[Long], cols: Column*): Column =
554579
ExplodeTiles(sampleFraction, seed, cols)
555580

556-
/** Create a row for each cell in Tile with random sampling (no seed). */
581+
/** Create a row for each cell in Tile with random sampling (no seed).
582+
* The output will include the columns `column_index`, `row_index` indicating where in the tile the cell originated. */
557583
def rf_explode_tiles_sample(sampleFraction: Double, cols: Column*): Column =
558584
ExplodeTiles(sampleFraction, None, cols)
559585
}

core/src/main/scala/org/locationtech/rasterframes/encoders/StandardSerializers.scala

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,21 @@
2121

2222
package org.locationtech.rasterframes.encoders
2323

24+
import java.nio.ByteBuffer
25+
2426
import com.github.blemale.scaffeine.Scaffeine
2527
import geotrellis.proj4.CRS
2628
import geotrellis.raster._
2729
import geotrellis.spark._
2830
import geotrellis.spark.tiling.LayoutDefinition
2931
import geotrellis.vector._
32+
import org.apache.spark.sql.catalyst.util.QuantileSummaries
3033
import org.apache.spark.sql.types._
3134
import org.locationtech.jts.geom.Envelope
3235
import org.locationtech.rasterframes.TileType
3336
import org.locationtech.rasterframes.encoders.CatalystSerializer.{CatalystIO, _}
3437
import org.locationtech.rasterframes.model.LazyCRS
38+
import org.locationtech.rasterframes.util.KryoSupport
3539

3640
/** Collection of CatalystSerializers for third-party types. */
3741
trait StandardSerializers {
@@ -294,9 +298,23 @@ trait StandardSerializers {
294298
implicit val spatialKeyTLMSerializer = tileLayerMetadataSerializer[SpatialKey]
295299
implicit val spaceTimeKeyTLMSerializer = tileLayerMetadataSerializer[SpaceTimeKey]
296300

301+
implicit val quantileSerializer: CatalystSerializer[QuantileSummaries] = new CatalystSerializer[QuantileSummaries] {
302+
override val schema: StructType = StructType(Seq(
303+
StructField("quantile_serializer_kryo", BinaryType, false)
304+
))
305+
306+
override protected def to[R](t: QuantileSummaries, io: CatalystSerializer.CatalystIO[R]): R = {
307+
val buf = KryoSupport.serialize(t)
308+
io.create(buf.array())
309+
}
310+
311+
override protected def from[R](t: R, io: CatalystSerializer.CatalystIO[R]): QuantileSummaries = {
312+
KryoSupport.deserialize[QuantileSummaries](ByteBuffer.wrap(io.getByteArray(t, 0)))
313+
}
314+
}
297315
}
298316

299-
object StandardSerializers {
317+
object StandardSerializers extends StandardSerializers {
300318
private val s2ctCache = Scaffeine().build[String, CellType](
301319
(s: String) => CellType.fromName(s)
302320
)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes.expressions.aggregates
23+
24+
import geotrellis.raster.{Tile, isNoData}
25+
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
26+
import org.apache.spark.sql.catalyst.util.QuantileSummaries
27+
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
28+
import org.apache.spark.sql.{Column, Encoder, Row, TypedColumn, types}
29+
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
30+
import org.locationtech.rasterframes.TileType
31+
import org.locationtech.rasterframes.encoders.CatalystSerializer._
32+
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
33+
34+
35+
case class ApproxCellQuantilesAggregate(probabilities: Seq[Double], relativeError: Double) extends UserDefinedAggregateFunction {
36+
import org.locationtech.rasterframes.encoders.StandardSerializers.quantileSerializer
37+
38+
override def inputSchema: StructType = StructType(Seq(
39+
StructField("value", TileType, true)
40+
))
41+
42+
override def bufferSchema: StructType = StructType(Seq(
43+
StructField("buffer", schemaOf[QuantileSummaries], false)
44+
))
45+
46+
override def dataType: types.DataType = DataTypes.createArrayType(DataTypes.DoubleType)
47+
48+
override def deterministic: Boolean = true
49+
50+
override def initialize(buffer: MutableAggregationBuffer): Unit =
51+
buffer.update(0, new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError).toRow)
52+
53+
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
54+
val qs = buffer.getStruct(0).to[QuantileSummaries]
55+
if (!input.isNullAt(0)) {
56+
val tile = input.getAs[Tile](0)
57+
var result = qs
58+
tile.foreachDouble(d => if (!isNoData(d)) result = result.insert(d))
59+
buffer.update(0, result.toRow)
60+
}
61+
else buffer
62+
}
63+
64+
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
65+
val left = buffer1.getStruct(0).to[QuantileSummaries]
66+
val right = buffer2.getStruct(0).to[QuantileSummaries]
67+
val merged = left.compress().merge(right.compress())
68+
buffer1.update(0, merged.toRow)
69+
}
70+
71+
override def evaluate(buffer: Row): Seq[Double] = {
72+
val summaries = buffer.getStruct(0).to[QuantileSummaries]
73+
probabilities.flatMap(summaries.query)
74+
}
75+
}
76+
77+
object ApproxCellQuantilesAggregate {
78+
private implicit def doubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
79+
80+
def apply(
81+
tile: Column,
82+
probabilities: Seq[Double],
83+
relativeError: Double = 0.00001): TypedColumn[Any, Seq[Double]] = {
84+
new ApproxCellQuantilesAggregate(probabilities, relativeError)(ExtractTile(tile))
85+
.as(s"rf_agg_approx_quantiles")
86+
.as[Seq[Double]]
87+
}
88+
}

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/HistogramAggregate.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ object HistogramAggregate {
9898
import org.locationtech.rasterframes.encoders.StandardEncoders.cellHistEncoder
9999

100100
def apply(col: Column): TypedColumn[Any, CellHistogram] =
101-
new HistogramAggregate()(ExtractTile(col))
101+
apply(col, StreamingHistogram.DEFAULT_NUM_BUCKETS)
102+
103+
def apply(col: Column, numBuckets: Int): TypedColumn[Any, CellHistogram] =
104+
new HistogramAggregate(numBuckets)(ExtractTile(col))
102105
.as(s"rf_agg_approx_histogram($col)")
103106
.as[CellHistogram]
104107

core/src/main/scala/org/locationtech/rasterframes/extensions/DataFrameMethods.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,6 @@ trait DataFrameMethods[DF <: DataFrame] extends MethodExtensions[DF] with Metada
155155
def withPrefixedColumnNames(prefix: String): DF =
156156
self.columns.foldLeft(self)((df, c) df.withColumnRenamed(c, s"$prefix$c").asInstanceOf[DF])
157157

158-
/** */
159-
def tileStat(): RasterFrameStatFunctions = new RasterFrameStatFunctions(self)
160-
161158
/**
162159
* Performs a jeft join on the dataframe `right` to this one, reprojecting and merging tiles as necessary.
163160
* The operation is logically a "left outer" join, with the left side also determining the target CRS and extents.

core/src/main/scala/org/locationtech/rasterframes/extensions/RasterFrameStatFunctions.scala

Lines changed: 0 additions & 76 deletions
This file was deleted.

core/src/main/scala/org/locationtech/rasterframes/stats/package.scala

Lines changed: 0 additions & 59 deletions
This file was deleted.

0 commit comments

Comments
 (0)