Skip to content

Commit 67417fe

Browse files
author
Paul Prescod
committed
Refactor YAML loading to use add_representer
1 parent 7844d2c commit 67417fe

5 files changed

Lines changed: 51 additions & 33 deletions

File tree

snowfakery/data_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def load_continuation_yaml(continuation_file: OpenFileLike):
9595
def save_continuation_yaml(continuation_data: Globals, continuation_file: OpenFileLike):
9696
"""Save the global interpreter state from Globals into a continuation_file"""
9797
yaml.dump(
98-
continuation_data.__getstate__(),
98+
continuation_data,
9999
continuation_file,
100100
Dumper=SnowfakeryDumper,
101101
)

snowfakery/data_generator_runtime.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
)
2828
from snowfakery.plugins import PluginContext, SnowfakeryPlugin, ScalarTypes
2929
from snowfakery.utils.collections import OrderedSet
30+
from snowfakery.utils.yaml_utils import register_for_continuation
3031

3132
OutputStream = "snowfakery.output_streams.OutputStream"
3233
VariableDefinition = "snowfakery.data_generator_runtime_object_model.VariableDefinition"
@@ -60,6 +61,7 @@ def generate_id(self, table_name: str) -> int:
6061
def __getitem__(self, table_name: str) -> int:
6162
return self.last_used_ids[table_name]
6263

64+
# TODO: Fix this to use the new convention of get_continuation_data
6365
def __getstate__(self):
6466
return {"last_used_ids": dict(self.last_used_ids)}
6567

@@ -195,21 +197,14 @@ def check_slots_filled(self):
195197
def first_new_id(self, tablename):
196198
return self.transients.first_new_id(tablename)
197199

198-
def __getstate__(self):
199-
def serialize_dict_of_object_rows(dct):
200-
return {k: v.__getstate__() for k, v in dct.items()}
201-
202-
persistent_nicknames = serialize_dict_of_object_rows(self.persistent_nicknames)
203-
persistent_objects_by_table = serialize_dict_of_object_rows(
204-
self.persistent_objects_by_table
205-
)
200+
def get_continuation_state(self):
206201
intertable_dependencies = [
207202
dict(v._asdict()) for v in self.intertable_dependencies
208203
] # converts ordered-dict to dict for Python 3.6 and 3.7
209204

