Skip to content

Commit 8cb329c

Browse files
authored
Add files via upload
1 parent b1075a9 commit 8cb329c

1 file changed

Lines changed: 172 additions & 61 deletions

File tree

src/rheelDM/main.py

Lines changed: 172 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,114 +1,206 @@
11
from __future__ import annotations
2-
32
from pathlib import Path
43
from datetime import datetime, date, time, timezone, timedelta
5-
from typing import Any, get_origin, get_args, Union
6-
import ast, types, copy, json, configparser
4+
from typing import Any, get_origin, get_args, Union, Optional
5+
import ast, types, copy, json, configparser, importlib, inspect
76
try: import tomllib # type: ignore
87
except ModuleNotFoundError: tomllib = None
98
try: import toml # type: ignore
109
except ModuleNotFoundError: toml = None
1110
try: import yaml # type: ignore
1211
except ModuleNotFoundError: yaml = None
1312

14-
1513
# =========================================================
1614
# Type Registry
1715
# =========================================================
1816

17+
DEFAULT_TYPES = {
18+
"str": str,
19+
"int": int,
20+
"float": float,
21+
"bool": bool,
22+
"NoneType": type(None),
23+
"Optional": Union[type(None), Any],
24+
25+
"list": list,
26+
"set": set,
27+
"tuple": tuple,
28+
"dict": dict,
29+
30+
"datetime": datetime,
31+
"date": date,
32+
"time": time,
33+
34+
"Path": Path,
35+
}
36+
37+
def import_from_path(path: str):
38+
parts = path.split(".")
39+
40+
for i in range(len(parts), 0, -1):
41+
module_path = ".".join(parts[:i])
42+
43+
try:
44+
module = importlib.import_module(module_path)
45+
obj = module
46+
47+
for part in parts[i:]:
48+
obj = getattr(obj, part)
49+
50+
return obj
51+
52+
except ModuleNotFoundError:
53+
continue
54+
55+
raise ImportError(f"Cannot import {path}")
56+
1957
class TypeRegistry:
2058
"""
21-
Registry for custom types.
22-
23-
Allows registering custom serializer/deserializer pairs
24-
for special Python types.
59+
Simple in-memory type registry.
2560
26-
Example:
27-
registry.register(
28-
"datetime",
29-
datetime,
30-
lambda v: f'{v.isoformat()}',
31-
lambda v: datetime.fromisoformat(v)
32-
)
61+
Supports custom serializer/deserializer pairs.
62+
No persistence, no import issues, no circular dependencies.
3363
"""
3464

35-
def __init__(self):
36-
self._registry: dict[str, tuple[type, callable, callable]] = {}
65+
_types: dict[str, type] = DEFAULT_TYPES.copy()
66+
_registry: dict[str, tuple[type, callable, callable]] = {}
67+
68+
# -----------------------------------------------------
69+
70+
@classmethod
71+
def register(cls, name: str, typ: type, serializer, deserializer):
72+
"""
73+
Register a custom type.
74+
75+
Example:
76+
TypeRegistry.register(
77+
"PartialEmoji",
78+
discord.PartialEmoji,
79+
lambda v: str(v),
80+
discord.PartialEmoji.from_str
81+
)
82+
"""
83+
cls._types[name] = typ
84+
cls._registry[name] = (typ, serializer, deserializer)
85+
86+
# -----------------------------------------------------
87+
88+
@classmethod
89+
def clear_custom(cls):
90+
"""Reset registry to default types."""
91+
cls._registry.clear()
92+
cls._types = DEFAULT_TYPES.copy()
3793

38-
def register(self, name: str, typ: type, serializer, deserializer):
39-
"""Register a new custom type."""
40-
self._registry[name] = (typ, serializer, deserializer)
94+
# -----------------------------------------------------
4195

42-
def serialize(self, value: Any) -> str:
43-
"""Convert a Python value to its RDM string representation."""
44-
for name, (typ, serializer, _) in self._registry.items():
96+
@classmethod
97+
def serialize(cls, value):
98+
"""Convert Python object to RDM string."""
99+
100+
for typ, serializer, _ in cls._registry.values():
45101
if isinstance(value, typ):
46102
return serializer(value)
103+
47104
return repr(value)
48105

49-
def deserialize(self, value_str: str, typ: type):
50-
"""Convert RDM string representation back to Python object."""
51-
for _, (registered_type, _, deserializer) in self._registry.items():
52-
if typ is registered_type:
106+
# -----------------------------------------------------
107+
108+
@classmethod
109+
def deserialize(cls, value_str, typ):
110+
"""Convert RDM string back to Python object."""
111+
112+
for registered_type, _, deserializer in cls._registry.values():
113+
origin = get_origin(typ) or typ
114+
if isinstance(origin, type) and issubclass(origin, registered_type):
53115
return deserializer(value_str)
54-
return ast.literal_eval(value_str)
55116

117+
return ast.literal_eval(value_str)
56118

57-
registry = TypeRegistry()
58119

