From bb6d861c1d8fef33e39565a882e02a57a4bbaeaa Mon Sep 17 00:00:00 2001 From: Ryan Causey Date: Wed, 27 Aug 2025 17:34:44 -0700 Subject: [PATCH] feat: support customizing attr getter in SubFactory This change should allow for the case where a `SelfAttribute` attempts to access the field of a `SubFactory` that is a `DictFactory`. The user can override the `deepgetattr_func` of the `SubFactory` to use `deepdictgetattr`, or one of their own choosing, which will be used to retrieve the attribute. closes #1134 --- factory/declarations.py | 68 ++++++++++++++++++++++++++++++++++++-- tests/test_declarations.py | 32 ++++++++++++++++++ tests/test_using.py | 17 ++++++++++ 3 files changed, 115 insertions(+), 2 deletions(-) diff --git a/factory/declarations.py b/factory/declarations.py index f835f0d2..df69c839 100644 --- a/factory/declarations.py +++ b/factory/declarations.py @@ -1,6 +1,7 @@ # Copyright: See the LICENSE file. +import contextlib import itertools import logging import typing as T @@ -166,6 +167,65 @@ class _UNSPECIFIED: pass +def dictgetattr(obj, attr, default=_UNSPECIFIED): + """A version of `getattr` that allows retrieving attributes from dictionaries. + + Args: + obj (object): the object of which an attribute or key should be read + name (str): the name of an attribute or key to look up. + default (object): the default value to use if the attribute or key wasn't found. + + Returns: + the value pointed to by 'name'. + + Raises: + KeyError: if obj has no 'name' key. + TypeError: if obj is not subscriptable. Also raised if object has no attribute. + """ + # First, try and get the attribute ignoring any attribute errors. + with contextlib.suppress(AttributeError): + return getattr(obj, attr) + + # If that doesn't work, assume it's a dictionary and try to retrieve the value by + # using the attr name as the key. + try: + return obj[attr] + except (KeyError, TypeError): + # If no default was specified, re-raise the exception. Otherwise, return the + # default. + if default is _UNSPECIFIED: + raise + + return default + + +def deepdictgetattr(obj, attr, default=_UNSPECIFIED): + """A deepgetattr that also works with dictionaries. This will first try retrieving + the named attribute before attempting to use bracket (`[]`) access assuming it might + be a dictionary. + + Args: + obj (object): the object of which an attribute or key should be read. + name (str): the name of an attribute or key to look up. This can be a dotted + path to a nested attribute or key. + default (object): the default value to use if the attribute wasn't found. + + Returns: + the value pointed to by 'name', splitting on '.'. + + Raises: + KeyError: if obj has no 'name' key. + TypeError: if obj is not subscriptable. Also raised if object has no attribute. + """ + result = obj + # Split on `.` and treat the values in between as parts of a path. Chain off the + # previously retrieved result to get the next nested attribute. + for attr_name in attr.split("."): + result = dictgetattr(result, attr_name, default) + + return result + + def deepgetattr(obj, name, default=_UNSPECIFIED): """Try to retrieve the given attribute of an object, digging on '.'. @@ -206,9 +266,12 @@ class SelfAttribute(BaseDeclaration): attribute_name (str): the name of the attribute to copy. default (object): the default value to use if the attribute doesn't exist. + deepgetattr_func (callable): the callable used to retrieve the potentially + nested attribute. Defaults to `deepgetattr` which supports dotted path + attribute access. """ - def __init__(self, attribute_name, default=_UNSPECIFIED): + def __init__(self, attribute_name, default=_UNSPECIFIED, deepgetattr_func=deepgetattr): super().__init__() depth = len(attribute_name) - len(attribute_name.lstrip('.')) attribute_name = attribute_name[depth:] @@ -216,6 +279,7 @@ def __init__(self, attribute_name, default=_UNSPECIFIED): self.depth = depth self.attribute_name = attribute_name self.default = default + self.deepgetattr_func = deepgetattr_func def evaluate(self, instance, step, extra): if self.depth > 1: @@ -225,7 +289,7 @@ def evaluate(self, instance, step, extra): target = instance logger.debug("SelfAttribute: Picking attribute %r on %r", self.attribute_name, target) - return deepgetattr(target, self.attribute_name, self.default) + return self.deepgetattr_func(target, self.attribute_name, self.default) def __repr__(self): return '<%s(%r, default=%r)>' % ( diff --git a/tests/test_declarations.py b/tests/test_declarations.py index c49bbba7..36581a38 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -42,6 +42,38 @@ def test_chaining(self): self.assertEqual(42, declarations.deepgetattr(obj, 'a.b.c.n.x', 42)) +class DigDictTestCase(unittest.TestCase): + class MyObj: + def __init__(self, n): + self.n = n + + def test_chaining(self): + """This is the same test as the `DigTestCase.test_chaining`, except it tests + that chaining works for dictionaries. + """ + dictionary = {"n": 1} + dictionary["a"] = {"n": 2} + dictionary["a"]["b"] = {"n": 3} + dictionary["a"]["b"]["c"] = {"n": 4} + + with self.assertRaises(TypeError): + declarations.deepdictgetattr(self.MyObj(1), 'a') + self.assertEqual(2, declarations.deepdictgetattr(dictionary, 'a')["n"]) + with self.assertRaises(KeyError): + declarations.deepdictgetattr(dictionary, 'b') + self.assertEqual(2, declarations.deepdictgetattr(dictionary, 'a.n')) + self.assertEqual(3, declarations.deepdictgetattr(dictionary, 'a.c', 3)) + with self.assertRaises(KeyError): + declarations.deepdictgetattr(dictionary, 'a.c.n') + with self.assertRaises(KeyError): + declarations.deepdictgetattr(dictionary, 'a.d') + self.assertEqual(3, declarations.deepdictgetattr(dictionary, 'a.b')["n"]) + self.assertEqual(3, declarations.deepdictgetattr(dictionary, 'a.b.n')) + self.assertEqual(4, declarations.deepdictgetattr(dictionary, 'a.b.c')["n"]) + self.assertEqual(4, declarations.deepdictgetattr(dictionary, 'a.b.c.n')) + self.assertEqual(42, declarations.deepdictgetattr(dictionary, 'a.b.c.n.x', 42)) + + class MaybeTestCase(unittest.TestCase): def test_init(self): declarations.Maybe('foo', 1, 2) diff --git a/tests/test_using.py b/tests/test_using.py index 5b2200a6..b5a530f7 100644 --- a/tests/test_using.py +++ b/tests/test_using.py @@ -567,6 +567,23 @@ class Meta: test_model = TestModel2Factory() self.assertEqual(4, test_model.two.three) + def test_self_attribute_dict_factory(self): + class ChildFactory(factory.DictFactory): + first_name = "Bob" + last_name = "Dole" + + class ParentFactory(factory.DictFactory): + child = factory.SubFactory(ChildFactory) + last_name = factory.SelfAttribute( + "child.last_name", deepgetattr_func=factory.declarations.deepdictgetattr + ) + + parent = ParentFactory() + self.assertEqual( + parent, + {"last_name": "Dole", "child": {"first_name": "Bob", "last_name": "Dole"}}, + ) + def test_sequence_decorator(self): class TestObjectFactory(factory.Factory): class Meta: