diff --git a/ci/format.sh b/ci/format.sh index b80b26a9d8..fed6ab9df7 100755 --- a/ci/format.sh +++ b/ci/format.sh @@ -125,11 +125,11 @@ format_files() { format_all_scripts() { echo "$(date)" "Ruff format...." - git ls-files -- '*.py' '*.pyx' '*.pxd' '*.pxi' "${GIT_LS_EXCLUDES[@]}" | xargs -P 10 \ + git ls-files -- '*.py' "${GIT_LS_EXCLUDES[@]}" | xargs -P 10 \ ruff format echo "$(date)" "Ruff check...." - git ls-files -- '*.py' '*.pyx' '*.pxd' '*.pxi' "${GIT_LS_EXCLUDES[@]}" | xargs \ + git ls-files -- '*.py' "${GIT_LS_EXCLUDES[@]}" | xargs \ ruff check --fix } @@ -193,10 +193,10 @@ format_changed() { # exist on both branches. MERGEBASE="$(git merge-base origin/main HEAD)" - if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyx' '*.pxd' '*.pxi' &>/dev/null; then - git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \ + if ! git diff --diff-filter=ACRM --quiet --exit-code "$MERGEBASE" -- '*.py' &>/dev/null; then + git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ ruff format - git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' '*.pyx' '*.pxd' '*.pxi' | xargs -P 5 \ + git diff --name-only --diff-filter=ACRM "$MERGEBASE" -- '*.py' | xargs -P 5 \ ruff check --fix fi diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index 6bc87ba1a9..b05b37ffcf 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -98,6 +98,7 @@ class Fory: __slots__ = ( "language", "is_py", + "compatible", "ref_tracking", "ref_resolver", "type_resolver", @@ -113,13 +114,13 @@ class Fory: "_unsupported_objects", "_peer_language", ) - serialization_context: "SerializationContext" def __init__( self, language=Language.PYTHON, ref_tracking: bool = False, require_type_registration: bool = True, + compatible: bool = False, ): """ :param require_type_registration: @@ -130,10 +131,14 @@ def __init__( Do not disable type registration if you can't ensure your environment 
are *indeed secure*. We are not responsible for security risks if you disable this option. + :param compatible: + Whether to enable compatible mode for cross-language serialization. + When enabled, type forward/backward compatibility for struct fields will be enabled. """ self.language = language self.is_py = language == Language.PYTHON self.require_type_registration = _ENABLE_TYPE_REGISTRATION_FORCIBLY or require_type_registration + self.compatible = compatible self.ref_tracking = ref_tracking if self.ref_tracking: self.ref_resolver = MapRefResolver() @@ -143,9 +148,11 @@ def __init__( from pyfory._registry import TypeResolver self.metastring_resolver = MetaStringResolver() - self.type_resolver = TypeResolver(self) + self.type_resolver = TypeResolver(self, meta_share=compatible) self.type_resolver.initialize() - self.serialization_context = SerializationContext() + from pyfory._serialization import SerializationContext + + self.serialization_context = SerializationContext(scoped_meta_share_enabled=compatible) self.buffer = Buffer.allocate(32) if not require_type_registration: warnings.warn( @@ -255,10 +262,26 @@ def _serialize( set_bit(buffer, mask_index, 3) else: clear_bit(buffer, mask_index, 3) + # Reserve space for type definitions offset, similar to Java implementation + type_defs_offset_pos = None + if self.serialization_context.scoped_meta_share_enabled: + type_defs_offset_pos = buffer.writer_index + buffer.write_int32(-1) # Reserve 4 bytes for type definitions offset + if self.language == Language.PYTHON: self.serialize_ref(buffer, obj) else: self.xserialize_ref(buffer, obj) + + # Write type definitions at the end, similar to Java implementation + if self.serialization_context.scoped_meta_share_enabled: + meta_context = self.serialization_context.meta_context + if meta_context is not None and len(meta_context.get_writing_type_defs()) > 0: + # Update the offset to point to current position + current_pos = buffer.writer_index + 
buffer.put_int32(type_defs_offset_pos, current_pos - type_defs_offset_pos - 4) + self.type_resolver.write_type_defs(buffer) + self.reset_write() if buffer is not self.buffer: return buffer @@ -369,6 +392,20 @@ def _deserialize( self._buffers = iter(buffers) else: assert buffers is None, "buffers should be null when the serialized stream is produced with buffer_callback null." + + # Read type definitions at the start, similar to Java implementation + if self.serialization_context.scoped_meta_share_enabled: + relative_type_defs_offset = buffer.read_int32() + if relative_type_defs_offset != -1: + # Save current reader position + current_reader_index = buffer.reader_index + # Jump to type definitions + buffer.reader_index = current_reader_index + relative_type_defs_offset + # Read type definitions + self.type_resolver.read_type_defs(buffer) + # Jump back to continue with object deserialization + buffer.reader_index = current_reader_index + if is_target_x_lang: obj = self.xdeserialize_ref(buffer) else: @@ -470,7 +507,7 @@ def read_ref_pyobject(self, buffer): def reset_write(self): self.ref_resolver.reset_write() self.type_resolver.reset_write() - self.serialization_context.reset() + self.serialization_context.reset_write() self.metastring_resolver.reset_write() self.pickler.clear_memo() self._buffer_callback = None @@ -479,7 +516,7 @@ def reset_write(self): def reset_read(self): self.ref_resolver.reset_read() self.type_resolver.reset_read() - self.serialization_context.reset() + self.serialization_context.reset_read() self.metastring_resolver.reset_write() self.unpickler = None self._buffers = None @@ -490,36 +527,6 @@ def reset(self): self.reset_read() -class SerializationContext: - """ - A context is used to add some context-related information, so that the - serializers can setup relation between serializing different objects. - The context will be reset after finished serializing/deserializing the - object tree. 
-    """
-
-    __slots__ = ("objects",)
-
-    def __init__(self):
-        self.objects = dict()
-
-    def add(self, key, obj):
-        self.objects[id(key)] = obj
-
-    def __contains__(self, key):
-        return id(key) in self.objects
-
-    def __getitem__(self, key):
-        return self.objects[id(key)]
-
-    def get(self, key):
-        return self.objects.get(id(key))
-
-    def reset(self):
-        if len(self.objects) > 0:
-            self.objects.clear()
-
-
 _ENABLE_TYPE_REGISTRATION_FORCIBLY = os.getenv("ENABLE_TYPE_REGISTRATION_FORCIBLY", "0") in {
     "1",
     "true",
diff --git a/python/pyfory/_registry.py b/python/pyfory/_registry.py
index 2333ff7824..8b421f036e 100644
--- a/python/pyfory/_registry.py
+++ b/python/pyfory/_registry.py
@@ -76,12 +76,16 @@
     Float32Type,
     Float64Type,
     load_class,
+    is_struct_type,
 )
 from pyfory._fory import (
     DYNAMIC_TYPE_ID,
     # preserve 0 as flag for type id not set in TypeInfo`
     NO_TYPE_ID,
 )
+from pyfory.meta.typedef import TypeDef
+from pyfory.meta.typedef_decoder import decode_typedef, skip_typedef
+from pyfory.meta.typedef_encoder import encode_typedef
 
 try:
     import numpy as np
@@ -104,6 +108,7 @@ class TypeInfo:
         "namespace_bytes",
         "typename_bytes",
         "dynamic_type",
+        "type_def",
     )
 
     def __init__(
@@ -114,6 +119,7 @@ def __init__(
         namespace_bytes=None,
         typename_bytes=None,
         dynamic_type: bool = False,
+        type_def: TypeDef = None,
     ):
         self.cls = cls
         self.type_id = type_id
@@ -121,6 +127,7 @@ def __init__(
         self.namespace_bytes = namespace_bytes
         self.typename_bytes = typename_bytes
         self.dynamic_type = dynamic_type
+        self.type_def = type_def  # keep the TypeDef passed by the caller; meta-share writes assert it is not None
 
     def __repr__(self):
         return f"TypeInfo(cls={self.cls}, type_id={self.type_id}, serializer={self.serializer})"
@@ -160,9 +167,11 @@ class TypeResolver:
         "metastring_resolver",
         "language",
         "_type_id_to_typeinfo",
+        "_meta_shared_typeinfo",
+        "meta_share",
     )
 
-    def __init__(self, fory):
+    def __init__(self, fory, meta_share=False):
         self.fory = fory
         self.metastring_resolver = fory.metastring_resolver
         self.language = fory.language
@@ -182,9 +191,12 @@ def __init__(self, fory):
self._named_type_to_typeinfo = dict() self.namespace_encoder = MetaStringEncoder(".", "_") self.namespace_decoder = MetaStringDecoder(".", "_") + # Cache for TypeDef and TypeInfo tuples (similar to Java's classIdToDef) + self._meta_shared_typeinfo = {} self.typename_encoder = MetaStringEncoder("$", "_") self.typename_decoder = MetaStringDecoder("$", "_") self.meta_compressor = DeflaterMetaCompressor() + self.meta_share = meta_share def initialize(self): self._initialize_xlang() @@ -356,7 +368,7 @@ def _register_xtype( serializer = FunctionSerializer(self.fory, cls) type_id = TypeId.NAMED_EXT if type_id is None else ((type_id << 8) + TypeId.EXT) else: - serializer = DataClassSerializer(self.fory, cls, xlang=True) + serializer = None type_id = TypeId.NAMED_STRUCT if type_id is None else ((type_id << 8) + TypeId.STRUCT) elif not internal: type_id = TypeId.NAMED_EXT if type_id is None else ((type_id << 8) + TypeId.EXT) @@ -460,7 +472,7 @@ def get_typeinfo(self, cls, create=True): type_info = self._types_info.get(cls) if type_info is not None: if type_info.serializer is None: - type_info.serializer = self._create_serializer(cls) + self._set_typeinfo(type_info) return type_info elif not create: return None @@ -491,6 +503,20 @@ def get_typeinfo(self, cls, create=True): serializer=serializer, ) + def _set_typeinfo(self, typeinfo): + type_id = typeinfo.type_id & 0xFF + if is_struct_type(type_id): + if self.meta_share: + type_def = encode_typedef(self, typeinfo.cls) + typeinfo.serializer = type_def.create_serializer(self) + typeinfo.type_def = type_def + else: + typeinfo.serializer = DataClassSerializer(self.fory, typeinfo.cls, xlang=not self.fory.is_py) + else: + typeinfo.serializer = self._create_serializer(typeinfo.cls) + + return typeinfo + def _create_serializer(self, cls): for clz in cls.__mro__: type_info = self._types_info.get(clz) @@ -502,7 +528,11 @@ def _create_serializer(self, cls): # Use FunctionSerializer for function types (including lambdas) serializer = 
FunctionSerializer(self.fory, cls) elif dataclasses.is_dataclass(cls): - serializer = DataClassSerializer(self.fory, cls) + if not self.meta_share: + serializer = DataClassSerializer(self.fory, cls, xlang=not self.fory.is_py) + else: + # lazy create serializer to handle nested struct fields. + serializer = None elif issubclass(cls, enum.Enum): serializer = EnumSerializer(self.fory, cls) elif (hasattr(cls, "__reduce__") and cls.__reduce__ is not object.__reduce__) or ( @@ -536,6 +566,28 @@ def _create_serializer(self, cls): serializer = PickleSerializer(self.fory, cls) return serializer + def is_registered_by_name(self, cls): + typeinfo = self._types_info.get(cls) + if typeinfo is None: + return False + return TypeId.is_namespaced_type(typeinfo.type_id & 0xFF) + + def is_registered_by_id(self, cls): + typeinfo = self._types_info.get(cls) + if typeinfo is None: + return False + return not TypeId.is_namespaced_type(typeinfo.type_id & 0xFF) + + def get_registered_name(self, cls): + typeinfo = self._types_info.get(cls) + assert typeinfo is not None, f"{cls} not registered" + return typeinfo.decode_namespace(), typeinfo.decode_typename() + + def get_registered_id(self, cls): + typeinfo = self._types_info.get(cls) + assert typeinfo is not None, f"{cls} not registered" + return typeinfo.type_id + def _load_metabytes_to_typeinfo(self, ns_metabytes, type_metabytes): typeinfo = self._ns_type_to_typeinfo.get((ns_metabytes, type_metabytes)) if typeinfo is not None: @@ -557,12 +609,22 @@ def write_typeinfo(self, buffer, typeinfo): return type_id = typeinfo.type_id internal_type_id = type_id & 0xFF + + # Check if meta share is enabled first + if self.meta_share: + self.write_shared_type_meta(buffer, typeinfo) + return + buffer.write_varuint32(type_id) if TypeId.is_namespaced_type(internal_type_id): self.metastring_resolver.write_meta_string_bytes(buffer, typeinfo.namespace_bytes) self.metastring_resolver.write_meta_string_bytes(buffer, typeinfo.typename_bytes) def 
read_typeinfo(self, buffer): + # Check if meta share is enabled first + if self.meta_share: + return self.read_shared_type_meta(buffer) + type_id = buffer.read_varuint32() internal_type_id = type_id & 0xFF if TypeId.is_namespaced_type(internal_type_id): @@ -595,6 +657,66 @@ def get_typeinfo_by_name(self, namespace, typename): def get_meta_compressor(self): return self.meta_compressor + def write_shared_type_meta(self, buffer, typeinfo): + """Write shared type meta information.""" + assert typeinfo.type_def is not None, "Type info must be set when meta share is enabled" + meta_context = self.fory.serialization_context.meta_context + meta_context.write_typeinfo(buffer, typeinfo) + + def read_shared_type_meta(self, buffer): + """Read shared type meta information.""" + meta_context = self.fory.serialization_context.meta_context + assert meta_context is not None, "Meta context must be set when meta share is enabled" + type_id = buffer.read_varuint32() + typeinfo = meta_context.get_read_type_info(type_id) + assert typeinfo is not None, f"Type info not found for ID {type_id}" + return typeinfo + + def write_type_defs(self, buffer): + """Write all type definitions that need to be sent.""" + meta_context = self.fory.serialization_context.meta_context + if meta_context is None: + return + writing_type_defs = meta_context.get_writing_type_defs() + buffer.write_varuint32(len(writing_type_defs)) + for type_def in writing_type_defs: + # Just copy the encoded bytes directly + buffer.write_bytes(type_def.encoded) + + def read_type_defs(self, buffer): + """Read all type definitions from the buffer.""" + meta_context = self.fory.serialization_context.meta_context + if meta_context is None: + return + + num_type_defs = buffer.read_varuint32() + for i in range(num_type_defs): + # Read the header (first 8 bytes) to get the type ID + header = buffer.read_int64() + # Check if we already have this TypeDef cached + type_info = self._meta_shared_typeinfo.get(header) + if type_info is not 
None: + # Skip the rest of the TypeDef binary for faster performance + skip_typedef(buffer, header) + else: + # Read the TypeDef and create TypeInfo + type_def = decode_typedef(buffer, self, header=header) + type_info = self._build_type_info_from_typedef(type_def) + # Cache the tuple for future use + self._meta_shared_typeinfo[header] = type_info + meta_context.add_read_type_info(type_info) + + def _build_type_info_from_typedef(self, type_def): + """Build TypeInfo from TypeDef using TypeDef's create_serializer method.""" + # Create serializer using TypeDef's create_serializer method + serializer = type_def.create_serializer(self) + ns_metastr = self.namespace_encoder.encode(type_def.namespace or "") + ns_meta_bytes = self.metastring_resolver.get_metastr_bytes(ns_metastr) + type_metastr = self.typename_encoder.encode(type_def.typename) + type_meta_bytes = self.metastring_resolver.get_metastr_bytes(type_metastr) + typeinfo = TypeInfo(type_def.cls, type_def.type_id, serializer, ns_meta_bytes, type_meta_bytes, False, type_def) + return typeinfo + def reset(self): pass diff --git a/python/pyfory/_serialization.pyx b/python/pyfory/_serialization.pyx index 315333a704..458eb640f6 100644 --- a/python/pyfory/_serialization.pyx +++ b/python/pyfory/_serialization.pyx @@ -399,6 +399,7 @@ cdef class TypeInfo: cdef public MetaStringBytes namespace_bytes cdef public MetaStringBytes typename_bytes cdef public c_bool dynamic_type + cdef public object type_def def __init__( self, @@ -408,6 +409,7 @@ cdef class TypeInfo: namespace_bytes: MetaStringBytes = None, typename_bytes: MetaStringBytes = None, dynamic_type: bool = False, + type_def: object = None ): self.cls = cls self.type_id = type_id @@ -415,6 +417,7 @@ cdef class TypeInfo: self.namespace_bytes = namespace_bytes self.typename_bytes = typename_bytes self.dynamic_type = dynamic_type + self.type_def = type_def def __repr__(self): return f"TypeInfo(cls={self.cls}, type_id={self.type_id}, " \ @@ -443,12 +446,16 @@ cdef class 
TypeResolver: # hash -> TypeInfo flat_hash_map[pair[int64_t, int64_t], PyObject *] _c_meta_hash_to_typeinfo MetaStringResolver meta_string_resolver + c_bool meta_share + SerializationContext serialization_context - def __init__(self, fory): + def __init__(self, fory, meta_share=False): self.fory = fory self.metastring_resolver = fory.metastring_resolver + self.meta_share = meta_share from pyfory._registry import TypeResolver - self._resolver = TypeResolver(fory) + self._resolver = TypeResolver(fory, meta_share=meta_share) + self.serialization_context = fory.serialization_context def initialize(self): self._resolver.initialize() @@ -518,7 +525,7 @@ cdef class TypeResolver: if type_info.serializer is not None: return type_info else: - type_info.serializer = self._resolver._create_serializer(cls) + type_info.serializer = self._resolver.get_typeinfo(cls).serializer return type_info elif not create: return None @@ -527,6 +534,18 @@ cdef class TypeResolver: self._c_types_info[ cls] = type_info self._populate_typeinfo(type_info) return type_info + + def is_registered_by_name(self, cls): + return self._resolver.is_registered_by_name(cls) + + def is_registered_by_id(self, cls): + return self._resolver.is_registered_by_id(cls) + + def get_registered_name(self, cls): + return self._resolver.get_registered_name(cls) + + def get_registered_id(self, cls): + return self._resolver.get_registered_id(cls) cdef inline TypeInfo _load_bytes_to_typeinfo( self, int32_t type_id, MetaStringBytes ns_metabytes, MetaStringBytes type_metabytes): @@ -546,12 +565,20 @@ cdef class TypeResolver: cdef: int32_t type_id = typeinfo.type_id int32_t internal_type_id = type_id & 0xFF + + if self.meta_share: + self.write_shared_type_meta(buffer, typeinfo) + return + buffer.write_varuint32(type_id) if IsNamespacedType(internal_type_id): self.metastring_resolver.write_meta_string_bytes(buffer, typeinfo.namespace_bytes) self.metastring_resolver.write_meta_string_bytes(buffer, typeinfo.typename_bytes) cpdef 
inline TypeInfo read_typeinfo(self, Buffer buffer): + if self.meta_share: + return self.read_shared_type_meta(buffer) + cdef: int32_t type_id = buffer.read_varuint32() if type_id < 0: @@ -580,6 +607,29 @@ cdef class TypeResolver: def get_meta_compressor(self): return self._resolver.get_meta_compressor() + cpdef write_shared_type_meta(self, Buffer buffer, TypeInfo typeinfo): + """Write shared type meta information.""" + meta_context = self.serialization_context.meta_context + assert meta_context is not None, "Meta context must be set when meta share is enabled" + meta_context.write_typeinfo(buffer, typeinfo) + + cpdef TypeInfo read_shared_type_meta(self, Buffer buffer): + """Read shared type meta information.""" + meta_context = self.serialization_context.meta_context + assert meta_context is not None, "Meta context must be set when meta share is enabled" + type_id = buffer.read_varuint32() + typeinfo = meta_context.get_read_type_info(type_id) + assert typeinfo is not None, f"Type info not found for ID {type_id}" + return typeinfo + + cpdef write_type_defs(self, Buffer buffer): + """Write all type definitions that need to be sent.""" + self._resolver.write_type_defs(buffer) + + cpdef read_type_defs(self, Buffer buffer): + """Read all type definitions from the buffer.""" + self._resolver.read_type_defs(buffer) + cpdef inline reset(self): pass @@ -590,12 +640,124 @@ cdef class TypeResolver: pass +@cython.final +cdef class MetaContext: + """ + Context for sharing type meta across multiple serialization. Type name, field name and field + type will be shared between different serialization. + + This is the Cython-optimized equivalent of Java's MetaContext class. 
+    """
+    cdef:
+        # Types which have sent definitions to peer
+        # Maps type objects to their assigned IDs
+        flat_hash_map[uint64_t, int32_t] _c_type_map
+
+        # Type definitions queued for sending; position in the list == assigned ID
+        list _writing_type_defs
+        list _read_type_infos
+
+    def __cinit__(self):
+        self._writing_type_defs = []
+        self._read_type_infos = []
+
+    cpdef inline int32_t write_typeinfo(self, Buffer buffer, typeinfo):
+        """Write the ID assigned to ``typeinfo.cls``, queueing its type_def on first use; return that ID."""
+        type_cls = typeinfo.cls
+        cdef uint64_t type_addr = type_cls  # NOTE(review): no pointer cast is visible in this copy of the patch - confirm against the real file
+        cdef flat_hash_map[uint64_t, int32_t].iterator it = self._c_type_map.find(type_addr)
+        if it != self._c_type_map.end():
+            buffer.write_varuint32(deref(it).second)
+            return deref(it).second  # already registered: must not write a second ID or queue the type_def again
+        cdef index = self._c_type_map.size()
+        buffer.write_varuint32(index)
+        self._c_type_map[type_addr] = index
+        self._writing_type_defs.append(typeinfo.type_def)
+        return index
+
+    cpdef inline list get_writing_type_defs(self):
+        """Get all type definitions that need to be written."""
+        return self._writing_type_defs
+
+    cpdef inline reset_write(self):
+        """Reset write state."""
+        self._writing_type_defs.clear()
+        self._c_type_map.clear()
+
+    cpdef inline add_read_type_info(self, type_info):
+        """Add a type info read from peer."""
+        self._read_type_infos.append(type_info)
+
+    cpdef inline get_read_type_info(self, int32_t index):
+        """Get a type info by index."""
+        return self._read_type_infos[index]
+
+    cpdef inline reset_read(self):
+        """Reset read state."""
+        self._read_type_infos.clear()
+
+    cpdef inline reset(self):
+        """Reset both read and write state."""
+        self.reset_write()
+        self.reset_read()
+
+    def __repr__(self):
+        return (f"MetaContext("
+                f"type_ids={self._c_type_map.size()}, "
+                f"read_infos={len(self._read_type_infos)}, "
+                f"writing_defs={len(self._writing_type_defs)})")
+
+
+@cython.final
+cdef class SerializationContext:
+    cdef dict objects
+    cdef readonly bint scoped_meta_share_enabled
+    cdef public object meta_context
+
+    def __init__(self, scoped_meta_share_enabled: bool =
False): + self.objects = dict() + self.scoped_meta_share_enabled = scoped_meta_share_enabled + if scoped_meta_share_enabled: + self.meta_context = MetaContext() + else: + self.meta_context = None + + def add(self, key, obj): + self.objects[id(key)] = obj + + def __contains__(self, key): + return id(key) in self.objects + + def __getitem__(self, key): + return self.objects[id(key)] + + def get(self, key): + return self.objects.get(id(key)) + + cpdef reset(self): + if len(self.objects) > 0: + self.objects.clear() + + cpdef reset_write(self): + if len(self.objects) > 0: + self.objects.clear() + if self.scoped_meta_share_enabled and self.meta_context is not None: + self.meta_context.reset_write() + + cpdef reset_read(self): + if len(self.objects) > 0: + self.objects.clear() + if self.scoped_meta_share_enabled and self.meta_context is not None: + self.meta_context.reset_read() + + @cython.final cdef class Fory: cdef readonly object language cdef readonly c_bool ref_tracking cdef readonly c_bool require_type_registration cdef readonly c_bool is_py + cdef readonly c_bool compatible cdef readonly MapRefResolver ref_resolver cdef readonly TypeResolver type_resolver cdef readonly MetaStringResolver metastring_resolver @@ -614,6 +776,7 @@ cdef class Fory: language=Language.PYTHON, ref_tracking: bool = False, require_type_registration: bool = True, + compatible: bool = False, ): """ :param require_type_registration: @@ -621,22 +784,26 @@ cdef class Fory: If disabled, unknown insecure types can be deserialized, which can be insecure and cause remote code execution attack if the types `__new__`/`__init__`/`__eq__`/`__hash__` method contain malicious code. - Do not disable type registration if you can't ensure your environment are - *indeed secure*. We are not responsible for security risks if - you disable this option. - """ + Do not disable type registration if you can't ensure your environment are + *indeed secure*. 
We are not responsible for security risks if + you disable this option. + :param compatible: + Whether to enable compatible mode for cross-language serialization. + When enabled, type forward/backward compatibility for struct fields will be enabled. + """ self.language = language if _ENABLE_TYPE_REGISTRATION_FORCIBLY or require_type_registration: self.require_type_registration = True else: self.require_type_registration = False + self.compatible = compatible self.ref_tracking = ref_tracking self.ref_resolver = MapRefResolver(ref_tracking) self.is_py = self.language == Language.PYTHON self.metastring_resolver = MetaStringResolver() - self.type_resolver = TypeResolver(self) + self.serialization_context = SerializationContext(scoped_meta_share_enabled=compatible) + self.type_resolver = TypeResolver(self, meta_share=compatible) self.type_resolver.initialize() - self.serialization_context = SerializationContext() self.buffer = Buffer.allocate(32) if not require_type_registration: warnings.warn( @@ -735,11 +902,27 @@ cdef class Fory: set_bit(buffer, mask_index, 3) else: clear_bit(buffer, mask_index, 3) + # Reserve space for type definitions offset, similar to Java implementation + cdef int32_t type_defs_offset_pos = -1 + if self.serialization_context.scoped_meta_share_enabled: + type_defs_offset_pos = buffer.writer_index + buffer.write_int32(-1) # Reserve 4 bytes for type definitions offset + cdef int32_t start_offset if self.language == Language.PYTHON: self.serialize_ref(buffer, obj) else: self.xserialize_ref(buffer, obj) + + # Write type definitions at the end, similar to Java implementation + if self.serialization_context.scoped_meta_share_enabled: + meta_context = self.serialization_context.meta_context + if meta_context is not None and len(meta_context.get_writing_type_defs()) > 0: + # Update the offset to point to current position + current_pos = buffer.writer_index + buffer.put_int32(type_defs_offset_pos, current_pos - type_defs_offset_pos - 4) + 
self.type_resolver.write_type_defs(buffer) + if buffer is not self.buffer: return buffer else: @@ -870,6 +1053,20 @@ cdef class Fory: "buffers should be null when the serialized stream is " "produced with buffer_callback null." ) + + # Read type definitions at the start, similar to Java implementation + if self.serialization_context.scoped_meta_share_enabled: + relative_type_defs_offset = buffer.read_int32() + if relative_type_defs_offset != -1: + # Save current reader position + current_reader_index = buffer.reader_index + # Jump to type definitions + buffer.reader_index = current_reader_index + relative_type_defs_offset + # Read type definitions + self.type_resolver.read_type_defs(buffer) + # Jump back to continue with object deserialization + buffer.reader_index = current_reader_index + if not is_target_x_lang: return self.deserialize_ref(buffer) return self.xdeserialize_ref(buffer) @@ -1001,7 +1198,7 @@ cdef class Fory: self.ref_resolver.reset_write() self.type_resolver.reset_write() self.metastring_resolver.reset_write() - self.serialization_context.reset() + self.serialization_context.reset_write() self.pickler.clear_memo() self._unsupported_callback = None @@ -1009,7 +1206,7 @@ cdef class Fory: self.ref_resolver.reset_read() self.type_resolver.reset_read() self.metastring_resolver.reset_read() - self.serialization_context.reset() + self.serialization_context.reset_read() self._buffers = None self.unpickler = None self._unsupported_objects = None @@ -1071,29 +1268,6 @@ cpdef inline read_nullable_pystr(Buffer buffer): return None -@cython.final -cdef class SerializationContext: - cdef dict objects - - def __init__(self): - self.objects = dict() - - def add(self, key, obj): - self.objects[id(key)] = obj - - def __contains__(self, key): - return id(key) in self.objects - - def __getitem__(self, key): - return self.objects[id(key)] - - def get(self, key): - return self.objects.get(id(key)) - - def reset(self): - if len(self.objects) > 0: - self.objects.clear() - 
cdef class Serializer: cdef readonly Fory fory cdef readonly object type_ diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py index dacdae6704..a12b8c0966 100644 --- a/python/pyfory/_serializer.py +++ b/python/pyfory/_serializer.py @@ -53,15 +53,11 @@ # Key is null, value type is declared type, and ref tracking for value is disabled. NULL_KEY_VALUE_DECL_TYPE = KEY_HAS_NULL | VALUE_DECL_TYPE # Key is null, value type is declared type, and ref tracking for value is enabled. -NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF = ( - KEY_HAS_NULL | VALUE_DECL_TYPE | TRACKING_VALUE_REF -) +NULL_KEY_VALUE_DECL_TYPE_TRACKING_REF = KEY_HAS_NULL | VALUE_DECL_TYPE | TRACKING_VALUE_REF # Value is null, key type is declared type, and ref tracking for key is disabled. NULL_VALUE_KEY_DECL_TYPE = VALUE_HAS_NULL | KEY_DECL_TYPE # Value is null, key type is declared type, and ref tracking for key is enabled. -NULL_VALUE_KEY_DECL_TYPE_TRACKING_REF = ( - VALUE_HAS_NULL | KEY_DECL_TYPE | TRACKING_VALUE_REF -) +NULL_VALUE_KEY_DECL_TYPE_TRACKING_REF = VALUE_HAS_NULL | KEY_DECL_TYPE | TRACKING_VALUE_REF class Serializer(ABC): @@ -182,11 +178,7 @@ def read(self, buffer): class DateSerializer(CrossLanguageCompatibleSerializer): def write(self, buffer, value: datetime.date): if not isinstance(value, datetime.date): - raise TypeError( - "{} should be {} instead of {}".format( - value, datetime.date, type(value) - ) - ) + raise TypeError("{} should be {} instead of {}".format(value, datetime.date, type(value))) days = (value - _base_date).days buffer.write_int32(days) @@ -208,9 +200,7 @@ def _get_timestamp(self, value: datetime.datetime): def write(self, buffer, value: datetime.datetime): if not isinstance(value, datetime.datetime): - raise TypeError( - "{} should be {} instead of {}".format(value, datetime, type(value)) - ) + raise TypeError("{} should be {} instead of {}".format(value, datetime, type(value))) # TimestampType represent micro seconds 
buffer.write_int64(self._get_timestamp(value)) @@ -287,10 +277,7 @@ def write_header(self, buffer, value): collect_flag |= COLLECTION_TRACKING_REF buffer.write_varuint32(len(value)) buffer.write_int8(collect_flag) - if ( - not has_different_type - and (collect_flag & COLLECTION_NOT_DECL_ELEMENT_TYPE) != 0 - ): + if not has_different_type and (collect_flag & COLLECTION_NOT_DECL_ELEMENT_TYPE) != 0: self.type_resolver.write_typeinfo(buffer, elem_typeinfo) return collect_flag, elem_typeinfo @@ -385,9 +372,7 @@ def _read_different_types(self, buffer, len_, collection_): for _ in range(len_): self._add_element( collection_, - get_next_element( - buffer, self.ref_resolver, self.type_resolver, self.is_py - ), + get_next_element(buffer, self.ref_resolver, self.type_resolver, self.is_py), ) def xwrite(self, buffer, value): @@ -532,12 +517,8 @@ def write(self, buffer, o): type_resolver.write_typeinfo(buffer, value_typeinfo) value_serializer = value_typeinfo.serializer - key_write_ref = ( - key_serializer.need_to_write_ref if key_serializer else False - ) - value_write_ref = ( - value_serializer.need_to_write_ref if value_serializer else False - ) + key_write_ref = key_serializer.need_to_write_ref if key_serializer else False + value_write_ref = value_serializer.need_to_write_ref if value_serializer else False if key_write_ref: chunk_header |= TRACKING_KEY_REF if value_write_ref: @@ -547,18 +528,11 @@ def write(self, buffer, o): chunk_size = 0 while chunk_size < MAX_CHUNK_SIZE: - if ( - key is None - or value is None - or type(key) is not key_cls - or type(value) is not value_cls - ): + if key is None or value is None or type(key) is not key_cls or type(value) is not value_cls: break if not key_write_ref or not ref_resolver.write_ref_or_null(buffer, key): self._write_obj(key_serializer, buffer, key) - if not value_write_ref or not ref_resolver.write_ref_or_null( - buffer, value - ): + if not value_write_ref or not ref_resolver.write_ref_or_null(buffer, value): 
value_serializer.write(buffer, value) chunk_size += 1 @@ -583,9 +557,7 @@ def read(self, buffer): if size != 0: chunk_header = buffer.read_uint8() key_serializer, value_serializer = self.key_serializer, self.value_serializer - deserialize_ref = ( - fory.deserialize_ref if self.fory.is_py else fory.xdeserialize_ref - ) + deserialize_ref = fory.deserialize_ref if self.fory.is_py else fory.xdeserialize_ref while size > 0: while True: key_has_null = (chunk_header & KEY_HAS_NULL) != 0 diff --git a/python/pyfory/_struct.py b/python/pyfory/_struct.py index 4e251c1c8b..ef1cdc68a4 100644 --- a/python/pyfory/_struct.py +++ b/python/pyfory/_struct.py @@ -243,7 +243,8 @@ def visit_dict(self, field_name, key_type, value_type, types_path=None): return TypeId.MAP, key_ids, value_ids def visit_customized(self, field_name, type_, types_path=None): - return None, None + typeinfo = self.fory.type_resolver.get_typeinfo(type_) + return [typeinfo.type_id] def visit_other(self, field_name, type_, types_path=None): from pyfory.serializer import PickleSerializer # Local import diff --git a/python/pyfory/format/__init__.py b/python/pyfory/format/__init__.py index 3bc70502fc..f6fd1d8f5a 100644 --- a/python/pyfory/format/__init__.py +++ b/python/pyfory/format/__init__.py @@ -41,8 +41,7 @@ ) except (ImportError, AttributeError) as e: warnings.warn( - f"Fory format initialization failed, please ensure pyarrow is installed " - f"with version which fory is compiled with: {e}", + f"Fory format initialization failed, please ensure pyarrow is installed with version which fory is compiled with: {e}", RuntimeWarning, stacklevel=2, ) diff --git a/python/pyfory/format/tests/test_encoder.py b/python/pyfory/format/tests/test_encoder.py index ac9dbdfdce..b4b01fc1c9 100644 --- a/python/pyfory/format/tests/test_encoder.py +++ b/python/pyfory/format/tests/test_encoder.py @@ -63,9 +63,7 @@ def test_encoder_with_schema(): @require_pyarrow def test_dict(): dict_ = {"f1": 1, "f2": "str"} - encoder = 
pyfory.create_row_encoder( - pa.schema([("f1", pa.int32()), ("f2", pa.utf8())]) - ) + encoder = pyfory.create_row_encoder(pa.schema([("f1", pa.int32()), ("f2", pa.utf8())])) row = encoder.to_row(dict_) new_obj = encoder.from_row(row) assert new_obj.f1 == dict_["f1"] @@ -74,9 +72,7 @@ def test_dict(): @require_pyarrow def test_ints(): - cls = pyfory.record_class_factory( - "TestNumeric", ["f" + str(i) for i in range(1, 9)] - ) + cls = pyfory.record_class_factory("TestNumeric", ["f" + str(i) for i in range(1, 9)]) schema = pa.schema( [ ("f1", pa.int64()), diff --git a/python/pyfory/meta/typedef.py b/python/pyfory/meta/typedef.py index dd786e990e..1f054c1564 100644 --- a/python/pyfory/meta/typedef.py +++ b/python/pyfory/meta/typedef.py @@ -19,9 +19,7 @@ import typing from pyfory.type import TypeId from pyfory._util import Buffer -from pyfory.serializer import MapSerializer, ListSerializer, SetSerializer -from pyfory._struct import _sort_fields, StructTypeIdVisitor, get_field_names -from pyfory.type import TypeId, infer_field, is_primitive_type, is_polymorphic_type +from pyfory.type import infer_field, is_primitive_type, is_polymorphic_type, is_struct_type from pyfory.meta.metastring import Encoding @@ -43,8 +41,12 @@ class TypeDef: - def __init__(self, name: str, type_id: int, fields: List["FieldInfo"], encoded: bytes = None, is_compressed: bool = False): - self.name = name + def __init__( + self, namespace: str, typename: str, cls: type, type_id: int, fields: List["FieldInfo"], encoded: bytes = None, is_compressed: bool = False + ): + self.namespace = namespace + self.typename = typename + self.cls = cls self.type_id = type_id self.fields = fields self.encoded = encoded @@ -54,8 +56,19 @@ def create_fields_serializer(self, resolver): serializers = [field_info.field_type.create_serializer(resolver) for field_info in self.fields] return serializers + def get_field_names(self): + return [field_info.name for field_info in self.fields] + + def create_serializer(self, 
resolver): + from pyfory.serializer import DataClassSerializer + + fory = resolver.fory + return DataClassSerializer( + fory, self.cls, xlang=not fory.is_py, field_names=self.get_field_names(), serializers=self.create_fields_serializer(resolver) + ) + def __repr__(self): - return f"TypeDef(name={self.name}, type_id={self.type_id}, fields={self.fields}, is_compressed={self.is_compressed})" + return f"TypeDef(namespace={self.namespace}, typename={self.typename}, cls={self.cls}, type_id={self.type_id}, fields={self.fields}, is_compressed={self.is_compressed})" class FieldInfo: @@ -121,7 +134,11 @@ def xread_with_type(cls, buffer: Buffer, resolver, xtype_id: int, is_nullable: b elif xtype_id == TypeId.UNKNOWN: return DynamicFieldType(xtype_id, False, is_nullable, is_tracking_ref) else: - return FieldType(xtype_id, False, is_nullable, is_tracking_ref) + # For primitive types, determine if they are monomorphic based on the type + from pyfory.type import is_polymorphic_type + + is_monomorphic = not is_polymorphic_type(xtype_id) + return FieldType(xtype_id, is_monomorphic, is_nullable, is_tracking_ref) def create_serializer(self, resolver): if self.type_id in [TypeId.EXT, TypeId.STRUCT, TypeId.NAMED_STRUCT, TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT, TypeId.UNKNOWN]: @@ -145,6 +162,8 @@ def __init__( self.element_type = element_type def create_serializer(self, resolver): + from pyfory.serializer import ListSerializer, SetSerializer + if self.type_id == TypeId.LIST: return ListSerializer(resolver.fory, list, self.element_type.create_serializer(resolver)) elif self.type_id == TypeId.SET: @@ -170,6 +189,8 @@ def __init__( def create_serializer(self, resolver): key_serializer = self.key_type.create_serializer(resolver) value_serializer = self.value_type.create_serializer(resolver) + from pyfory.serializer import MapSerializer + return MapSerializer(resolver.fory, dict, key_serializer, value_serializer) def __repr__(self): @@ -192,6 +213,8 @@ def __repr__(self): 
def build_field_infos(type_resolver, cls): """Build field information for the class.""" + from pyfory._struct import _sort_fields, StructTypeIdVisitor, get_field_names + field_names = get_field_names(cls) type_hints = typing.get_type_hints(cls) @@ -205,6 +228,7 @@ def build_field_infos(type_resolver, cls): field_infos.append(field_info) serializers = [field_info.field_type.create_serializer(type_resolver) for field_info in field_infos] + field_names, serializers = _sort_fields(type_resolver, field_names, serializers) field_infos_map = {field_info.name: field_info for field_info in field_infos} new_field_infos = [] @@ -217,12 +241,15 @@ def build_field_infos(type_resolver, cls): def build_field_type(type_resolver, field_name: str, type_hint, visitor): """Build field type from type hint.""" type_ids = infer_field(field_name, type_hint, visitor) + print(f"=??????????=> {field_name, type_hint, visitor, type_ids}") return build_field_type_from_type_ids(type_resolver, field_name, type_ids, visitor) def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids, visitor): tracking_ref = type_resolver.fory.ref_tracking type_id = type_ids[0] + if type_id is not None and type_id >= 0: + type_id = type_id & 0xFF morphic = not is_polymorphic_type(type_id) if type_id in [TypeId.SET, TypeId.LIST]: elem_type = build_field_type_from_type_ids(type_resolver, field_name, type_ids[1], visitor) @@ -234,7 +261,7 @@ def build_field_type_from_type_ids(type_resolver, field_name: str, type_ids, vis elif type_id in [TypeId.UNKNOWN, TypeId.EXT, TypeId.STRUCT, TypeId.NAMED_STRUCT, TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT]: return DynamicFieldType(type_id, False, True, tracking_ref) else: - assert is_primitive_type(type_id) or type_id in [TypeId.STRING, TypeId.ENUM, TypeId.NAMED_ENUM], ( + assert is_primitive_type(type_id) or type_id in [TypeId.STRING, TypeId.ENUM, TypeId.NAMED_ENUM] or is_struct_type(type_id), ( f"Unknown type: {type_id} for field: {field_name}" ) 
return FieldType(type_id, morphic, True, tracking_ref) diff --git a/python/pyfory/meta/typedef_decoder.py b/python/pyfory/meta/typedef_decoder.py index 18f6cdbb43..2cf5633c7e 100644 --- a/python/pyfory/meta/typedef_decoder.py +++ b/python/pyfory/meta/typedef_decoder.py @@ -25,18 +25,15 @@ from pyfory._util import Buffer from pyfory.meta.typedef import TypeDef, FieldInfo, FieldType from pyfory.meta.typedef import ( - FieldInfo, - TypeDef, SMALL_NUM_FIELDS_THRESHOLD, REGISTER_BY_NAME_FLAG, FIELD_NAME_SIZE_THRESHOLD, COMPRESS_META_FLAG, HAS_FIELDS_META_FLAG, META_SIZE_MASKS, - NUM_HASH_BITS, FIELD_NAME_ENCODINGS, ) -from pyfory.type import TypeId +from pyfory.type import TypeId, record_class_factory from pyfory.meta.metastring import MetaStringDecoder, Encoding @@ -46,7 +43,20 @@ FIELD_NAME_DECODER = MetaStringDecoder("$", "_") -def decode_typedef(buffer: Buffer, resolver) -> TypeDef: +def skip_typedef(buffer: Buffer, header) -> None: + """ + Skip a TypeDef from the buffer. + """ + # Extract components from header + meta_size = header & META_SIZE_MASKS + # If meta size is at maximum, read additional size + if meta_size == META_SIZE_MASKS: + meta_size += buffer.read_varuint32() + # Read meta data + buffer.read_bytes(meta_size) + + +def decode_typedef(buffer: Buffer, resolver, header=None) -> TypeDef: """ Decode a TypeDef from the buffer. @@ -58,7 +68,8 @@ def decode_typedef(buffer: Buffer, resolver) -> TypeDef: The decoded TypeDef. """ # Read global binary header - header = buffer.read_int64() + if header is None: + header = buffer.read_int64() # Extract components from header meta_size = header & META_SIZE_MASKS @@ -90,11 +101,11 @@ def decode_typedef(buffer: Buffer, resolver) -> TypeDef: # Check if registered by name is_registered_by_name = (meta_header & REGISTER_BY_NAME_FLAG) != 0 + type_cls = None # Read type info if is_registered_by_name: namespace = read_namespace(meta_buffer) typename = read_typename(meta_buffer) - name = namespace + "." 
+ typename if namespace else typename # Look up the type_id from namespace and typename type_info = resolver.get_typeinfo_by_name(namespace, typename) if type_info: @@ -105,15 +116,23 @@ def decode_typedef(buffer: Buffer, resolver) -> TypeDef: else: type_id = meta_buffer.read_varuint32() type_info = resolver.get_typeinfo_by_id(type_id) - name = type_info.cls.__name__ - + if type_info is not None: + type_cls = type_info.cls + namespace = type_info.decode_namespace() + typename = type_info.decode_typename() + else: + namespace = "fory" + typename = f"Nonexistent{type_id}" + name = namespace + "." + typename if namespace else typename # Read fields info if present field_infos = [] if has_fields_meta: field_infos = read_fields_info(meta_buffer, resolver, name, num_fields) + if type_cls is None: + type_cls = record_class_factory(name, [field_info.name for field_info in field_infos]) # Create TypeDef object - return TypeDef(name, type_id, field_infos, meta_data, is_compressed) + return TypeDef(namespace, typename, type_cls, type_id, field_infos, meta_data, is_compressed) def read_namespace(buffer: Buffer) -> str: @@ -174,7 +193,7 @@ def read_field_info(buffer: Buffer, resolver, defined_class: str) -> FieldInfo: field_name_size += 1 encoding = FIELD_NAME_ENCODINGS[field_name_encoding] is_nullable = (header & 0b10) != 0 - is_tracking_ref = header & 0b1 + is_tracking_ref = (header & 0b1) != 0 # Read field type info (without flags since they're in the header) xtype_id = buffer.read_varuint32() diff --git a/python/pyfory/meta/typedef_encoder.py b/python/pyfory/meta/typedef_encoder.py index f652fc40f9..7d8b5fdb3e 100644 --- a/python/pyfory/meta/typedef_encoder.py +++ b/python/pyfory/meta/typedef_encoder.py @@ -33,7 +33,6 @@ from pyfory.meta.metastring import MetaStringEncoder from pyfory._util import Buffer -from pyfory.type import TypeId from pyfory.lib.mmh3 import hash_buffer @@ -75,18 +74,17 @@ def encode_typedef(type_resolver, cls): buffer.write_varuint32(len(field_infos) 
- SMALL_NUM_FIELDS_THRESHOLD) # Write type info - type_info = type_resolver.get_typeinfo(cls) - assert type_info.type_id > 0 - - if not TypeId.is_namespaced_type(type_info.type_id): - buffer.write_varuint32(type_info.type_id) - else: + if type_resolver.is_registered_by_name(cls): header |= REGISTER_BY_NAME_FLAG - namespace = type_info.decode_namespace() - typename = type_info.decode_typename() + namespace, typename = type_resolver.get_registered_name(cls) write_namespace(buffer, namespace) write_typename(buffer, typename) - + # Use the actual type_id from the resolver, not a generic one + type_id = type_resolver.get_registered_id(cls) + else: + assert type_resolver.is_registered_by_id(cls), "Class must be registered by name or id" + type_id = type_resolver.get_registered_id(cls) + buffer.write_varuint32(type_id) # Update header byte buffer.put_uint8(0, header) @@ -103,7 +101,15 @@ def encode_typedef(type_resolver, cls): binary = compressed_binary # Prepend header binary = prepend_header(binary, is_compressed, len(field_infos) > 0) - return TypeDef(cls.__name__, type_info.type_id, field_infos, binary, is_compressed) + # Extract namespace and typename + if type_resolver.is_registered_by_name(cls): + namespace, typename = type_resolver.get_registered_name(cls) + else: + splits = cls.__name__.rsplit(".", 1) + if len(splits) == 1: + splits.insert(0, "") + namespace, typename = splits + return TypeDef(namespace, typename, cls, type_id, field_infos, binary, is_compressed) def prepend_header(buffer: bytes, is_compressed: bool, has_fields_meta: bool): @@ -125,7 +131,7 @@ def prepend_header(buffer: bytes, is_compressed: bool, has_fields_meta: bool): result.write_varuint32(meta_size - META_SIZE_MASKS) result.write_bytes(buffer) - return result + return result.to_bytes() def write_namespace(buffer: Buffer, namespace: str): diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 0018232d92..c19bfc777b 100644 --- a/python/pyfory/serializer.py +++ 
b/python/pyfory/serializer.py @@ -24,6 +24,7 @@ import pickle import types import typing +from typing import List import warnings from weakref import WeakValueDictionary @@ -297,21 +298,22 @@ def xread(self, buffer): class DataClassSerializer(Serializer): - def __init__(self, fory, clz: type, xlang: bool = False): + def __init__(self, fory, clz: type, xlang: bool = False, field_names: List[str] = None, serializers: List[Serializer] = None): super().__init__(fory, clz) self._xlang = xlang # This will get superclass type hints too. self._type_hints = typing.get_type_hints(clz) - self._field_names = self._get_field_names(clz) + self._field_names = field_names or self._get_field_names(clz) self._has_slots = hasattr(clz, "__slots__") if self._xlang: - self._serializers = [None] * len(self._field_names) - visitor = ComplexTypeVisitor(fory) - for index, key in enumerate(self._field_names): - serializer = infer_field(key, self._type_hints[key], visitor, types_path=[]) - self._serializers[index] = serializer - self._field_names, self._serializers = _sort_fields(fory.type_resolver, self._field_names, self._serializers) + self._serializers = serializers or [None] * len(self._field_names) + if serializers is None: + visitor = ComplexTypeVisitor(fory) + for index, key in enumerate(self._field_names): + serializer = infer_field(key, self._type_hints[key], visitor, types_path=[]) + self._serializers[index] = serializer + self._field_names, self._serializers = _sort_fields(fory.type_resolver, self._field_names, self._serializers) self._hash = 0 # Will be computed on first xwrite/xread self._generated_xwrite_method = self._gen_xwrite_method() self._generated_xread_method = self._gen_xread_method() @@ -443,13 +445,14 @@ def _gen_xwrite_method(self): context["_field_names"] = self._field_names context["_type_hints"] = self._type_hints context["_serializers"] = self._serializers - # Compute hash at generation time since we're in xlang mode - if self._hash == 0: - self._hash = 
_get_hash(self.fory, self._field_names, self._type_hints) stmts = [ f'"""xwrite method for {self.type_}"""', - f"{buffer}.write_int32({self._hash})", ] + if not self.fory.compatible: + # Compute hash at generation time since we're in xlang mode + if self._hash == 0: + self._hash = _get_hash(self.fory, self._field_names, self._type_hints) + stmts.append(f"{buffer}.write_int32({self._hash})") if not self._has_slots: stmts.append(f"{value_dict} = {value}.__dict__") for index, field_name in enumerate(self._field_names): @@ -487,18 +490,29 @@ def _gen_xread_method(self): context["_field_names"] = self._field_names context["_type_hints"] = self._type_hints context["_serializers"] = self._serializers - # Compute hash at generation time since we're in xlang mode - if self._hash == 0: - self._hash = _get_hash(self.fory, self._field_names, self._type_hints) + current_class_field_names = set(self._get_field_names(self.type_)) stmts = [ f'"""xread method for {self.type_}"""', - f"read_hash = {buffer}.read_int32()", - f"if read_hash != {self._hash}:", - f""" raise TypeNotCompatibleError( - f"Hash {{read_hash}} is not consistent with {self._hash} for type {self.type_}")""", - f"{obj} = {obj_class}.__new__({obj_class})", - f"{ref_resolver}.reference({obj})", ] + if not self.fory.compatible: + # Compute hash at generation time since we're in xlang mode + if self._hash == 0: + self._hash = _get_hash(self.fory, self._field_names, self._type_hints) + stmts.extend( + [ + f"read_hash = {buffer}.read_int32()", + f"if read_hash != {self._hash}:", + f""" raise TypeNotCompatibleError( + f"Hash {{read_hash}} is not consistent with {self._hash} for type {self.type_}")""", + ] + ) + stmts.extend( + [ + f"{obj} = {obj_class}.__new__({obj_class})", + f"{ref_resolver}.reference({obj})", + ] + ) + if not self._has_slots: stmts.append(f"{obj_dict} = {obj}.__dict__") @@ -507,6 +521,9 @@ def _gen_xread_method(self): context[serializer_var] = self._serializers[index] field_value = 
f"field_value{index}" stmts.append(f"{field_value} = {fory}.xdeserialize_ref({buffer}, serializer={serializer_var})") + if field_name not in current_class_field_names: + stmts.append(f"# {field_name} is not in {self.type_}") + continue if not self._has_slots: stmts.append(f"{obj_dict}['{field_name}'] = {field_value}") else: diff --git a/python/pyfory/tests/benchmark.py b/python/pyfory/tests/benchmark.py index ab6abe1b70..75c883296e 100644 --- a/python/pyfory/tests/benchmark.py +++ b/python/pyfory/tests/benchmark.py @@ -33,13 +33,9 @@ def test_encode(): assert foo == encoder.from_row(row) t1 = timeit.timeit(lambda: encoder.to_row(foo), number=iter_nums) - print( - "encoder take {0} for {1} times, avg: {2}".format(t1, iter_nums, t1 / iter_nums) - ) + print("encoder take {0} for {1} times, avg: {2}".format(t1, iter_nums, t1 / iter_nums)) t2 = timeit.timeit(lambda: pickle.dumps(foo), number=iter_nums) - print( - "pickle take {0} for {1} times, avg: {2}".format(t2, iter_nums, t2 / iter_nums) - ) + print("pickle take {0} for {1} times, avg: {2}".format(t2, iter_nums, t2 / iter_nums)) @pytest.mark.skip(reason="take too long") @@ -51,18 +47,10 @@ def test_decode(): row = encoder.to_row(foo) assert foo == encoder.from_row(row) t1 = timeit.timeit(lambda: encoder.from_row(row), number=iter_nums) - print( - "encoder take {0} for {1} times, avg: {2}, size {3}".format( - t1, iter_nums, t1 / iter_nums, row.size_bytes() - ) - ) + print("encoder take {0} for {1} times, avg: {2}, size {3}".format(t1, iter_nums, t1 / iter_nums, row.size_bytes())) pickled_data = pickle.dumps(foo) t2 = timeit.timeit(lambda: pickle.loads(pickled_data), number=iter_nums) - print( - "pickle take {0} for {1} times, avg: {2}, size {3}".format( - t2, iter_nums, t2 / iter_nums, len(pickled_data) - ) - ) + print("pickle take {0} for {1} times, avg: {2}, size {3}".format(t2, iter_nums, t2 / iter_nums, len(pickled_data))) if __name__ == "__main__": diff --git a/python/pyfory/tests/record.py 
b/python/pyfory/tests/record.py index 2f56a9ad81..31ebd66a81 100644 --- a/python/pyfory/tests/record.py +++ b/python/pyfory/tests/record.py @@ -117,9 +117,7 @@ def foo_schema(): ("f4", pa.map_(pa.string(), pa.int32())), ("f5", pa.list_(pa.int32())), ("f6", pa.int32()), - pa.field( - "f7", bar_struct, metadata={"cls": fory.get_qualified_classname(Bar)} - ), + pa.field("f7", bar_struct, metadata={"cls": fory.get_qualified_classname(Bar)}), ], metadata={"cls": fory.get_qualified_classname(Foo)}, ) diff --git a/python/pyfory/tests/test_buffer.py b/python/pyfory/tests/test_buffer.py index 3ba9c388ed..cefd6abf5a 100644 --- a/python/pyfory/tests/test_buffer.py +++ b/python/pyfory/tests/test_buffer.py @@ -217,10 +217,7 @@ def check_varuint64(buf: Buffer, value: int, bytes_written: int): assert buf.writer_index == buf.reader_index assert value == varint # test slow read branch in `read_varint64` - assert ( - buf.slice(reader_index, buf.reader_index - reader_index).read_varuint64() - == value - ) + assert buf.slice(reader_index, buf.reader_index - reader_index).read_varuint64() == value def test_write_buffer(): diff --git a/python/pyfory/tests/test_codegen.py b/python/pyfory/tests/test_codegen.py index 3b2243b29a..b73d2465e2 100644 --- a/python/pyfory/tests/test_codegen.py +++ b/python/pyfory/tests/test_codegen.py @@ -43,8 +43,6 @@ def _debug_compiled(x): def test_compile_function(): - code, func = codegen.compile_function( - "test_compile_function", ["x"], ["print(1)", "print(2)", "return x"], {} - ) + code, func = codegen.compile_function("test_compile_function", ["x"], ["print(1)", "print(2)", "return x"], {}) print(code) assert func(100) == 100 diff --git a/python/pyfory/tests/test_meta_share.py b/python/pyfory/tests/test_meta_share.py new file mode 100644 index 0000000000..d405b24dc0 --- /dev/null +++ b/python/pyfory/tests/test_meta_share.py @@ -0,0 +1,284 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import dataclasses +from typing import List, Dict +from pyfory import Fory, Language +import pyfory + + +@dataclasses.dataclass +class SimpleDataClass: + name: str + age: int + active: bool + + +@dataclasses.dataclass +class SimpleNestedDataClass: + value: int + name: str + + +@dataclasses.dataclass +class ExtendedDataClass: + name: str + age: int + active: bool + email: str # Additional field + + +@dataclasses.dataclass +class ReducedDataClass: + name: str + age: int + # Missing 'active' field + + +@dataclasses.dataclass +class NestedStructClass: + name: str + nested: SimpleNestedDataClass + + +@dataclasses.dataclass +class NestedStructClassInconsistent: + name: str + nested: ExtendedDataClass # Different nested type + + +@dataclasses.dataclass +class ListFieldsClass: + name: str + int_list: List[pyfory.Int32Type] + str_list: List[str] + + +@dataclasses.dataclass +class ListFieldsClassInconsistent: + name: str + int_list: List[str] # Changed from Int32Type to str + str_list: List[pyfory.Int32Type] # Changed from str to Int32Type + + +@dataclasses.dataclass +class DictFieldsClass: + name: str + int_dict: Dict[str, pyfory.Int32Type] + str_dict: Dict[str, str] + + +@dataclasses.dataclass +class DictFieldsClassInconsistent: + name: str + int_dict: Dict[str, str] # Changed from Int32Type 
to str + str_dict: Dict[str, pyfory.Int32Type] # Changed from str to Int32Type + + +class TestMetaShareMode: + def setup_method(self): + """Setup method to register dataclasses for each test.""" + pass + + def test_meta_share_enabled(self): + """Test that meta share mode can be enabled.""" + fory = Fory(language=Language.XLANG, compatible=True) + assert fory.serialization_context.scoped_meta_share_enabled + assert fory.serialization_context.meta_context is not None + + def test_meta_share_disabled(self): + """Test that meta share mode can be disabled.""" + fory = Fory(language=Language.XLANG, compatible=False) + assert not fory.serialization_context.scoped_meta_share_enabled + assert fory.serialization_context.meta_context is None + + def test_simple_dataclass_serialization(self): + """Test serialization of simple dataclass with meta share.""" + fory = Fory(language=Language.XLANG, compatible=True) + + # Register the dataclass + fory.register_type(SimpleDataClass) + + obj = SimpleDataClass(name="test", age=25, active=True) + buffer = fory.serialize(obj) + + # Deserialize + deserialized = fory.deserialize(buffer) + assert deserialized.name == obj.name + assert deserialized.age == obj.age + assert deserialized.active == obj.active + + def test_multiple_objects_same_type(self): + """Test that multiple objects of same type reuse type definition.""" + fory = Fory(language=Language.XLANG, compatible=True) + + # Register the dataclass + fory.register_type(SimpleDataClass) + + obj1 = SimpleDataClass(name="test1", age=25, active=True) + obj2 = SimpleDataClass(name="test2", age=30, active=False) + + # Serialize both objects + buffer1 = fory.serialize(obj1) + buffer2 = fory.serialize(obj2) + + # Create a new fory instance with the same meta context for deserialization + fory2 = Fory(language=Language.XLANG, compatible=True) + fory2.register_type(SimpleDataClass) + # Copy the meta context from the first fory instance + fory2.serialization_context.meta_context = 
fory.serialization_context.meta_context + + # Deserialize both + deserialized1 = fory2.deserialize(buffer1) + deserialized2 = fory2.deserialize(buffer2) + + assert deserialized1.name == obj1.name + assert deserialized2.name == obj2.name + assert deserialized1.age == obj1.age + assert deserialized2.age == obj2.age + + def test_simple_nested_dataclass_serialization(self): + """Test serialization of simple nested dataclass with meta share.""" + fory = Fory(language=Language.XLANG, compatible=True) + + # Register the dataclass + fory.register_type(SimpleNestedDataClass) + + obj = SimpleNestedDataClass(value=42, name="test") + + buffer = fory.serialize(obj) + deserialized = fory.deserialize(buffer) + + assert deserialized.value == obj.value + assert deserialized.name == obj.name + + def test_serialization_without_meta_share(self): + """Test that serialization works without meta share mode.""" + fory = Fory(language=Language.XLANG, compatible=False) + + # Register the dataclass + fory.register_type(SimpleDataClass) + + obj = SimpleDataClass(name="test", age=25, active=True) + buffer = fory.serialize(obj) + deserialized = fory.deserialize(buffer) + + assert deserialized.name == obj.name + assert deserialized.age == obj.age + assert deserialized.active == obj.active + + def test_schema_evolution_more_fields(self): + # Serialize with original schema + fory1 = Fory(language=Language.XLANG, compatible=True) + fory1.register_type(SimpleDataClass) + + obj = SimpleDataClass(name="test", age=25, active=True) + buffer = fory1.serialize(obj) + + # Deserialize with extended schema (more fields) + fory2 = Fory(language=Language.XLANG, compatible=True) + fory2.register_type(ExtendedDataClass) + deserialized = fory2.deserialize(buffer) + + # Current behavior: deserialized object is of the new registered type + assert isinstance(deserialized, ExtendedDataClass) + assert deserialized.name == obj.name + assert deserialized.age == obj.age + assert deserialized.active == obj.active + assert 
not hasattr(deserialized, "email") + + def test_schema_evolution_fewer_fields(self): + # Serialize with original schema + fory1 = Fory(language=Language.XLANG, compatible=True) + fory1.register_type(SimpleDataClass) + obj = SimpleDataClass(name="test", age=25, active=True) + buffer = fory1.serialize(obj) + + # Deserialize with reduced schema (fewer fields) + fory2 = Fory(language=Language.XLANG, compatible=True) + fory2.register_type(ReducedDataClass) + deserialized = fory2.deserialize(buffer) + + assert isinstance(deserialized, ReducedDataClass) + assert deserialized.name == obj.name + assert deserialized.age == obj.age + # The missing field should not be present + assert not hasattr(deserialized, "active") + + def test_schema_inconsistent_nested_struct(self): + """Test schema inconsistency with nested struct types.""" + # Serialize with original schema + fory1 = Fory(language=Language.XLANG, compatible=True) + fory1.register_type(NestedStructClass) + fory1.register_type(SimpleNestedDataClass) + + obj = NestedStructClass(name="test", nested=SimpleNestedDataClass(value=42, name="nested_test")) + buffer = fory1.serialize(obj) + + # Deserialize with inconsistent schema (different nested type) + fory2 = Fory(language=Language.XLANG, compatible=True) + fory2.register_type(NestedStructClassInconsistent) + fory2.register_type(ExtendedDataClass) + + # This should handle the schema inconsistency gracefully + deserialized = fory2.deserialize(buffer) + assert isinstance(deserialized, NestedStructClassInconsistent) + assert deserialized.name == obj.name + # The nested field type has changed, so we expect different behavior + assert hasattr(deserialized, "nested") + + def test_schema_inconsistent_list_fields(self): + """Test schema inconsistency with List field types.""" + # Serialize with original schema + fory1 = Fory(language=Language.XLANG, compatible=True) + fory1.register_type(ListFieldsClass) + + obj = ListFieldsClass(name="test", int_list=[1, 2, 3], str_list=["a", "b", 
"c"]) + buffer = fory1.serialize(obj) + + # Deserialize with inconsistent schema (swapped List types) + fory2 = Fory(language=Language.XLANG, compatible=True) + fory2.register_type(ListFieldsClassInconsistent) + + # This should handle the schema inconsistency gracefully + deserialized = fory2.deserialize(buffer) + assert isinstance(deserialized, ListFieldsClassInconsistent) + assert deserialized.name == obj.name + # The field types have been swapped, so we expect different behavior + assert hasattr(deserialized, "int_list") + assert hasattr(deserialized, "str_list") + + def test_schema_inconsistent_dict_fields(self): + """Test schema inconsistency with Dict field types.""" + # Serialize with original schema + fory1 = Fory(language=Language.XLANG, compatible=True) + fory1.register_type(DictFieldsClass) + + obj = DictFieldsClass(name="test", int_dict={"key1": 1, "key2": 2}, str_dict={"key1": "value1", "key2": "value2"}) + buffer = fory1.serialize(obj) + + # Deserialize with inconsistent schema (swapped Dict value types) + fory2 = Fory(language=Language.XLANG, compatible=True) + fory2.register_type(DictFieldsClassInconsistent) + + # This should handle the schema inconsistency gracefully + deserialized = fory2.deserialize(buffer) + assert isinstance(deserialized, DictFieldsClassInconsistent) + assert deserialized.name == obj.name + # The field value types have been swapped, so we expect different behavior + assert hasattr(deserialized, "int_dict") + assert hasattr(deserialized, "str_dict") diff --git a/python/pyfory/tests/test_metastring.py b/python/pyfory/tests/test_metastring.py index f21e09585c..d470de2994 100644 --- a/python/pyfory/tests/test_metastring.py +++ b/python/pyfory/tests/test_metastring.py @@ -196,7 +196,5 @@ def test_non_ascii_encoding_and_non_utf8(): non_ascii_string = "こんにちは" # Non-ASCII string - with pytest.raises( - ValueError, match="Unsupported character for LOWER_SPECIAL encoding: こ" - ): + with pytest.raises(ValueError, match="Unsupported 
character for LOWER_SPECIAL encoding: こ"): encoder.encode_with_encoding(non_ascii_string, Encoding.LOWER_SPECIAL) diff --git a/python/pyfory/tests/test_typedef_encoding.py b/python/pyfory/tests/test_typedef_encoding.py index 70ad351828..b53fa0986c 100644 --- a/python/pyfory/tests/test_typedef_encoding.py +++ b/python/pyfory/tests/test_typedef_encoding.py @@ -75,9 +75,10 @@ def test_typedef_creation(): FieldInfo("age", FieldType(TypeId.INT32, True, True, False), "TestTypeDef"), ] - typedef = TypeDef("TestTypeDef", TypeId.STRUCT, fields, b"encoded_data", False) + typedef = TypeDef("", "TestTypeDef", None, TypeId.STRUCT, fields, b"encoded_data", False) - assert typedef.name == "TestTypeDef" + assert typedef.namespace == "" + assert typedef.typename == "TestTypeDef" assert typedef.type_id == TypeId.STRUCT assert len(typedef.fields) == 2 assert typedef.encoded == b"encoded_data" diff --git a/python/pyfory/type.py b/python/pyfory/type.py index 7018f504fb..1ff93e509f 100644 --- a/python/pyfory/type.py +++ b/python/pyfory/type.py @@ -129,6 +129,7 @@ class TypeId: Fory type for cross-language serialization. 
See `org.apache.fory.types.Type` """ + UNKNOWN = -1 # null value NA = 0 @@ -356,7 +357,7 @@ def is_map_type(type_): return issubclass(type_, typing.Dict) except TypeError: return False - + _polymorphic_type_ids = { TypeId.STRUCT, @@ -368,11 +369,22 @@ def is_map_type(type_): TypeId.UNKNOWN, } +_struct_type_ids = { + TypeId.STRUCT, + TypeId.COMPATIBLE_STRUCT, + TypeId.NAMED_STRUCT, + TypeId.NAMED_COMPATIBLE_STRUCT, +} + def is_polymorphic_type(type_id: int) -> bool: return type_id in _polymorphic_type_ids +def is_struct_type(type_id: int) -> bool: + return type_id in _struct_type_ids + + def is_subclass(from_type, to_type): try: return issubclass(from_type, to_type) @@ -401,30 +413,18 @@ def visit_other(self, field_name, type_, types_path=None): def infer_field(field_name, type_, visitor: TypeVisitor, types_path=None): types_path = list(types_path or []) types_path.append(type_) - origin = ( - typing.get_origin(type_) - if hasattr(typing, "get_origin") - else getattr(type_, "__origin__", type_) - ) + origin = typing.get_origin(type_) if hasattr(typing, "get_origin") else getattr(type_, "__origin__", type_) origin = origin or type_ - args = ( - typing.get_args(type_) - if hasattr(typing, "get_args") - else getattr(type_, "__args__", ()) - ) + args = typing.get_args(type_) if hasattr(typing, "get_args") else getattr(type_, "__args__", ()) if args: if origin is list or origin == typing.List: elem_type = args[0] return visitor.visit_list(field_name, elem_type, types_path=types_path) elif origin is dict or origin == typing.Dict: key_type, value_type = args - return visitor.visit_dict( - field_name, key_type, value_type, types_path=types_path - ) + return visitor.visit_dict(field_name, key_type, value_type, types_path=types_path) else: - raise TypeError( - f"Collection types should be {list, dict} instead of {type_}" - ) + raise TypeError(f"Collection types should be {list, dict} instead of {type_}") else: if is_function(origin) or not hasattr(origin, "__annotations__"): 
return visitor.visit_other(field_name, type_, types_path=types_path)