Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.fory.type.Generics;
import org.apache.fory.type.TypeUtils;
import org.apache.fory.type.Types;
import org.apache.fory.util.ExceptionUtils;
import org.apache.fory.util.Preconditions;
import org.apache.fory.util.record.RecordInfo;
import org.apache.fory.util.record.RecordUtils;
Expand Down Expand Up @@ -353,7 +354,7 @@ public static int computeStructHash(Fory fory, Collection<Descriptor> descriptor
}

private static int computeFieldHash(int hash, Fory fory, TypeRef<?> typeRef) {
int id;
int id = 0;
if (typeRef.isSubtypeOf(List.class)) {
// TODO(chaokunyang) add list element type into schema hash
id = Types.LIST;
Expand All @@ -365,21 +366,17 @@ private static int computeFieldHash(int hash, Fory fory, TypeRef<?> typeRef) {
TypeResolver resolver =
fory.isCrossLanguage() ? fory.getXtypeResolver() : fory.getClassResolver();
Class<?> cls = typeRef.getRawType();
if (ReflectionUtils.isAbstract(cls) || cls.isInterface()) {
id = 0;
} else {
ClassInfo classInfo = resolver.getClassInfo(typeRef.getRawType());
int xtypeId = classInfo.getXtypeId();
if (Types.isStructType(xtypeId & 0xff)) {
if (!ReflectionUtils.isAbstract(cls) && !cls.isInterface()) {
ClassInfo classInfo = resolver.getClassInfo(cls);
int xtypeId = id = classInfo.getXtypeId();
if (Types.isNamedType(xtypeId & 0xff)) {
id =
TypeUtils.computeStringHash(
classInfo.decodeNamespace() + classInfo.decodeTypeName());
} else {
id = Math.abs(xtypeId);
}
}
} catch (Exception e) {
id = 0;
ExceptionUtils.ignore(e);
}
}
long newHash = ((long) hash) * 31 + id;
Expand Down
13 changes: 13 additions & 0 deletions java/fory-core/src/main/java/org/apache/fory/type/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,19 @@ public class Types {
public static final int UNKNOWN = 63;

// Helper methods
/**
 * Tests whether {@code value} equals one of the {@code NAMED_*} type id constants:
 * {@link #NAMED_STRUCT}, {@link #NAMED_COMPATIBLE_STRUCT}, {@link #NAMED_ENUM} or
 * {@link #NAMED_EXT}.
 *
 * @param value the type id to test; with assertions enabled it must be below {@code 0xff}
 * @return {@code true} if {@code value} matches one of the four constants above
 */
public static boolean isNamedType(int value) {
  assert value < 0xff;
  return value == NAMED_STRUCT
      || value == NAMED_COMPATIBLE_STRUCT
      || value == NAMED_ENUM
      || value == NAMED_EXT;
}

public static boolean isStructType(int value) {
assert value < 0xff;
return value == STRUCT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ private void structRoundBack(Fory fory, Object obj, String testName) throws IOEx
System.out.println(dataFile.toAbsolutePath());
Files.deleteIfExists(dataFile);
Files.write(dataFile, serialized);
dataFile.toFile().deleteOnExit();
// dataFile.toFile().deleteOnExit();
ImmutableList<String> command =
ImmutableList.of(
PYTHON_EXECUTABLE, "-m", PYTHON_MODULE, testName, dataFile.toAbsolutePath().toString());
Expand Down Expand Up @@ -823,9 +823,15 @@ static class EnumFieldStruct {
String f3;
}

@Test
public void testEnumField() throws java.io.IOException {
Fory fory = Fory.builder().withLanguage(Language.XLANG).requireClassRegistration(true).build();
@Test(dataProvider = "compatible")
public void testEnumField(boolean compatible) throws java.io.IOException {
Fory fory =
Fory.builder()
.withLanguage(Language.XLANG)
.withCompatibleMode(
compatible ? CompatibleMode.COMPATIBLE : CompatibleMode.SCHEMA_CONSISTENT)
.requireClassRegistration(true)
.build();
fory.register(EnumTestClass.class, "test.EnumTestClass");
fory.register(EnumFieldStruct.class, "test.EnumFieldStruct");

Expand All @@ -834,6 +840,6 @@ public void testEnumField() throws java.io.IOException {
a.f2 = EnumTestClass.BAR;
a.f3 = "abc";
Assert.assertEquals(xserDe(fory, a), a);
structRoundBack(fory, a, "test_enum_field");
structRoundBack(fory, a, "test_enum_field" + (compatible ? "_compatible" : ""));
}
}
10 changes: 4 additions & 6 deletions python/pyfory/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
PickleStrongCacheSerializer,
PickleSerializer,
DataClassSerializer,
DataClassStubSerializer,
StatefulSerializer,
ReduceSerializer,
FunctionSerializer,
Expand Down Expand Up @@ -127,7 +128,7 @@ def __init__(
self.namespace_bytes = namespace_bytes
self.typename_bytes = typename_bytes
self.dynamic_type = dynamic_type
self.type_def = None
self.type_def = type_def

def __repr__(self):
return f"TypeInfo(cls={self.cls}, type_id={self.type_id}, serializer={self.serializer})"
Expand Down Expand Up @@ -533,11 +534,8 @@ def _create_serializer(self, cls):
# Use FunctionSerializer for function types (including lambdas)
serializer = FunctionSerializer(self.fory, cls)
elif dataclasses.is_dataclass(cls):
if not self.meta_share:
serializer = DataClassSerializer(self.fory, cls, xlang=not self.fory.is_py)
else:
# lazy create serializer to handle nested struct fields.
serializer = None
# lazy create serializer to handle nested struct fields.
serializer = DataClassStubSerializer(self.fory, cls, xlang=not self.fory.is_py)
elif issubclass(cls, enum.Enum):
serializer = EnumSerializer(self.fory, cls)
elif (hasattr(cls, "__reduce__") and cls.__reduce__ is not object.__reduce__) or (
Expand Down
56 changes: 31 additions & 25 deletions python/pyfory/_serialization.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -534,10 +534,10 @@ cdef class TypeResolver:
self._c_types_info[<uintptr_t> <PyObject *> cls] = <PyObject *> type_info
self._populate_typeinfo(type_info)
return type_info

def is_registered_by_name(self, cls):
return self._resolver.is_registered_by_name(cls)

def is_registered_by_id(self, cls):
return self._resolver.is_registered_by_id(cls)

Expand Down Expand Up @@ -565,11 +565,11 @@ cdef class TypeResolver:
cdef:
int32_t type_id = typeinfo.type_id
int32_t internal_type_id = type_id & 0xFF

if self.meta_share:
self.write_shared_type_meta(buffer, typeinfo)
return

buffer.write_varuint32(type_id)
if IsNamespacedType(internal_type_id):
self.metastring_resolver.write_meta_string_bytes(buffer, typeinfo.namespace_bytes)
Expand All @@ -578,7 +578,7 @@ cdef class TypeResolver:
cpdef inline TypeInfo read_typeinfo(self, Buffer buffer):
if self.meta_share:
return self.read_shared_type_meta(buffer)

cdef:
int32_t type_id = buffer.read_varuint32()
if type_id < 0:
Expand All @@ -597,19 +597,22 @@ cdef class TypeResolver:
raise ValueError(f"Unexpected type_id {type_id}")
typeinfo = <TypeInfo> typeinfo_ptr
return typeinfo

cpdef inline TypeInfo get_typeinfo_by_id(self, int32_t type_id):
    """Look up the TypeInfo stored at index ``type_id`` in the registered-id table.

    Raises ValueError when ``type_id`` is negative, is not below the table size,
    has a low byte for which ``IsNamespacedType`` is true, or refers to a slot
    holding NULL.
    """
    if (
        type_id >= self._c_registered_id_to_type_info.size()
        or type_id < 0
        or IsNamespacedType(type_id & 0xFF)
    ):
        raise ValueError(f"Unexpected type_id {type_id}")
    typeinfo_ptr = self._c_registered_id_to_type_info[type_id]
    if typeinfo_ptr == NULL:
        # The index is in range but no entry was stored in that slot.
        raise ValueError(f"Unexpected type_id {type_id}")
    return <TypeInfo> typeinfo_ptr

def get_typeinfo_by_name(self, namespace, typename):
return self._resolver.get_typeinfo_by_name(namespace=namespace, typename=typename)

cpdef _set_typeinfo(self, typeinfo):
self._resolver._set_typeinfo(typeinfo)

def get_meta_compressor(self):
return self._resolver.get_meta_compressor()

Expand Down Expand Up @@ -647,26 +650,26 @@ cdef class MetaContext:
"""
Context for sharing type meta across multiple serialization. Type name, field name and field
type will be shared between different serialization.

Note that this context is not thread-safe, you should use it with one Fory instance.
"""
cdef:
# Types which have sent definitions to peer
# Maps type objects to their assigned IDs
flat_hash_map[uint64_t, int32_t] _c_type_map
flat_hash_map[uint64_t, int32_t] _c_type_map

# Counter for assigning new IDs
list _writing_type_defs
list _read_type_infos
object fory
object type_resolver

def __cinit__(self, object fory):
self.fory = fory
self.type_resolver = fory.type_resolver
self._writing_type_defs = []
self._read_type_infos = []

cpdef inline void write_shared_typeinfo(self, Buffer buffer, typeinfo):
"""Add a type definition to the writing queue."""
type_cls = typeinfo.cls
Expand All @@ -680,45 +683,48 @@ cdef class MetaContext:
cdef flat_hash_map[uint64_t, int32_t].iterator it = self._c_type_map.find(type_addr)
if it != self._c_type_map.end():
buffer.write_varuint32(deref(it).second)

cdef index = self._c_type_map.size()
buffer.write_varuint32(index)
self._c_type_map[type_addr] = index
type_def = typeinfo.type_def
if type_def is None:
self.type_resolver._set_typeinfo(typeinfo)
type_def = typeinfo.type_def
self._writing_type_defs.append(type_def)

cpdef inline list get_writing_type_defs(self):
"""Get all type definitions that need to be written."""
return self._writing_type_defs

cpdef inline reset_write(self):
"""Reset write state."""
self._writing_type_defs.clear()
self._c_type_map.clear()

cpdef inline add_read_typeinfo(self, type_info):
"""Add a type info read from peer."""
self._read_type_infos.append(type_info)

cpdef inline read_shared_typeinfo(self, Buffer buffer):
    """Read a type info from buffer.

    A varuint32 type id is read first. When ``IsTypeShareMeta`` accepts its low
    byte, a second varuint32 is read and used as an index into the type infos
    previously added via ``add_read_typeinfo``; otherwise the id is resolved
    through ``type_resolver.get_typeinfo_by_id``.
    """
    cdef type_id = buffer.read_varuint32()
    if not IsTypeShareMeta(type_id & 0xFF):
        return self.type_resolver.get_typeinfo_by_id(type_id)
    shared_index = buffer.read_varuint32()
    return self._read_type_infos[shared_index]

cpdef inline reset_read(self):
"""Reset read state."""
self._read_type_infos.clear()

cpdef inline reset(self):
"""Reset both read and write state."""
self.reset_write()
self.reset_read()

def __str__(self):
return self.__repr__()

def __repr__(self):
return (f"MetaContext("
f"read_infos={self._read_type_infos}, "
Expand Down Expand Up @@ -930,13 +936,13 @@ cdef class Fory:
if self.serialization_context.scoped_meta_share_enabled:
type_defs_offset_pos = buffer.writer_index
buffer.write_int32(-1) # Reserve 4 bytes for type definitions offset

cdef int32_t start_offset
if self.language == Language.PYTHON:
self.serialize_ref(buffer, obj)
else:
self.xserialize_ref(buffer, obj)

# Write type definitions at the end, similar to Java implementation
if self.serialization_context.scoped_meta_share_enabled:
meta_context = self.serialization_context.meta_context
Expand All @@ -945,7 +951,7 @@ cdef class Fory:
current_pos = buffer.writer_index
buffer.put_int32(type_defs_offset_pos, current_pos - type_defs_offset_pos - 4)
self.type_resolver.write_type_defs(buffer)

if buffer is not self.buffer:
return buffer
else:
Expand Down Expand Up @@ -1076,7 +1082,7 @@ cdef class Fory:
"buffers should be null when the serialized stream is "
"produced with buffer_callback null."
)

# Read type definitions at the start, similar to Java implementation
if self.serialization_context.scoped_meta_share_enabled:
relative_type_defs_offset = buffer.read_int32()
Expand All @@ -1089,7 +1095,7 @@ cdef class Fory:
self.type_resolver.read_type_defs(buffer)
# Jump back to continue with object deserialization
buffer.reader_index = current_reader_index

if not is_target_x_lang:
return self.deserialize_ref(buffer)
return self.xdeserialize_ref(buffer)
Expand Down
33 changes: 30 additions & 3 deletions python/pyfory/_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
}


class ComplexTypeVisitor(TypeVisitor):
class StructFieldSerializerVisitor(TypeVisitor):
def __init__(
self,
fory,
Expand All @@ -88,6 +88,8 @@ def visit_dict(self, field_name, key_type, value_type, types_path=None):
return MapSerializer(self.fory, dict, key_serializer, value_serializer)

def visit_customized(self, field_name, type_, types_path=None):
    """Return the serializer to use for a customized field type, if any.

    Enum subclasses get the serializer that ``self.fory.type_resolver`` returns
    for them; every other type yields ``None``.
    """
    if not issubclass(type_, enum.Enum):
        return None
    resolver = self.fory.type_resolver
    return resolver.get_serializer(type_)

def visit_other(self, field_name, type_, types_path=None):
Expand Down Expand Up @@ -210,7 +212,10 @@ def visit_other(self, field_name, type_, types_path=None):
assert not isinstance(serializer, (PickleSerializer,))
id_ = typeinfo.type_id
assert id_ is not None, serializer
id_ = abs(id_)
if TypeId.is_namespaced_type(typeinfo.type_id):
namespace_str = typeinfo.decode_namespace()
typename_str = typeinfo.decode_typename()
id_ = compute_string_hash(namespace_str + typename_str)
self._hash = self._compute_field_hash(self._hash, id_)

@staticmethod
Expand Down Expand Up @@ -254,14 +259,36 @@ def visit_other(self, field_name, type_, types_path=None):
from pyfory.serializer import PickleSerializer # Local import

if is_subclass(type_, enum.Enum):
return self.fory.type_resolver.get_typeinfo(type_).type_id
return [self.fory.type_resolver.get_typeinfo(type_).type_id]
if type_ not in basic_types and not is_py_array_type(type_):
return None, None
typeinfo = self.fory.type_resolver.get_typeinfo(type_)
assert not isinstance(typeinfo.serializer, (PickleSerializer,))
return [typeinfo.type_id]


class StructTypeVisitor(TypeVisitor):
    """Visitor that reports the type structure of a field annotation.

    List and dict annotations come back as a tuple headed by ``typing.List`` or
    ``typing.Dict`` and followed by whatever ``infer_field`` returns for the
    nested element / key / value types. Every other annotation comes back as
    the single-element list ``[type_]``.
    """

    def __init__(self, cls):
        # Kept on the instance; no method of this class reads it.
        self.cls = cls

    def visit_list(self, field_name, elem_type, types_path=None):
        """Return ``(typing.List, <result of infer_field for the element type>)``."""
        # Recurse through infer_field so nested annotations such as
        # List[Dict[str, str]] are visited as well.
        nested = infer_field("item", elem_type, self, types_path=types_path)
        return typing.List, nested

    def visit_dict(self, field_name, key_type, value_type, types_path=None):
        """Return ``(typing.Dict, <key result>, <value result>)``."""
        # Key is inferred before value, each recursively, so annotations such
        # as Dict[str, Dict[str, str]] are visited as well.
        resolved = tuple(
            infer_field(label, annotation, self, types_path=types_path)
            for label, annotation in (("key", key_type), ("value", value_type))
        )
        return (typing.Dict, *resolved)

    def visit_customized(self, field_name, type_, types_path=None):
        """Return the type itself wrapped in a one-element list."""
        wrapped = [type_]
        return wrapped

    def visit_other(self, field_name, type_, types_path=None):
        """Return the type itself wrapped in a one-element list."""
        wrapped = [type_]
        return wrapped


def get_field_names(clz, type_hints=None):
if hasattr(clz, "__dict__"):
# Regular object with __dict__
Expand Down
Loading
Loading