Skip to content

Commit 03b36cc

Browse files
committed
Fix defaults
1 parent 1877af2 commit 03b36cc

1 file changed

Lines changed: 80 additions & 75 deletions

File tree

batchflow/config.py

Lines changed: 80 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
""" Config class"""
12
from pprint import pformat
23

34
class Config(dict):
5+
""" Class for configs that can be represented as nested dicts with easy indexing by slashes. """
46

57
# Should be defined temporarily for the already pickled configs
68
class IAddDict(dict):
@@ -27,6 +29,7 @@ def __init__(self, config=None, **kwargs):
2729
kwargs :
2830
Parameters from kwargs also are parsed and saved to self.config.
2931
"""
32+
# pylint: disable=super-init-not-called
3033
self.config = {}
3134

3235
if config is None:
@@ -40,7 +43,7 @@ def __init__(self, config=None, **kwargs):
4043

4144
for key, value in kwargs.items():
4245
self.put(key, value)
43-
46+
4447
def parse(self, config):
4548
""" Parses flatten config with slashes.
4649
@@ -54,14 +57,14 @@ def parse(self, config):
5457
5558
"""
5659
if isinstance(config, Config):
57-
items = config.items(flatten=True) # suppose we have config = {'a': {'b': {'c': 1}}},
58-
# and we try to update config with other = {'a': {'b': {'d': 3}}},
59-
# and expect to see config = {'a': {'b': {'c': 1, 'd': 3}}}
60+
# suppose we have config = {'a': {'b': {'c': 1}}},
61+
# and we try to update config with other = {'a': {'b': {'d': 3}}},
62+
# and expect to see config = {'a': {'b': {'c': 1, 'd': 3}}}
63+
items = config.items(flatten=True)
6064
elif isinstance(config, dict):
6165
items = config.items()
6266
else:
6367
items = dict(config).items()
64-
# items = config.items() if isinstance(config, dict) else dict(config).items()
6568

6669
for key, value in items:
6770
if isinstance(key, str): # if key contains multiple consecutive '/'
@@ -98,11 +101,12 @@ def put(self, key, value):
98101
if isinstance(value, dict) and last_level in config and isinstance(config[last_level], dict):
99102
config[last_level].update(value)
100103
else:
101-
# for example, we try to set config['a/b/c'] = 3, where config = Config({'a/b': 1}) and don't want error here
102104
if isinstance(config, dict):
103105
config[last_level] = value
106+
# for example, we try to set my_config['a/b/c'] = 3,
107+
# where my_config = Config({'a/b': 1}) and don't want error here
104108
else:
105-
prev_config[level] = {last_level: value}
109+
prev_config[level] = {last_level: value} # pylint: disable=undefined-loop-variable
106110
else:
107111
self.config[key] = value
108112

@@ -118,15 +122,19 @@ def _get(self, key, default=None, has_default=False, pop=False):
118122
key = [key]
119123
unpack = True
120124

121-
# Provide `default` for each variable in key
122-
if default is not None and len(key) != 1 and len(default) != len(key):
123-
raise ValueError('You should provide `default` for each variable in `key`') # edit
124-
default = [default] if not isinstance(default, list) else default
125+
n = len(key)
126+
if n > 1:
127+
default = [default] * n if not isinstance(default, list) else default
128+
if len(default) != n:
129+
raise ValueError('The length of `default` must be equal to the length of `key`')
130+
else:
131+
default = [default]
125132

126133
ret_vars = []
127134
for ix, variable in enumerate(key):
128135

129136
if isinstance(variable, str) and '/' in variable:
137+
130138
value = self.config
131139
levels = variable.split('/')
132140
values = []
@@ -137,29 +145,29 @@ def _get(self, key, default=None, has_default=False, pop=False):
137145
if not has_default:
138146
raise KeyError(level)
139147
value = default[ix]
140-
ret_vars.append(value)
148+
values.append(value)
141149
break
142150

143-
elif level not in value:
151+
if level not in value:
144152
if not has_default:
145153
raise KeyError(level)
146154
value = default[ix]
147-
ret_vars.append(value)
155+
values.append(value)
148156
break
149157

150-
else:
151-
value = value[level]
152-
values.append(value)
158+
value = value[level]
159+
values.append(value)
153160

154161
if pop:
155-
del values[-2][level] # delete the last level from the parent dict
162+
# delete the last level from the parent dict
163+
values[-2].pop(level, default[ix]) # pylint: disable=undefined-loop-variable
156164

157165
else:
166+
158167
if variable not in self.config:
159168
if not has_default:
160169
raise KeyError(variable)
161170
value = default[ix]
162-
ret_vars.append(value)
163171

164172
else:
165173
value = method(variable)
@@ -182,7 +190,8 @@ def get(self, key, default=None):
182190
A key in the dictionary. '/' is used to get value from nested dict.
183191
default : misc
184192
Default value if key doesn't exist in config.
185-
Defaults to None, so that this method never raises a KeyError.
193+
By default None, so this method never raises a KeyError.
194+
If key has several variables, `default` can be a list with defaults for each variable.
186195
187196
Returns
188197
-------
@@ -192,7 +201,7 @@ def get(self, key, default=None):
192201
value = self._get(key, default=default, has_default=True)
193202

194203
return value
195-
204+
196205
def pop(self, key, **kwargs):
197206
""" Returns the value or tuple of values for key in the config.
198207
If not found, returns a default value.
@@ -203,7 +212,6 @@ def pop(self, key, **kwargs):
203212
A key in the dictionary. '/' is used to get value from nested dict.
204213
default : misc
205214
Default value if key doesn't exist in config.
206-
Defaults to None, so that this method never raises a KeyError.
207215
208216
Returns
209217
-------
@@ -216,13 +224,6 @@ def pop(self, key, **kwargs):
216224

