1616
1717import static com .google .common .base .Preconditions .checkNotNull ;
1818
19+ import com .google .common .base .Defaults ;
20+ import com .google .common .collect .ImmutableMap ;
21+ import com .google .common .primitives .UnsignedLong ;
1922import com .google .errorprone .annotations .Immutable ;
23+ import com .google .protobuf .ByteString ;
24+ import com .google .protobuf .CodedInputStream ;
25+ import com .google .protobuf .ExtensionRegistryLite ;
2026import com .google .protobuf .MessageLite ;
27+ import com .google .protobuf .WireFormat ;
2128import dev .cel .common .annotations .Internal ;
2229import dev .cel .common .internal .CelLiteDescriptorPool ;
2330import dev .cel .common .internal .WellKnownProto ;
31+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor ;
32+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor .CelFieldValueType ;
33+ import dev .cel .protobuf .CelLiteDescriptor .FieldLiteDescriptor .JavaType ;
2434import dev .cel .protobuf .CelLiteDescriptor .MessageLiteDescriptor ;
35+ import java .io .IOException ;
36+ import java .util .ArrayList ;
37+ import java .util .Collections ;
38+ import java .util .HashMap ;
39+ import java .util .LinkedHashMap ;
40+ import java .util .List ;
41+ import java .util .Map ;
2542
2643/**
2744 * {@code ProtoLiteCelValueConverter} handles bidirectional conversion between native Java and
@@ -43,6 +60,262 @@ public static ProtoLiteCelValueConverter newInstance(
4360 return new ProtoLiteCelValueConverter (celLiteDescriptorPool );
4461 }
4562
63+ private static Object readPrimitiveField (
64+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
65+ switch (fieldDescriptor .getProtoFieldType ()) {
66+ case SINT32 :
67+ return inputStream .readSInt32 ();
68+ case SINT64 :
69+ return inputStream .readSInt64 ();
70+ case INT32 :
71+ case ENUM :
72+ return inputStream .readInt32 ();
73+ case INT64 :
74+ return inputStream .readInt64 ();
75+ case UINT32 :
76+ return UnsignedLong .fromLongBits (inputStream .readUInt32 ());
77+ case UINT64 :
78+ return UnsignedLong .fromLongBits (inputStream .readUInt64 ());
79+ case BOOL :
80+ return inputStream .readBool ();
81+ case FLOAT :
82+ case FIXED32 :
83+ case SFIXED32 :
84+ return readFixed32BitField (inputStream , fieldDescriptor );
85+ case DOUBLE :
86+ case FIXED64 :
87+ case SFIXED64 :
88+ return readFixed64BitField (inputStream , fieldDescriptor );
89+ default :
90+ throw new IllegalStateException (
91+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
92+ }
93+ }
94+
95+ private static Object readFixed32BitField (
96+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
97+ switch (fieldDescriptor .getProtoFieldType ()) {
98+ case FLOAT :
99+ return inputStream .readFloat ();
100+ case FIXED32 :
101+ case SFIXED32 :
102+ return inputStream .readRawLittleEndian32 ();
103+ default :
104+ throw new IllegalStateException (
105+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
106+ }
107+ }
108+
109+ private static Object readFixed64BitField (
110+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
111+ switch (fieldDescriptor .getProtoFieldType ()) {
112+ case DOUBLE :
113+ return inputStream .readDouble ();
114+ case FIXED64 :
115+ case SFIXED64 :
116+ return inputStream .readRawLittleEndian64 ();
117+ default :
118+ throw new IllegalStateException (
119+ "Unexpected field type: " + fieldDescriptor .getProtoFieldType ());
120+ }
121+ }
122+
123+ private Object readLengthDelimitedField (
124+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
125+ FieldLiteDescriptor .Type fieldType = fieldDescriptor .getProtoFieldType ();
126+
127+ switch (fieldType ) {
128+ case BYTES :
129+ return inputStream .readBytes ();
130+ case MESSAGE :
131+ MessageLite .Builder builder =
132+ getDefaultMessageBuilder (fieldDescriptor .getFieldProtoTypeName ());
133+
134+ inputStream .readMessage (builder , ExtensionRegistryLite .getEmptyRegistry ());
135+ return builder .build ();
136+ case STRING :
137+ return inputStream .readStringRequireUtf8 ();
138+ default :
139+ throw new IllegalStateException ("Unexpected field type: " + fieldType );
140+ }
141+ }
142+
143+ private MessageLite .Builder getDefaultMessageBuilder (String protoTypeName ) {
144+ return descriptorPool .getDescriptorOrThrow (protoTypeName ).newMessageBuilder ();
145+ }
146+
147+ CelValue getDefaultCelValue (String protoTypeName , String fieldName ) {
148+ MessageLiteDescriptor messageDescriptor = descriptorPool .getDescriptorOrThrow (protoTypeName );
149+ FieldLiteDescriptor fieldDescriptor = messageDescriptor .getByFieldNameOrThrow (fieldName );
150+
151+ Object defaultValue = getDefaultValue (fieldDescriptor );
152+ if (defaultValue instanceof MessageLite ) {
153+ return fromProtoMessageToCelValue (
154+ fieldDescriptor .getFieldProtoTypeName (), (MessageLite ) defaultValue );
155+ } else {
156+ return fromJavaObjectToCelValue (getDefaultValue (fieldDescriptor ));
157+ }
158+ }
159+
160+ private Object getDefaultValue (FieldLiteDescriptor fieldDescriptor ) {
161+ FieldLiteDescriptor .CelFieldValueType celFieldValueType =
162+ fieldDescriptor .getCelFieldValueType ();
163+ switch (celFieldValueType ) {
164+ case LIST :
165+ return Collections .unmodifiableList (new ArrayList <>());
166+ case MAP :
167+ return Collections .unmodifiableMap (new HashMap <>());
168+ case SCALAR :
169+ return getScalarDefaultValue (fieldDescriptor );
170+ }
171+ throw new IllegalStateException ("Unexpected cel field value type: " + celFieldValueType );
172+ }
173+
174+ private Object getScalarDefaultValue (FieldLiteDescriptor fieldDescriptor ) {
175+ JavaType type = fieldDescriptor .getJavaType ();
176+ switch (type ) {
177+ case INT :
178+ return fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .UINT32 )
179+ ? UnsignedLong .ZERO
180+ : Defaults .defaultValue (long .class );
181+ case LONG :
182+ return fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .UINT64 )
183+ ? UnsignedLong .ZERO
184+ : Defaults .defaultValue (long .class );
185+ case ENUM :
186+ return Defaults .defaultValue (long .class );
187+ case FLOAT :
188+ return Defaults .defaultValue (float .class );
189+ case DOUBLE :
190+ return Defaults .defaultValue (double .class );
191+ case BOOLEAN :
192+ return Defaults .defaultValue (boolean .class );
193+ case STRING :
194+ return "" ;
195+ case BYTE_STRING :
196+ return ByteString .EMPTY ;
197+ case MESSAGE :
198+ if (WellKnownProto .isWrapperType (fieldDescriptor .getFieldProtoTypeName ())) {
199+ return NullValue .NULL_VALUE ;
200+ }
201+
202+ return getDefaultMessageBuilder (fieldDescriptor .getFieldProtoTypeName ()).build ();
203+ }
204+ throw new IllegalStateException ("Unexpected java type: " + type );
205+ }
206+
207+ private List <Object > readPackedRepeatedFields (
208+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
209+ int length = inputStream .readInt32 ();
210+ int oldLimit = inputStream .pushLimit (length );
211+ List <Object > repeatedFieldValues = new ArrayList <>();
212+ while (inputStream .getBytesUntilLimit () > 0 ) {
213+ Object value = readPrimitiveField (inputStream , fieldDescriptor );
214+ repeatedFieldValues .add (value );
215+ }
216+ inputStream .popLimit (oldLimit );
217+ return Collections .unmodifiableList (repeatedFieldValues );
218+ }
219+
220+ private ImmutableMap <Object , Object > readSingleMapEntry (
221+ CodedInputStream inputStream , FieldLiteDescriptor fieldDescriptor ) throws IOException {
222+ ImmutableMap <String , Object > singleMapEntry =
223+ readAllFields (inputStream .readByteArray (), fieldDescriptor .getFieldProtoTypeName ());
224+ Object key = checkNotNull (singleMapEntry .get ("key" ));
225+ Object value = checkNotNull (singleMapEntry .get ("value" ));
226+ return ImmutableMap .of (key , value );
227+ }
228+
229+ private ImmutableMap <String , Object > readAllFields (byte [] bytes , String protoTypeName )
230+ throws IOException {
231+ MessageLiteDescriptor messageDescriptor = descriptorPool .getDescriptorOrThrow (protoTypeName );
232+ CodedInputStream inputStream = CodedInputStream .newInstance (bytes );
233+
234+ ImmutableMap .Builder <String , Object > fieldValues = ImmutableMap .builder ();
235+ Map <Integer , List <Object >> nonPackedRepeatedFields = new LinkedHashMap <>();
236+ Map <Integer , Map <Object , Object >> mapFieldValues = new LinkedHashMap <>();
237+ for (int iterCount = 0 ; iterCount < bytes .length ; iterCount ++) {
238+ int tag = inputStream .readTag ();
239+ if (tag == 0 ) {
240+ break ;
241+ }
242+
243+ int tagWireType = WireFormat .getTagWireType (tag );
244+ int fieldNumber = WireFormat .getTagFieldNumber (tag );
245+ FieldLiteDescriptor fieldDescriptor = messageDescriptor .getByFieldNumberOrThrow (fieldNumber );
246+
247+ Object payload ;
248+ switch (tagWireType ) {
249+ case WireFormat .WIRETYPE_VARINT :
250+ payload = readPrimitiveField (inputStream , fieldDescriptor );
251+ break ;
252+ case WireFormat .WIRETYPE_FIXED32 :
253+ payload = readFixed32BitField (inputStream , fieldDescriptor );
254+ break ;
255+ case WireFormat .WIRETYPE_FIXED64 :
256+ payload = readFixed64BitField (inputStream , fieldDescriptor );
257+ break ;
258+ case WireFormat .WIRETYPE_LENGTH_DELIMITED :
259+ CelFieldValueType celFieldValueType = fieldDescriptor .getCelFieldValueType ();
260+ switch (celFieldValueType ) {
261+ case LIST :
262+ if (fieldDescriptor .getIsPacked ()) {
263+ payload = readPackedRepeatedFields (inputStream , fieldDescriptor );
264+ } else {
265+ boolean isLenDelimited =
266+ fieldDescriptor .getProtoFieldType ().equals (FieldLiteDescriptor .Type .MESSAGE )
267+ || fieldDescriptor
268+ .getProtoFieldType ()
269+ .equals (FieldLiteDescriptor .Type .STRING )
270+ || fieldDescriptor
271+ .getProtoFieldType ()
272+ .equals (FieldLiteDescriptor .Type .BYTES );
273+ payload =
274+ isLenDelimited
275+ ? readLengthDelimitedField (inputStream , fieldDescriptor )
276+ : readPrimitiveField (inputStream , fieldDescriptor );
277+ }
278+ break ;
279+ case MAP :
280+ Map <Object , Object > fieldMap =
281+ mapFieldValues .computeIfAbsent (fieldNumber , (unused ) -> new LinkedHashMap <>());
282+ fieldMap .putAll (readSingleMapEntry (inputStream , fieldDescriptor ));
283+ payload = fieldMap ;
284+ break ;
285+ default :
286+ payload = readLengthDelimitedField (inputStream , fieldDescriptor );
287+ break ;
288+ }
289+ break ;
290+ case WireFormat .WIRETYPE_START_GROUP :
291+ case WireFormat .WIRETYPE_END_GROUP :
292+ // TODO: Support groups
293+ throw new UnsupportedOperationException ("Groups are not supported" );
294+ default :
295+ throw new IllegalArgumentException ("Unexpected wire type: " + tagWireType );
296+ }
297+
298+ if (fieldDescriptor .getCelFieldValueType ().equals (CelFieldValueType .LIST )
299+ && !fieldDescriptor .getIsPacked ()) {
300+ List <Object > repeatedValues =
301+ nonPackedRepeatedFields .computeIfAbsent (fieldNumber , (unused ) -> new ArrayList <>());
302+ repeatedValues .add (payload );
303+ payload = repeatedValues ;
304+ }
305+
306+ fieldValues .put (fieldDescriptor .getFieldName (), payload );
307+ }
308+
309+ // Protobuf encoding follows a "last one wins" semantics. This means for duplicated fields,
310+ // we accept the last value encountered.
311+ return fieldValues .buildKeepingLast ();
312+ }
313+
314+ ImmutableMap <String , Object > readAllFields (MessageLite msg , String protoTypeName )
315+ throws IOException {
316+ return readAllFields (msg .toByteArray (), protoTypeName );
317+ }
318+
46319 @ Override
47320 public CelValue fromProtoMessageToCelValue (String protoTypeName , MessageLite msg ) {
48321 checkNotNull (msg );
0 commit comments