Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 42 additions & 35 deletions python/pyfory/_fory.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class Fory:
__slots__ = (
"language",
"is_py",
"compatbile",
"ref_tracking",
"ref_resolver",
"type_resolver",
Expand All @@ -113,13 +114,13 @@ class Fory:
"_unsupported_objects",
"_peer_language",
)
serialization_context: "SerializationContext"

def __init__(
self,
language=Language.PYTHON,
ref_tracking: bool = False,
require_type_registration: bool = True,
compatbile: bool = False,
):
"""
:param require_type_registration:
Expand All @@ -130,10 +131,14 @@ def __init__(
Do not disable type registration if you can't ensure your environment are
*indeed secure*. We are not responsible for security risks if
you disable this option.
:param compatbile:
Whether to enable compatbile mode for cross-language serialization.
When enabled, type forward/backward compatibility for struct fields will be enabled.
"""
self.language = language
self.is_py = language == Language.PYTHON
self.require_type_registration = _ENABLE_TYPE_REGISTRATION_FORCIBLY or require_type_registration
self.compatbile = compatbile
self.ref_tracking = ref_tracking
if self.ref_tracking:
self.ref_resolver = MapRefResolver()
Expand All @@ -143,9 +148,11 @@ def __init__(
from pyfory._registry import TypeResolver

self.metastring_resolver = MetaStringResolver()
self.type_resolver = TypeResolver(self)
self.type_resolver = TypeResolver(self, meta_share=compatbile)
self.type_resolver.initialize()
self.serialization_context = SerializationContext()
from pyfory._serialization import SerializationContext

self.serialization_context = SerializationContext(scoped_meta_share_enabled=compatbile)
self.buffer = Buffer.allocate(32)
if not require_type_registration:
warnings.warn(
Expand Down Expand Up @@ -255,10 +262,26 @@ def _serialize(
set_bit(buffer, mask_index, 3)
else:
clear_bit(buffer, mask_index, 3)
# Reserve space for type definitions offset, similar to Java implementation
type_defs_offset_pos = None
if self.serialization_context.scoped_meta_share_enabled:
type_defs_offset_pos = buffer.writer_index
buffer.write_int32(-1) # Reserve 4 bytes for type definitions offset

if self.language == Language.PYTHON:
self.serialize_ref(buffer, obj)
else:
self.xserialize_ref(buffer, obj)

# Write type definitions at the end, similar to Java implementation
if self.serialization_context.scoped_meta_share_enabled:
meta_context = self.serialization_context.meta_context
if meta_context is not None and len(meta_context.get_writing_type_defs()) > 0:
# Update the offset to point to current position
current_pos = buffer.writer_index
buffer.put_int32(type_defs_offset_pos, current_pos - type_defs_offset_pos - 4)
self.type_resolver.write_type_defs(buffer)

self.reset_write()
if buffer is not self.buffer:
return buffer
Expand Down Expand Up @@ -369,6 +392,20 @@ def _deserialize(
self._buffers = iter(buffers)
else:
assert buffers is None, "buffers should be null when the serialized stream is produced with buffer_callback null."

# Read type definitions at the start, similar to Java implementation
if self.serialization_context.scoped_meta_share_enabled:
relative_type_defs_offset = buffer.read_int32()
if relative_type_defs_offset != -1:
# Save current reader position
current_reader_index = buffer.reader_index
# Jump to type definitions
buffer.reader_index = current_reader_index + relative_type_defs_offset
# Read type definitions
self.type_resolver.read_type_defs(buffer)
# Jump back to continue with object deserialization
buffer.reader_index = current_reader_index

if is_target_x_lang:
obj = self.xdeserialize_ref(buffer)
else:
Expand Down Expand Up @@ -470,7 +507,7 @@ def read_ref_pyobject(self, buffer):
def reset_write(self):
self.ref_resolver.reset_write()
self.type_resolver.reset_write()
self.serialization_context.reset()
self.serialization_context.reset_write()
self.metastring_resolver.reset_write()
self.pickler.clear_memo()
self._buffer_callback = None
Expand All @@ -479,7 +516,7 @@ def reset_write(self):
def reset_read(self):
self.ref_resolver.reset_read()
self.type_resolver.reset_read()
self.serialization_context.reset()
self.serialization_context.reset_read()
self.metastring_resolver.reset_write()
self.unpickler = None
self._buffers = None
Expand All @@ -490,36 +527,6 @@ def reset(self):
self.reset_read()


class SerializationContext:
"""
A context is used to add some context-related information, so that the
serializers can setup relation between serializing different objects.
The context will be reset after finished serializing/deserializing the
object tree.
"""

__slots__ = ("objects",)

def __init__(self):
self.objects = dict()

def add(self, key, obj):
self.objects[id(key)] = obj

def __contains__(self, key):
return id(key) in self.objects

def __getitem__(self, key):
return self.objects[id(key)]

def get(self, key):
return self.objects.get(id(key))

def reset(self):
if len(self.objects) > 0:
self.objects.clear()


_ENABLE_TYPE_REGISTRATION_FORCIBLY = os.getenv("ENABLE_TYPE_REGISTRATION_FORCIBLY", "0") in {
"1",
"true",
Expand Down
115 changes: 113 additions & 2 deletions python/pyfory/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@
# preserve 0 as flag for type id not set in TypeInfo`
NO_TYPE_ID,
)
from pyfory.meta.typedef import TypeDef
from pyfory.meta.typedef_decoder import decode_typedef, skip_typedef
from pyfory.meta.typedef_encoder import encode_typedef

try:
import numpy as np
Expand All @@ -104,6 +107,7 @@ class TypeInfo:
"namespace_bytes",
"typename_bytes",
"dynamic_type",
"type_def",
)

def __init__(
Expand All @@ -114,13 +118,15 @@ def __init__(
namespace_bytes=None,
typename_bytes=None,
dynamic_type: bool = False,
type_def: TypeDef = None,
):
self.cls = cls
self.type_id = type_id
self.serializer = serializer
self.namespace_bytes = namespace_bytes
self.typename_bytes = typename_bytes
self.dynamic_type = dynamic_type
self.type_def = None

def __repr__(self):
return f"TypeInfo(cls={self.cls}, type_id={self.type_id}, serializer={self.serializer})"
Expand Down Expand Up @@ -160,9 +166,11 @@ class TypeResolver:
"metastring_resolver",
"language",
"_type_id_to_typeinfo",
"_meta_shared_typeinfo",
"meta_share",
)

def __init__(self, fory):
def __init__(self, fory, meta_share=False):
self.fory = fory
self.metastring_resolver = fory.metastring_resolver
self.language = fory.language
Expand All @@ -182,9 +190,12 @@ def __init__(self, fory):
self._named_type_to_typeinfo = dict()
self.namespace_encoder = MetaStringEncoder(".", "_")
self.namespace_decoder = MetaStringDecoder(".", "_")
# Cache for TypeDef and TypeInfo tuples (similar to Java's classIdToDef)
self._meta_shared_typeinfo = {}
self.typename_encoder = MetaStringEncoder("$", "_")
self.typename_decoder = MetaStringDecoder("$", "_")
self.meta_compressor = DeflaterMetaCompressor()
self.meta_share = meta_share

def initialize(self):
self._initialize_xlang()
Expand Down Expand Up @@ -423,6 +434,9 @@ def __register_type(
if type_id not in self._type_id_to_typeinfo or not internal:
self._type_id_to_typeinfo[type_id] = typeinfo
self._types_info[cls] = typeinfo

if self.meta_share and isinstance(serializer, DataClassSerializer):
self._set_struct_typeinfo(typeinfo)
return typeinfo

def _next_type_id(self):
Expand Down Expand Up @@ -502,7 +516,7 @@ def _create_serializer(self, cls):
# Use FunctionSerializer for function types (including lambdas)
serializer = FunctionSerializer(self.fory, cls)
elif dataclasses.is_dataclass(cls):
serializer = DataClassSerializer(self.fory, cls)
serializer = DataClassSerializer(self.fory, cls, xlang=not self.fory.is_py)
elif issubclass(cls, enum.Enum):
serializer = EnumSerializer(self.fory, cls)
elif (hasattr(cls, "__reduce__") and cls.__reduce__ is not object.__reduce__) or (
Expand Down Expand Up @@ -536,6 +550,34 @@ def _create_serializer(self, cls):
serializer = PickleSerializer(self.fory, cls)
return serializer

def _set_struct_typeinfo(self, typeinfo):
assert self.meta_share, "Meta share must be enabled"
type_def = encode_typedef(self, typeinfo.cls)
typeinfo.serializer = type_def.create_serializer(self)
typeinfo.type_def = type_def

def is_registered_by_name(self, cls):
typeinfo = self._types_info.get(cls)
if typeinfo is None:
return False
return TypeId.is_namespaced_type(typeinfo.type_id & 0xFF)

def is_registered_by_id(self, cls):
typeinfo = self._types_info.get(cls)
if typeinfo is None:
return False
return not TypeId.is_namespaced_type(typeinfo.type_id & 0xFF)

def get_registered_name(self, cls):
typeinfo = self._types_info.get(cls)
assert typeinfo is not None, f"{cls} not registered"
return typeinfo.decode_namespace(), typeinfo.decode_typename()

def get_registered_id(self, cls):
typeinfo = self._types_info.get(cls)
assert typeinfo is not None, f"{cls} not registered"
return typeinfo.type_id

def _load_metabytes_to_typeinfo(self, ns_metabytes, type_metabytes):
typeinfo = self._ns_type_to_typeinfo.get((ns_metabytes, type_metabytes))
if typeinfo is not None:
Expand All @@ -557,12 +599,22 @@ def write_typeinfo(self, buffer, typeinfo):
return
type_id = typeinfo.type_id
internal_type_id = type_id & 0xFF

# Check if meta share is enabled first
if self.meta_share:
self.write_shared_type_meta(buffer, typeinfo)
return

buffer.write_varuint32(type_id)
if TypeId.is_namespaced_type(internal_type_id):
self.metastring_resolver.write_meta_string_bytes(buffer, typeinfo.namespace_bytes)
self.metastring_resolver.write_meta_string_bytes(buffer, typeinfo.typename_bytes)

def read_typeinfo(self, buffer):
# Check if meta share is enabled first
if self.meta_share:
return self.read_shared_type_meta(buffer)

type_id = buffer.read_varuint32()
internal_type_id = type_id & 0xFF
if TypeId.is_namespaced_type(internal_type_id):
Expand Down Expand Up @@ -595,6 +647,65 @@ def get_typeinfo_by_name(self, namespace, typename):
def get_meta_compressor(self):
return self.meta_compressor

def write_shared_type_meta(self, buffer, typeinfo):
"""Write shared type meta information."""
meta_context = self.fory.serialization_context.meta_context
meta_context.write_typeinfo(buffer, typeinfo)

def read_shared_type_meta(self, buffer):
"""Read shared type meta information."""
meta_context = self.fory.serialization_context.meta_context
assert meta_context is not None, "Meta context must be set when meta share is enabled"
type_id = buffer.read_varuint32()
typeinfo = meta_context.get_read_type_info(type_id)
assert typeinfo is not None, f"Type info not found for ID {type_id}"
return typeinfo

def write_type_defs(self, buffer):
"""Write all type definitions that need to be sent."""
meta_context = self.fory.serialization_context.meta_context
if meta_context is None:
return
writing_type_defs = meta_context.get_writing_type_defs()
buffer.write_varuint32(len(writing_type_defs))
for type_def in writing_type_defs:
# Just copy the encoded bytes directly
buffer.write_bytes(type_def.encoded)

def read_type_defs(self, buffer):
"""Read all type definitions from the buffer."""
meta_context = self.fory.serialization_context.meta_context
if meta_context is None:
return

num_type_defs = buffer.read_varuint32()
for i in range(num_type_defs):
# Read the header (first 8 bytes) to get the type ID
header = buffer.read_int64()
# Check if we already have this TypeDef cached
type_info = self._meta_shared_typeinfo.get(header)
if type_info is not None:
# Skip the rest of the TypeDef binary for faster performance
skip_typedef(buffer, header)
else:
# Read the TypeDef and create TypeInfo
type_def = decode_typedef(buffer, self, header=header)
type_info = self._build_type_info_from_typedef(type_def)
# Cache the tuple for future use
self._meta_shared_typeinfo[header] = type_info
meta_context.add_read_type_info(type_info)

def _build_type_info_from_typedef(self, type_def):
"""Build TypeInfo from TypeDef using TypeDef's create_serializer method."""
# Create serializer using TypeDef's create_serializer method
serializer = type_def.create_serializer(self)
ns_metastr = self.namespace_encoder.encode(type_def.namespace or "")
ns_meta_bytes = self.metastring_resolver.get_metastr_bytes(ns_metastr)
type_metastr = self.typename_encoder.encode(type_def.typename)
type_meta_bytes = self.metastring_resolver.get_metastr_bytes(type_metastr)
typeinfo = TypeInfo(type_def.cls, type_def.type_id, serializer, ns_meta_bytes, type_meta_bytes, False, type_def)
return typeinfo

def reset(self):
pass

Expand Down
Loading
Loading