Skip to content

Commit 1ce28de

Browse files
authored
jextract/jni: Implement generic equals method (#778)
1 parent b96f649 commit 1ce28de

5 files changed

Lines changed: 208 additions & 0 deletions

File tree

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2026 Apple Inc. and the Swift.org project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of Swift.org project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
public class HashableClass: Hashable {
16+
public let value: Int
17+
public init(value: Int) {
18+
self.value = value
19+
}
20+
21+
public static func == (lhs: HashableClass, rhs: HashableClass) -> Bool {
22+
lhs.value == rhs.value
23+
}
24+
25+
public func hash(into hasher: inout Hasher) {
26+
hasher.combine(value)
27+
}
28+
}
29+
30+
public class HashableSubclass: HashableClass {
31+
public override init(value: Int) {
32+
super.init(value: value)
33+
}
34+
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift.org open source project
4+
//
5+
// Copyright (c) 2026 Apple Inc. and the Swift.org project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of Swift.org project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
15+
package com.example.swift;
16+
17+
import org.junit.jupiter.api.Test;
18+
import org.swift.swiftkit.core.SwiftArena;
19+
20+
import static org.junit.jupiter.api.Assertions.*;
21+
22+
import java.util.HashSet;
23+
import java.util.List;
24+
25+
@SuppressWarnings({"AssertBetweenInconvertibleTypes", "EqualsWithItself"})
26+
public class HashableTest {
27+
@Test
28+
void valueTypeEquals() {
29+
try (var arena = SwiftArena.ofConfined()) {
30+
var a = MyIDs.makeIntID(42, arena);
31+
var b = MyIDs.makeIntID(42, arena);
32+
var c = MyIDs.makeIntID(0, arena);
33+
var d = MyIDs.makeStringID("42", arena);
34+
assertEquals(a, a);
35+
assertEquals(a, b);
36+
assertNotEquals(a, c);
37+
assertNotEquals(a, d);
38+
assertNotEquals("foo", a);
39+
}
40+
}
41+
42+
@Test
43+
void referenceTypeEquals() {
44+
try (var arena = SwiftArena.ofConfined()) {
45+
var a = HashableClass.init(42, arena);
46+
var b = HashableSubclass.init(42, arena);
47+
var c = HashableSubclass.init(0, arena);
48+
assertEquals(a, b);
49+
assertEquals(b, a);
50+
assertEquals(b, b);
51+
assertNotEquals(a, c);
52+
assertNotEquals(b, c);
53+
}
54+
}
55+
56+
@Test
57+
void hashSetValueType() {
58+
try (var arena = SwiftArena.ofConfined()) {
59+
var a = MyIDs.makeIntID(42, arena);
60+
var b = MyIDs.makeIntID(42, arena);
61+
var c = MyIDs.makeIntID(0, arena);
62+
var set = new HashSet<>(List.of(
63+
a, b
64+
));
65+
assertTrue(set.contains(a));
66+
assertTrue(set.contains(b));
67+
assertFalse(set.contains(c));
68+
assertEquals(1, set.size());
69+
}
70+
}
71+
72+
@Test
73+
void hashSetReferenceType() {
74+
try (var arena = SwiftArena.ofConfined()) {
75+
var a = HashableClass.init(42, arena);
76+
var b = HashableClass.init(42, arena);
77+
var c = HashableSubclass.init(42, arena);
78+
var d = HashableSubclass.init(0, arena);
79+
var set = new HashSet<>(List.of(
80+
a, b, c
81+
));
82+
assertTrue(set.contains(a));
83+
assertTrue(set.contains(b));
84+
assertTrue(set.contains(c));
85+
assertFalse(set.contains(d));
86+
assertEquals(1, set.size());
87+
}
88+
}
89+
}

Sources/JExtractSwiftLib/JNI/JNISwift2JavaGenerator+JavaBindingsPrinting.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,17 @@ extension JNISwift2JavaGenerator {
373373

374374
printer.print(
375375
"""
376+
public boolean equals(Object obj) {
377+
if (obj instanceof JNISwiftInstance rhs) {
378+
return SwiftObjects.equals(this.$memoryAddress(), this.$typeMetadataAddress(), rhs.$memoryAddress(), rhs.$typeMetadataAddress());
379+
}
380+
return false;
381+
}
382+
383+
public int hashCode() {
384+
return SwiftObjects.hashCode(this.$memoryAddress(), this.$typeMetadataAddress());
385+
}
386+
376387
public java.lang.String toString() {
377388
return SwiftObjects.toString(this.$memoryAddress(), this.$typeMetadataAddress());
378389
}

Sources/SwiftJava/SwiftObjects.swift

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,76 @@ extension SwiftObjects {
9898
let typeMetadata = unsafeBitCast(selfType$, to: Any.Type.self)
9999
return String(describing: typeMetadata)
100100
}
101+
102+
@JavaMethod
103+
public static func equals(environment: UnsafeMutablePointer<JNIEnv?>!, lhsPointer: Int64, lhsTypePointer: Int64, rhsPointer: Int64, rhsTypePointer: Int64) -> Bool {
104+
guard let lhsType$ = UnsafeRawPointer(bitPattern: Int(lhsTypePointer)) else {
105+
fatalError("lhsType metadata address was null")
106+
}
107+
let lhsMetatype = unsafeBitCast(lhsType$, to: Any.Type.self)
108+
guard let lhsMetatype = lhsMetatype as? (any Equatable.Type) else {
109+
return false
110+
}
111+
112+
guard let rhsType$ = UnsafeRawPointer(bitPattern: Int(rhsTypePointer)) else {
113+
fatalError("rhsType metadata address was null")
114+
}
115+
let rhsMetatype = unsafeBitCast(rhsType$, to: Any.Type.self)
116+
guard let rhsMetatype = rhsMetatype as? (any Equatable.Type) else {
117+
return false
118+
}
119+
120+
func perform<L: Equatable, R: Equatable>(lhsType: L.Type, rhsType: R.Type) -> Bool {
121+
guard let lhs$ = UnsafeMutablePointer<L>(bitPattern: Int(lhsPointer)) else {
122+
fatalError("lhs memory address was null")
123+
}
124+
guard let rhs$ = UnsafeMutablePointer<R>(bitPattern: Int(rhsPointer)) else {
125+
fatalError("rhs memory address was null")
126+
}
127+
if lhsType == rhsType {
128+
return lhs$.pointee == rhs$.pointee as! L
129+
} else if let lhs = lhs$.pointee as? R {
130+
return lhs == rhs$.pointee
131+
} else if let rhs = rhs$.pointee as? L {
132+
return lhs$.pointee == rhs
133+
}
134+
return false
135+
}
136+
return perform(lhsType: lhsMetatype, rhsType: rhsMetatype)
137+
}
138+
139+
@JavaMethod
140+
public static func hashCode(environment: UnsafeMutablePointer<JNIEnv?>!, selfPointer: Int64, selfTypePointer: Int64) -> Int32 {
141+
guard let selfType$ = UnsafeRawPointer(bitPattern: Int(selfTypePointer)) else {
142+
fatalError("selfType metadata address was null")
143+
}
144+
let typeMetadata = unsafeBitCast(selfType$, to: Any.Type.self)
145+
guard let typeMetadata = typeMetadata as? (any Hashable.Type) else {
146+
// For value types, different instances may return different hash codes even if the values are same.
147+
return Int32(truncatingIfNeeded: selfPointer.hashValue)
148+
}
149+
150+
func perform<T: Hashable>(as type: T.Type) -> Int32 {
151+
guard let self$ = UnsafeMutablePointer<T>(bitPattern: Int(selfPointer)) else {
152+
fatalError("self memory address was null")
153+
}
154+
return Int32(truncatingIfNeeded: self$.pointee.hashValue)
155+
}
156+
return perform(as: typeMetadata)
157+
}
158+
}
159+
160+
public class HashableClass: Hashable {
161+
public let value: Int
162+
public init(value: Int) {
163+
self.value = value
164+
}
165+
166+
public static func == (lhs: HashableClass, rhs: HashableClass) -> Bool {
167+
lhs.value == rhs.value
168+
}
169+
170+
public func hash(into hasher: inout Hasher) {
171+
hasher.combine(value)
172+
}
101173
}

SwiftKitCore/src/main/java/org/swift/swiftkit/core/SwiftObjects.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,6 @@ public static void requireNonZero(long number, String name) {
2929
public static native String toDebugString(long selfPointer, long selfTypePointer);
3030
public static native void destroy(long selfPointer, long selfTypePointer);
3131
public static native String typeDescription(long selfTypePointer);
32+
public static native boolean equals(long lhsPointer, long lhsTypePointer, long rhsPointer, long rhsTypePointer);
33+
public static native int hashCode(long selfPointer, long selfTypePointer);
3234
}

0 commit comments

Comments
 (0)