Skip to content

Commit bb6d861

Browse files
committed
feat: support customizing attr getter in SubFactory
This change should allow for the case where a `SelfAttribute` attempts to access the field of a `SubFactory` that is a `DictFactory`. The user can override the `deepgetattr_func` of the `SubFactory` to use `deepdictgetattr`, or one of their own choosing, which will be used to retrieve the attribute. closes #1134
1 parent a912544 commit bb6d861

3 files changed

Lines changed: 115 additions & 2 deletions

File tree

factory/declarations.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright: See the LICENSE file.
22

33

4+
import contextlib
45
import itertools
56
import logging
67
import typing as T
@@ -166,6 +167,65 @@ class _UNSPECIFIED:
166167
pass
167168

168169

170+
def dictgetattr(obj, attr, default=_UNSPECIFIED):
171+
"""A version of `getattr` that allows retrieving attributes from dictionaries.
172+
173+
Args:
174+
obj (object): the object of which an attribute or key should be read
175+
name (str): the name of an attribute or key to look up.
176+
default (object): the default value to use if the attribute or key wasn't found.
177+
178+
Returns:
179+
the value pointed to by 'name'.
180+
181+
Raises:
182+
KeyError: if obj has no 'name' key.
183+
TypeError: if obj is not subscriptable. Also raised if object has no attribute.
184+
"""
185+
# First, try and get the attribute ignoring any attribute errors.
186+
with contextlib.suppress(AttributeError):
187+
return getattr(obj, attr)
188+
189+
# If that doesn't work, assume it's a dictionary and try to retrieve the value by
190+
# using the attr name as the key.
191+
try:
192+
return obj[attr]
193+
except (KeyError, TypeError):
194+
# If no default was specified, re-raise the exception. Otherwise, return the
195+
# default.
196+
if default is _UNSPECIFIED:
197+
raise
198+
199+
return default
200+
201+
202+
def deepdictgetattr(obj, attr, default=_UNSPECIFIED):
203+
"""A deepgetattr that also works with dictionaries. This will first try retrieving
204+
the named attribute before attempting to use bracket (`[]`) access assuming it might
205+
be a dictionary.
206+
207+
Args:
208+
obj (object): the object of which an attribute or key should be read.
209+
name (str): the name of an attribute or key to look up. This can be a dotted
210+
path to a nested attribute or key.
211+
default (object): the default value to use if the attribute wasn't found.
212+
213+
Returns:
214+
the value pointed to by 'name', splitting on '.'.
215+
216+
Raises:
217+
KeyError: if obj has no 'name' key.
218+
TypeError: if obj is not subscriptable. Also raised if object has no attribute.
219+
"""
220+
result = obj
221+
# Split on `.` and treat the values in between as parts of a path. Chain off the
222+
# previously retrieved result to get the next nested attribute.
223+
for attr_name in attr.split("."):
224+
result = dictgetattr(result, attr_name, default)
225+
226+
return result
227+
228+
169229
def deepgetattr(obj, name, default=_UNSPECIFIED):
170230
"""Try to retrieve the given attribute of an object, digging on '.'.
171231
@@ -206,16 +266,20 @@ class SelfAttribute(BaseDeclaration):
206266
attribute_name (str): the name of the attribute to copy.
207267
default (object): the default value to use if the attribute doesn't
208268
exist.
269+
deepgetattr_func (callable): the callable used to retrieve the potentially
270+
nested attribute. Defaults to `deepgetattr` which supports dotted path
271+
attribute access.
209272
"""
210273

211-
def __init__(self, attribute_name, default=_UNSPECIFIED):
274+
def __init__(self, attribute_name, default=_UNSPECIFIED, deepgetattr_func=deepgetattr):
212275
super().__init__()
213276
depth = len(attribute_name) - len(attribute_name.lstrip('.'))
214277
attribute_name = attribute_name[depth:]
215278

216279
self.depth = depth
217280
self.attribute_name = attribute_name
218281
self.default = default
282+
self.deepgetattr_func = deepgetattr_func
219283

220284
def evaluate(self, instance, step, extra):
221285
if self.depth > 1:
@@ -225,7 +289,7 @@ def evaluate(self, instance, step, extra):
225289
target = instance
226290

227291
logger.debug("SelfAttribute: Picking attribute %r on %r", self.attribute_name, target)
228-
return deepgetattr(target, self.attribute_name, self.default)
292+
return self.deepgetattr_func(target, self.attribute_name, self.default)
229293

230294
def __repr__(self):
231295
return '<%s(%r, default=%r)>' % (

tests/test_declarations.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,38 @@ def test_chaining(self):
4242
self.assertEqual(42, declarations.deepgetattr(obj, 'a.b.c.n.x', 42))
4343

4444

45+
class DigDictTestCase(unittest.TestCase):
46+
class MyObj:
47+
def __init__(self, n):
48+
self.n = n
49+
50+
def test_chaining(self):
51+
"""This is the same test as the `DigTestCase.test_chaining`, except it tests
52+
that chaining works for dictionaries.
53+
"""
54+
dictionary = {"n": 1}
55+
dictionary["a"] = {"n": 2}
56+
dictionary["a"]["b"] = {"n": 3}
57+
dictionary["a"]["b"]["c"] = {"n": 4}
58+
59+
with self.assertRaises(TypeError):
60+
declarations.deepdictgetattr(self.MyObj(1), 'a')
61+
self.assertEqual(2, declarations.deepdictgetattr(dictionary, 'a')["n"])
62+
with self.assertRaises(KeyError):
63+
declarations.deepdictgetattr(dictionary, 'b')
64+
self.assertEqual(2, declarations.deepdictgetattr(dictionary, 'a.n'))
65+
self.assertEqual(3, declarations.deepdictgetattr(dictionary, 'a.c', 3))
66+
with self.assertRaises(KeyError):
67+
declarations.deepdictgetattr(dictionary, 'a.c.n')
68+
with self.assertRaises(KeyError):
69+
declarations.deepdictgetattr(dictionary, 'a.d')
70+
self.assertEqual(3, declarations.deepdictgetattr(dictionary, 'a.b')["n"])
71+
self.assertEqual(3, declarations.deepdictgetattr(dictionary, 'a.b.n'))
72+
self.assertEqual(4, declarations.deepdictgetattr(dictionary, 'a.b.c')["n"])
73+
self.assertEqual(4, declarations.deepdictgetattr(dictionary, 'a.b.c.n'))
74+
self.assertEqual(42, declarations.deepdictgetattr(dictionary, 'a.b.c.n.x', 42))
75+
76+
4577
class MaybeTestCase(unittest.TestCase):
4678
def test_init(self):
4779
declarations.Maybe('foo', 1, 2)

tests/test_using.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,23 @@ class Meta:
567567
test_model = TestModel2Factory()
568568
self.assertEqual(4, test_model.two.three)
569569

570+
def test_self_attribute_dict_factory(self):
571+
class ChildFactory(factory.DictFactory):
572+
first_name = "Bob"
573+
last_name = "Dole"
574+
575+
class ParentFactory(factory.DictFactory):
576+
child = factory.SubFactory(ChildFactory)
577+
last_name = factory.SelfAttribute(
578+
"child.last_name", deepgetattr_func=factory.declarations.deepdictgetattr
579+
)
580+
581+
parent = ParentFactory()
582+
self.assertEqual(
583+
parent,
584+
{"last_name": "Dole", "child": {"first_name": "Bob", "last_name": "Dole"}},
585+
)
586+
570587
def test_sequence_decorator(self):
571588
class TestObjectFactory(factory.Factory):
572589
class Meta:

0 commit comments

Comments
 (0)