Skip to content

Commit 1877af2

Browse files
committed
Allow default to be unique for each variable
1 parent 7aac9d8 commit 1877af2

1 file changed

Lines changed: 46 additions & 25 deletions

File tree

batchflow/config.py

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from pprint import pformat
22

33
class Config(dict):
4+
5+
# Should be defined temporarily for the already pickled configs
46
class IAddDict(dict):
57
pass
8+
69
def __init__(self, config=None, **kwargs):
710
""" Create Config.
811
@@ -50,7 +53,15 @@ def parse(self, config):
5053
self : Config
5154
5255
"""
53-
items = config.items() if isinstance(config, dict) else dict(config).items()
56+
if isinstance(config, Config):
57+
items = config.items(flatten=True) # suppose we have config = {'a': {'b': {'c': 1}}},
58+
# and we try to update config with other = {'a': {'b': {'d': 3}}},
59+
# and expect to see config = {'a': {'b': {'c': 1, 'd': 3}}}
60+
elif isinstance(config, dict):
61+
items = config.items()
62+
else:
63+
items = dict(config).items()
64+
# items = config.items() if isinstance(config, dict) else dict(config).items()
5465

5566
for key, value in items:
5667
if isinstance(key, str): # if key contains multiple consecutive '/'
@@ -87,7 +98,7 @@ def put(self, key, value):
8798
if isinstance(value, dict) and last_level in config and isinstance(config[last_level], dict):
8899
config[last_level].update(value)
89100
else:
90-
# for example, we try to set config['a/b/c'] = 3, where config = Config({'a/b': 1})
101+
# for example, we try to set config['a/b/c'] = 3, where config = Config({'a/b': 1}) and don't want error here
91102
if isinstance(config, dict):
92103
config[last_level] = value
93104
else:
@@ -96,7 +107,9 @@ def put(self, key, value):
96107
self.config[key] = value
97108

98109
def _get(self, key, default=None, has_default=False, pop=False):
99-
""" Consecutively retrieve values for a given key if the key contains '/'. """
110+
""" Consecutively retrieve values for a given key if the key contains '/'.
111+
This method supports the `default` to be unique for each variable in key.
112+
"""
100113
method = 'get' if not pop else 'pop'
101114
method = getattr(self.config, method)
102115

@@ -106,48 +119,57 @@ def _get(self, key, default=None, has_default=False, pop=False):
106119
unpack = True
107120

108121
# Provide `default` for each variable in key
109-
if has_default:
110-
if isinstance(default, (list, tuple)) and len(key) != 1 and len(default) != len(key):
111-
raise ValueError() #
112-
elif not isinstance(default, (list, tuple)) and len(key) != 1:
113-
default = [default] * len(key)
122+
if default is not None and len(key) != 1 and len(default) != len(key):
123+
raise ValueError('You should provide `default` for each variable in `key`') # edit
124+
default = [default] if not isinstance(default, list) else default
114125

115126
ret_vars = []
116-
for variable in key:
127+
for ix, variable in enumerate(key):
117128

118129
if isinstance(variable, str) and '/' in variable:
119130
value = self.config
120131
levels = variable.split('/')
121132
values = []
122133

123134
for level in levels:
135+
124136
if not isinstance(value, dict):
125-
if has_default:
126-
return default
127-
raise KeyError(level)
128-
if level not in value:
129-
if has_default:
130-
return default
131-
raise KeyError(level)
132-
value = value[level]
133-
values.append(value)
137+
if not has_default:
138+
raise KeyError(level)
139+
value = default[ix]
140+
ret_vars.append(value)
141+
break
142+
143+
elif level not in value:
144+
if not has_default:
145+
raise KeyError(level)
146+
value = default[ix]
147+
ret_vars.append(value)
148+
break
149+
150+
else:
151+
value = value[level]
152+
values.append(value)
134153

135154
if pop:
136155
del values[-2][level] # delete the last level from the parent dict
137156

138157
else:
139158
if variable not in self.config:
140-
if has_default:
141-
return default
142-
raise KeyError
143-
value = method(variable)
159+
if not has_default:
160+
raise KeyError(variable)
161+
value = default[ix]
162+
ret_vars.append(value)
163+
164+
else:
165+
value = method(variable)
144166

145167
if isinstance(value, dict):
146168
value = Config(value)
147-
148169
ret_vars.append(value)
149170

150171
ret_vars = ret_vars[0] if unpack else tuple(ret_vars)
172+
151173
return ret_vars
152174

153175
def get(self, key, default=None):
@@ -206,7 +228,7 @@ def update(self, other=None, **kwargs):
206228
if not isinstance(other, (dict, tuple, list)):
207229
raise TypeError(f'{type(other)} object is not iterable')
208230

209-
self.parse(other)
231+
self.parse(Config(other))
210232

211233
for key, value in kwargs.items():
212234
self.put(key, value)
@@ -346,7 +368,6 @@ def __len__(self):
346368

347369
def __eq__(self, other):
348370
self_ = self.flatten()
349-
print(self_, 'self_')
350371
other_ = Config(other).flatten() if isinstance(other, dict) else other
351372
return self_.__eq__(other_)
352373

0 commit comments

Comments
 (0)