217225
return value
218226

219-
def __repr__(self):
220-
return repr(self.config)
221-
222-
def __getitem__(self, key):
223-
value = self._get(key)
224-
return value
225-
226227
def update(self, other=None, **kwargs):
227228
other = other or {}
228229
if not isinstance(other, (dict, tuple, list)):
@@ -233,6 +234,34 @@ def update(self, other=None, **kwargs):
233234
for key, value in kwargs.items():
234235
self.put(key, value)
235236

237+
def flatten(self, config=None):
238+
""" Transforms nested dict into flatten dict.
239+
240+
Parameters
241+
----------
242+
config : dict, Config or None
243+
If None `self.config` will be parsed else config.
244+
245+
Returns
246+
-------
247+
new_config : dict
248+
249+
"""
250+
config = self.config if config is None else config
251+
new_config = {}
252+
for key, value in config.items():
253+
if isinstance(value, dict) and len(value) > 0:
254+
value = self.flatten(value)
255+
for _key, _value in value.items():
256+
if isinstance(_key, str):
257+
new_config[key + '/' + _key] = _value
258+
else:
259+
new_config[key] = {_key: _value}
260+
else:
261+
new_config[key] = value
262+
263+
return new_config
264+
236265
def keys(self, flatten=False):
237266
""" Returns config keys
238267
@@ -290,33 +319,13 @@ def items(self, flatten=False):
290319
items = self.config.items()
291320
return items
292321

293-
def flatten(self, config=None):
294-
""" Transforms nested dict into flatten dict.
295-
296-
Parameters
297-
----------
298-
config : dict, Config or None
299-
If None `self.config` will be parsed else config.
300-
301-
Returns
302-
-------
303-
new_config : dict
304-
305-
"""
306-
config = self.config if config is None else config
307-
new_config = {}
308-
for key, value in config.items():
309-
if isinstance(value, dict) and len(value) > 0:
310-
value = self.flatten(value)
311-
for _key, _value in value.items():
312-
if isinstance(_key, str):
313-
new_config[key + '/' + _key] = _value
314-
else:
315-
new_config[key] = {_key: _value}
316-
else:
317-
new_config[key] = value
322+
def copy(self):
323+
""" Create a shallow copy of the instance. """
324+
return Config(self.config.copy())
318325

319-
return new_config
326+
def __getitem__(self, key):
327+
value = self._get(key)
328+
return value
320329

321330
def __setitem__(self, key, value):
322331
if key in self.config:
@@ -326,10 +335,6 @@ def __setitem__(self, key, value):
326335
def __delitem__(self, key):
327336
self.pop(key)
328337

329-
def copy(self):
330-
""" Create a shallow copy of the instance. """
331-
return Config(self.config.copy())
332-
333338
def __getattr__(self, key):
334339
if key in self.config:
335340
value = self.config.get(key)
@@ -344,13 +349,6 @@ def __add__(self, other):
344349
return Config([*self.flatten().items(), *other.flatten().items()])
345350
return NotImplemented
346351

347-
def __iter__(self):
348-
return iter(self.config)
349-
350-
def __repr__(self):
351-
lines = ['\n' + 4 * ' ' + line for line in pformat(self.config).split('\n')]
352-
return f"Config({''.join(lines)})"
353-
354352
def __iadd__(self, other):
355353
if isinstance(other, dict):
356354
self.update(other)
@@ -363,21 +361,20 @@ def __radd__(self, other):
363361
other = Config(other)
364362
return other.__add__(self)
365363

366-
def __len__(self):
367-
return len(self.config)
368-
369364
def __eq__(self, other):
370365
self_ = self.flatten()
371366
other_ = Config(other).flatten() if isinstance(other, dict) else other
372367
return self_.__eq__(other_)
373368

374-
def __getstate__(self):
375-
""" Must be explicitly defined for pickling to work. """
376-
return vars(self)
369+
def __len__(self):
370+
return len(self.config)
377371

378-
def __setstate__(self, state):
379-
""" Must be explicitly defined for pickling to work. """
380-
vars(self).update(state)
372+
def __iter__(self):
373+
return iter(self.config)
374+
375+
def __repr__(self):
376+
lines = ['\n' + 4 * ' ' + line for line in pformat(self.config).split('\n')]
377+
return f"Config({''.join(lines)})"
381378

382379
def __rshift__(self, other):
383380
""" Parameters
@@ -390,3 +387,11 @@ def __rshift__(self, other):
390387
Pipeline object with an updated config.
391388
"""
392389
return other << self
390+
391+
def __getstate__(self):
392+
""" Must be explicitly defined for pickling to work. """
393+
return vars(self)
394+
395+
def __setstate__(self, state):
396+
""" Must be explicitly defined for pickling to work. """
397+
vars(self).update(state)

0 commit comments

Comments
 (0)