diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index b144c4ec9..add918f64 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -343,7 +343,7 @@ private Optional maybeInterceptOptionalCalls( private PlannedInterpretable planCreateStruct(CelExpr celExpr, PlannerContext ctx) { CelStruct struct = celExpr.struct(); - CelType structType = resolveStructType(struct); + CelType structType = resolveStructType(celExpr, ctx); ImmutableList entries = struct.entries(); String[] keys = new String[entries.size()]; @@ -489,7 +489,17 @@ private ResolvedFunction resolveFunction( return ResolvedFunction.newBuilder().setFunctionName(functionName).setTarget(target).build(); } - private CelType resolveStructType(CelStruct struct) { + private CelType resolveStructType(CelExpr expr, PlannerContext ctx) { + CelType checkedType = ctx.typeMap().get(expr.id()); + if (checkedType != null) { + CelKind kind = checkedType.kind(); + // Type-checked ASTs do not need a type-provider lookup as long as it's of expected kind. + if (isValidStructKind(kind)) { + return checkedType; + } + } + + CelStruct struct = expr.struct(); String messageName = struct.messageName(); for (String typeName : container.resolveCandidateNames(messageName)) { CelType structType = typeProvider.findType(typeName).orElse(null); @@ -499,9 +509,7 @@ private CelType resolveStructType(CelStruct struct) { CelKind kind = structType.kind(); - if (!kind.equals(CelKind.STRUCT) - && !kind.equals(CelKind.TIMESTAMP) - && !kind.equals(CelKind.DURATION)) { + if (!isValidStructKind(kind)) { throw new IllegalArgumentException( String.format( "Expected struct type for %s, got %s", structType.name(), structType.kind())); @@ -513,6 +521,12 @@ private CelType resolveStructType(CelStruct struct) { throw new IllegalArgumentException("Undefined type name: " + messageName); } + private static boolean isValidStructKind(CelKind kind) { + return kind.equals(CelKind.STRUCT) + || kind.equals(CelKind.TIMESTAMP) + || kind.equals(CelKind.DURATION); + } + /** Converts a given expression into a qualified name, if possible. */ private Optional toQualifiedName(CelExpr operand) { switch (operand.getKind()) {