Skip to content

Commit d930769

Browse files
zhangfengcdtjiayuasu
authored and committed
[SEDONA-690] Set default metric to use Haversine for KNN join and code refactoring (#1909)
* [SEDONA-690] Set default metric to use Haversine for KNN join and some code refactor * fix unit tests * clean up join params
1 parent 86f0fc2 commit d930769

7 files changed

Lines changed: 329 additions & 235 deletions

File tree

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.apache.sedona.core.joinJudgement;
20+
21+
import java.util.ArrayList;
22+
import java.util.Collections;
23+
import java.util.Iterator;
24+
import java.util.LinkedHashSet;
25+
import java.util.List;
26+
import java.util.NoSuchElementException;
27+
import org.apache.commons.lang3.tuple.Pair;
28+
import org.apache.sedona.core.enums.DistanceMetric;
29+
import org.apache.sedona.core.wrapper.UniqueGeometry;
30+
import org.apache.spark.util.LongAccumulator;
31+
import org.locationtech.jts.geom.Envelope;
32+
import org.locationtech.jts.geom.Geometry;
33+
import org.locationtech.jts.index.strtree.ItemDistance;
34+
import org.locationtech.jts.index.strtree.STRtree;
35+
36+
public class InMemoryKNNJoinIterator<T extends Geometry, U extends Geometry>
37+
implements Iterator<Pair<T, U>> {
38+
private final Iterator<T> querySideIterator;
39+
private final STRtree strTree;
40+
41+
private final int k;
42+
private final DistanceMetric distanceMetric;
43+
private final boolean includeTies;
44+
private final ItemDistance itemDistance;
45+
46+
private final LongAccumulator streamCount;
47+
private final LongAccumulator resultCount;
48+
49+
private final List<Pair<T, U>> currentResults = new ArrayList<>();
50+
private int currentResultIndex = 0;
51+
52+
public InMemoryKNNJoinIterator(
53+
Iterator<T> querySideIterator,
54+
STRtree strTree,
55+
int k,
56+
DistanceMetric distanceMetric,
57+
boolean includeTies,
58+
LongAccumulator streamCount,
59+
LongAccumulator resultCount) {
60+
this.querySideIterator = querySideIterator;
61+
this.strTree = strTree;
62+
63+
this.k = k;
64+
this.distanceMetric = distanceMetric;
65+
this.includeTies = includeTies;
66+
this.itemDistance = KnnJoinIndexJudgement.getItemDistance(distanceMetric);
67+
68+
this.streamCount = streamCount;
69+
this.resultCount = resultCount;
70+
}
71+
72+
@Override
73+
public boolean hasNext() {
74+
if (currentResultIndex < currentResults.size()) {
75+
return true;
76+
}
77+
78+
currentResultIndex = 0;
79+
currentResults.clear();
80+
while (querySideIterator.hasNext()) {
81+
populateNextBatch();
82+
if (!currentResults.isEmpty()) {
83+
return true;
84+
}
85+
}
86+
87+
return false;
88+
}
89+
90+
@Override
91+
public Pair<T, U> next() {
92+
if (!hasNext()) {
93+
throw new NoSuchElementException();
94+
}
95+
96+
return currentResults.get(currentResultIndex++);
97+
}
98+
99+
private void populateNextBatch() {
100+
T queryItem = querySideIterator.next();
101+
Geometry queryGeom;
102+
if (queryItem instanceof UniqueGeometry) {
103+
queryGeom = (Geometry) ((UniqueGeometry<?>) queryItem).getOriginalGeometry();
104+
} else {
105+
queryGeom = queryItem;
106+
}
107+
streamCount.add(1);
108+
109+
Object[] localK =
110+
strTree.nearestNeighbour(queryGeom.getEnvelopeInternal(), queryGeom, itemDistance, k);
111+
if (includeTies) {
112+
localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree);
113+
}
114+
115+
for (Object obj : localK) {
116+
U candidate = (U) obj;
117+
Pair<T, U> pair = Pair.of(queryItem, candidate);
118+
currentResults.add(pair);
119+
resultCount.add(1);
120+
}
121+
}
122+
123+
private Object[] getUpdatedLocalKWithTies(
124+
Geometry streamShape, Object[] localK, STRtree strTree) {
125+
Envelope searchEnvelope = streamShape.getEnvelopeInternal();
126+
// get the maximum distance from the k nearest neighbors
127+
double maxDistance = 0.0;
128+
LinkedHashSet<U> uniqueCandidates = new LinkedHashSet<>();
129+
for (Object obj : localK) {
130+
U candidate = (U) obj;
131+
uniqueCandidates.add(candidate);
132+
double distance = streamShape.distance(candidate);
133+
if (distance > maxDistance) {
134+
maxDistance = distance;
135+
}
136+
}
137+
searchEnvelope.expandBy(maxDistance);
138+
List<U> candidates = strTree.query(searchEnvelope);
139+
if (!candidates.isEmpty()) {
140+
// update localK with all candidates that are within the maxDistance
141+
List<Object> tiedResults = new ArrayList<>();
142+
// add all localK
143+
Collections.addAll(tiedResults, localK);
144+
145+
for (U candidate : candidates) {
146+
double distance = streamShape.distance(candidate);
147+
if (distance == maxDistance && !uniqueCandidates.contains(candidate)) {
148+
tiedResults.add(candidate);
149+
}
150+
}
151+
localK = tiedResults.toArray();
152+
}
153+
return localK;
154+
}
155+
}

0 commit comments

Comments
 (0)