Skip to content

Commit 303fc31

Browse files
l46kokcopybara-github
authored andcommitted
Add ProtoMessageLiteValueProvider
PiperOrigin-RevId: 749868732
1 parent 220312c commit 303fc31

File tree

8 files changed

+700
-2
lines changed

8 files changed

+700
-2
lines changed

common/src/main/java/dev/cel/common/values/BUILD.bazel

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,26 @@ java_library(
179179
"//protobuf:cel_lite_descriptor",
180180
"@maven//:com_google_errorprone_error_prone_annotations",
181181
"@maven//:com_google_guava_guava",
182+
"@maven//:com_google_protobuf_protobuf_java",
182183
"@maven//:org_jspecify_jspecify",
183184
"@maven_android//:com_google_protobuf_protobuf_javalite",
184185
],
185186
)
187+
188+
java_library(
189+
name = "proto_message_lite_value_provider",
190+
srcs = ["ProtoMessageLiteValueProvider.java"],
191+
tags = [
192+
],
193+
deps = [
194+
":cel_value",
195+
":cel_value_provider",
196+
":proto_message_lite_value",
197+
"//common/internal:cel_lite_descriptor_pool",
198+
"//common/internal:default_lite_descriptor_pool",
199+
"//protobuf:cel_lite_descriptor",
200+
"@maven//:com_google_errorprone_error_prone_annotations",
201+
"@maven//:com_google_guava_guava",
202+
"@maven_android//:com_google_protobuf_protobuf_javalite",
203+
],
204+
)

common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,32 @@
1616

1717
import static com.google.common.base.Preconditions.checkNotNull;
1818

19+
import com.google.common.annotations.VisibleForTesting;
20+
import com.google.common.base.Defaults;
21+
import com.google.common.collect.ImmutableMap;
22+
import com.google.common.primitives.UnsignedLong;
1923
import com.google.errorprone.annotations.Immutable;
24+
import com.google.protobuf.ByteString;
25+
import com.google.protobuf.CodedInputStream;
26+
import com.google.protobuf.ExtensionRegistryLite;
2027
import com.google.protobuf.MessageLite;
28+
import com.google.protobuf.WireFormat;
2129
import dev.cel.common.annotations.Internal;
2230
import dev.cel.common.internal.CelLiteDescriptorPool;
2331
import dev.cel.common.internal.WellKnownProto;
32+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor;
33+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor.CelFieldValueType;
34+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor.JavaType;
2435
import dev.cel.protobuf.CelLiteDescriptor.MessageLiteDescriptor;
36+
import java.io.IOException;
37+
import java.util.AbstractMap;
38+
import java.util.ArrayList;
39+
import java.util.Collection;
40+
import java.util.Collections;
41+
import java.util.HashMap;
42+
import java.util.LinkedHashMap;
43+
import java.util.List;
44+
import java.util.Map;
2545

