11from pprint import pformat
22
33class Config (dict ):
4+
5+ # Should be defined temporarily for the already pickled configs
46 class IAddDict (dict ):
57 pass
8+
69 def __init__ (self , config = None , ** kwargs ):
710 """ Create Config.
811
@@ -50,7 +53,15 @@ def parse(self, config):
5053 self : Config
5154
5255 """
53- items = config .items () if isinstance (config , dict ) else dict (config ).items ()
56+ if isinstance (config , Config ):
57+ items = config .items (flatten = True ) # suppose we have config = {'a': {'b': {'c': 1}}},
58+ # and we try to update config with other = {'a': {'b': {'d': 3}}},
59+ # and expect to see config = {'a': {'b': {'c': 1, 'd': 3}}}
60+ elif isinstance (config , dict ):
61+ items = config .items ()
62+ else :
63+ items = dict (config ).items ()
64+ # items = config.items() if isinstance(config, dict) else dict(config).items()
5465
5566 for key , value in items :
5667 if isinstance (key , str ): # if key contains multiple consecutive '/'
@@ -87,7 +98,7 @@ def put(self, key, value):
8798 if isinstance (value , dict ) and last_level in config and isinstance (config [last_level ], dict ):
8899 config [last_level ].update (value )
89100 else :
90- # for example, we try to set config['a/b/c'] = 3, where config = Config({'a/b': 1})
101+ # for example, we try to set config['a/b/c'] = 3, where config = Config({'a/b': 1}) and don't want error here
91102 if isinstance (config , dict ):
92103 config [last_level ] = value
93104 else :
@@ -96,7 +107,9 @@ def put(self, key, value):
96107 self .config [key ] = value
97108
98109 def _get (self , key , default = None , has_default = False , pop = False ):
99- """ Consecutively retrieve values for a given key if the key contains '/'. """
110+ """ Consecutively retrieve values for a given key if the key contains '/'.
111+ This method supports the `default` to be unique for each variable in key.
112+ """
100113 method = 'get' if not pop else 'pop'
101114 method = getattr (self .config , method )
102115
@@ -106,48 +119,57 @@ def _get(self, key, default=None, has_default=False, pop=False):
106119 unpack = True
107120
108121 # Provide `default` for each variable in key
109- if has_default :
110- if isinstance (default , (list , tuple )) and len (key ) != 1 and len (default ) != len (key ):
111- raise ValueError () #
112- elif not isinstance (default , (list , tuple )) and len (key ) != 1 :
113- default = [default ] * len (key )
122+ if default is not None and len (key ) != 1 and len (default ) != len (key ):
123+ raise ValueError ('You should provide `default` for each variable in `key`' ) # edit
124+ default = [default ] if not isinstance (default , list ) else default
114125
115126 ret_vars = []
116- for variable in key :
127+ for ix , variable in enumerate ( key ) :
117128
118129 if isinstance (variable , str ) and '/' in variable :
119130 value = self .config
120131 levels = variable .split ('/' )
121132 values = []
122133
123134 for level in levels :
135+
124136 if not isinstance (value , dict ):
125- if has_default :
126- return default
127- raise KeyError (level )
128- if level not in value :
129- if has_default :
130- return default
131- raise KeyError (level )
132- value = value [level ]
133- values .append (value )
137+ if not has_default :
138+ raise KeyError (level )
139+ value = default [ix ]
140+ ret_vars .append (value )
141+ break
142+
143+ elif level not in value :
144+ if not has_default :
145+ raise KeyError (level )
146+ value = default [ix ]
147+ ret_vars .append (value )
148+ break
149+
150+ else :
151+ value = value [level ]
152+ values .append (value )
134153
135154 if pop :
136155 del values [- 2 ][level ] # delete the last level from the parent dict
137156
138157 else :
139158 if variable not in self .config :
140- if has_default :
141- return default
142- raise KeyError
143- value = method (variable )
159+ if not has_default :
160+ raise KeyError (variable )
161+ value = default [ix ]
162+ ret_vars .append (value )
163+
164+ else :
165+ value = method (variable )
144166
145167 if isinstance (value , dict ):
146168 value = Config (value )
147-
148169 ret_vars .append (value )
149170
150171 ret_vars = ret_vars [0 ] if unpack else tuple (ret_vars )
172+
151173 return ret_vars
152174
153175 def get (self , key , default = None ):
@@ -206,7 +228,7 @@ def update(self, other=None, **kwargs):
206228 if not isinstance (other , (dict , tuple , list )):
207229 raise TypeError (f'{ type (other )} object is not iterable' )
208230
209- self .parse (other )
231+ self .parse (Config ( other ) )
210232
211233 for key , value in kwargs .items ():
212234 self .put (key , value )
@@ -346,7 +368,6 @@ def __len__(self):
346368
347369 def __eq__ (self , other ):
348370 self_ = self .flatten ()
349- print (self_ , 'self_' )
350371 other_ = Config (other ).flatten () if isinstance (other , dict ) else other
351372 return self_ .__eq__ (other_ )
352373
0 commit comments