Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions factory/declarations.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright: See the LICENSE file.


import contextlib
import itertools
import logging
import typing as T
Expand Down Expand Up @@ -166,6 +167,65 @@ class _UNSPECIFIED:
pass


def dictgetattr(obj, attr, default=_UNSPECIFIED):
"""A version of `getattr` that allows retrieving attributes from dictionaries.

Args:
obj (object): the object of which an attribute or key should be read
name (str): the name of an attribute or key to look up.
default (object): the default value to use if the attribute or key wasn't found.

Returns:
the value pointed to by 'name'.

Raises:
KeyError: if obj has no 'name' key.
TypeError: if obj is not subscriptable. Also raised if object has no attribute.
"""
# First, try and get the attribute ignoring any attribute errors.
with contextlib.suppress(AttributeError):
return getattr(obj, attr)

# If that doesn't work, assume it's a dictionary and try to retrieve the value by
# using the attr name as the key.
try:
return obj[attr]
except (KeyError, TypeError):
# If no default was specified, re-raise the exception. Otherwise, return the
# default.
if default is _UNSPECIFIED:
raise

return default


def deepdictgetattr(obj, attr, default=_UNSPECIFIED):
"""A deepgetattr that also works with dictionaries. This will first try retrieving
the named attribute before attempting to use bracket (`[]`) access assuming it might
be a dictionary.

Args:
obj (object): the object of which an attribute or key should be read.
name (str): the name of an attribute or key to look up. This can be a dotted
path to a nested attribute or key.
default (object): the default value to use if the attribute wasn't found.

Returns:
the value pointed to by 'name', splitting on '.'.

Raises:
KeyError: if obj has no 'name' key.
TypeError: if obj is not subscriptable. Also raised if object has no attribute.
"""
result = obj
# Split on `.` and treat the values in between as parts of a path. Chain off the
# previously retrieved result to get the next nested attribute.
for attr_name in attr.split("."):
result = dictgetattr(result, attr_name, default)

return result


def deepgetattr(obj, name, default=_UNSPECIFIED):
Comment thread
ryancausey marked this conversation as resolved.
Comment thread
ryancausey marked this conversation as resolved.
"""Try to retrieve the given attribute of an object, digging on '.'.

Expand Down Expand Up @@ -206,16 +266,20 @@ class SelfAttribute(BaseDeclaration):
attribute_name (str): the name of the attribute to copy.
default (object): the default value to use if the attribute doesn't
exist.
deepgetattr_func (callable): the callable used to retrieve the potentially
nested attribute. Defaults to `deepgetattr` which supports dotted path
attribute access.
"""

def __init__(self, attribute_name, default=_UNSPECIFIED):
def __init__(self, attribute_name, default=_UNSPECIFIED, deepgetattr_func=deepgetattr):
super().__init__()
depth = len(attribute_name) - len(attribute_name.lstrip('.'))
attribute_name = attribute_name[depth:]

self.depth = depth
self.attribute_name = attribute_name
self.default = default
self.deepgetattr_func = deepgetattr_func

def evaluate(self, instance, step, extra):
if self.depth > 1:
Expand All @@ -225,7 +289,7 @@ def evaluate(self, instance, step, extra):
target = instance

logger.debug("SelfAttribute: Picking attribute %r on %r", self.attribute_name, target)
return deepgetattr(target, self.attribute_name, self.default)
return self.deepgetattr_func(target, self.attribute_name, self.default)

def __repr__(self):
return '<%s(%r, default=%r)>' % (
Expand Down
32 changes: 32 additions & 0 deletions tests/test_declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,38 @@ def test_chaining(self):
self.assertEqual(42, declarations.deepgetattr(obj, 'a.b.c.n.x', 42))


class DigDictTestCase(unittest.TestCase):
class MyObj:
def __init__(self, n):
self.n = n

def test_chaining(self):
"""This is the same test as the `DigTestCase.test_chaining`, except it tests
that chaining works for dictionaries.
"""
dictionary = {"n": 1}
dictionary["a"] = {"n": 2}
dictionary["a"]["b"] = {"n": 3}
dictionary["a"]["b"]["c"] = {"n": 4}

with self.assertRaises(TypeError):
declarations.deepdictgetattr(self.MyObj(1), 'a')
self.assertEqual(2, declarations.deepdictgetattr(dictionary, 'a')["n"])
with self.assertRaises(KeyError):
declarations.deepdictgetattr(dictionary, 'b')
self.assertEqual(2, declarations.deepdictgetattr(dictionary, 'a.n'))
self.assertEqual(3, declarations.deepdictgetattr(dictionary, 'a.c', 3))
with self.assertRaises(KeyError):
declarations.deepdictgetattr(dictionary, 'a.c.n')
with self.assertRaises(KeyError):
declarations.deepdictgetattr(dictionary, 'a.d')
self.assertEqual(3, declarations.deepdictgetattr(dictionary, 'a.b')["n"])
self.assertEqual(3, declarations.deepdictgetattr(dictionary, 'a.b.n'))
self.assertEqual(4, declarations.deepdictgetattr(dictionary, 'a.b.c')["n"])
self.assertEqual(4, declarations.deepdictgetattr(dictionary, 'a.b.c.n'))
self.assertEqual(42, declarations.deepdictgetattr(dictionary, 'a.b.c.n.x', 42))


class MaybeTestCase(unittest.TestCase):
def test_init(self):
declarations.Maybe('foo', 1, 2)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_using.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,23 @@ class Meta:
test_model = TestModel2Factory()
self.assertEqual(4, test_model.two.three)

def test_self_attribute_dict_factory(self):
class ChildFactory(factory.DictFactory):
first_name = "Bob"
last_name = "Dole"

class ParentFactory(factory.DictFactory):
child = factory.SubFactory(ChildFactory)
last_name = factory.SelfAttribute(
"child.last_name", deepgetattr_func=factory.declarations.deepdictgetattr
)

parent = ParentFactory()
self.assertEqual(
parent,
{"last_name": "Dole", "child": {"first_name": "Bob", "last_name": "Dole"}},
)

def test_sequence_decorator(self):
class TestObjectFactory(factory.Factory):
class Meta:
Expand Down