1616
1717import static com .google .common .base .Preconditions .checkNotNull ;
1818
19+ import com .google .common .base .Defaults ;
20+ import com .google .common .collect .ImmutableMap ;
21+ import com .google .common .primitives .UnsignedLong ;
1922import com .google .errorprone .annotations .Immutable ;
23+ import com .google .protobuf .ByteString ;
24+ import com .google .protobuf .CodedInputStream ;
25+ import com .google .protobuf .ExtensionRegistryLite ;
2026import com .google .protobuf .MessageLite ;
27+ import com .google .protobuf .WireFormat ;
2128import dev .cel .common .annotations .Internal ;
2229import dev .cel .common .internal .CelLiteDescriptorPool ;
2330import dev .cel .common .internal .WellKnownProto ;
31+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor ;
32+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor .CelFieldValueType ;
33+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor .JavaType ;
2434import dev .cel .protobuf .CelLiteDescriptor .MessageLiteDescriptor ;
35+ import java .io .IOException ;
36+ import java .util .AbstractMap ;
37+ import java .util .ArrayList ;
38+ import java .util .Collections ;
39+ import java .util .HashMap ;
40+ import java .util .LinkedHashMap ;
41+ import java .util .List ;
42+ import java .util .Map ;
2543
2644/**
2745 * {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and
@@ -43,6 +61,264 @@ public static ProtoLiteCelValueConverter newInstance(
4361 return new ProtoLiteCelValueConverter (celLiteDescriptorPool );
4462 }
4563
64+ private static Object readPrimitiveField (
65+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
66+ switch (fieldDescriptor .getProtoFieldType ()) {
67+ case SINT32 :
68+ return inputStream .readSInt32 ();
69+ case SINT64 :
70+ return inputStream .readSInt64 ();
71+ case INT32 :
72+ case ENUM :
73+ return inputStream .readInt32 ();
74+ case INT64 :
75+ return inputStream .readInt64 ();
76+ case UINT32 :
77+ return UnsignedLong .fromLongBits (inputStream .readUInt32 ());
78+ case UINT64 :
79+ return UnsignedLong .fromLongBits (inputStream .readUInt64 ());
80+ case BOOL :
81+ return inputStream .readBool ();
82+ case FLOAT :
83+ case FIXED32 :
84+ case SFIXED32 :
85+ return readFixed32BitField (inputStream , fieldDescriptor );
86+ case DOUBLE :
87+ case FIXED64 :
88+ case SFIXED64 :
89+ return readFixed64BitField (inputStream , fieldDescriptor );
90+ default :
91+ throw new IllegalStateException (
92+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
93+ }
94+ }
95+
96+ private static Object readFixed32BitField (
97+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
98+ switch (fieldDescriptor .getProtoFieldType ()) {
99+ case FLOAT :
100+ return inputStream .readFloat ();
101+ case FIXED32 :
102+ case SFIXED32 :
103+ return inputStream .readRawLittleEndian32 ();
104+ default :
105+ throw new IllegalStateException (
106+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
107+ }
108+ }
109+
110+ private static Object readFixed64BitField (
111+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
112+ switch (fieldDescriptor .getProtoFieldType ()) {
113+ case DOUBLE :
114+ return inputStream .readDouble ();
115+ case FIXED64 :
116+ case SFIXED64 :
117+ return inputStream .readRawLittleEndian64 ();
118+ default :
119+ throw new IllegalStateException (
120+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
121+ }
122+ }
123+
124+ private Object readLengthDelimitedField (
125+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
126+ FieldLiteDescriptor .Type fieldType = fieldDescriptor .getProtoFieldType ();
127+
128+ switch (fieldType ) {
129+ case BYTES :
130+ return inputStream .readBytes ();
131+ case MESSAGE :
132+ MessageLite .Builder builder =
133+ getDefaultMessageBuilder (fieldDescriptor .getFieldProtoTypeName ());
134+
135+ inputStream .readMessage (builder , ExtensionRegistryLite .getEmptyRegistry ());
136+ return builder .build ();
137+ case STRING :
138+ return inputStream .readStringRequireUtf8 ();
139+ default :
140+ throw new IllegalStateException ("Unexpected field type: " + fieldType );
141+ }
142+ }
143+
144+ private MessageLite .Builder getDefaultMessageBuilder (String protoTypeName ) {
145+ return descriptorPool .getDescriptorOrThrow (protoTypeName ).newMessageBuilder ();
146+ }
147+
148+ CelValue getDefaultCelValue (String protoTypeName , String fieldName ) {
149+ MessageLiteDescriptor messageDescriptor = descriptorPool .getDescriptorOrThrow (protoTypeName );
150+ FieldLiteDescriptor fieldDescriptor = messageDescriptor .getByFieldNameOrThrow (fieldName );
151+
152+ Object defaultValue = getDefaultValue (fieldDescriptor );
153+ if (defaultValue instanceof MessageLite ) {
154+ return fromProtoMessageToCelValue (
155+ fieldDescriptor .getFieldProtoTypeName (), (MessageLite ) defaultValue );
156+ } else {
157+ return fromJavaObjectToCelValue (getDefaultValue (fieldDescriptor ));
158+ }
159+ }
160+
161+ private Object getDefaultValue (FieldLiteDescriptor fieldDescriptor ) {
162+ FieldLiteDescriptor .CelFieldValueType celFieldValueType =
163+ fieldDescriptor .getCelFieldValueType ();
164+ switch (celFieldValueType ) {
165+ case LIST :
166+ return Collections .unmodifiableList (new ArrayList <>());
167+ case MAP :
168+ return Collections .unmodifiableMap (new HashMap <>());
169+ case SCALAR :
170+ return getScalarDefaultValue (fieldDescriptor );
171+ }
172+ throw new IllegalStateException ("Unexpected cel field value type: " + celFieldValueType );
173+ }
174+
175+ private Object getScalarDefaultValue (FieldLiteDescriptor fieldDescriptor ) {
176+ JavaType type = fieldDescriptor .getJavaType ();
177+ switch (type ) {
178+ case INT :
179+ return fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .UINT32 )
180+ ? UnsignedLong .ZERO
181+ : Defaults .defaultValue (long .class );
182+ case LONG :
183+ return fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .UINT64 )
184+ ? UnsignedLong .ZERO
185+ : Defaults .defaultValue (long .class );
186+ case ENUM :
187+ return Defaults .defaultValue (long .class );
188+ case FLOAT :
189+ return Defaults .defaultValue (float .class );
190+ case DOUBLE :
191+ return Defaults .defaultValue (double .class );
192+ case BOOLEAN :
193+ return Defaults .defaultValue (boolean .class );
194+ case STRING :
195+ return "" ;
196+ case BYTE_STRING :
197+ return ByteString .EMPTY ;
198+ case MESSAGE :
199+ if (WellKnownProto .isWrapperType (fieldDescriptor .getFieldProtoTypeName ())) {
200+ return NullValue .NULL_VALUE ;
201+ }
202+
203+ return getDefaultMessageBuilder (fieldDescriptor .getFieldProtoTypeName ()).build ();
204+ }
205+ throw new IllegalStateException ("Unexpected java type: " + type );
206+ }
207+
208+ private List <Object > readPackedRepeatedFields (
209+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
210+ int length = inputStream .readInt32 ();
211+ int oldLimit = inputStream .pushLimit (length );
212+ List <Object > repeatedFieldValues = new ArrayList <>();
213+ while (inputStream .getBytesUntilLimit () > 0 ) {
214+ Object value = readPrimitiveField (inputStream , fieldDescriptor );
215+ repeatedFieldValues .add (value );
216+ }
217+ inputStream .popLimit (oldLimit );
218+ return Collections .unmodifiableList (repeatedFieldValues );
219+ }
220+
221+ private Map .Entry <Object , Object > readSingleMapEntry (
222+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
223+ ImmutableMap <String , Object > singleMapEntry =
224+ readAllFields (inputStream .readByteArray (), fieldDescriptor .getFieldProtoTypeName ());
225+ Object key = checkNotNull (singleMapEntry .get ("key" ));
226+ Object value = checkNotNull (singleMapEntry .get ("value" ));
227+
228+ return new AbstractMap .SimpleEntry <>(key , value );
229+ }
230+
231+ private ImmutableMap <String , Object > readAllFields (byte [] bytes , String protoTypeName )
232+ throws IOException {
233+ MessageLiteDescriptor messageDescriptor = descriptorPool .getDescriptorOrThrow (protoTypeName );
234+ CodedInputStream inputStream = CodedInputStream .newInstance (bytes );
235+
236+ ImmutableMap .Builder <String , Object > fieldValues = ImmutableMap .builder ();
237+ Map <Integer , List <Object >> nonPackedRepeatedFields = new LinkedHashMap <>();
238+ Map <Integer , Map <Object , Object >> mapFieldValues = new LinkedHashMap <>();
239+ for (int iterCount = 0 ; iterCount < bytes .length ; iterCount ++) {
240+ int tag = inputStream .readTag ();
241+ if (tag == 0 ) {
242+ break ;
243+ }
244+
245+ int tagWireType = WireFormat .getTagWireType (tag );
246+ int fieldNumber = WireFormat .getTagFieldNumber (tag );
247+ FieldLiteDescriptor fieldDescriptor = messageDescriptor .getByFieldNumberOrThrow (fieldNumber );
248+
249+ Object payload ;
250+ switch (tagWireType ) {
251+ case WireFormat .WIRETYPE_VARINT :
252+ payload = readPrimitiveField (inputStream , fieldDescriptor );
253+ break ;
254+ case WireFormat .WIRETYPE_FIXED32 :
255+ payload = readFixed32BitField (inputStream , fieldDescriptor );
256+ break ;
257+ case WireFormat .WIRETYPE_FIXED64 :
258+ payload = readFixed64BitField (inputStream , fieldDescriptor );
259+ break ;
260+ case WireFormat .WIRETYPE_LENGTH_DELIMITED :
261+ CelFieldValueType celFieldValueType = fieldDescriptor .getCelFieldValueType ();
262+ switch (celFieldValueType ) {
263+ case LIST :
264+ if (fieldDescriptor .getIsPacked ()) {
265+ payload = readPackedRepeatedFields (inputStream , fieldDescriptor );
266+ } else {
267+ FieldLiteDescriptor .Type protoFieldType = fieldDescriptor .getProtoFieldType ();
268+ boolean isLenDelimited =
269+ protoFieldType .equals (FieldLiteDescriptor .Type .MESSAGE )
270+ || protoFieldType .equals (FieldLiteDescriptor .Type .STRING )
271+ || protoFieldType .equals (FieldLiteDescriptor .Type .BYTES );
272+ if (!isLenDelimited ) {
273+ throw new IllegalStateException (
274+ "Unexpected field type encountered for LEN-Delimited record: "
275+ + protoFieldType );
276+ }
277+
278+ payload = readLengthDelimitedField (inputStream , fieldDescriptor );
279+ }
280+ break ;
281+ case MAP :
282+ Map <Object , Object > fieldMap =
283+ mapFieldValues .computeIfAbsent (fieldNumber , (unused ) -> new LinkedHashMap <>());
284+ Map .Entry <Object , Object > mapEntry = readSingleMapEntry (inputStream , fieldDescriptor );
285+ fieldMap .put (mapEntry .getKey (), mapEntry .getValue ());
286+ payload = fieldMap ;
287+ break ;
288+ default :
289+ payload = readLengthDelimitedField (inputStream , fieldDescriptor );
290+ break ;
291+ }
292+ break ;
293+ case WireFormat .WIRETYPE_START_GROUP :
294+ case WireFormat .WIRETYPE_END_GROUP :
295+ // TODO: Support groups
296+ throw new UnsupportedOperationException ("Groups are not supported" );
297+ default :
298+ throw new IllegalArgumentException ("Unexpected wire type: " + tagWireType );
299+ }
300+
301+ if (fieldDescriptor .getCelFieldValueType ().equals (CelFieldValueType .LIST )
302+ && !fieldDescriptor .getIsPacked ()) {
303+ List <Object > repeatedValues =
304+ nonPackedRepeatedFields .computeIfAbsent (fieldNumber , (unused ) -> new ArrayList <>());
305+ repeatedValues .add (payload );
306+ payload = repeatedValues ;
307+ }
308+
309+ fieldValues .put (fieldDescriptor .getFieldName (), payload );
310+ }
311+
312+ // Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields,
313+ // we accept the last value encountered.
314+ return fieldValues .buildKeepingLast ();
315+ }
316+
317+ ImmutableMap <String , Object > readAllFields (MessageLite msg , String protoTypeName )
318+ throws IOException {
319+ return readAllFields (msg .toByteArray (), protoTypeName );
320+ }
321+
46322 @ Override
47323 public CelValue fromProtoMessageToCelValue (String protoTypeName , MessageLite msg ) {
48324 checkNotNull (msg );
0 commit comments