59-
registry.register(
120+
TypeRegistry.register(
60121
"datetime",
61122
datetime,
62-
lambda v: f'{v.isoformat()}',
63-
lambda v: datetime.fromisoformat(v)
123+
lambda v: v.isoformat(),
124+
datetime.fromisoformat
64125
)
65126

66-
registry.register(
127+
TypeRegistry.register(
67128
"date",
68129
date,
69-
lambda v: f'{v.isoformat()}',
70-
lambda v: date.fromisoformat(v)
130+
lambda v: v.isoformat(),
131+
date.fromisoformat
71132
)
72133

73-
registry.register(
134+
TypeRegistry.register(
74135
"time",
75136
time,
76-
lambda v: f'{v.isoformat()}',
77-
lambda v: time.fromisoformat(v)
137+
lambda v: v.isoformat(),
138+
time.fromisoformat
78139
)
79140

80-
registry.register(
141+
TypeRegistry.register(
81142
"Path",
82143
Path,
83-
lambda v: f'{str(v)}',
84-
lambda v: Path(v)
144+
str,
145+
Path
85146
)
86147

87148

88149
# =========================================================
89150
# Safe Type Parsing
90151
# =========================================================
91152

92-
SAFE_TYPES = {
93-
"str": str,
94-
"int": int,
95-
"float": float,
96-
"bool": bool,
97-
"NoneType": type(None),
98-
"list": list,
99-
"set": set,
100-
"tuple": tuple,
101-
"dict": dict,
102-
"datetime": datetime,
103-
"date": date,
104-
"time": time,
105-
"Path": Path,
106-
}
153+
def _convert_optional(type_str: str) -> str:
154+
"""
155+
Converts:
156+
Optional[str] → str | NoneType
157+
Optional[list[int]] → list[int] | NoneType
158+
Optional[dict[str,int]] → dict[str,int] | NoneType
159+
"""
160+
161+
result = ""
162+
i = 0
163+
164+
while i < len(type_str):
165+
if type_str.startswith("Optional[", i):
166+
i += len("Optional[")
167+
168+
bracket_level = 1
169+
inner = ""
170+
171+
while i < len(type_str) and bracket_level > 0:
172+
if type_str[i] == "[":
173+
bracket_level += 1
174+
elif type_str[i] == "]":
175+
bracket_level -= 1
176+
177+
if bracket_level > 0:
178+
inner += type_str[i]
179+
180+
i += 1
181+
182+
# Recursively handle nested Optional
183+
inner = _convert_optional(inner)
184+
185+
result += f"{inner} | NoneType"
186+
else:
187+
result += type_str[i]
188+
i += 1
189+
190+
return result
107191

108192
def parse_type(type_str: str):
109-
"""Safely parse a type string like 'list[str | int]'."""
110-
return eval(type_str, SAFE_TYPES)
193+
"""
194+
Safely parse type strings like:
195+
list[str | int]
196+
Optional[str]
197+
"""
198+
199+
# ---- Convert Optional[T] → T | NoneType ----
200+
if "Optional[" in type_str:
201+
type_str = _convert_optional(type_str)
111202

203+
return eval(type_str, TypeRegistry._types)
112204

113205
# =========================================================
114206
# ExpiredKey Type
@@ -373,7 +465,7 @@ def serialize(self) -> list[str]:
373465
key_pad = key.ljust(max_key)
374466
type_name = self._type_name(typ)
375467
type_pad = type_name.ljust(max_type)
376-
value_str = registry.serialize(value)
468+
value_str = TypeRegistry.serialize(value)
377469

378470
lines.append(f"{key_pad} : {type_pad} = {value_str}")
379471

@@ -383,12 +475,31 @@ def _type_name(self, typ: type) -> str:
383475
origin = get_origin(typ)
384476
args = get_args(typ)
385477

478+
# -------------------------
479+
# Handle Union / Optional
480+
# -------------------------
386481
if origin in (Union, types.UnionType):
482+
args_set = set(args)
483+
484+
# Detect Optional[T] → Union[T, NoneType]
485+
if type(None) in args_set and len(args) == 2:
486+
non_none = [a for a in args if a is not type(None)][0]
487+
return f"Optional[{self._type_name(non_none)}]"
488+
489+
# Fallback: normal union
387490
return " | ".join(self._type_name(a) for a in args)
388491

492+
# -------------------------
493+
# Simple types
494+
# -------------------------
389495
if origin is None:
390-
return typ.__name__
496+
if typ is type(None):
497+
return "NoneType"
498+
return TypeRegistry._types.get(typ.__name__, typ).__name__
391499

500+
# -------------------------
501+
# Generics (list, dict, etc.)
502+
# -------------------------
392503
inner = ", ".join(self._type_name(a) for a in args)
393504
return f"{origin.__name__}[{inner}]"
394505

@@ -410,7 +521,7 @@ def from_lines(cls, name: str, lines: list[str]):
410521
value_str = value_str.strip()
411522

412523
typ = parse_type(type_str)
413-
value = registry.deserialize(value_str, typ)
524+
value = TypeRegistry.deserialize(value_str, typ)
414525

415526
section._items[key] = (typ, value)
416527

0 commit comments

Comments
 (0)