Skip to content

Commit ceb8db3

Browse files
dmitriplotnikovcopybara-github
authored andcommitted
Optimize list.distinct()
PiperOrigin-RevId: 850474438
1 parent 5588410 commit ceb8db3

File tree

3 files changed

+91
-17
lines changed

3 files changed

+91
-17
lines changed

extensions/src/main/java/dev/cel/extensions/CelListsExtensions.java

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.google.common.collect.ImmutableList;
2323
import com.google.common.collect.ImmutableSet;
2424
import com.google.common.collect.Lists;
25+
import com.google.common.collect.Sets;
2526
import dev.cel.checker.CelCheckerBuilder;
2627
import dev.cel.common.CelFunctionDecl;
2728
import dev.cel.common.CelIssue;
@@ -316,27 +317,38 @@ public static ImmutableList<Long> genRange(long end) {
316317
return builder.build();
317318
}
318319

320+
private static class RuntimeEqualityObjectWrapper {
321+
private final Object object;
322+
private final int hashCode;
323+
private final RuntimeEquality runtimeEquality;
324+
325+
RuntimeEqualityObjectWrapper(Object object, RuntimeEquality runtimeEquality) {
326+
this.object = object;
327+
this.runtimeEquality = runtimeEquality;
328+
this.hashCode = runtimeEquality.hashCode(object);
329+
}
330+
331+
@Override
332+
public int hashCode() {
333+
return hashCode;
334+
}
335+
336+
@Override
337+
public boolean equals(Object obj) {
338+
if (!(obj instanceof RuntimeEqualityObjectWrapper)) {
339+
return false;
340+
}
341+
return runtimeEquality.objectEquals(object, ((RuntimeEqualityObjectWrapper) obj).object);
342+
}
343+
}
344+
319345
private static ImmutableList<Object> distinct(
320346
Collection<Object> list, RuntimeEquality runtimeEquality) {
321-
// TODO Optimize this method, which currently has the O(N^2) complexity.
322347
int size = list.size();
323348
ImmutableList.Builder<Object> builder = ImmutableList.builderWithExpectedSize(size);
324-
List<Object> theList;
325-
if (list instanceof List) {
326-
theList = (List<Object>) list;
327-
} else {
328-
theList = ImmutableList.copyOf(list);
329-
}
330-
for (int i = 0; i < size; i++) {
331-
Object element = theList.get(i);
332-
boolean found = false;
333-
for (int j = 0; j < i; j++) {
334-
if (runtimeEquality.objectEquals(element, theList.get(j))) {
335-
found = true;
336-
break;
337-
}
338-
}
339-
if (!found) {
349+
Set<RuntimeEqualityObjectWrapper> distinctValues = Sets.newHashSetWithExpectedSize(size);
350+
for (Object element : list) {
351+
if (distinctValues.add(new RuntimeEqualityObjectWrapper(element, runtimeEquality))) {
340352
builder.add(element);
341353
}
342354
}

runtime/src/main/java/dev/cel/runtime/RuntimeEquality.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,41 @@ public boolean objectEquals(Object x, Object y) {
223223
return Objects.equals(x, y);
224224
}
225225

226+
/**
227+
* Returns the hash code consistent with the {@link #objectEquals(Object, Object)} method. For
228+
* example, {@code hashCode(1) == hashCode(1.0)} since {@code objectEquals(1, 1.0)} is true.
229+
*/
230+
public int hashCode(Object object) {
231+
if (object == null) {
232+
return 0;
233+
}
234+
235+
if (celOptions.disableCelStandardEquality()) {
236+
return Objects.hashCode(object);
237+
}
238+
239+
object = runtimeHelpers.adaptValue(object);
240+
if (object instanceof Number) {
241+
return Double.hashCode(((Number) object).doubleValue());
242+
}
243+
if (object instanceof Iterable) {
244+
int h = 1;
245+
Iterable<?> iter = (Iterable<?>) object;
246+
for (Object elem : iter) {
247+
h = h * 31 + hashCode(elem);
248+
}
249+
return h;
250+
}
251+
if (object instanceof Map) {
252+
int h = 0;
253+
for (Map.Entry<?, ?> entry : ((Map<?, ?>) object).entrySet()) {
254+
h += hashCode(entry.getKey()) ^ hashCode(entry.getValue());
255+
}
256+
return h;
257+
}
258+
return Objects.hashCode(object);
259+
}
260+
226261
private static Optional<UnsignedLong> doubleToUnsignedLossless(Number v) {
227262
Optional<UnsignedLong> conv = RuntimeHelpers.doubleToUnsignedChecked(v.doubleValue());
228263
return conv.map(ul -> ul.longValue() == v.doubleValue() ? ul : null);

runtime/src/test/java/dev/cel/runtime/RuntimeEqualityTest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,12 @@
1414

1515
package dev.cel.runtime;
1616

17+
import static com.google.common.truth.Truth.assertThat;
1718
import static org.junit.Assert.assertThrows;
1819

20+
import com.google.common.collect.ImmutableList;
21+
import com.google.common.collect.ImmutableMap;
22+
import com.google.common.primitives.UnsignedLong;
1923
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
2024
import dev.cel.common.CelOptions;
2125
import dev.cel.expr.conformance.proto2.TestAllTypes;
@@ -25,6 +29,29 @@
2529
@RunWith(TestParameterInjector.class)
2630
public final class RuntimeEqualityTest {
2731

32+
@Test
33+
public void objectEquals_and_hashCode() {
34+
RuntimeEquality runtimeEquality =
35+
RuntimeEquality.create(RuntimeHelpers.create(), CelOptions.DEFAULT);
36+
assertEqualityAndHashCode(runtimeEquality, 1, 1);
37+
assertEqualityAndHashCode(runtimeEquality, 2, 2L);
38+
assertEqualityAndHashCode(runtimeEquality, 3, 3.0);
39+
assertEqualityAndHashCode(runtimeEquality, 4, UnsignedLong.valueOf(4));
40+
assertEqualityAndHashCode(
41+
runtimeEquality,
42+
ImmutableList.of(1, 2, 3),
43+
ImmutableList.of(1.0, 2L, UnsignedLong.valueOf(3)));
44+
assertEqualityAndHashCode(
45+
runtimeEquality,
46+
ImmutableMap.of("a", 1, "b", 2),
47+
ImmutableMap.of("a", 1L, "b", UnsignedLong.valueOf(2)));
48+
}
49+
50+
private void assertEqualityAndHashCode(RuntimeEquality runtimeEquality, Object obj1, Object obj2) {
51+
assertThat(runtimeEquality.objectEquals(obj1, obj2)).isTrue();
52+
assertThat(runtimeEquality.hashCode(obj1)).isEqualTo(runtimeEquality.hashCode(obj2));
53+
}
54+
2855
@Test
2956
public void objectEquals_messageLite_throws() {
3057
RuntimeEquality runtimeEquality =

0 commit comments

Comments
 (0)