210205
state = {
211-
"persistent_nicknames": persistent_nicknames,
212-
"persistent_objects_by_table": persistent_objects_by_table,
206+
"persistent_nicknames": self.persistent_nicknames,
207+
"persistent_objects_by_table": self.persistent_objects_by_table,
213208
"id_manager": self.id_manager.__getstate__(),
214209
"today": self.today,
215210
"nicknames_and_tables": self.nicknames_and_tables,
@@ -244,6 +239,9 @@ def deserialize_dict_of_object_rows(dct):
244239
self.reset_slots()
245240

246241

242+
register_for_continuation(Globals, Globals.get_continuation_state)
243+
244+
247245
class JinjaTemplateEvaluatorFactory:
248246
def __init__(self, native_types: bool):
249247
if native_types:

snowfakery/object_rows.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import yaml
44
import snowfakery # noqa
5-
from .utils.yaml_utils import SnowfakeryDumper
5+
from .utils.yaml_utils import register_for_continuation
66
from contextvars import ContextVar
77

88
IdManager = "snowfakery.data_generator_runtime.IdManager"
@@ -14,10 +14,6 @@ class ObjectRow:
1414
1515
Uses __getattr__ so that the template evaluator can use dot-notation."""
1616

17-
yaml_loader = yaml.SafeLoader
18-
yaml_dumper = SnowfakeryDumper
19-
yaml_tag = "!snowfakery_objectrow"
20-
2117
# be careful changing these slots because these objects must be serializable
2218
# to YAML and JSON
2319
__slots__ = ["_tablename", "_values", "_child_index"]
@@ -49,11 +45,17 @@ def __repr__(self):
4945
except Exception:
5046
return super().__repr__()
5147

52-
def __getstate__(self):
48+
def get_continuation_state(self):
5349
"""Get the state of this ObjectRow for serialization.
5450
5551
Do not include related ObjectRows because circular
5652
references in serialization formats cause problems."""
53+
54+
# If we decided to try to serialize hierarchies, we could
55+
# do it like this:
56+
# * keep track of if an object has already been serialized using a
57+
# property of the SnowfakeryDumper
58+
# * If so, output an ObjectReference instead of an ObjectRow
5759
values = {k: v for k, v in self._values.items() if not isinstance(v, ObjectRow)}
5860
return {"_tablename": self._tablename, "_values": values}
5961

@@ -62,6 +64,9 @@ def __setstate__(self, state):
6264
setattr(self, slot, value)
6365

6466

67+
register_for_continuation(ObjectRow, ObjectRow.get_continuation_state)
68+
69+
6570
class ObjectReference(yaml.YAMLObject):
6671
def __init__(self, tablename: str, id: int):
6772
self._tablename = tablename

snowfakery/plugins.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@
88
from functools import wraps
99
import typing as T
1010

11-
import yaml
12-
from yaml.representer import Representer
1311
from faker.providers import BaseProvider as FakerProvider
1412
from dateutil.relativedelta import relativedelta
1513

1614
import snowfakery.data_gen_exceptions as exc
17-
from .utils.yaml_utils import SnowfakeryDumper
15+
from snowfakery.utils.yaml_utils import register_for_continuation
1816
from .utils.collections import CaseInsensitiveDict
1917

2018
from numbers import Number
@@ -306,17 +304,7 @@ def _from_continuation(cls, args):
306304

307305
def __init_subclass__(cls, **kwargs):
308306
super().__init_subclass__(**kwargs)
309-
_register_for_continuation(cls)
310-
311-
312-
def _register_for_continuation(cls):
313-
SnowfakeryDumper.add_representer(cls, Representer.represent_object)
314-
yaml.SafeLoader.add_constructor(
315-
f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}",
316-
lambda loader, node: cls._from_continuation(
317-
loader.construct_mapping(node.value[0])
318-
),
319-
)
307+
register_for_continuation(cls)
320308

321309

322310
class PluginResultIterator(PluginResult):
@@ -372,4 +360,4 @@ def convert(self, value):
372360

373361

374362
# round-trip PluginResult objects through continuation YAML if needed.
375-
_register_for_continuation(PluginResult)
363+
register_for_continuation(PluginResult)

snowfakery/utils/yaml_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from yaml import SafeDumper
1+
from typing import Callable
2+
from yaml import SafeDumper, SafeLoader
3+
from yaml.representer import Representer
24

35

46
class SnowfakeryDumper(SafeDumper):
@@ -9,3 +11,28 @@ def hydrate(cls, data):
911
obj = cls.__new__(cls)
1012
obj.__setstate__(data)
1113
return obj
14+
15+
16+
# Evaluate whether its cleaner for functions to bypass register_for_continuation
17+
# and go directly to SnowfakeryDumper.add_representer.
18+
#
19+
#
20+
21+
22+
def represent_continuation(dumper: SnowfakeryDumper, data):
23+
if isinstance(data, dict):
24+
return Representer.represent_dict(dumper, data)
25+
else:
26+
return Representer.represent_object(dumper, data)
27+
28+
29+
def register_for_continuation(cls, dump_transformer: Callable = lambda x: x):
30+
SnowfakeryDumper.add_representer(
31+
cls, lambda self, data: represent_continuation(self, dump_transformer(data))
32+
)
33+
SafeLoader.add_constructor(
34+
f"tag:yaml.org,2002:python/object/apply:{cls.__module__}.{cls.__name__}",
35+
lambda loader, node: cls._from_continuation(
36+
loader.construct_mapping(node.value[0])
37+
),
38+
)

0 commit comments

Comments
 (0)