|
18 | 18 | import datetime |
19 | 19 | import enum |
20 | 20 | import logging |
| 21 | +import typing |
21 | 22 |
|
22 | 23 | from pyfory.type import ( |
23 | 24 | TypeVisitor, |
@@ -166,7 +167,7 @@ def numeric_sorter(item): |
166 | 167 | map_types = sorted(map_types, key=sorter) |
167 | 168 | other_types = sorted(other_types, key=sorter) |
168 | 169 | all_types = boxed_types + final_types + other_types + collection_types + map_types |
169 | | - return [t[1] for t in all_types], [t[2] for t in all_types] |
| 170 | + return [t[2] for t in all_types], [t[1] for t in all_types] |
170 | 171 |
|
171 | 172 |
|
172 | 173 | class StructHashVisitor(TypeVisitor): |
@@ -221,3 +222,49 @@ def _compute_field_hash(hash_, id_): |
221 | 222 |
|
222 | 223 | def get_hash(self): |
223 | 224 | return self._hash |
| 225 | + |
| 226 | + |
| 227 | +class StructTypeIdVisitor(TypeVisitor): |
| 228 | + def __init__( |
| 229 | + self, |
| 230 | + fory, |
| 231 | + ): |
| 232 | + self.fory = fory |
| 233 | + |
| 234 | + def visit_list(self, field_name, elem_type, types_path=None): |
| 235 | + # Infer type recursively for type such as List[Dict[str, str]] |
| 236 | + elem_ids = infer_field("item", elem_type, self, types_path=types_path) |
| 237 | + return TypeId.LIST, elem_ids |
| 238 | + |
| 239 | + def visit_dict(self, field_name, key_type, value_type, types_path=None): |
| 240 | + # Infer type recursively for type such as Dict[str, Dict[str, str]] |
| 241 | + key_ids = infer_field("key", key_type, self, types_path=types_path) |
| 242 | + value_ids = infer_field("value", value_type, self, types_path=types_path) |
| 243 | + return TypeId.MAP, key_ids, value_ids |
| 244 | + |
| 245 | + def visit_customized(self, field_name, type_, types_path=None): |
| 246 | + return None, None |
| 247 | + |
| 248 | + def visit_other(self, field_name, type_, types_path=None): |
| 249 | + from pyfory.serializer import PickleSerializer # Local import |
| 250 | + |
| 251 | + if is_subclass(type_, enum.Enum): |
| 252 | + return self.fory.type_resolver.get_typeinfo(type_).type_id |
| 253 | + if type_ not in basic_types and not is_py_array_type(type_): |
| 254 | + return None, None |
| 255 | + typeinfo = self.fory.type_resolver.get_typeinfo(type_) |
| 256 | + assert not isinstance(typeinfo.serializer, (PickleSerializer,)) |
| 257 | + return [typeinfo.type_id] |
| 258 | + |
| 259 | + |
| 260 | +def get_field_names(clz, type_hints=None): |
| 261 | + if hasattr(clz, "__dict__"): |
| 262 | + # Regular object with __dict__ |
| 263 | + # We can't know the fields without an instance, so we rely on type hints |
| 264 | + if type_hints is None: |
| 265 | + type_hints = typing.get_type_hints(clz) |
| 266 | + return sorted(type_hints.keys()) |
| 267 | + elif hasattr(clz, "__slots__"): |
| 268 | + # Object with __slots__ |
| 269 | + return sorted(clz.__slots__) |
| 270 | + return [] |
0 commit comments