Skip to content

Commit a138ce3

Browse files
maskri17copybara-github
authored andcommitted
Adding transformMap and transformMapEntry macros
PiperOrigin-RevId: 800668571
1 parent e73283b commit a138ce3

7 files changed

Lines changed: 400 additions & 20 deletions

File tree

checker/src/test/resources/standardEnvDump.baseline

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Source: 'redundant expression so the env is constructed and can be printed'
44

55

66
Standard environment:
7+
declare cel.@mapInsert {
8+
function cel_@mapInsert_map_map (map(K, V), map(K, V)) -> map(K, V)
9+
function cel_@mapInsert_map_key_value (map(K, V), K, V) -> map(K, V)
10+
}
711
declare !_ {
812
function logical_not (bool) -> bool
913
}

extensions/src/main/java/dev/cel/extensions/BUILD.bazel

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,19 @@ java_library(
299299
name = "comprehensions",
300300
srcs = ["CelComprehensionsExtensions.java"],
301301
deps = [
302+
"//checker:checker_builder",
302303
"//common:compiler_common",
304+
"//common:options",
303305
"//common/ast",
306+
"//common/types",
304307
"//compiler:compiler_builder",
308+
"//extensions:extension_library",
305309
"//parser:macro",
306310
"//parser:operator",
307311
"//parser:parser_builder",
312+
"//runtime",
313+
"//runtime:function_binding",
314+
"//runtime:runtime_equality",
308315
"@maven//:com_google_guava_guava",
309316
],
310317
)

extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java

Lines changed: 260 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,137 @@
1818
import static com.google.common.base.Preconditions.checkNotNull;
1919

2020
import com.google.common.collect.ImmutableList;
21+
import com.google.common.collect.ImmutableMap;
2122
import com.google.common.collect.ImmutableSet;
23+
import dev.cel.checker.CelCheckerBuilder;
24+
import dev.cel.common.CelFunctionDecl;
2225
import dev.cel.common.CelIssue;
26+
import dev.cel.common.CelOptions;
27+
import dev.cel.common.CelOverloadDecl;
2328
import dev.cel.common.ast.CelExpr;
29+
import dev.cel.common.types.MapType;
30+
import dev.cel.common.types.TypeParamType;
2431
import dev.cel.compiler.CelCompilerLibrary;
2532
import dev.cel.parser.CelMacro;
2633
import dev.cel.parser.CelMacroExprFactory;
2734
import dev.cel.parser.CelParserBuilder;
2835
import dev.cel.parser.Operator;
36+
import dev.cel.runtime.CelFunctionBinding;
37+
import dev.cel.runtime.CelInternalRuntimeLibrary;
38+
import dev.cel.runtime.CelRuntimeBuilder;
39+
import dev.cel.runtime.RuntimeEquality;
40+
import java.util.Map;
2941
import java.util.Optional;
3042

