11# Copyright: See the LICENSE file.
22
33
4+ import contextlib
45import itertools
56import logging
67import typing as T
@@ -166,6 +167,65 @@ class _UNSPECIFIED:
166167 pass
167168
168169
170+ def dictgetattr (obj , attr , default = _UNSPECIFIED ):
171+ """A version of `getattr` that allows retrieving attributes from dictionaries.
172+
173+ Args:
174+ obj (object): the object of which an attribute or key should be read
175+ name (str): the name of an attribute or key to look up.
176+ default (object): the default value to use if the attribute or key wasn't found.
177+
178+ Returns:
179+ the value pointed to by 'name'.
180+
181+ Raises:
182+ KeyError: if obj has no 'name' key.
183+ TypeError: if obj is not subscriptable. Also raised if object has no attribute.
184+ """
185+ # First, try and get the attribute ignoring any attribute errors.
186+ with contextlib .suppress (AttributeError ):
187+ return getattr (obj , attr )
188+
189+ # If that doesn't work, assume it's a dictionary and try to retrieve the value by
190+ # using the attr name as the key.
191+ try :
192+ return obj [attr ]
193+ except (KeyError , TypeError ):
194+ # If no default was specified, re-raise the exception. Otherwise, return the
195+ # default.
196+ if default is _UNSPECIFIED :
197+ raise
198+
199+ return default
200+
201+
202+ def deepdictgetattr (obj , attr , default = _UNSPECIFIED ):
203+ """A deepgetattr that also works with dictionaries. This will first try retrieving
204+ the named attribute before attempting to use bracket (`[]`) access assuming it might
205+ be a dictionary.
206+
207+ Args:
208+ obj (object): the object of which an attribute or key should be read.
209+ name (str): the name of an attribute or key to look up. This can be a dotted
210+ path to a nested attribute or key.
211+ default (object): the default value to use if the attribute wasn't found.
212+
213+ Returns:
214+ the value pointed to by 'name', splitting on '.'.
215+
216+ Raises:
217+ KeyError: if obj has no 'name' key.
218+ TypeError: if obj is not subscriptable. Also raised if object has no attribute.
219+ """
220+ result = obj
221+ # Split on `.` and treat the values in between as parts of a path. Chain off the
222+ # previously retrieved result to get the next nested attribute.
223+ for attr_name in attr .split ("." ):
224+ result = dictgetattr (result , attr_name , default )
225+
226+ return result
227+
228+
169229def deepgetattr (obj , name , default = _UNSPECIFIED ):
170230 """Try to retrieve the given attribute of an object, digging on '.'.
171231
@@ -206,16 +266,20 @@ class SelfAttribute(BaseDeclaration):
206266 attribute_name (str): the name of the attribute to copy.
207267 default (object): the default value to use if the attribute doesn't
208268 exist.
269+ deepgetattr_func (callable): the callable used to retrieve the potentially
270+ nested attribute. Defaults to `deepgetattr` which supports dotted path
271+ attribute access.
209272 """
210273
211- def __init__ (self , attribute_name , default = _UNSPECIFIED ):
274+ def __init__ (self , attribute_name , default = _UNSPECIFIED , deepgetattr_func = deepgetattr ):
212275 super ().__init__ ()
213276 depth = len (attribute_name ) - len (attribute_name .lstrip ('.' ))
214277 attribute_name = attribute_name [depth :]
215278
216279 self .depth = depth
217280 self .attribute_name = attribute_name
218281 self .default = default
282+ self .deepgetattr_func = deepgetattr_func
219283
220284 def evaluate (self , instance , step , extra ):
221285 if self .depth > 1 :
@@ -225,7 +289,7 @@ def evaluate(self, instance, step, extra):
225289 target = instance
226290
227291 logger .debug ("SelfAttribute: Picking attribute %r on %r" , self .attribute_name , target )
228- return deepgetattr (target , self .attribute_name , self .default )
292+ return self . deepgetattr_func (target , self .attribute_name , self .default )
229293
230294 def __repr__ (self ):
231295 return '<%s(%r, default=%r)>' % (
0 commit comments