Skip to content

Commit 7f963d4

Browse files
authored
Add files via upload
1 parent 8061c00 commit 7f963d4

File tree

2 files changed

+280
-0
lines changed

2 files changed

+280
-0
lines changed

src/rheelDM/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .main import *

src/rheelDM/main.py

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
from datetime import datetime, date, time
5+
from typing import Any, get_origin, get_args, Union
6+
import ast
7+
import types
8+
9+
10+
# =========================================================
11+
# Type Registry
12+
# =========================================================
13+
14+
class TypeRegistry:
15+
def __init__(self):
16+
self._registry: dict[str, tuple[type, callable, callable]] = {}
17+
18+
def register(self, name: str, typ: type, serializer, deserializer):
19+
self._registry[name] = (typ, serializer, deserializer)
20+
21+
def serialize(self, value: Any) -> str:
22+
for name, (typ, serializer, _) in self._registry.items():
23+
if isinstance(value, typ):
24+
return serializer(value)
25+
return repr(value)
26+
27+
def deserialize(self, value_str: str, typ: type):
28+
for name, (registered_type, _, deserializer) in self._registry.items():
29+
if typ is registered_type:
30+
return deserializer(value_str)
31+
return ast.literal_eval(value_str)
32+
33+
34+
registry = TypeRegistry()
35+
36+
registry.register(
37+
"datetime",
38+
datetime,
39+
lambda v: f'"{v.isoformat()}"',
40+
lambda v: datetime.fromisoformat(v.strip('"'))
41+
)
42+
43+
registry.register(
44+
"date",
45+
date,
46+
lambda v: f'"{v.isoformat()}"',
47+
lambda v: date.fromisoformat(v.strip('"'))
48+
)
49+
50+
registry.register(
51+
"time",
52+
time,
53+
lambda v: f'"{v.isoformat()}"',
54+
lambda v: time.fromisoformat(v.strip('"'))
55+
)
56+
57+
registry.register(
58+
"Path",
59+
Path,
60+
lambda v: f'"{str(v)}"',
61+
lambda v: Path(v.strip('"'))
62+
)
63+
64+
65+
# =========================================================
66+
# Safe Type Parsing
67+
# =========================================================
68+
69+
SAFE_TYPES = {
70+
"str": str,
71+
"int": int,
72+
"float": float,
73+
"bool": bool,
74+
"list": list,
75+
"set": set,
76+
"tuple": tuple,
77+
"dict": dict,
78+
"datetime": datetime,
79+
"date": date,
80+
"time": time,
81+
"Path": Path,
82+
}
83+
84+
85+
def parse_type(type_str: str):
86+
return eval(type_str, SAFE_TYPES)
87+
88+
89+
# =========================================================
90+
# Section
91+
# =========================================================
92+
93+
class Section:
94+
95+
def __init__(self, name: str):
96+
self.name = name
97+
self._items: dict[str, tuple[type, Any]] = {}
98+
99+
# -----------------------
100+
# Public API
101+
# -----------------------
102+
103+
def set(self, key: str, typ: type, value: Any):
104+
self._validate(value, typ)
105+
self._items[key] = (typ, value)
106+
107+
def get(self, key: str):
108+
return self._items[key][1]
109+
110+
# -----------------------
111+
# Type Validation
112+
# -----------------------
113+
114+
def _validate(self, value: Any, typ: type):
115+
origin = get_origin(typ)
116+
args = get_args(typ)
117+
118+
# Union (str | int)
119+
if origin in (Union, types.UnionType):
120+
for option in args:
121+
try:
122+
self._validate(value, option)
123+
return
124+
except TypeError:
125+
continue
126+
raise TypeError(f"{value} does not match any type in {typ}")
127+
128+
# Normal type
129+
if origin is None:
130+
if not isinstance(value, typ):
131+
raise TypeError(f"{value} is not {typ}")
132+
return
133+
134+
if origin is list:
135+
if not isinstance(value, list):
136+
raise TypeError("Expected list")
137+
for v in value:
138+
self._validate(v, args[0])
139+
return
140+
141+
if origin is set:
142+
if not isinstance(value, set):
143+
raise TypeError("Expected set")
144+
for v in value:
145+
self._validate(v, args[0])
146+
return
147+
148+
if origin is tuple:
149+
if not isinstance(value, tuple):
150+
raise TypeError("Expected tuple")
151+
for v in value:
152+
self._validate(v, args[0])
153+
return
154+
155+
if origin is dict:
156+
if not isinstance(value, dict):
157+
raise TypeError("Expected dict")
158+
key_t, val_t = args
159+
for k, v in value.items():
160+
self._validate(k, key_t)
161+
self._validate(v, val_t)
162+
return
163+
164+
raise TypeError(f"Unsupported type {typ}")
165+
166+
# -----------------------
167+
# Serialization
168+
# -----------------------
169+
170+
def serialize(self) -> list[str]:
171+
lines = [f"[{self.name}]"]
172+
173+
max_key = max((len(k) for k in self._items), default=0)
174+
max_type = max((len(self._type_name(t)) for t, _ in self._items.values()), default=0)
175+
176+
for key, (typ, value) in self._items.items():
177+
key_pad = key.ljust(max_key)
178+
type_name = self._type_name(typ)
179+
type_pad = type_name.ljust(max_type)
180+
value_str = registry.serialize(value)
181+
182+
lines.append(f"{key_pad} : {type_pad} = {value_str}")
183+
184+
return lines
185+
186+
def _type_name(self, typ: type) -> str:
187+
origin = get_origin(typ)
188+
args = get_args(typ)
189+
190+
if origin in (Union, types.UnionType):
191+
return " | ".join(self._type_name(a) for a in args)
192+
193+
if origin is None:
194+
return typ.__name__
195+
196+
inner = ", ".join(self._type_name(a) for a in args)
197+
return f"{origin.__name__}[{inner}]"
198+
199+
# -----------------------
200+
# Parsing
201+
# -----------------------
202+
203+
@classmethod
204+
def from_lines(cls, name: str, lines: list[str]):
205+
section = cls(name)
206+
207+
for raw_line in lines:
208+
line = raw_line.split("#", 1)[0].strip()
209+
if not line:
210+
continue
211+
212+
left, value_str = line.split("=", 1)
213+
key_part, type_part = left.split(":", 1)
214+
215+
key = key_part.strip()
216+
type_str = type_part.strip()
217+
value_str = value_str.strip()
218+
219+
typ = parse_type(type_str)
220+
value = registry.deserialize(value_str, typ)
221+
222+
section._items[key] = (typ, value)
223+
224+
return section
225+
226+
227+
# =========================================================
228+
# Obj
229+
# =========================================================
230+
231+
class Obj:
232+
233+
def __init__(self):
234+
self._sections: dict[str, Section] = {}
235+
236+
def section(self, name: str) -> Section:
237+
if name not in self._sections:
238+
self._sections[name] = Section(name)
239+
return self._sections[name]
240+
241+
def save(self, filename: str):
242+
if not filename.endswith(".rdm"):
243+
filename += ".rdm"
244+
245+
lines = []
246+
247+
for section in self._sections.values():
248+
lines.extend(section.serialize())
249+
lines.append("")
250+
251+
Path(filename).write_text("\n".join(lines).rstrip())
252+
253+
@classmethod
254+
def load(cls, filename: str):
255+
content = Path(filename).read_text().splitlines()
256+
257+
obj = cls()
258+
current_name = None
259+
buffer = []
260+
261+
for raw_line in content:
262+
stripped = raw_line.strip()
263+
264+
if not stripped or stripped.startswith("#"):
265+
continue
266+
267+
if stripped.startswith("[") and stripped.endswith("]"):
268+
if current_name:
269+
obj._sections[current_name] = Section.from_lines(current_name, buffer)
270+
buffer = []
271+
272+
current_name = stripped[1:-1]
273+
else:
274+
buffer.append(raw_line)
275+
276+
if current_name:
277+
obj._sections[current_name] = Section.from_lines(current_name, buffer)
278+
279+
return obj

0 commit comments

Comments
 (0)