Skip to content

Commit a2af686

Browse files
authored
Search up tree for config (#1122)
* write_default_config takes parent dir * search up tree * formatting * Add another test
1 parent 070d529 commit a2af686

2 files changed

Lines changed: 83 additions & 19 deletions

File tree

src/modelgauge/config.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,40 @@
1212
CONFIG_TEMPLATES = [DEFAULT_SECRETS]
1313

1414

15-
def write_default_config(dir: str = DEFAULT_CONFIG_DIR):
15+
def find_config_dir(path: str = ".") -> str:
16+
"""Search up the tree for the config directory."""
17+
current_dir = os.path.abspath(path)
18+
while True:
19+
config_dir = os.path.join(current_dir, DEFAULT_CONFIG_DIR)
20+
if os.path.exists(config_dir):
21+
return config_dir
22+
parent_dir = os.path.dirname(current_dir)
23+
if parent_dir == current_dir: # Reached root directory
24+
raise FileNotFoundError(
25+
f"Could not find the config directory '{DEFAULT_CONFIG_DIR}' anywhere along the path to '{path}'."
26+
)
27+
current_dir = parent_dir
28+
29+
30+
def write_default_config(parent_dir: str = "."):
1631
"""If the config directory doesn't exist, fill it with defaults."""
17-
if os.path.exists(dir):
32+
try:
33+
find_config_dir(parent_dir)
34+
# Don't do anything if the config directory already exists.
1835
# Assume if it exists we don't need to add templates
19-
return
20-
os.makedirs(dir)
21-
for template in CONFIG_TEMPLATES:
22-
source_file = str(resources.files(config_templates) / template)
23-
output_file = os.path.join(dir, template)
24-
shutil.copyfile(source_file, output_file)
36+
except FileNotFoundError:
37+
dir = os.path.join(parent_dir, DEFAULT_CONFIG_DIR)
38+
os.makedirs(dir)
39+
for template in CONFIG_TEMPLATES:
40+
source_file = str(resources.files(config_templates) / template)
41+
output_file = os.path.join(dir, template)
42+
shutil.copyfile(source_file, output_file)
2543

2644

27-
def load_secrets_from_config(path: str = SECRETS_PATH) -> RawSecrets:
45+
def load_secrets_from_config(path: str = ".") -> RawSecrets:
2846
"""Load the toml file and verify it is shaped as expected."""
29-
with open(path, "rb") as f:
47+
secrets_path = os.path.join(find_config_dir(path), DEFAULT_SECRETS)
48+
with open(secrets_path, "rb") as f:
3049
data = tomli.load(f)
3150
for values in data.values():
3251
# Verify the config is shaped as expected.

tests/modelgauge_tests/test_config.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,43 +3,88 @@
33
from modelgauge.config import (
44
DEFAULT_SECRETS,
55
MissingSecretsFromConfig,
6+
find_config_dir,
67
load_secrets_from_config,
78
raise_if_missing_from_config,
89
write_default_config,
910
)
1011
from modelgauge.secret_values import MissingSecretValues, SecretDescription
1112

1213

14+
def test_find_config_dir(tmpdir):
15+
config_dir = tmpdir.join("config")
16+
os.makedirs(config_dir)
17+
found_dir = find_config_dir(str(tmpdir))
18+
assert found_dir == config_dir
19+
20+
21+
def test_find_config_dir_searches_up_tree(tmpdir):
22+
config_dir = tmpdir.join("config")
23+
os.makedirs(config_dir)
24+
sub_dir = tmpdir.join("subdir")
25+
os.makedirs(sub_dir)
26+
found_dir = find_config_dir(str(sub_dir))
27+
assert found_dir == config_dir
28+
29+
30+
def test_find_config_dir_no_config(tmpdir):
31+
with pytest.raises(FileNotFoundError):
32+
find_config_dir(str(tmpdir))
33+
34+
1335
def test_write_default_config_writes_files(tmpdir):
36+
write_default_config(tmpdir)
1437
config_dir = tmpdir.join("config")
15-
write_default_config(config_dir)
1638
files = [f.basename for f in config_dir.listdir()]
1739
assert files == ["secrets.toml"]
1840

1941

2042
def test_write_default_config_skips_existing_dir(tmpdir):
2143
config_dir = tmpdir.join("config")
2244
os.makedirs(config_dir)
23-
write_default_config(config_dir)
45+
write_default_config(tmpdir)
2446
files = [f.basename for f in config_dir.listdir()]
2547
# No files created
2648
assert files == []
2749

2850

29-
def test_load_secrets_from_config_loads_default(tmpdir):
51+
def test_write_default_config_searches_up_tree(tmpdir):
3052
config_dir = tmpdir.join("config")
31-
write_default_config(config_dir)
32-
secrets_file = config_dir.join(DEFAULT_SECRETS)
53+
os.makedirs(config_dir)
54+
sub_dir = tmpdir.join("subdir")
55+
os.makedirs(sub_dir)
56+
write_default_config(sub_dir)
57+
# Nothing created in subdir
58+
assert not os.path.exists(sub_dir.join("config"))
3359

34-
assert load_secrets_from_config(secrets_file) == {"demo": {"api_key": "12345"}}
60+
61+
def test_load_secrets_from_config_loads_default(tmpdir):
62+
write_default_config(tmpdir)
63+
assert load_secrets_from_config(tmpdir) == {"demo": {"api_key": "12345"}}
64+
65+
66+
def test_load_secrets_works_with_file_path(tmpdir):
67+
"""Test that you can also pass in a file path to load_secrets_from_config."""
68+
config_dir = tmpdir.join("subdir", "config")
69+
os.makedirs(config_dir)
70+
secrets_file = config_dir.join("secrets.toml")
71+
with open(secrets_file, "w") as f:
72+
f.write(
73+
"""\
74+
[scope]
75+
api_key = "12345"
76+
"""
77+
)
78+
secrets = load_secrets_from_config(secrets_file)
79+
assert secrets == {"scope": {"api_key": "12345"}}
3580

3681

3782
def test_load_secrets_from_config_no_file(tmpdir):
3883
config_dir = tmpdir.join("config")
39-
secrets_file = config_dir.join(DEFAULT_SECRETS)
84+
os.makedirs(config_dir)
4085

4186
with pytest.raises(FileNotFoundError):
42-
load_secrets_from_config(secrets_file)
87+
load_secrets_from_config(tmpdir)
4388

4489

4590
def test_load_secrets_from_config_bad_format(tmpdir):
@@ -49,7 +94,7 @@ def test_load_secrets_from_config_bad_format(tmpdir):
4994
with open(secrets_file, "w") as f:
5095
f.write("""not_scoped = "some-value"\n""")
5196
with pytest.raises(AssertionError) as err_info:
52-
load_secrets_from_config(secrets_file)
97+
load_secrets_from_config(tmpdir)
5398
err_text = str(err_info.value)
5499
assert err_text == "All keys should be in a [scope]."
55100

0 commit comments

Comments
 (0)