forked from apache/fory
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_fory.py
More file actions
550 lines (493 loc) · 19.3 KB
/
Copy path_fory.py
File metadata and controls
550 lines (493 loc) · 19.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import enum
import logging
import os
import warnings
from abc import ABC, abstractmethod
from typing import Union, Iterable, TypeVar
from pyfory.buffer import Buffer
from pyfory.resolver import (
MapRefResolver,
NoRefResolver,
NULL_FLAG,
NOT_NULL_VALUE_FLAG,
)
from pyfory.util import is_little_endian, set_bit, get_bit, clear_bit
from pyfory.type import TypeId
try:
import numpy as np
except ImportError:
np = None
from cloudpickle import Pickler
from pickle import Unpickler
logger = logging.getLogger(__name__)
# Written as an int16 at the start of every cross-language (xlang) stream and
# verified when such a stream is read back.
MAGIC_NUMBER = 0x62D4
DEFAULT_DYNAMIC_WRITE_META_STR_ID = -1
DYNAMIC_TYPE_ID = -1
USE_TYPE_NAME = 0
USE_TYPE_ID = 1
# Type id 0 is reserved as the "type id not set" marker in TypeInfo.
NO_TYPE_ID = 0
INT64_TYPE_ID = TypeId.INT64
FLOAT64_TYPE_ID = TypeId.FLOAT64
BOOL_TYPE_ID = TypeId.BOOL
STRING_TYPE_ID = TypeId.STRING
# Two-byte headers for the str/int/bool fast path: the low byte is
# NOT_NULL_VALUE_FLAG truncated to 8 bits and the high byte is the type id
# (`TYPE_ID << 8`). Written with a single little-endian int16, this emits the
# not-null flag byte followed by the type-id byte.
# NOTE(review): presumably assumes each type id fits in one byte — confirm.
NOT_NULL_INT64_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (INT64_TYPE_ID << 8)
NOT_NULL_FLOAT64_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (FLOAT64_TYPE_ID << 8)
NOT_NULL_BOOL_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (BOOL_TYPE_ID << 8)
NOT_NULL_STRING_FLAG = NOT_NULL_VALUE_FLAG & 0b11111111 | (STRING_TYPE_ID << 8)
SMALL_STRING_THRESHOLD = 16
class Language(enum.Enum):
    """Identifiers for the languages / protocols a Fory stream can target.

    The numeric values are part of the wire format: for xlang streams the
    writer emits ``Language.PYTHON.value`` as an int8, and the reader rebuilds
    the peer language with ``Language(buffer.read_int8())``. Do not renumber.
    """

    XLANG = 0
    JAVA = 1
    PYTHON = 2
    CPP = 3
    GO = 4
    JAVA_SCRIPT = 5
    RUST = 6
    DART = 7
class BufferObject(ABC):
    """
    Abstract handle to the serialized bytes of a single object.

    Exists solely to support zero-copy, out-of-band serialization; it is not
    meant to be used for anything else.
    """

    @abstractmethod
    def total_bytes(self) -> int:
        """Return how many bytes the serialized object occupies."""
        ...

    @abstractmethod
    def write_to(self, buffer: "Buffer"):
        """Copy the serialized bytes into ``buffer``."""
        ...

    @abstractmethod
    def to_buffer(self) -> "Buffer":
        """Return the serialized bytes wrapped in a Buffer."""
        ...