31-
/** Internal implementation of CEL comprehensions extensions. */
32-
public final class CelComprehensionsExtensions implements CelCompilerLibrary {
43+
/** Internal implementation of CEL two variable comprehensions extensions. */
44+
final class CelComprehensionsExtensions
45+
implements CelCompilerLibrary, CelInternalRuntimeLibrary, CelExtensionLibrary.FeatureSet {
3346

34-
private static final String TRANSFORM_LIST = "transformList";
47+
private static final String MAP_INSERT_FUNCTION = "cel.@mapInsert";
48+
private static final String MAP_INSERT_OVERLOAD_MAP_MAP = "cel_@mapInsert_map_map";
49+
private static final String MAP_INSERT_OVERLOAD_KEY_VALUE = "cel_@mapInsert_map_key_value";
50+
private static final TypeParamType TYPE_PARAM_K = TypeParamType.create("K");
51+
private static final TypeParamType TYPE_PARAM_V = TypeParamType.create("V");
52+
private static final MapType MAP_KV_TYPE = MapType.create(TYPE_PARAM_K, TYPE_PARAM_V);
3553

54+
enum Function {
55+
MAP_INSERT(
56+
CelFunctionDecl.newFunctionDeclaration(
57+
MAP_INSERT_FUNCTION,
58+
CelOverloadDecl.newGlobalOverload(
59+
MAP_INSERT_OVERLOAD_MAP_MAP,
60+
"Returns a map that's the result of merging given two maps.",
61+
MAP_KV_TYPE,
62+
MAP_KV_TYPE,
63+
MAP_KV_TYPE),
64+
CelOverloadDecl.newGlobalOverload(
65+
MAP_INSERT_OVERLOAD_KEY_VALUE,
66+
"Adds the given key-value pair to the map.",
67+
MAP_KV_TYPE,
68+
MAP_KV_TYPE,
69+
TYPE_PARAM_K,
70+
TYPE_PARAM_V)));
71+
72+
private final CelFunctionDecl functionDecl;
73+
74+
String getFunction() {
75+
return functionDecl.name();
76+
}
77+
78+
Function(CelFunctionDecl functionDecl) {
79+
this.functionDecl = functionDecl;
80+
}
81+
}
82+
83+
private static final CelExtensionLibrary<CelComprehensionsExtensions> LIBRARY =
84+
new CelExtensionLibrary<CelComprehensionsExtensions>() {
85+
private final CelComprehensionsExtensions version0 = new CelComprehensionsExtensions();
86+
87+
@Override
88+
public String name() {
89+
return "comprehensions";
90+
}
91+
92+
@Override
93+
public ImmutableSet<CelComprehensionsExtensions> versions() {
94+
return ImmutableSet.of(version0);
95+
}
96+
};
97+
98+
static CelExtensionLibrary<CelComprehensionsExtensions> library() {
99+
return LIBRARY;
100+
}
101+
102+
private final ImmutableSet<Function> functions;
103+
104+
CelComprehensionsExtensions() {
105+
this.functions = ImmutableSet.copyOf(Function.values());
106+
}
107+
108+
@Override
109+
public void setCheckerOptions(CelCheckerBuilder checkerBuilder) {
110+
functions.forEach(function -> checkerBuilder.addFunctionDeclarations(function.functionDecl));
111+
}
112+
113+
@Override
114+
public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
115+
throw new UnsupportedOperationException("Unsupported");
116+
}
117+
118+
@Override
119+
public void setRuntimeOptions(
120+
CelRuntimeBuilder runtimeBuilder, RuntimeEquality runtimeEquality, CelOptions celOptions) {
121+
for (Function function : functions) {
122+
for (CelOverloadDecl overload : function.functionDecl.overloads()) {
123+
switch (overload.overloadId()) {
124+
case MAP_INSERT_OVERLOAD_MAP_MAP:
125+
runtimeBuilder.addFunctionBindings(
126+
CelFunctionBinding.from(
127+
MAP_INSERT_OVERLOAD_MAP_MAP,
128+
Map.class,
129+
Map.class,
130+
(map1, map2) -> mapInsertMap(map1, map2, runtimeEquality)));
131+
break;
132+
case MAP_INSERT_OVERLOAD_KEY_VALUE:
133+
runtimeBuilder.addFunctionBindings(
134+
CelFunctionBinding.from(
135+
MAP_INSERT_OVERLOAD_KEY_VALUE,
136+
ImmutableList.of(Map.class, Object.class, Object.class),
137+
args -> mapInsertKeyValue(args, runtimeEquality)));
138+
break;
139+
default:
140+
// Nothing to add.
141+
}
142+
}
143+
}
144+
}
145+
146+
@Override
147+
public int version() {
148+
return 0;
149+
}
150+
151+
@Override
36152
public ImmutableSet<CelMacro> macros() {
37153
return ImmutableSet.of(
38154
CelMacro.newReceiverMacro(
@@ -44,16 +160,58 @@ public ImmutableSet<CelMacro> macros() {
44160
3,
45161
CelComprehensionsExtensions::expandExistsOneMacro),
46162
CelMacro.newReceiverMacro(
47-
TRANSFORM_LIST, 3, CelComprehensionsExtensions::transformListMacro),
163+
"transformList", 3, CelComprehensionsExtensions::transformListMacro),
164+
CelMacro.newReceiverMacro(
165+
"transformList", 4, CelComprehensionsExtensions::transformListMacro),
166+
CelMacro.newReceiverMacro(
167+
"transformMap", 3, CelComprehensionsExtensions::transformMapMacro),
168+
CelMacro.newReceiverMacro(
169+
"transformMap", 4, CelComprehensionsExtensions::transformMapMacro),
170+
CelMacro.newReceiverMacro(
171+
"transformMapEntry", 3, CelComprehensionsExtensions::transformMapEntryMacro),
48172
CelMacro.newReceiverMacro(
49-
TRANSFORM_LIST, 4, CelComprehensionsExtensions::transformListMacro));
173+
"transformMapEntry", 4, CelComprehensionsExtensions::transformMapEntryMacro));
50174
}
51175

52176
@Override
53177
public void setParserOptions(CelParserBuilder parserBuilder) {
54178
parserBuilder.addMacros(macros());
55179
}
56180

