Skip to content

Commit c0783d2

Browse files
l46kokcopybara-github
authored andcommitted
Perform field selections on lite messages by reading from the wire format
PiperOrigin-RevId: 748423924
1 parent 220312c commit c0783d2

File tree

4 files changed

+508
-2
lines changed

4 files changed

+508
-2
lines changed

common/src/main/java/dev/cel/common/values/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ java_library(
179179
"//protobuf:cel_lite_descriptor",
180180
"@maven//:com_google_errorprone_error_prone_annotations",
181181
"@maven//:com_google_guava_guava",
182+
"@maven//:com_google_protobuf_protobuf_java",
182183
"@maven//:org_jspecify_jspecify",
183184
"@maven_android//:com_google_protobuf_protobuf_javalite",
184185
],

common/src/main/java/dev/cel/common/values/ProtoLiteCelValueConverter.java

Lines changed: 276 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,30 @@
1616

1717
import static com.google.common.base.Preconditions.checkNotNull;
1818

19+
import com.google.common.base.Defaults;
20+
import com.google.common.collect.ImmutableMap;
21+
import com.google.common.primitives.UnsignedLong;
1922
import com.google.errorprone.annotations.Immutable;
23+
import com.google.protobuf.ByteString;
24+
import com.google.protobuf.CodedInputStream;
25+
import com.google.protobuf.ExtensionRegistryLite;
2026
import com.google.protobuf.MessageLite;
27+
import com.google.protobuf.WireFormat;
2128
import dev.cel.common.annotations.Internal;
2229
import dev.cel.common.internal.CelLiteDescriptorPool;
2330
import dev.cel.common.internal.WellKnownProto;
31+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor;
32+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor.CelFieldValueType;
33+
import dev.cel.protobuf.CelLiteDescriptor.FieldLiteDescriptor.JavaType;
2434
import dev.cel.protobuf.CelLiteDescriptor.MessageLiteDescriptor;
35+
import java.io.IOException;
36+
import java.util.AbstractMap;
37+
import java.util.ArrayList;
38+
import java.util.Collections;
39+
import java.util.HashMap;
40+
import java.util.LinkedHashMap;
41+
import java.util.List;
42+
import java.util.Map;
2543

