Skip to content

Commit 82cc224

Browse files
committed
track PType instead of Set<PType>, fix subtype check for type args with unions
1 parent 50fe92e commit 82cc224

5 files changed

Lines changed: 92 additions & 74 deletions

File tree

pkl-core/src/main/java/org/pkl/core/PType.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,6 @@ public String toString() {
125125
}
126126
return result;
127127
}
128-
129-
@Override
130-
public boolean equals(Object obj) {
131-
if (this == obj) return true;
132-
if (!(obj instanceof Class clazz)) return false;
133-
return pClass == clazz.getPClass() && typeArguments.equals(clazz.getTypeArguments());
134-
}
135128
}
136129

137130
public static final class Nullable extends PType {

pkl-core/src/main/java/org/pkl/core/runtime/VmReference.java

Lines changed: 75 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import java.util.HashSet;
2222
import java.util.List;
2323
import java.util.Set;
24+
import java.util.function.BiConsumer;
25+
import java.util.function.Supplier;
2426
import org.jspecify.annotations.Nullable;
2527
import org.pkl.core.Composite;
2628
import org.pkl.core.PClass;
@@ -38,7 +40,7 @@ public final class VmReference extends VmValue {
3840
private final ImRrbt<VmTyped> path;
3941
// candidate types can only be: PType.Class, PType.Alias (only preservedAliasTypes),
4042
// PType.StringLiteral, or PType.UNKNOWN
41-
private final Set<PType> candidateTypes;
43+
private final PType referentType;
4244

4345
private boolean forced = false;
4446

@@ -69,10 +71,10 @@ public VmReference(VmTyped domain, VmClass clazz, Object data) {
6971
normalizeTypes(new PType.Class(clazz.export()), clazz.getModule().getVmClass().export()));
7072
}
7173

72-
public VmReference(VmTyped domain, Object data, ImRrbt<VmTyped> path, Set<PType> candidateTypes) {
74+
public VmReference(VmTyped domain, Object data, ImRrbt<VmTyped> path, PType referentType) {
7375
this.domain = domain;
7476
this.data = data;
75-
this.candidateTypes = candidateTypes;
77+
this.referentType = referentType;
7678
this.path = path;
7779
}
7880

@@ -88,19 +90,32 @@ public List<VmTyped> getPath() {
8890
return path;
8991
}
9092

93+
public PType getReferentType() {
94+
return referentType;
95+
}
96+
9197
// simplifies a type by:
9298
// * erasing constraints
9399
// * transforming T? into T|Null
94100
// * dereferencing aliases (except for well-known stdlib alias types)
95101
// * flattening unions
96102
// * when moduleClass is supplied, replace PType.MODULE with appropriate PType.Class
97103
// * drop PType.NOTHING, PType.Function, and PType.TypeVariable
98-
private static Set<PType> normalizeTypes(PType type, PClass moduleClass) {
104+
private static PType normalizeTypes(PType type, PClass moduleClass) {
99105
var types = new HashSet<PType>();
100106
normalizeTypes(type, moduleClass, types);
101-
if (types.contains(PType.UNKNOWN)) return Set.of(PType.UNKNOWN);
102-
if (containsClass(types, anyType.getPClass())) return Set.of(anyType);
103-
return types;
107+
return minimizeTypes(types);
108+
}
109+
110+
private static PType minimizeTypes(Set<PType> types) {
111+
if (types.size() == 1) return types.iterator().next();
112+
// optimization: unknown allows all references, erase all candidates to only unknown
113+
if (types.contains(PType.UNKNOWN)) return PType.UNKNOWN;
114+
// optimization: All allows all references, erase all candidates to only All
115+
if (containsClass(types, anyType.getPClass())) return anyType;
116+
var typesList = new ArrayList<>(types);
117+
typesList.sort(Comparator.comparing(Object::toString));
118+
return new PType.Union(typesList);
104119
}
105120

106121
private static void normalizeTypes(PType type, PClass moduleClass, Set<PType> result) {
@@ -119,8 +134,7 @@ private static void normalizeTypes(PType type, PClass moduleClass, Set<PType> re
119134
} else {
120135
var typeArgs = new ArrayList<PType>(clazz.getTypeArguments().size());
121136
for (var arg : clazz.getTypeArguments()) {
122-
var tt = new ArrayList<>(normalizeTypes(arg, moduleClass));
123-
typeArgs.add(tt.size() == 1 ? tt.get(0) : new PType.Union(tt));
137+
typeArgs.add(normalizeTypes(arg, moduleClass));
124138
}
125139
result.add(new PType.Class(clazz.getPClass(), typeArgs));
126140
}
@@ -144,33 +158,34 @@ private static void normalizeTypes(PType type, PClass moduleClass, Set<PType> re
144158
}
145159
}
146160

161+
private static Iterable<PType> iterateTypes(PType t) {
162+
if (t instanceof PType.Union union) return union.getElementTypes();
163+
return Collections.singleton(t);
164+
}
165+
147166
public @Nullable VmReference withPropertyAccess(Identifier property) {
148-
Set<PType> candidates = new HashSet<>();
149-
for (var t : candidateTypes) {
150-
getCandidatePropertyType(t, property.toString(), candidates);
151-
}
152-
if (candidates.isEmpty()) {
153-
return null; // no valid property found
154-
} else if (candidates.contains(PType.UNKNOWN)) {
155-
// optimization: unknown allows all references, erase all candidates to only unknown
156-
candidates = Set.of(PType.UNKNOWN);
157-
}
158-
return new VmReference(
159-
domain, data, path.append(newAccess(property.toString(), null)), candidates);
167+
var propString = property.toString();
168+
return withAccess(
169+
(t, candidates) -> getCandidatePropertyType(t, propString, candidates),
170+
() -> newAccess(property.toString(), null));
160171
}
161172

162173
public @Nullable VmReference withSubscriptAccess(Object key) {
174+
return withAccess(
175+
(t, candidates) -> getCandidateSubscriptType(t, key, candidates),
176+
() -> newAccess(null, key));
177+
}
178+
179+
private @Nullable VmReference withAccess(
180+
BiConsumer<PType, Set<PType>> checkCandidate, Supplier<VmTyped> makeAccess) {
163181
Set<PType> candidates = new HashSet<>();
164-
for (var t : candidateTypes) {
165-
getCandidateSubscriptType(t, key, candidates);
182+
for (var t : iterateTypes(referentType)) {
183+
checkCandidate.accept(t, candidates);
166184
}
167185
if (candidates.isEmpty()) {
168-
return null; // no valid subscript found
169-
} else if (candidates.contains(PType.UNKNOWN)) {
170-
// optimization: unknown allows all references, erase all candidates to only unknown
171-
candidates = Set.of(PType.UNKNOWN);
186+
return null; // no valid access found
172187
}
173-
return new VmReference(domain, data, path.append(newAccess(null, key)), candidates);
188+
return new VmReference(domain, data, path.append(makeAccess.get()), minimizeTypes(candidates));
174189
}
175190

176191
@SuppressWarnings("DuplicatedCode")
@@ -234,7 +249,7 @@ private static void getCandidateSubscriptType(PType type, Object key, Set<PType>
234249
|| clazz.getPClass().getInfo() == PClassInfo.Map) {
235250
var typeArgs = clazz.getTypeArguments();
236251
var keyTypes = normalizeTypes(typeArgs.get(0), clazz.getPClass().getModuleClass());
237-
for (var kt : keyTypes) {
252+
for (var kt : iterateTypes(keyTypes)) {
238253
if (kt == PType.UNKNOWN
239254
|| (kt instanceof PType.Class klazz
240255
&& klazz.getPClass().getInfo() == PClassInfo.forValue(VmValue.export(key)))
@@ -252,38 +267,34 @@ private static void getCandidateSubscriptType(PType type, Object key, Set<PType>
252267
*/
253268
public boolean referentTypeIsSubtypeOf(PType type, PClass moduleClass) {
254269
// fast path: if referent is unknown it can match any type check
255-
if (candidateTypes.contains(PType.UNKNOWN)) {
270+
if (referentType == PType.UNKNOWN) {
256271
return true;
257272
}
258273

259-
var checkTypes = normalizeTypes(type, moduleClass);
274+
var checkType = normalizeTypes(type, moduleClass);
260275
// fast path: short circuit if any referent is accepted
261-
if (checkTypes.contains(PType.UNKNOWN) || containsClass(checkTypes, anyType.getPClass())) {
276+
if (checkType == PType.UNKNOWN || isClass(checkType, anyType.getPClass())) {
262277
return true;
263278
}
264279
// fast path: short circuit if nothing is accepted
265-
if (checkTypes.size() == 1 && checkTypes.contains(PType.NOTHING)) {
280+
if (checkType == PType.NOTHING) {
266281
return false;
267282
}
268283

269-
// all candidate types must be subtypes of at least one target type
270-
candidate:
271-
for (var c : candidateTypes) {
272-
for (var t : checkTypes) {
273-
if (isSubtype(c, t)) continue candidate;
274-
}
275-
return false;
276-
}
277-
return true;
284+
return isSubtype(referentType, checkType);
278285
}
279286

280287
private static boolean containsClass(Set<PType> types, PClass pClass) {
281288
for (var t : types) {
282-
if (t instanceof PType.Class clazz && clazz.getPClass() == pClass) return true;
289+
if (isClass(t, pClass)) return true;
283290
}
284291
return false;
285292
}
286293

294+
private static boolean isClass(PType t, PClass pClass) {
295+
return t instanceof PType.Class clazz && clazz.getPClass() == pClass;
296+
}
297+
287298
private static boolean isSubtype(PType a, PType b) {
288299
// checks if A is a subtype of B
289300
// cases (A -> B)
@@ -305,6 +316,8 @@ private static boolean isSubtype(PType a, PType b) {
305316
// * invariant: A_i must be identical to B_i
306317
// * covariant: A_i must be a subtype of B_i
307318
// * contravariant: B_i must be a subtype of A_i
319+
// * Union -> Union: Each elem of A must be a subtype of at least one elem of B
320+
// * Non-union -> Union: A must be a subtype of at least one elem of B
308321
if (a == b) return true;
309322

310323
if (a instanceof PType.StringLiteral aStr) {
@@ -366,6 +379,21 @@ private static boolean isSubtype(PType a, PType b) {
366379
}
367380
}
368381
return true;
382+
} else if (b instanceof PType.Union bUnion) {
383+
if (a instanceof PType.Union aUnion) {
384+
a:
385+
for (var aElem : aUnion.getElementTypes()) {
386+
for (var bElem : bUnion.getElementTypes()) {
387+
if (isSubtype(aElem, bElem)) continue a;
388+
}
389+
return false;
390+
}
391+
return true;
392+
} else {
393+
for (var bElem : bUnion.getElementTypes()) {
394+
if (isSubtype(a, bElem)) return true;
395+
}
396+
}
369397
}
370398
return false;
371399
}
@@ -395,22 +423,14 @@ public Reference export() {
395423
pathList.add(elem.export());
396424
}
397425

398-
return new Reference(domain.export(), VmValue.export(data), pathList, exportReferentType());
399-
}
400-
401-
public PType exportReferentType() {
402-
if (candidateTypes.size() == 1) return candidateTypes.iterator().next();
403-
var types = new ArrayList<>(candidateTypes);
404-
// sort multiple candidate types to ensure stable output
405-
types.sort(Comparator.comparing(Object::toString));
406-
return new PType.Union(types);
426+
return new Reference(domain.export(), VmValue.export(data), pathList, getReferentType());
407427
}
408428

409429
public PType exportType() {
410430
return new PType.Class(
411431
RefModule.getReferenceClass().export(),
412432
new PType.Class(domain.getVmClass().export()),
413-
exportReferentType());
433+
getReferentType());
414434
}
415435

416436
@Override
@@ -433,15 +453,15 @@ public boolean equals(@Nullable Object o) {
433453
return domain.equals(that.domain)
434454
&& data.equals(that.data)
435455
&& path.equals(that.path)
436-
&& candidateTypes.equals(that.candidateTypes);
456+
&& referentType.equals(that.referentType);
437457
}
438458

439459
@Override
440460
public int hashCode() {
441461
int result = domain.hashCode();
442462
result = 31 * result + data.hashCode();
443463
result = 31 * result + path.hashCode();
444-
result = 31 * result + candidateTypes.hashCode();
464+
result = 31 * result + referentType.hashCode();
445465
return result;
446466
}
447467
}

pkl-core/src/main/java/org/pkl/core/runtime/VmValueRenderer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ public void visitReference(VmReference value) {
277277
append("Reference(");
278278
visit(value.getDomain());
279279
append(", ");
280-
append(value.exportReferentType());
280+
append(value.getReferentType());
281281
append(", ");
282282
visit(value.getData());
283283
append(")");

pkl-core/src/test/files/LanguageSnippetTests/input/api/reference.pkl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,14 @@ import "pkl:ref"
44
class D extends ref.Domain {
55
function renderReference(reference: ref.Reference<D, Any>): String =
66
let (data = reference.getData())
7-
if (data is Resource)
8-
let (
9-
path =
10-
reference
11-
.getPath()
12-
.map((elem) -> if (elem.isProperty) ".\(elem.property)" else "[\(elem.key)]")
13-
)
14-
"${\(data.name)\(path.join(""))}"
15-
else
16-
throw("can only render references rooted to Resource instances")
7+
let (root = if (data is Resource) data.name else data.toString())
8+
let (
9+
path =
10+
reference
11+
.getPath()
12+
.map((elem) -> if (elem.isProperty) ".\(elem.property)" else "[\(elem.key)]")
13+
)
14+
"${\(root)\(path.join(""))}"
1715
}
1816

1917
local const d: D = new {}
@@ -86,7 +84,7 @@ aRef: Ref<A> = a.$
8684
bRef: Ref<B> = b.$
8785
unknownRef: Ref<A | B> = aRef
8886
unknownRef2: Ref<A> | Ref<B> = aRef
89-
unknownRef3: Ref<A | B> = aOrB.$
87+
unknownRef3: Ref<B | A> = aOrB.$
9088

9189
k: K = new {
9290
aId = aRef.id
@@ -121,6 +119,12 @@ refInterpolation = "\(aRef.outputs.someListing[1])"
121119
kInterpolation = "\(k)"
122120
aValuesJoined = k.aValues.join("\n").replaceAll(Regex("@[a-z0-9]+"), "@<addr>")
123121

122+
// ensure that type arguments that are unions are handled correctly
123+
typeArgs = ref.Reference(d, TypeHolder, null).prop as Ref<Listing<Number | Boolean | String>>
124+
class TypeHolder {
125+
prop: Listing<String | Boolean | Number>
126+
}
127+
124128
output {
125129
renderer {
126130
converters {

pkl-core/src/test/files/LanguageSnippetTests/output/api/reference.pcf

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,4 @@ aValuesJoined = """
5757
org.pkl.core.runtime.VmReference@<addr>
5858
org.pkl.core.runtime.VmReference@<addr>
5959
"""
60+
typeArgs = "${null.prop}"

0 commit comments

Comments
 (0)