181+
// TODO: Implement a more efficient map insertion based on mutability once mutable
182+
// maps are supported in Java stack.
183+
private static ImmutableMap<Object, Object> mapInsertMap(
184+
Map<?, ?> targetMap, Map<?, ?> mapToMerge, RuntimeEquality equality) {
185+
ImmutableMap.Builder<Object, Object> resultBuilder =
186+
ImmutableMap.builderWithExpectedSize(targetMap.size() + mapToMerge.size());
187+
188+
for (Map.Entry<?, ?> entry : mapToMerge.entrySet()) {
189+
if (equality.findInMap(targetMap, entry.getKey()).isPresent()) {
190+
throw new IllegalArgumentException(
191+
String.format("insert failed: key '%s' already exists", entry.getKey()));
192+
} else {
193+
resultBuilder.put(entry.getKey(), entry.getValue());
194+
}
195+
}
196+
return resultBuilder.putAll(targetMap).buildOrThrow();
197+
}
198+
199+
private static ImmutableMap<Object, Object> mapInsertKeyValue(
200+
Object[] args, RuntimeEquality equality) {
201+
Map<?, ?> map = (Map<?, ?>) args[0];
202+
Object key = args[1];
203+
Object value = args[2];
204+
205+
if (equality.findInMap(map, key).isPresent()) {
206+
throw new IllegalArgumentException(
207+
String.format("insert failed: key '%s' already exists", key));
208+
}
209+
210+
ImmutableMap.Builder<Object, Object> builder =
211+
ImmutableMap.builderWithExpectedSize(map.size() + 1);
212+
return builder.put(key, value).putAll(map).buildOrThrow();
213+
}
214+
57215
private static Optional<CelExpr> expandAllMacro(
58216
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
59217
checkNotNull(exprFactory);
@@ -220,6 +378,103 @@ private static Optional<CelExpr> transformListMacro(
220378
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
221379
}
222380

381+
private static Optional<CelExpr> transformMapMacro(
382+
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
383+
checkNotNull(exprFactory);
384+
checkNotNull(target);
385+
checkArgument(arguments.size() == 3 || arguments.size() == 4);
386+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
387+
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
388+
return Optional.of(arg0);
389+
}
390+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
391+
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
392+
return Optional.of(arg1);
393+
}
394+
CelExpr transform;
395+
CelExpr filter = null;
396+
if (arguments.size() == 4) {
397+
filter = checkNotNull(arguments.get(2));
398+
transform = checkNotNull(arguments.get(3));
399+
} else {
400+
transform = checkNotNull(arguments.get(2));
401+
}
402+
CelExpr accuInit = exprFactory.newMap();
403+
CelExpr condition = exprFactory.newBoolLiteral(true);
404+
CelExpr step =
405+
exprFactory.newGlobalCall(
406+
MAP_INSERT_FUNCTION,
407+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
408+
arg0,
409+
transform);
410+
if (filter != null) {
411+
step =
412+
exprFactory.newGlobalCall(
413+
Operator.CONDITIONAL.getFunction(),
414+
filter,
415+
step,
416+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
417+
}
418+
return Optional.of(
419+
exprFactory.fold(
420+
arg0.ident().name(),
421+
arg1.ident().name(),
422+
target,
423+
exprFactory.getAccumulatorVarName(),
424+
accuInit,
425+
condition,
426+
step,
427+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
428+
}
429+
430+
private static Optional<CelExpr> transformMapEntryMacro(
431+
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
432+
checkNotNull(exprFactory);
433+
checkNotNull(target);
434+
checkArgument(arguments.size() == 3 || arguments.size() == 4);
435+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
436+
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
437+
return Optional.of(arg0);
438+
}
439+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
440+
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
441+
return Optional.of(arg1);
442+
}
443+
CelExpr transform;
444+
CelExpr filter = null;
445+
if (arguments.size() == 4) {
446+
filter = checkNotNull(arguments.get(2));
447+
transform = checkNotNull(arguments.get(3));
448+
} else {
449+
transform = checkNotNull(arguments.get(2));
450+
}
451+
CelExpr accuInit = exprFactory.newMap();
452+
CelExpr condition = exprFactory.newBoolLiteral(true);
453+
CelExpr step =
454+
exprFactory.newGlobalCall(
455+
MAP_INSERT_FUNCTION,
456+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
457+
transform);
458+
if (filter != null) {
459+
step =
460+
exprFactory.newGlobalCall(
461+
Operator.CONDITIONAL.getFunction(),
462+
filter,
463+
step,
464+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
465+
}
466+
return Optional.of(
467+
exprFactory.fold(
468+
arg0.ident().name(),
469+
arg1.ident().name(),
470+
target,
471+
exprFactory.getAccumulatorVarName(),
472+
accuInit,
473+
condition,
474+
step,
475+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
476+
}
477+
223478
private static CelExpr validatedIterationVariable(
224479
CelMacroExprFactory exprFactory, CelExpr argument) {
225480

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,8 @@ public static CelExtensionLibrary<? extends CelExtensionLibrary.FeatureSet> getE
363363
return CelSetsExtensions.library(options);
364364
case "strings":
365365
return CelStringExtensions.library();
366+
case "comprehensions":
367+
return CelComprehensionsExtensions.library();
366368
// TODO: add support for remaining standard extensions
367369
default:
368370
throw new IllegalArgumentException("Unknown standard extension '" + name + "'");

0 commit comments

Comments
 (0)