2644
/**
2745
* {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and
@@ -43,6 +61,264 @@ public static ProtoLiteCelValueConverter newInstance(
4361
return new ProtoLiteCelValueConverter(celLiteDescriptorPool);
4462
}
4563

64+
private static Object readPrimitiveField(
65+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
66+
switch (fieldDescriptor.getProtoFieldType()) {
67+
case SINT32:
68+
return inputStream.readSInt32();
69+
case SINT64:
70+
return inputStream.readSInt64();
71+
case INT32:
72+
case ENUM:
73+
return inputStream.readInt32();
74+
case INT64:
75+
return inputStream.readInt64();
76+
case UINT32:
77+
return UnsignedLong.fromLongBits(inputStream.readUInt32());
78+
case UINT64:
79+
return UnsignedLong.fromLongBits(inputStream.readUInt64());
80+
case BOOL:
81+
return inputStream.readBool();
82+
case FLOAT:
83+
case FIXED32:
84+
case SFIXED32:
85+
return readFixed32BitField(inputStream, fieldDescriptor);
86+
case DOUBLE:
87+
case FIXED64:
88+
case SFIXED64:
89+
return readFixed64BitField(inputStream, fieldDescriptor);
90+
default:
91+
throw new IllegalStateException(
92+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
93+
}
94+
}
95+
96+
private static Object readFixed32BitField(
97+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
98+
switch (fieldDescriptor.getProtoFieldType()) {
99+
case FLOAT:
100+
return inputStream.readFloat();
101+
case FIXED32:
102+
case SFIXED32:
103+
return inputStream.readRawLittleEndian32();
104+
default:
105+
throw new IllegalStateException(
106+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
107+
}
108+
}
109+
110+
private static Object readFixed64BitField(
111+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
112+
switch (fieldDescriptor.getProtoFieldType()) {
113+
case DOUBLE:
114+
return inputStream.readDouble();
115+
case FIXED64:
116+
case SFIXED64:
117+
return inputStream.readRawLittleEndian64();
118+
default:
119+
throw new IllegalStateException(
120+
"Unexpected field type: " + fieldDescriptor.getProtoFieldType());
121+
}
122+
}
123+
124+
private Object readLengthDelimitedField(
125+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
126+
FieldLiteDescriptor.Type fieldType = fieldDescriptor.getProtoFieldType();
127+
128+
switch (fieldType) {
129+
case BYTES:
130+
return inputStream.readBytes();
131+
case MESSAGE:
132+
MessageLite.Builder builder =
133+
getDefaultMessageBuilder(fieldDescriptor.getFieldProtoTypeName());
134+
135+
inputStream.readMessage(builder, ExtensionRegistryLite.getEmptyRegistry());
136+
return builder.build();
137+
case STRING:
138+
return inputStream.readStringRequireUtf8();
139+
default:
140+
throw new IllegalStateException("Unexpected field type: " + fieldType);
141+
}
142+
}
143+
144+
private MessageLite.Builder getDefaultMessageBuilder(String protoTypeName) {
145+
return descriptorPool.getDescriptorOrThrow(protoTypeName).newMessageBuilder();
146+
}
147+
148+
CelValue getDefaultCelValue(String protoTypeName, String fieldName) {
149+
MessageLiteDescriptor messageDescriptor = descriptorPool.getDescriptorOrThrow(protoTypeName);
150+
FieldLiteDescriptor fieldDescriptor = messageDescriptor.getByFieldNameOrThrow(fieldName);
151+
152+
Object defaultValue = getDefaultValue(fieldDescriptor);
153+
if (defaultValue instanceof MessageLite) {
154+
return fromProtoMessageToCelValue(
155+
fieldDescriptor.getFieldProtoTypeName(), (MessageLite) defaultValue);
156+
} else {
157+
return fromJavaObjectToCelValue(getDefaultValue(fieldDescriptor));
158+
}
159+
}
160+
161+
private Object getDefaultValue(FieldLiteDescriptor fieldDescriptor) {
162+
FieldLiteDescriptor.CelFieldValueType celFieldValueType =
163+
fieldDescriptor.getCelFieldValueType();
164+
switch (celFieldValueType) {
165+
case LIST:
166+
return Collections.unmodifiableList(new ArrayList<>());
167+
case MAP:
168+
return Collections.unmodifiableMap(new HashMap<>());
169+
case SCALAR:
170+
return getScalarDefaultValue(fieldDescriptor);
171+
}
172+
throw new IllegalStateException("Unexpected cel field value type: " + celFieldValueType);
173+
}
174+
175+
private Object getScalarDefaultValue(FieldLiteDescriptor fieldDescriptor) {
176+
JavaType type = fieldDescriptor.getJavaType();
177+
switch (type) {
178+
case INT:
179+
return fieldDescriptor.getProtoFieldType().equals(FieldLiteDescriptor.Type.UINT32)
180+
? UnsignedLong.ZERO
181+
: Defaults.defaultValue(long.class);
182+
case LONG:
183+
return fieldDescriptor.getProtoFieldType().equals(FieldLiteDescriptor.Type.UINT64)
184+
? UnsignedLong.ZERO
185+
: Defaults.defaultValue(long.class);
186+
case ENUM:
187+
return Defaults.defaultValue(long.class);
188+
case FLOAT:
189+
return Defaults.defaultValue(float.class);
190+
case DOUBLE:
191+
return Defaults.defaultValue(double.class);
192+
case BOOLEAN:
193+
return Defaults.defaultValue(boolean.class);
194+
case STRING:
195+
return "";
196+
case BYTE_STRING:
197+
return ByteString.EMPTY;
198+
case MESSAGE:
199+
if (WellKnownProto.isWrapperType(fieldDescriptor.getFieldProtoTypeName())) {
200+
return NullValue.NULL_VALUE;
201+
}
202+
203+
return getDefaultMessageBuilder(fieldDescriptor.getFieldProtoTypeName()).build();
204+
}
205+
throw new IllegalStateException("Unexpected java type: " + type);
206+
}
207+
208+
private List<Object> readPackedRepeatedFields(
209+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
210+
int length = inputStream.readInt32();
211+
int oldLimit = inputStream.pushLimit(length);
212+
List<Object> repeatedFieldValues = new ArrayList<>();
213+
while (inputStream.getBytesUntilLimit() > 0) {
214+
Object value = readPrimitiveField(inputStream, fieldDescriptor);
215+
repeatedFieldValues.add(value);
216+
}
217+
inputStream.popLimit(oldLimit);
218+
return Collections.unmodifiableList(repeatedFieldValues);
219+
}
220+
221+
private Map.Entry<Object, Object> readSingleMapEntry(
222+
CodedInputStream inputStream, FieldLiteDescriptor fieldDescriptor) throws IOException {
223+
ImmutableMap<String, Object> singleMapEntry =
224+
readAllFields(inputStream.readByteArray(), fieldDescriptor.getFieldProtoTypeName());
225+
Object key = checkNotNull(singleMapEntry.get("key"));
226+
Object value = checkNotNull(singleMapEntry.get("value"));
227+
228+
return new AbstractMap.SimpleEntry<>(key, value);
229+
}
230+
231+
private ImmutableMap<String, Object> readAllFields(byte[] bytes, String protoTypeName)
232+
throws IOException {
233+
MessageLiteDescriptor messageDescriptor = descriptorPool.getDescriptorOrThrow(protoTypeName);
234+
CodedInputStream inputStream = CodedInputStream.newInstance(bytes);
235+
236+
ImmutableMap.Builder<String, Object> fieldValues = ImmutableMap.builder();
237+
Map<Integer, List<Object>> nonPackedRepeatedFields = new LinkedHashMap<>();
238+
Map<Integer, Map<Object, Object>> mapFieldValues = new LinkedHashMap<>();
239+
for (int iterCount = 0; iterCount < bytes.length; iterCount++) {
240+
int tag = inputStream.readTag();
241+
if (tag == 0) {
242+
break;
243+
}
244+
245+
int tagWireType = WireFormat.getTagWireType(tag);
246+
int fieldNumber = WireFormat.getTagFieldNumber(tag);
247+
FieldLiteDescriptor fieldDescriptor = messageDescriptor.getByFieldNumberOrThrow(fieldNumber);
248+
249+
Object payload;
250+
switch (tagWireType) {
251+
case WireFormat.WIRETYPE_VARINT:
252+
payload = readPrimitiveField(inputStream, fieldDescriptor);
253+
break;
254+
case WireFormat.WIRETYPE_FIXED32:
255+
payload = readFixed32BitField(inputStream, fieldDescriptor);
256+
break;
257+
case WireFormat.WIRETYPE_FIXED64:
258+
payload = readFixed64BitField(inputStream, fieldDescriptor);
259+
break;
260+
case WireFormat.WIRETYPE_LENGTH_DELIMITED:
261+
CelFieldValueType celFieldValueType = fieldDescriptor.getCelFieldValueType();
262+
switch (celFieldValueType) {
263+
case LIST:
264+
if (fieldDescriptor.getIsPacked()) {
265+
payload = readPackedRepeatedFields(inputStream, fieldDescriptor);
266+
} else {
267+
FieldLiteDescriptor.Type protoFieldType = fieldDescriptor.getProtoFieldType();
268+
boolean isLenDelimited =
269+
protoFieldType.equals(FieldLiteDescriptor.Type.MESSAGE)
270+
|| protoFieldType.equals(FieldLiteDescriptor.Type.STRING)
271+
|| protoFieldType.equals(FieldLiteDescriptor.Type.BYTES);
272+
if (!isLenDelimited) {
273+
throw new IllegalStateException(
274+
"Unexpected field type encountered for LEN-Delimited record: "
275+
+ protoFieldType);
276+
}
277+
278+
payload = readLengthDelimitedField(inputStream, fieldDescriptor);
279+
}
280+
break;
281+
case MAP:
282+
Map<Object, Object> fieldMap =
283+
mapFieldValues.computeIfAbsent(fieldNumber, (unused) -> new LinkedHashMap<>());
284+
Map.Entry<Object, Object> mapEntry = readSingleMapEntry(inputStream, fieldDescriptor);
285+
fieldMap.put(mapEntry.getKey(), mapEntry.getValue());
286+
payload = fieldMap;
287+
break;
288+
default:
289+
payload = readLengthDelimitedField(inputStream, fieldDescriptor);
290+
break;
291+
}
292+
break;
293+
case WireFormat.WIRETYPE_START_GROUP:
294+
case WireFormat.WIRETYPE_END_GROUP:
295+
// TODO: Support groups
296+
throw new UnsupportedOperationException("Groups are not supported");
297+
default:
298+
throw new IllegalArgumentException("Unexpected wire type: " + tagWireType);
299+
}
300+
301+
if (fieldDescriptor.getCelFieldValueType().equals(CelFieldValueType.LIST)
302+
&& !fieldDescriptor.getIsPacked()) {
303+
List<Object> repeatedValues =
304+
nonPackedRepeatedFields.computeIfAbsent(fieldNumber, (unused) -> new ArrayList<>());
305+
repeatedValues.add(payload);
306+
payload = repeatedValues;
307+
}
308+
309+
fieldValues.put(fieldDescriptor.getFieldName(), payload);
310+
}
311+
312+
// Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields,
313+
// we accept the last value encountered.
314+
return fieldValues.buildKeepingLast();
315+
}
316+
317+
ImmutableMap<String, Object> readAllFields(MessageLite msg, String protoTypeName)
318+
throws IOException {
319+
return readAllFields(msg.toByteArray(), protoTypeName);
320+
}
321+
46322
@Override
47323
public CelValue fromProtoMessageToCelValue(String protoTypeName, MessageLite msg) {
48324
checkNotNull(msg);

common/src/main/java/dev/cel/common/values/ProtoMessageLiteValue.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
package dev.cel.common.values;
1616

1717
import com.google.auto.value.AutoValue;
18+
import com.google.auto.value.extension.memoized.Memoized;
1819
import com.google.common.base.Preconditions;
20+
import com.google.common.collect.ImmutableMap;
1921
import com.google.errorprone.annotations.Immutable;
2022
import com.google.protobuf.MessageLite;
2123
import dev.cel.common.types.CelType;
2224
import dev.cel.common.types.StructTypeReference;
25+
import java.io.IOException;
2326
import java.util.Optional;
2427

2528
/**
@@ -42,19 +45,32 @@ public abstract class ProtoMessageLiteValue extends StructValue<StringValue> {
4245

4346
abstract ProtoLiteCelValueConverter protoLiteCelValueConverter();
4447

48+
@Memoized
49+
ImmutableMap<String, Object> fieldValues() {
50+
try {
51+
return protoLiteCelValueConverter().readAllFields(value(), celType().name());
52+
} catch (IOException e) {
53+
throw new IllegalStateException("Unable to read message fields for " + celType().name(), e);
54+
}
55+
}
56+
4557
@Override
4658
public boolean isZeroValue() {
4759
return value().getDefaultInstanceForType().equals(value());
4860
}
4961

5062
@Override
5163
public CelValue select(StringValue field) {
52-
throw new UnsupportedOperationException("Not implemented yet");
64+
return find(field)
65+
.orElseGet(
66+
() -> protoLiteCelValueConverter().getDefaultCelValue(celType().name(), field.value()));
5367
}
5468

5569
@Override
5670
public Optional<CelValue> find(StringValue field) {
57-
throw new UnsupportedOperationException("Not implemented yet");
71+
Object fieldValue = fieldValues().get(field.value());
72+
return Optional.ofNullable(fieldValue)
73+
.map(value -> protoLiteCelValueConverter().fromJavaObjectToCelValue(fieldValue));
5874
}
5975

6076
public static ProtoMessageLiteValue create(

0 commit comments

Comments
 (0)