2121import java .util .HashSet ;
2222import java .util .List ;
2323import java .util .Set ;
24+ import java .util .function .BiConsumer ;
25+ import java .util .function .Supplier ;
2426import org .jspecify .annotations .Nullable ;
2527import org .pkl .core .Composite ;
2628import org .pkl .core .PClass ;
@@ -38,7 +40,7 @@ public final class VmReference extends VmValue {
3840 private final ImRrbt <VmTyped > path ;
3941 // candidate types can only be: PType.Class, PType.Alias (only preservedAliasTypes),
4042 // PType.StringLiteral, or PType.UNKNOWN
41- private final Set < PType > candidateTypes ;
43+ private final PType referentType ;
4244
4345 private boolean forced = false ;
4446
@@ -69,10 +71,10 @@ public VmReference(VmTyped domain, VmClass clazz, Object data) {
6971 normalizeTypes (new PType .Class (clazz .export ()), clazz .getModule ().getVmClass ().export ()));
7072 }
7173
72- public VmReference (VmTyped domain , Object data , ImRrbt <VmTyped > path , Set < PType > candidateTypes ) {
74+ public VmReference (VmTyped domain , Object data , ImRrbt <VmTyped > path , PType referentType ) {
7375 this .domain = domain ;
7476 this .data = data ;
75- this .candidateTypes = candidateTypes ;
77+ this .referentType = referentType ;
7678 this .path = path ;
7779 }
7880
@@ -88,19 +90,32 @@ public List<VmTyped> getPath() {
8890 return path ;
8991 }
9092
93+ public PType getReferentType () {
94+ return referentType ;
95+ }
96+
9197 // simplifies a type by:
9298 // * erasing constraints
9399 // * transforming T? into T|Null
94100 // * dereferencing aliases (except for well-known stdlib alias types)
95101 // * flattening unions
96102 // * when moduleClass is supplied, replace PType.MODULE with appropriate PType.Class
97103 // * drop PType.NOTHING, PType.Function, and PType.TypeVariable
98- private static Set < PType > normalizeTypes (PType type , PClass moduleClass ) {
104+ private static PType normalizeTypes (PType type , PClass moduleClass ) {
99105 var types = new HashSet <PType >();
100106 normalizeTypes (type , moduleClass , types );
101- if (types .contains (PType .UNKNOWN )) return Set .of (PType .UNKNOWN );
102- if (containsClass (types , anyType .getPClass ())) return Set .of (anyType );
103- return types ;
107+ return minimizeTypes (types );
108+ }
109+
110+ private static PType minimizeTypes (Set <PType > types ) {
111+ if (types .size () == 1 ) return types .iterator ().next ();
112+ // optimization: unknown allows all references, erase all candidates to only unknown
113+ if (types .contains (PType .UNKNOWN )) return PType .UNKNOWN ;
114+ // optimization: All allows all references, erase all candidates to only All
115+ if (containsClass (types , anyType .getPClass ())) return anyType ;
116+ var typesList = new ArrayList <>(types );
117+ typesList .sort (Comparator .comparing (Object ::toString ));
118+ return new PType .Union (typesList );
104119 }
105120
106121 private static void normalizeTypes (PType type , PClass moduleClass , Set <PType > result ) {
@@ -119,8 +134,7 @@ private static void normalizeTypes(PType type, PClass moduleClass, Set<PType> re
119134 } else {
120135 var typeArgs = new ArrayList <PType >(clazz .getTypeArguments ().size ());
121136 for (var arg : clazz .getTypeArguments ()) {
122- var tt = new ArrayList <>(normalizeTypes (arg , moduleClass ));
123- typeArgs .add (tt .size () == 1 ? tt .get (0 ) : new PType .Union (tt ));
137+ typeArgs .add (normalizeTypes (arg , moduleClass ));
124138 }
125139 result .add (new PType .Class (clazz .getPClass (), typeArgs ));
126140 }
@@ -144,33 +158,34 @@ private static void normalizeTypes(PType type, PClass moduleClass, Set<PType> re
144158 }
145159 }
146160
161+ private static Iterable <PType > iterateTypes (PType t ) {
162+ if (t instanceof PType .Union union ) return union .getElementTypes ();
163+ return Collections .singleton (t );
164+ }
165+
147166 public @ Nullable VmReference withPropertyAccess (Identifier property ) {
148- Set <PType > candidates = new HashSet <>();
149- for (var t : candidateTypes ) {
150- getCandidatePropertyType (t , property .toString (), candidates );
151- }
152- if (candidates .isEmpty ()) {
153- return null ; // no valid property found
154- } else if (candidates .contains (PType .UNKNOWN )) {
155- // optimization: unknown allows all references, erase all candidates to only unknown
156- candidates = Set .of (PType .UNKNOWN );
157- }
158- return new VmReference (
159- domain , data , path .append (newAccess (property .toString (), null )), candidates );
167+ var propString = property .toString ();
168+ return withAccess (
169+ (t , candidates ) -> getCandidatePropertyType (t , propString , candidates ),
170+ () -> newAccess (property .toString (), null ));
160171 }
161172
162173 public @ Nullable VmReference withSubscriptAccess (Object key ) {
174+ return withAccess (
175+ (t , candidates ) -> getCandidateSubscriptType (t , key , candidates ),
176+ () -> newAccess (null , key ));
177+ }
178+
179+ private @ Nullable VmReference withAccess (
180+ BiConsumer <PType , Set <PType >> checkCandidate , Supplier <VmTyped > makeAccess ) {
163181 Set <PType > candidates = new HashSet <>();
164- for (var t : candidateTypes ) {
165- getCandidateSubscriptType ( t , key , candidates );
182+ for (var t : iterateTypes ( referentType ) ) {
183+ checkCandidate . accept ( t , candidates );
166184 }
167185 if (candidates .isEmpty ()) {
168- return null ; // no valid subscript found
169- } else if (candidates .contains (PType .UNKNOWN )) {
170- // optimization: unknown allows all references, erase all candidates to only unknown
171- candidates = Set .of (PType .UNKNOWN );
186+ return null ; // no valid access found
172187 }
173- return new VmReference (domain , data , path .append (newAccess ( null , key )), candidates );
188+ return new VmReference (domain , data , path .append (makeAccess . get ( )), minimizeTypes ( candidates ) );
174189 }
175190
176191 @ SuppressWarnings ("DuplicatedCode" )
@@ -234,7 +249,7 @@ private static void getCandidateSubscriptType(PType type, Object key, Set<PType>
234249 || clazz .getPClass ().getInfo () == PClassInfo .Map ) {
235250 var typeArgs = clazz .getTypeArguments ();
236251 var keyTypes = normalizeTypes (typeArgs .get (0 ), clazz .getPClass ().getModuleClass ());
237- for (var kt : keyTypes ) {
252+ for (var kt : iterateTypes ( keyTypes ) ) {
238253 if (kt == PType .UNKNOWN
239254 || (kt instanceof PType .Class klazz
240255 && klazz .getPClass ().getInfo () == PClassInfo .forValue (VmValue .export (key )))
@@ -252,38 +267,34 @@ private static void getCandidateSubscriptType(PType type, Object key, Set<PType>
252267 */
253268 public boolean referentTypeIsSubtypeOf (PType type , PClass moduleClass ) {
254269 // fast path: if referent is unknown it can match any type check
255- if (candidateTypes . contains ( PType .UNKNOWN ) ) {
270+ if (referentType == PType .UNKNOWN ) {
256271 return true ;
257272 }
258273
259- var checkTypes = normalizeTypes (type , moduleClass );
274+ var checkType = normalizeTypes (type , moduleClass );
260275 // fast path: short circuit if any referent is accepted
261- if (checkTypes . contains ( PType .UNKNOWN ) || containsClass ( checkTypes , anyType .getPClass ())) {
276+ if (checkType == PType .UNKNOWN || isClass ( checkType , anyType .getPClass ())) {
262277 return true ;
263278 }
264279 // fast path: short circuit if nothing is accepted
265- if (checkTypes . size () == 1 && checkTypes . contains ( PType .NOTHING ) ) {
280+ if (checkType == PType .NOTHING ) {
266281 return false ;
267282 }
268283
269- // all candidate types must be subtypes of at least one target type
270- candidate :
271- for (var c : candidateTypes ) {
272- for (var t : checkTypes ) {
273- if (isSubtype (c , t )) continue candidate ;
274- }
275- return false ;
276- }
277- return true ;
284+ return isSubtype (referentType , checkType );
278285 }
279286
280287 private static boolean containsClass (Set <PType > types , PClass pClass ) {
281288 for (var t : types ) {
282- if (t instanceof PType . Class clazz && clazz . getPClass () == pClass ) return true ;
289+ if (isClass ( t , pClass ) ) return true ;
283290 }
284291 return false ;
285292 }
286293
294+ private static boolean isClass (PType t , PClass pClass ) {
295+ return t instanceof PType .Class clazz && clazz .getPClass () == pClass ;
296+ }
297+
287298 private static boolean isSubtype (PType a , PType b ) {
288299 // checks if A is a subtype of B
289300 // cases (A -> B)
@@ -305,6 +316,8 @@ private static boolean isSubtype(PType a, PType b) {
305316 // * invariant: A_i must be identical to B_i
306317 // * covariant: A_i must be a subtype of B_i
307318 // * contravariant: B_i must be a subtype of A_i
319+ // * Union -> Union: Each elem of A must be a subtype of at least one elem of B
320+ // * Non-union -> Union: A must be a subtype of at least one elem of B
308321 if (a == b ) return true ;
309322
310323 if (a instanceof PType .StringLiteral aStr ) {
@@ -366,6 +379,21 @@ private static boolean isSubtype(PType a, PType b) {
366379 }
367380 }
368381 return true ;
382+ } else if (b instanceof PType .Union bUnion ) {
383+ if (a instanceof PType .Union aUnion ) {
384+ a :
385+ for (var aElem : aUnion .getElementTypes ()) {
386+ for (var bElem : bUnion .getElementTypes ()) {
387+ if (isSubtype (aElem , bElem )) continue a ;
388+ }
389+ return false ;
390+ }
391+ return true ;
392+ } else {
393+ for (var bElem : bUnion .getElementTypes ()) {
394+ if (isSubtype (a , bElem )) return true ;
395+ }
396+ }
369397 }
370398 return false ;
371399 }
@@ -395,22 +423,14 @@ public Reference export() {
395423 pathList .add (elem .export ());
396424 }
397425
398- return new Reference (domain .export (), VmValue .export (data ), pathList , exportReferentType ());
399- }
400-
401- public PType exportReferentType () {
402- if (candidateTypes .size () == 1 ) return candidateTypes .iterator ().next ();
403- var types = new ArrayList <>(candidateTypes );
404- // sort multiple candidate types to ensure stable output
405- types .sort (Comparator .comparing (Object ::toString ));
406- return new PType .Union (types );
426+ return new Reference (domain .export (), VmValue .export (data ), pathList , getReferentType ());
407427 }
408428
409429 public PType exportType () {
410430 return new PType .Class (
411431 RefModule .getReferenceClass ().export (),
412432 new PType .Class (domain .getVmClass ().export ()),
413- exportReferentType ());
433+ getReferentType ());
414434 }
415435
416436 @ Override
@@ -433,15 +453,15 @@ public boolean equals(@Nullable Object o) {
433453 return domain .equals (that .domain )
434454 && data .equals (that .data )
435455 && path .equals (that .path )
436- && candidateTypes .equals (that .candidateTypes );
456+ && referentType .equals (that .referentType );
437457 }
438458
439459 @ Override
440460 public int hashCode () {
441461 int result = domain .hashCode ();
442462 result = 31 * result + data .hashCode ();
443463 result = 31 * result + path .hashCode ();
444- result = 31 * result + candidateTypes .hashCode ();
464+ result = 31 * result + referentType .hashCode ();
445465 return result ;
446466 }
447467}
0 commit comments