Skip to content

Commit 6586d1a

Browse files
authored
Inference rules (#109)
* use custom loader class, disable inference * additional custom loader tests * add inference tests to cmd tests
1 parent b58e462 commit 6586d1a

6 files changed

Lines changed: 98 additions & 37 deletions

File tree

lib/envstack/node.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -250,30 +250,35 @@ def resolve(self, env: dict = os.environ):
250250

251251

252252
class CustomLoader(yaml.SafeLoader):
253+
"""Custom Loader class to preserve order of keys and ensure required
254+
keys are first in the mapping."""
255+
253256
required_keys = {"include", "all", "darwin", "linux", "windows"}
254257

255258
def construct_mapping(self, node: yaml.Node, deep: bool = False):
256-
mapping = super().construct_mapping(node, deep=deep)
257-
for key, value in mapping.items():
259+
"""Construct a mapping from a YAML node, preserving the order of keys
260+
and ensuring"""
261+
# keep YAML merge keys (<<) working
262+
self.flatten_mapping(node)
263+
264+
mapping = {}
265+
266+
for key_node, value_node in node.value:
267+
# never implicit-resolve scalar KEYS (YES/NO/ON/OFF/null/date/etc)
268+
if isinstance(key_node, yaml.ScalarNode):
269+
key = key_node.value
270+
else:
271+
key = self.construct_object(key_node, deep=deep)
272+
273+
value = self.construct_object(value_node, deep=deep)
274+
mapping[key] = value
275+
276+
# preserve order of keys and ensure required keys are first in the mapping
277+
for key, value in list(mapping.items()):
258278
if key in self.required_keys:
259279
continue
260-
try:
261-
if node.tag == Base64Node.yaml_tag:
262-
mapping[key] = Base64Node(value)
263-
elif node.tag == EncryptedNode.yaml_tag:
264-
mapping[key] = EncryptedNode(value)
265-
elif node.tag == AESGCMNode.yaml_tag:
266-
mapping[key] = AESGCMNode(value)
267-
elif node.tag == FernetNode.yaml_tag:
268-
mapping[key] = FernetNode(value)
269-
elif node.tag == MD5Node.yaml_tag:
270-
mapping[key] = MD5Node(value)
271-
else:
272-
mapping[key] = Template(value)
273-
except Exception as e:
274-
raise yaml.constructor.ConstructorError(
275-
None, None, f"Error parsing template: {e}", node.start_mark
276-
)
280+
mapping[key] = value
281+
277282
return mapping
278283

279284

lib/envstack/util.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,13 @@
4545
import yaml
4646

4747
from envstack import config
48-
from envstack.node import AESGCMNode, Base64Node, EncryptedNode, FernetNode
48+
from envstack.node import (
49+
AESGCMNode,
50+
Base64Node,
51+
EncryptedNode,
52+
FernetNode,
53+
CustomLoader,
54+
)
4955

5056
# default memoization cache timeout in seconds
5157
CACHE_TIMEOUT = 5
@@ -643,8 +649,7 @@ def validate_yaml(file_path: str):
643649
"""
644650
try:
645651
with open(file_path, "r") as stream:
646-
data = yaml.safe_load(stream.read())
647-
# data = yaml.load(stream.read(), Loader=CustomLoader)
652+
data = yaml.load(stream.read(), Loader=CustomLoader)
648653
if not isinstance(data, dict):
649654
raise yaml.YAMLError("invalid data structure")
650655
return data

tests/fixtures/env/thing.env

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ all: &all
55
LOG_LEVEL: ${LOG_LEVEL:=INFO}
66
FLOAT: 1.0
77
INT: 5
8+
YES: 1
9+
TRUE: 1
810
NUMBER_LIST: [1, 2, 3]
911
CHAR_LIST: ['a', 'b', 'c', "${HELLO}"]
1012
DICT: {a: 1, b: 2, c: "${INT}"}

tests/test_cmds.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,8 @@ def test_thing(self):
159159
PYTHONPATH=${DEPLOY_ROOT}/lib/python:${PYTHONPATH}
160160
ROOT=${HOME}/.local/pipe
161161
STACK=thing
162+
TRUE=1
163+
YES=1
162164
"""
163165
command = "%s thing" % self.envstack_bin
164166
output = subprocess.check_output(command, shell=True, universal_newlines=True)
@@ -437,6 +439,8 @@ def test_thing(self):
437439
INT: 5
438440
LOG_LEVEL: ${LOG_LEVEL:=INFO}
439441
NUMBER_LIST: [1, 2, 3]
442+
TRUE: 1
443+
YES: 1
440444
darwin:
441445
<<: *all
442446
ROOT: ${HOME}/Library/Application Support/pipe
@@ -502,6 +506,8 @@ def test_thing_encrypted(self):
502506
INT: !encrypt NQ==
503507
LOG_LEVEL: !encrypt JHtMT0dfTEVWRUw6PUlORk99
504508
NUMBER_LIST: !encrypt WzEsIDIsIDNd
509+
TRUE: !encrypt MQ==
510+
YES: !encrypt MQ==
505511
darwin:
506512
<<: *all
507513
ROOT: !encrypt JHtIT01FfS9MaWJyYXJ5L0FwcGxpY2F0aW9uIFN1cHBvcnQvcGlwZQ==

tests/test_node.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ def test_str(self):
170170

171171

172172
class TestCustomLoader(unittest.TestCase):
173+
"""Test the CustomLoader class"""
174+
175+
def load(self, s: str):
176+
return yaml.load(s, Loader=CustomLoader)
177+
173178
def test_construct_mapping(self):
174179
"""test the CustomLoader construct_mapping method"""
175180
envfile = os.path.join(envpath, "project.env")
@@ -178,8 +183,56 @@ def test_construct_mapping(self):
178183
mapping = loader.construct_mapping(node)
179184
self.assertIsInstance(mapping, dict)
180185

186+
def test_keys_stay_strings_booleanish(self):
187+
"""test that keys that look like booleans are not converted to bools
188+
by the resolver"""
189+
data = self.load("YES: 1\nNO: 2\nON: 3\nOFF: 4\ntrue: 5\nFalse: 6\n")
190+
assert "YES" in data and data["YES"] == 1
191+
assert "NO" in data and data["NO"] == 2
192+
assert "ON" in data and data["ON"] == 3
193+
assert "OFF" in data and data["OFF"] == 4
194+
assert "true" in data and data["true"] == 5
195+
assert "False" in data and data["False"] == 6
196+
# ensure the resolver didn't produce bool keys
197+
assert True not in data
198+
assert False not in data
199+
200+
def test_keys_stay_strings_nullish(self):
201+
"""test that keys that look like nulls are not converted to None by
202+
the resolver"""
203+
data = self.load("NULL: 1\nNull: 2\nnull: 3\n~: 4\n")
204+
assert data["NULL"] == 1
205+
assert data["Null"] == 2
206+
assert data["null"] == 3
207+
assert "~" in data and data["~"] == 4
208+
assert None not in data
209+
210+
def test_keys_stay_strings_timestampish(self):
211+
"""test that keys that look like timestamps are not converted to
212+
datetime objects"""
213+
# YAML 1.1 can infer timestamps
214+
data = self.load("2024-01-02: 1\n2024-01-02T03:04:05Z: 2\n")
215+
assert data["2024-01-02"] == 1
216+
assert data["2024-01-02T03:04:05Z"] == 2
217+
218+
def test_yaml_merge_key_works(self):
219+
"""test that the YAML merge key (<<) works correctly with anchors and
220+
aliases"""
221+
s = """
222+
all: &all
223+
A: 1
224+
dev:
225+
<<: *all
226+
B: 2
227+
"""
228+
data = self.load(s)
229+
assert data["dev"]["A"] == 1
230+
assert data["dev"]["B"] == 2
231+
181232

182233
class TestCustomDumper(unittest.TestCase):
234+
"""Test the CustomDumper class"""
235+
183236
def test_init(self):
184237
"""test the CustomDumper __init__ method"""
185238
dumper = CustomDumper(None)

tests/test_util.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,7 @@ def test_tokenized_value_single(self):
131131
"ROOT": "/var/tmp",
132132
}
133133
result = evaluate_modifiers(expression, environ)
134-
self.assertEqual(
135-
result, r"/var/tmp/{foo}"
136-
)
134+
self.assertEqual(result, r"/var/tmp/{foo}")
137135