2646
/**
2747
* {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and
@@ -43,6 +63,276 @@ public static ProtoLiteCelValueConverter newInstance(
4363
return new ProtoLiteCelValueConverter(celLiteDescriptorPool);
4464
}
4565

66+
private static Object readPrimitiveField(
67+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
68+
switch (fieldDescriptor.getProtoFieldType()) {
69+
case SINT32:
70+
return inputStream.readSInt32();
71+
case SINT64:
72+
return inputStream.readSInt64();
73+
case INT32:
74+
case ENUM:
75+
return inputStream.readInt32();
76+
case INT64:
77+
return inputStream.readInt64();
78+
case UINT32:
79+
return UnsignedLong.fromLongBits(inputStream.readUInt32());
80+
case UINT64:
81+
return UnsignedLong.fromLongBits(inputStream.readUInt64());
82+
case BOOL:
83+
return inputStream.readBool();
84+
case FLOAT:
85+
case FIXED32:
86+
case SFIXED32:
87+
return readFixed32BitField(inputStream, fieldDescriptor);
88+
case DOUBLE:
89+
case FIXED64:
90+
case SFIXED64:
91+
return readFixed64BitField(inputStream, fieldDescriptor);
92+
default:
93+
throw new IllegalStateException(
94+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
95+
}
96+
}
97+
98+
private static Object readFixed32BitField(
99+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
100+
switch (fieldDescriptor.getProtoFieldType()) {
101+
case FLOAT:
102+
return inputStream.readFloat();
103+
case FIXED32:
104+
case SFIXED32:
105+
return inputStream.readRawLittleEndian32();
106+
default:
107+
throw new IllegalStateException(
108+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
109+
}
110+
}
111+
112+
private static Object readFixed64BitField(
113+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
114+
switch (fieldDescriptor.getProtoFieldType()) {
115+
case DOUBLE:
116+
return inputStream.readDouble();
117+
case FIXED64:
118+
case SFIXED64:
119+
return inputStream.readRawLittleEndian64();
120+
default:
121+
throw new IllegalStateException(
122+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
123+
}
124+
}
125+
126+
private Object readLengthDelimitedField(
127+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
128+
FieldLiteDescriptor.Type fieldType = fieldDescriptor.getProtoFieldType();
129+
130+
switch (fieldType) {
131+
case BYTES:
132+
return inputStream.readBytes();
133+
case MESSAGE:
134+
MessageLite.Builder builder =
135+
getDefaultMessageBuilder(fieldDescriptor.getFieldProtoTypeName());
136+
137+
inputStream.readMessage(builder, ExtensionRegistryLite.getEmptyRegistry());
138+
return builder.build();
139+
case STRING:
140+
return inputStream.readStringRequireUtf8();
141+
default:
142+
throw new IllegalStateException("Unexpected field type: " + fieldType);
143+
}
144+
}
145+
146+
private MessageLite.Builder getDefaultMessageBuilder(String protoTypeName) {
147+
return descriptorPool.getDescriptorOrThrow(protoTypeName).newMessageBuilder();
148+
}
149+
150+
CelValue getDefaultCelValue(String protoTypeName, String fieldName) {
151+
MessageLiteDescriptor messageDescriptor = descriptorPool.getDescriptorOrThrow(protoTypeName);
152+
FieldLiteDescriptor fieldDescriptor = messageDescriptor.getByFieldNameOrThrow(fieldName);
153+
154+
Object defaultValue = getDefaultValue(fieldDescriptor);
155+
if (defaultValue instanceof MessageLite) {
156+
return fromProtoMessageToCelValue(
157+
fieldDescriptor.getFieldProtoTypeName(), (MessageLite) defaultValue);
158+
} else {
159+
return fromJavaObjectToCelValue(getDefaultValue(fieldDescriptor));
160+
}
161+
}
162+
163+
private Object getDefaultValue(FieldLiteDescriptor fieldDescriptor) {
164+
FieldLiteDescriptor.CelFieldValueType celFieldValueType =
165+
fieldDescriptor.getCelFieldValueType();
166+
switch (celFieldValueType) {
167+
case LIST:
168+
return Collections.unmodifiableList(new ArrayList<>());
169+
case MAP:
170+
return Collections.unmodifiableMap(new HashMap<>());
171+
case SCALAR:
172+
return getScalarDefaultValue(fieldDescriptor);
173+
}
174+
throw new IllegalStateException("Unexpected cel field value type: " + celFieldValueType);
175+
}
176+
177+
private Object getScalarDefaultValue(FieldLiteDescriptor fieldDescriptor) {
178+
JavaType type = fieldDescriptor.getJavaType();
179+
switch (type) {
180+
case INT:
181+
return fieldDescriptor.getProtoFieldType().equals(FieldLiteDescriptor.Type.UINT32)
182+
? UnsignedLong.ZERO
183+
: Defaults.defaultValue(long.class);
184+
case LONG:
185+
return fieldDescriptor.getProtoFieldType().equals(FieldLiteDescriptor.Type.UINT64)
186+
? UnsignedLong.ZERO
187+
: Defaults.defaultValue(long.class);
188+
case ENUM:
189+
return Defaults.defaultValue(long.class);
190+
case FLOAT:
191+
return Defaults.defaultValue(float.class);
192+
case DOUBLE:
193+
return Defaults.defaultValue(double.class);
194+
case BOOLEAN:
195+
return Defaults.defaultValue(boolean.class);
196+
case STRING:
197+
return "";
198+
case BYTE_STRING:
199+
return ByteString.EMPTY;
200+
case MESSAGE:
201+
if (WellKnownProto.isWrapperType(fieldDescriptor.getFieldProtoTypeName())) {
202+
return NullValue.NULL_VALUE;
203+
}
204+
205+
return getDefaultMessageBuilder(fieldDescriptor.getFieldProtoTypeName()).build();
206+
}
207+
throw new IllegalStateException("Unexpected java type: " + type);
208+
}
209+
210+
private List<Object> readPackedRepeatedFields(
211+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
212+
int length = inputStream.readInt32();
213+
int oldLimit = inputStream.pushLimit(length);
214+
List<Object> repeatedFieldValues = new ArrayList<>();
215+
while (inputStream.getBytesUntilLimit() > 0) {
216+
Object value = readPrimitiveField(inputStream, fieldDescriptor);
217+
repeatedFieldValues.add(value);
218+
}
219+
inputStream.popLimit(oldLimit);
220+
return Collections.unmodifiableList(repeatedFieldValues);
221+
}
222+
223+
private Map.Entry<Object, Object> readSingleMapEntry(
224+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
225+
ImmutableMap<String, Object> singleMapEntry =
226+
readAllFields(inputStream.readByteArray(), fieldDescriptor.getFieldProtoTypeName());
227+
Object key = checkNotNull(singleMapEntry.get("key"));
228+
Object value = checkNotNull(singleMapEntry.get("value"));
229+
230+
return new AbstractMap.SimpleEntry<>(key, value);
231+
}
232+
233+
@VisibleForTesting
234+
ImmutableMap<String, Object> readAllFields(byte[] bytes, String protoTypeName)
235+
throws IOException {
236+
// TODO: Handle unknown fields by collecting them into a separate map.
237+
MessageLiteDescriptor messageDescriptor = descriptorPool.getDescriptorOrThrow(protoTypeName);
238+
CodedInputStream inputStream = CodedInputStream.newInstance(bytes);
239+
240+
ImmutableMap.Builder<String, Object> fieldValues = ImmutableMap.builder();
241+
Map<Integer, List<Object>> repeatedFieldValues = new LinkedHashMap<>();
242+
Map<Integer, Map<Object, Object>> mapFieldValues = new LinkedHashMap<>();
243+
for (int iterCount = 0; iterCount < bytes.length; iterCount++) {
244+
int tag = inputStream.readTag();
245+
if (tag == 0) {
246+
break;
247+
}
248+
249+
int tagWireType = WireFormat.getTagWireType(tag);
250+
int fieldNumber = WireFormat.getTagFieldNumber(tag);
251+
FieldLiteDescriptor fieldDescriptor = messageDescriptor.getByFieldNumberOrThrow(fieldNumber);
252+
253+
Object payload;
254+
switch (tagWireType) {
255+
case WireFormat.WIRETYPE_VARINT:
256+
payload = readPrimitiveField(inputStream, fieldDescriptor);
257+
break;
258+
case WireFormat.WIRETYPE_FIXED32:
259+
payload = readFixed32BitField(inputStream, fieldDescriptor);
260+
break;
261+
case WireFormat.WIRETYPE_FIXED64:
262+
payload = readFixed64BitField(inputStream, fieldDescriptor);
263+
break;
264+
case WireFormat.WIRETYPE_LENGTH_DELIMITED:
265+
CelFieldValueType celFieldValueType = fieldDescriptor.getCelFieldValueType();
266+
switch (celFieldValueType) {
267+
case LIST:
268+
if (fieldDescriptor.getIsPacked()) {
269+
payload = readPackedRepeatedFields(inputStream, fieldDescriptor);
270+
} else {
271+
FieldLiteDescriptor.Type protoFieldType = fieldDescriptor.getProtoFieldType();
272+
boolean isLenDelimited =
273+
protoFieldType.equals(FieldLiteDescriptor.Type.MESSAGE)
274+
|| protoFieldType.equals(FieldLiteDescriptor.Type.STRING)
275+
|| protoFieldType.equals(FieldLiteDescriptor.Type.BYTES);
276+
if (!isLenDelimited) {
277+
throw new IllegalStateException(
278+
"Unexpected field type encountered for LEN-Delimited record: "
279+
+ protoFieldType);
280+
}
281+
282+
payload = readLengthDelimitedField(inputStream, fieldDescriptor);
283+
}
284+
break;
285+
case MAP:
286+
Map<Object, Object> fieldMap =
287+
mapFieldValues.computeIfAbsent(fieldNumber, (unused) -> new LinkedHashMap<>());
288+
Map.Entry<Object, Object> mapEntry = readSingleMapEntry(inputStream, fieldDescriptor);
289+
fieldMap.put(mapEntry.getKey(), mapEntry.getValue());
290+
payload = fieldMap;
291+
break;
292+
default:
293+
payload = readLengthDelimitedField(inputStream, fieldDescriptor);
294+
break;
295+
}
296+
break;
297+
case WireFormat.WIRETYPE_START_GROUP:
298+
case WireFormat.WIRETYPE_END_GROUP:
299+
// TODO: Support groups
300+
throw new UnsupportedOperationException("Groups are not supported");
301+
default:
302+
throw new IllegalArgumentException("Unexpected wire type: " + tagWireType);
303+
}
304+
305+
if (fieldDescriptor.getCelFieldValueType().equals(CelFieldValueType.LIST)) {
306+
String fieldName = fieldDescriptor.getFieldName();
307+
List<Object> repeatedValues =
308+
repeatedFieldValues.computeIfAbsent(
309+
fieldNumber,
310+
(unused) -> {
311+
List<Object> newList = new ArrayList<>();
312+
fieldValues.put(fieldName, newList);
313+
return newList;
314+
});
315+
316+
if (payload instanceof Collection) {
317+
repeatedValues.addAll((Collection<?>) payload);
318+
} else {
319+
repeatedValues.add(payload);
320+
}
321+
} else {
322+
fieldValues.put(fieldDescriptor.getFieldName(), payload);
323+
}
324+
}
325+
326+
// Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields,
327+
// we accept the last value encountered.
328+
return fieldValues.buildKeepingLast();
329+
}
330+
331+
ImmutableMap<String, Object> readAllFields(MessageLite msg, String protoTypeName)
332+
throws IOException {
333+
return readAllFields(msg.toByteArray(), protoTypeName);
334+
}
335+
46336
@Override
47337
public CelValue fromProtoMessageToCelValue(String protoTypeName, MessageLite msg) {
48338
checkNotNull(msg);

common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
package dev.cel.common.values;
1616

1717
import com.google.auto.value.AutoValue;
18+
import com.google.auto.value.extension.memoized.Memoized;
1819
import com.google.common.base.Preconditions;
20+
import com.google.common.collect.ImmutableMap;
1921
import com.google.errorprone.annotations.Immutable;
2022
import com.google.protobuf.MessageLite;
2123
import dev.cel.common.types.CelType;
2224
import dev.cel.common.types.StructTypeReference;
25+
import java.io.IOException;
2326
import java.util.Optional;
2427

2528
/**
@@ -42,19 +45,32 @@ public abstract class ProtoMessageLiteValue extends StructValue<StringValue> {
4245

4346
abstract ProtoLiteCelValueConverter protoLiteCelValueConverter();
4447

48+
@Memoized
49+
ImmutableMap<String, Object> fieldValues() {
50+
try {
51+
return protoLiteCelValueConverter().readAllFields(value(), celType().name());
52+
} catch (IOException e) {
53+
throw new IllegalStateException("Unable to read message fields for " + celType().name(), e);
54+
}
55+
}
56+
4557
@Override
4658
public boolean isZeroValue() {
4759
return value().getDefaultInstanceForType().equals(value());
4860
}
4961

5062
@Override
5163
public CelValue select(StringValue field) {
52-
throw new UnsupportedOperationException("Not implemented yet");
64+
return find(field)
65+
.orElseGet(
66+
() -> protoLiteCelValueConverter().getDefaultCelValue(celType().name(), field.value()));
5367
}
5468

5569
@Override
5670
public Optional<CelValue> find(StringValue field) {
57-
throw new UnsupportedOperationException("Not implemented yet");
71+
Object fieldValue = fieldValues().get(field.value());
72+
return Optional.ofNullable(fieldValue)
73+
.map(value -> protoLiteCelValueConverter().fromJavaObjectToCelValue(fieldValue));
5874
}
5975

6076
public static ProtoMessageLiteValue create(

0 commit comments

Comments
 (0)