class Fory:
    """Pure-Python Fory serializer.

    Writes and reads the stream header (magic number for xlang, a one-byte bit
    mask, the writer language for xlang, and an optional type-definition
    offset) and delegates the payload to serializers obtained from
    ``type_resolver``.
    """

    __slots__ = (
        "language",
        "is_py",
        "compatbile",  # (sic) spelling kept for backward compatibility
        "ref_tracking",
        "ref_resolver",
        "type_resolver",
        "serialization_context",
        "require_type_registration",
        "buffer",
        "pickler",
        "unpickler",
        "_buffer_callback",
        "_buffers",
        "metastring_resolver",
        "_unsupported_callback",
        "_unsupported_objects",
        "_peer_language",
    )

    def __init__(
        self,
        language=Language.PYTHON,
        ref_tracking: bool = False,
        require_type_registration: bool = True,
        compatbile: bool = False,
    ):
        """
        :param language:
            Protocol to use. ``Language.PYTHON`` uses the Python-native format;
            other values (e.g. ``Language.XLANG``) take the cross-language path,
            and ``Language.XLANG`` streams are prefixed with ``MAGIC_NUMBER``.
        :param ref_tracking:
            Whether to track references. When enabled a ``MapRefResolver`` is
            used, otherwise a ``NoRefResolver``.
        :param require_type_registration:
            Whether to require registering types for serialization, enabled by default.
            If disabled, unknown insecure types can be deserialized, which can be
            insecure and cause remote code execution attack if the types
            `__new__`/`__init__`/`__eq__`/`__hash__` method contain malicious code.
            Do not disable type registration if you can't ensure your environment is
            *indeed secure*. We are not responsible for security risks if
            you disable this option. Setting the ``ENABLE_TYPE_REGISTRATION_FORCIBLY``
            environment variable forces this option on; pickle then stays disabled.
        :param compatbile:
            Whether to enable compatible mode for cross-language serialization.
            When enabled, type forward/backward compatibility for struct fields will be enabled.
        """
        self.language = language
        self.is_py = language == Language.PYTHON
        self.require_type_registration = _ENABLE_TYPE_REGISTRATION_FORCIBLY or require_type_registration
        self.compatbile = compatbile
        self.ref_tracking = ref_tracking
        if self.ref_tracking:
            self.ref_resolver = MapRefResolver()
        else:
            self.ref_resolver = NoRefResolver()
        from pyfory._serialization import MetaStringResolver
        from pyfory._registry import TypeResolver

        self.metastring_resolver = MetaStringResolver()
        self.type_resolver = TypeResolver(self, meta_share=compatbile)
        self.type_resolver.initialize()
        from pyfory._serialization import SerializationContext

        self.serialization_context = SerializationContext(scoped_meta_share_enabled=compatbile)
        self.buffer = Buffer.allocate(32)
        # Decide on the *effective* setting so that the forced-registration
        # environment override also keeps pickle disabled.
        if not self.require_type_registration:
            warnings.warn(
                "Type registration is disabled, unknown types can be deserialized which may be insecure.",
                RuntimeWarning,
                stacklevel=2,
            )
            self.pickler = Pickler(self.buffer)
            self.unpickler = None
        else:
            # Stubs raise ValueError instead of (un)pickling.
            self.pickler = _PicklerStub()
            self.unpickler = _UnpicklerStub()
        self._buffer_callback = None
        self._buffers = None
        self._unsupported_callback = None
        self._unsupported_objects = None
        self._peer_language = None

    def register(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer=None,
    ):
        """Register ``cls``; alias of :meth:`register_type` and returns its result."""
        return self.register_type(cls, type_id=type_id, namespace=namespace, typename=typename, serializer=serializer)

    # `Union[type, TypeVar]` is not supported in py3.6
    def register_type(
        self,
        cls: Union[type, TypeVar],
        *,
        type_id: int = None,
        namespace: str = None,
        typename: str = None,
        serializer=None,
    ):
        """Register ``cls`` with the type resolver (by id or by namespace/typename)."""
        return self.type_resolver.register_type(
            cls,
            type_id=type_id,
            namespace=namespace,
            typename=typename,
            serializer=serializer,
        )

    def register_serializer(self, cls: type, serializer):
        """Attach a custom ``serializer`` to ``cls`` via the type resolver."""
        self.type_resolver.register_serializer(cls, serializer)

    def serialize(
        self,
        obj,
        buffer: Buffer = None,
        buffer_callback=None,
        unsupported_callback=None,
    ) -> Union[Buffer, bytes]:
        """Serialize ``obj``; per-call write state is always reset afterwards.

        :param obj: object to serialize.
        :param buffer: optional destination buffer. When omitted the internal
            buffer is rewound and reused, and a ``bytes`` copy is returned.
        :param buffer_callback: called for each ``BufferObject``; a falsy
            result keeps that buffer out of band, otherwise it is written in band.
        :param unsupported_callback: called for each object passed to
            :meth:`handle_unsupported_write`; a falsy result keeps it out of
            band, otherwise it is pickled in band.
        :return: ``buffer`` if one was given, else the written ``bytes``.
        """
        try:
            return self._serialize(
                obj,
                buffer,
                buffer_callback=buffer_callback,
                unsupported_callback=unsupported_callback,
            )
        finally:
            self.reset_write()

    def _serialize(
        self,
        obj,
        buffer: Buffer = None,
        buffer_callback=None,
        unsupported_callback=None,
    ) -> Union[Buffer, bytes]:
        """Write the stream header, the object, and any trailing type definitions."""
        self._buffer_callback = buffer_callback
        self._unsupported_callback = unsupported_callback
        if buffer is None:
            self.buffer.writer_index = 0
            buffer = self.buffer
        if not self.require_type_registration:
            # Bind the pickler to the buffer written by *this* call. Rebinding
            # every time avoids pickling into a stale user buffer from an
            # earlier call; when registration is required the stub is kept so
            # passing a buffer cannot re-enable pickle.
            self.pickler = Pickler(buffer)
        if self.language == Language.XLANG:
            buffer.write_int16(MAGIC_NUMBER)
        mask_index = buffer.writer_index
        # 1byte used for bit mask
        buffer.grow(1)
        buffer.writer_index = mask_index + 1
        # bit 0: null object
        if obj is None:
            set_bit(buffer, mask_index, 0)
        else:
            clear_bit(buffer, mask_index, 0)
        # bit 1: endian
        if is_little_endian:
            set_bit(buffer, mask_index, 1)
        else:
            clear_bit(buffer, mask_index, 1)
        # bit 2: xlang reader; followed by the writer language byte
        if self.language == Language.XLANG:
            set_bit(buffer, mask_index, 2)
            buffer.write_int8(Language.PYTHON.value)
        else:
            clear_bit(buffer, mask_index, 2)
        # bit 3: out-of-band buffers in use
        if self._buffer_callback is not None:
            set_bit(buffer, mask_index, 3)
        else:
            clear_bit(buffer, mask_index, 3)
        # Reserve space for type definitions offset, similar to Java implementation
        type_defs_offset_pos = None
        if self.serialization_context.scoped_meta_share_enabled:
            type_defs_offset_pos = buffer.writer_index
            buffer.write_int32(-1)  # -1 means "no type definitions"
        if self.language == Language.PYTHON:
            self.serialize_ref(buffer, obj)
        else:
            self.xserialize_ref(buffer, obj)
        # Type definitions trail the object; patch the reserved slot with their
        # offset relative to the end of that 4-byte slot.
        if self.serialization_context.scoped_meta_share_enabled:
            meta_context = self.serialization_context.meta_context
            if meta_context is not None and len(meta_context.get_writing_type_defs()) > 0:
                current_pos = buffer.writer_index
                buffer.put_int32(type_defs_offset_pos, current_pos - type_defs_offset_pos - 4)
                self.type_resolver.write_type_defs(buffer)
        self.reset_write()
        if buffer is not self.buffer:
            return buffer
        else:
            return buffer.to_bytes(0, buffer.writer_index)

    def serialize_ref(self, buffer, obj, typeinfo=None):
        """Write ``obj`` with a ref/null header (Python-native protocol).

        Exact ``str``/``int``/``bool`` instances take a fast path that writes a
        combined not-null flag + type id as one int16; ``int`` is written with
        ``write_varint64`` (NOTE(review): presumably limited to 64-bit values).
        """
        cls = type(obj)
        if cls is str:
            buffer.write_int16(NOT_NULL_STRING_FLAG)
            buffer.write_string(obj)
            return
        elif cls is int:
            buffer.write_int16(NOT_NULL_INT64_FLAG)
            buffer.write_varint64(obj)
            return
        elif cls is bool:
            buffer.write_int16(NOT_NULL_BOOL_FLAG)
            buffer.write_bool(obj)
            return
        if self.ref_resolver.write_ref_or_null(buffer, obj):
            return
        if typeinfo is None:
            typeinfo = self.type_resolver.get_typeinfo(cls)
        self.type_resolver.write_typeinfo(buffer, typeinfo)
        typeinfo.serializer.write(buffer, obj)

    def serialize_nonref(self, buffer, obj):
        """Write type info and payload for ``obj`` without any ref/null header."""
        cls = type(obj)
        if cls is str:
            buffer.write_varuint32(STRING_TYPE_ID)
            buffer.write_string(obj)
            return
        elif cls is int:
            buffer.write_varuint32(INT64_TYPE_ID)
            buffer.write_varint64(obj)
            return
        elif cls is bool:
            buffer.write_varuint32(BOOL_TYPE_ID)
            buffer.write_bool(obj)
            return
        else:
            typeinfo = self.type_resolver.get_typeinfo(cls)
            self.type_resolver.write_typeinfo(buffer, typeinfo)
            typeinfo.serializer.write(buffer, obj)

    def xserialize_ref(self, buffer, obj, serializer=None):
        """Write ``obj`` for the xlang protocol with a ref or null/not-null header.

        Reference tracking is used unless ``serializer`` says it need not write refs.
        """
        if serializer is None or serializer.need_to_write_ref:
            if not self.ref_resolver.write_ref_or_null(buffer, obj):
                self.xserialize_nonref(buffer, obj, serializer=serializer)
        else:
            if obj is None:
                buffer.write_int8(NULL_FLAG)
            else:
                buffer.write_int8(NOT_NULL_VALUE_FLAG)
                self.xserialize_nonref(buffer, obj, serializer=serializer)

    def xserialize_nonref(self, buffer, obj, serializer=None):
        """Write ``obj`` for xlang; type info is written only when no serializer is given."""
        if serializer is not None:
            serializer.xwrite(buffer, obj)
            return
        cls = type(obj)
        typeinfo = self.type_resolver.get_typeinfo(cls)
        self.type_resolver.write_typeinfo(buffer, typeinfo)
        typeinfo.serializer.xwrite(buffer, obj)

    def deserialize(
        self,
        buffer: Union[Buffer, bytes],
        buffers: Iterable = None,
        unsupported_objects: Iterable = None,
    ):
        """Deserialize one object; per-call read state is always reset afterwards.

        :param buffer: serialized data; ``bytes`` are wrapped in a ``Buffer``.
        :param buffers: out-of-band buffers, required iff the stream was written
            with a ``buffer_callback``.
        :param unsupported_objects: objects that were kept out of band by
            ``unsupported_callback`` at write time.
        """
        try:
            return self._deserialize(buffer, buffers, unsupported_objects)
        finally:
            self.reset_read()

    def _deserialize(
        self,
        buffer: Union[Buffer, bytes],
        buffers: Iterable = None,
        unsupported_objects: Iterable = None,
    ):
        """Read the stream header, any type definitions, then the object."""
        if isinstance(buffer, bytes):
            buffer = Buffer(buffer)
        if unsupported_objects is not None:
            self._unsupported_objects = iter(unsupported_objects)
        if self.language == Language.XLANG:
            magic_number = buffer.read_int16()
            assert magic_number == MAGIC_NUMBER, (
                f"The fory xlang serialization must start with magic number {hex(MAGIC_NUMBER)}. "
                "Please check whether the serialization is based on the xlang protocol and the data didn't corrupt."
            )
        reader_index = buffer.reader_index
        buffer.reader_index = reader_index + 1
        if get_bit(buffer, reader_index, 0):
            return None
        is_little_endian_ = get_bit(buffer, reader_index, 1)
        assert is_little_endian_, "Big endian is not supported for now, please ensure peer machine is little endian."
        is_target_x_lang = get_bit(buffer, reader_index, 2)
        if is_target_x_lang:
            self._peer_language = Language(buffer.read_int8())
        else:
            self._peer_language = Language.PYTHON
        is_out_of_band_serialization_enabled = get_bit(buffer, reader_index, 3)
        if is_out_of_band_serialization_enabled:
            assert buffers is not None, "buffers shouldn't be null when the serialized stream is produced with buffer_callback not null."
            self._buffers = iter(buffers)
        else:
            assert buffers is None, "buffers should be null when the serialized stream is produced with buffer_callback null."
        # Type definitions trail the object; read them first, then come back.
        type_defs_end_index = None
        if self.serialization_context.scoped_meta_share_enabled:
            relative_type_defs_offset = buffer.read_int32()
            if relative_type_defs_offset != -1:
                current_reader_index = buffer.reader_index
                buffer.reader_index = current_reader_index + relative_type_defs_offset
                self.type_resolver.read_type_defs(buffer)
                type_defs_end_index = buffer.reader_index
                buffer.reader_index = current_reader_index
        if is_target_x_lang:
            obj = self.xdeserialize_ref(buffer)
        else:
            obj = self.deserialize_ref(buffer)
        if type_defs_end_index is not None:
            # Leave the reader after the trailing type definitions so it ends
            # where the writer ended, not in the middle of this stream.
            buffer.reader_index = type_defs_end_index
        return obj

    def deserialize_ref(self, buffer):
        """Read an object written by :meth:`serialize_ref`, resolving back-references."""
        ref_resolver = self.ref_resolver
        ref_id = ref_resolver.try_preserve_ref_id(buffer)
        # indicates that the object is first read.
        if ref_id >= NOT_NULL_VALUE_FLAG:
            typeinfo = self.type_resolver.read_typeinfo(buffer)
            o = typeinfo.serializer.read(buffer)
            ref_resolver.set_read_object(ref_id, o)
            return o
        else:
            return ref_resolver.get_read_object()

    def deserialize_nonref(self, buffer):
        """Deserialize not-null and non-reference object from buffer."""
        typeinfo = self.type_resolver.read_typeinfo(buffer)
        return typeinfo.serializer.read(buffer)

    def xdeserialize_ref(self, buffer, serializer=None):
        """Read an object written by :meth:`xserialize_ref`."""
        if serializer is None or serializer.need_to_write_ref:
            ref_resolver = self.ref_resolver
            ref_id = ref_resolver.try_preserve_ref_id(buffer)
            # indicates that the object is first read.
            if ref_id >= NOT_NULL_VALUE_FLAG:
                o = self.xdeserialize_nonref(buffer, serializer=serializer)
                ref_resolver.set_read_object(ref_id, o)
                return o
            else:
                return ref_resolver.get_read_object()
        head_flag = buffer.read_int8()
        if head_flag == NULL_FLAG:
            return None
        return self.xdeserialize_nonref(buffer, serializer=serializer)

    def xdeserialize_nonref(self, buffer, serializer=None):
        """Read an xlang payload; type info is read only when no serializer is given."""
        if serializer is None:
            serializer = self.type_resolver.read_typeinfo(buffer).serializer
        return serializer.xread(buffer)

    def write_buffer_object(self, buffer, buffer_object: BufferObject):
        """Write ``buffer_object`` in band (bool flag, varuint32 size, bytes) or mark it out of band."""
        if self._buffer_callback is None or self._buffer_callback(buffer_object):
            buffer.write_bool(True)
            size = buffer_object.total_bytes()
            # writer length.
            buffer.write_varuint32(size)
            writer_index = buffer.writer_index
            buffer.ensure(writer_index + size)
            buf = buffer.slice(writer_index, size)
            buffer_object.write_to(buf)
            buffer.writer_index += size
        else:
            buffer.write_bool(False)

    def read_buffer_object(self, buffer) -> Buffer:
        """Return an in-band slice of ``buffer`` or the next out-of-band buffer."""
        in_band = buffer.read_bool()
        if in_band:
            size = buffer.read_varuint32()
            buf = buffer.slice(buffer.reader_index, size)
            buffer.reader_index += size
            return buf
        else:
            assert self._buffers is not None
            return next(self._buffers)

    def handle_unsupported_write(self, buffer, obj):
        """Pickle ``obj`` in band unless ``unsupported_callback`` keeps it out of band."""
        if self._unsupported_callback is None or self._unsupported_callback(obj):
            buffer.write_bool(True)
            self.pickler.dump(obj)
        else:
            buffer.write_bool(False)

    def handle_unsupported_read(self, buffer):
        """Unpickle an in-band object or return the next out-of-band one."""
        in_band = buffer.read_bool()
        if in_band:
            unpickler = self.unpickler
            if unpickler is None:
                # Only reachable when pickle is allowed; otherwise the stub is set.
                self.unpickler = unpickler = Unpickler(buffer)
            return unpickler.load()
        else:
            assert self._unsupported_objects is not None
            return next(self._unsupported_objects)

    def write_ref_pyobject(self, buffer, value, typeinfo=None):
        """Write ``value`` with a ref/null header and its type info (no fast path)."""
        if self.ref_resolver.write_ref_or_null(buffer, value):
            return
        if typeinfo is None:
            typeinfo = self.type_resolver.get_typeinfo(type(value))
        self.type_resolver.write_typeinfo(buffer, typeinfo)
        typeinfo.serializer.write(buffer, value)

    def read_ref_pyobject(self, buffer):
        """Counterpart of :meth:`write_ref_pyobject`."""
        return self.deserialize_ref(buffer)

    def reset_write(self):
        """Clear per-call write state."""
        self.ref_resolver.reset_write()
        self.type_resolver.reset_write()
        self.serialization_context.reset_write()
        self.metastring_resolver.reset_write()
        self.pickler.clear_memo()
        self._buffer_callback = None
        self._unsupported_callback = None

    def reset_read(self):
        """Clear per-call read state."""
        self.ref_resolver.reset_read()
        self.type_resolver.reset_read()
        self.serialization_context.reset_read()
        # Reset the read side of the meta string resolver (not reset_write).
        self.metastring_resolver.reset_read()
        if not self.require_type_registration:
            # Drop the unpickler bound to the previous input; a new one is
            # created lazily by handle_unsupported_read.
            self.unpickler = None
        # Otherwise keep the _UnpicklerStub: resetting it to None would let
        # handle_unsupported_read build a real Unpickler on later calls.
        self._buffers = None
        self._unsupported_objects = None

    def reset(self):
        """Clear both write and read state."""
        self.reset_write()
        self.reset_read()
# When the ENABLE_TYPE_REGISTRATION_FORCIBLY environment variable is "1" or
# "true" (case-insensitive, surrounding whitespace ignored), type registration
# is required for every Fory instance regardless of its constructor argument.
# Matching case-insensitively avoids "True"/"TRUE" silently leaving this
# security-hardening switch off.
_ENABLE_TYPE_REGISTRATION_FORCIBLY = os.getenv("ENABLE_TYPE_REGISTRATION_FORCIBLY", "0").strip().lower() in {
    "1",
    "true",
}
class _PicklerStub:
def dump(self, o):
raise ValueError(
f"Type {type(o)} is not registered, "
f"pickle is not allowed when type registration enabled, "
f"Please register the type or pass unsupported_callback"
)
def clear_memo(self):
pass
class _UnpicklerStub:
def load(self):
raise ValueError("pickle is not allowed when type registration enabled, Please register the type or pass unsupported_callback")