138136
def test_tokenized_value_trailing_slash(self):
139137
"""Test tokenized value with a single token and trailing slash."""
@@ -142,9 +140,7 @@ def test_tokenized_value_trailing_slash(self):
142140
"ROOT": "/var/tmp",
143141
}
144142
result = evaluate_modifiers(expression, environ)
145-
self.assertEqual(
146-
result, r"/var/tmp/{foo}/"
147-
)
143+
self.assertEqual(result, r"/var/tmp/{foo}/")
148144

149145
def test_tokenized_value_two_tokens(self):
150146
"""Test tokenized value with two tokens."""
@@ -153,9 +149,7 @@ def test_tokenized_value_two_tokens(self):
153149
"ROOT": "/var/tmp",
154150
}
155151
result = evaluate_modifiers(expression, environ)
156-
self.assertEqual(
157-
result, r"/var/tmp/{foo}/{bar}"
158-
)
152+
self.assertEqual(result, r"/var/tmp/{foo}/{bar}")
159153

160154
def test_tokenized_value_three_tokens(self):
161155
"""Test tokenized value with three tokens."""
@@ -164,9 +158,7 @@ def test_tokenized_value_three_tokens(self):
164158
"ROOT": "/var/tmp",
165159
}
166160
result = evaluate_modifiers(expression, environ)
167-
self.assertEqual(
168-
result, r"/var/tmp/{foo}/{bar}/{baz}"
169-
)
161+
self.assertEqual(result, r"/var/tmp/{foo}/{bar}/{baz}")
170162

171163
def test_tokenized_value_three_tokens_slash(self):
172164
"""Test tokenized value with three tokens and a trailing slash."""
@@ -175,9 +167,7 @@ def test_tokenized_value_three_tokens_slash(self):
175167
"ROOT": "/var/tmp",
176168
}
177169
result = evaluate_modifiers(expression, environ)
178-
self.assertEqual(
179-
result, r"/var/tmp/{foo}/{bar}/{baz}/"
180-
)
170+
self.assertEqual(result, r"/var/tmp/{foo}/{bar}/{baz}/")
181171

182172
def test_default_value_with_default_args(self):
183173
"""Test default value with default args."""

0 commit comments

Comments
 (0)