Skip to content

Commit 527c1aa

Browse files
Kontinuation, Copilot, paleolimbot
authored
feat(rust/sedona-raster-functions): add RS_SetSRID/RS_SetCRS with batch-local cache refactoring (#630)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Dewey Dunnington <dewey@dunnington.ca>
1 parent 6c06f80 commit 527c1aa

6 files changed

Lines changed: 856 additions & 47 deletions

File tree

rust/sedona-functions/src/st_setsrid.rs

Lines changed: 12 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,7 @@
1414
// KIND, either express or implied. See the License for the
1515
// specific language governing permissions and limitations
1616
// under the License.
17-
use std::{
18-
collections::{HashMap, HashSet},
19-
sync::{Arc, OnceLock},
20-
};
17+
use std::sync::{Arc, OnceLock};
2118

2219
use arrow_array::{
2320
builder::{BinaryBuilder, NullBufferBuilder},
@@ -41,7 +38,11 @@ use sedona_expr::{
4138
scalar_udf::{ScalarKernelRef, SedonaScalarKernel, SedonaScalarUDF},
4239
};
4340
use sedona_geometry::transform::CrsEngine;
44-
use sedona_schema::{crs::deserialize_crs, datatypes::SedonaType, matchers::ArgMatcher};
41+
use sedona_schema::{
42+
crs::{deserialize_crs, CachedCrsNormalization, CachedSRIDToCrs},
43+
datatypes::SedonaType,
44+
matchers::ArgMatcher,
45+
};
4546

4647
/// ST_SetSRID() scalar UDF implementation
4748
///
@@ -473,8 +474,7 @@ fn normalize_crs_array(
473474
| DataType::UInt16
474475
| DataType::UInt32
475476
| DataType::UInt64 => {
476-
// Local cache to avoid re-validating inputs
477-
let mut known_valid = HashSet::new();
477+
let mut srid_to_crs = CachedSRIDToCrs::new();
478478

479479
let int_value = crs_value.cast_to(&DataType::Int64, None)?;
480480
let int_array_ref = ColumnarValue::values_to_arrays(&[int_value])?;
@@ -483,18 +483,10 @@ fn normalize_crs_array(
483483
.iter()
484484
.map(|maybe_srid| -> Result<Option<String>> {
485485
if let Some(srid) = maybe_srid {
486-
if srid == 0 {
486+
let Some(auth_code) = srid_to_crs.get_crs(srid)? else {
487487
return Ok(None);
488-
} else if srid == 4326 {
489-
return Ok(Some("OGC:CRS84".to_string()));
490-
}
491-
492-
let auth_code = format!("EPSG:{srid}");
493-
if !known_valid.contains(&srid) {
494-
validate_crs(&auth_code, maybe_engine)?;
495-
known_valid.insert(srid);
496-
}
497-
488+
};
489+
validate_crs(&auth_code, maybe_engine)?;
498490
Ok(Some(auth_code))
499491
} else {
500492
Ok(None)
@@ -505,7 +497,7 @@ fn normalize_crs_array(
505497
Ok(Arc::new(utf8_view_array))
506498
}
507499
_ => {
508-
let mut known_abbreviated = HashMap::<String, String>::new();
500+
let mut crs_norm = CachedCrsNormalization::new();
509501

510502
let string_value = crs_value.cast_to(&DataType::Utf8View, None)?;
511503
let string_array_ref = ColumnarValue::values_to_arrays(&[string_value])?;
@@ -514,25 +506,7 @@ fn normalize_crs_array(
514506
.iter()
515507
.map(|maybe_crs| -> Result<Option<String>> {
516508
if let Some(crs_str) = maybe_crs {
517-
if crs_str == "0" {
518-
return Ok(None);
519-
}
520-
521-
if let Some(abbreviated_crs) = known_abbreviated.get(crs_str) {
522-
Ok(Some(abbreviated_crs.clone()))
523-
} else if let Some(crs) = deserialize_crs(crs_str)? {
524-
let abbreviated_crs =
525-
if let Some(auth_code) = crs.to_authority_code()? {
526-
auth_code
527-
} else {
528-
crs_str.to_string()
529-
};
530-
531-
known_abbreviated.insert(crs.to_string(), abbreviated_crs.clone());
532-
Ok(Some(abbreviated_crs))
533-
} else {
534-
Ok(None)
535-
}
509+
crs_norm.normalize(crs_str)
536510
} else {
537511
Ok(None)
538512
}

rust/sedona-raster-functions/benches/native-raster-functions.rs

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,20 @@ fn criterion_benchmark(c: &mut Criterion) {
5555
BenchmarkArgs::ArrayScalarScalar(Raster(64, 64), Int32(0, 63), Int32(0, 63)),
5656
);
5757
benchmark::scalar(c, &f, "native-raster", "rs_rotation", Raster(64, 64));
58+
benchmark::scalar(
59+
c,
60+
&f,
61+
"native-raster",
62+
"rs_setcrs",
63+
BenchmarkArgs::ArrayScalar(Raster(64, 64), String("EPSG:3857".to_string())),
64+
);
65+
benchmark::scalar(
66+
c,
67+
&f,
68+
"native-raster",
69+
"rs_setsrid",
70+
BenchmarkArgs::ArrayScalar(Raster(64, 64), Int32(3857, 3858)),
71+
);
5872
benchmark::scalar(c, &f, "native-raster", "rs_scalex", Raster(64, 64));
5973
benchmark::scalar(c, &f, "native-raster", "rs_scaley", Raster(64, 64));
6074
benchmark::scalar(c, &f, "native-raster", "rs_skewx", Raster(64, 64));

rust/sedona-raster-functions/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@ pub mod rs_georeference;
2424
pub mod rs_geotransform;
2525
pub mod rs_numbands;
2626
pub mod rs_rastercoordinate;
27+
pub mod rs_setsrid;
2728
pub mod rs_size;
2829
pub mod rs_srid;
2930
pub mod rs_worldcoordinate;

rust/sedona-raster-functions/src/register.rs

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,8 @@ pub fn default_function_set() -> FunctionSet {
5555
crate::rs_rastercoordinate::rs_worldtorastercoordy_udf,
5656
crate::rs_size::rs_height_udf,
5757
crate::rs_size::rs_width_udf,
58+
crate::rs_setsrid::rs_set_crs_udf,
59+
crate::rs_setsrid::rs_set_srid_udf,
5860
crate::rs_srid::rs_crs_udf,
5961
crate::rs_srid::rs_srid_udf,
6062
crate::rs_worldcoordinate::rs_rastertoworldcoord_udf,

0 commit comments

Comments
 (0)