diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index 466f8199bfd1..0d9178dae812 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -431,6 +431,15 @@ class BaseDatabaseFeatures: # that should be skipped for this database. django_test_skips = {} + # DatabaseWrapper methods that should raise an error if accessed in + # django.test.SimpleTestCase. + disallowed_simple_test_case_connection_methods = [ + ("connect", "connections"), + ("temporary_connection", "connections"), + ("cursor", "queries"), + ("chunked_cursor", "queries"), + ] + supports_uuid4_function = False supports_uuid7_function = False supports_uuid7_function_shift = False diff --git a/django/test/testcases.py b/django/test/testcases.py index b31e70d9d18e..f866a33963d0 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -189,7 +189,7 @@ def __init__(self, wrapped, message): self.wrapped = wrapped self.message = message - def __call__(self): + def __call__(self, *args, **kwargs): raise DatabaseOperationForbidden(self.message) @@ -209,12 +209,6 @@ class SimpleTestCase(unittest.TestCase): "proper test isolation or add %(alias)r to %(test)s.databases to silence " "this failure." ) - _disallowed_connection_methods = [ - ("connect", "connections"), - ("temporary_connection", "connections"), - ("cursor", "queries"), - ("chunked_cursor", "queries"), - ] @classmethod def setUpClass(cls): @@ -254,7 +248,10 @@ def _add_databases_failures(cls): if alias in cls.databases: continue connection = connections[alias] - for name, operation in cls._disallowed_connection_methods: + disallowed_methods = ( + connection.features.disallowed_simple_test_case_connection_methods + ) + for name, operation in disallowed_methods: message = cls._disallowed_database_msg % { "test": "%s.%s" % (cls.__module__, cls.__qualname__), "alias": alias, @@ -276,7 +273,10 @@ def _remove_databases_failures(cls): if alias in cls.databases: continue connection = connections[alias] - for name, _ in cls._disallowed_connection_methods: + disallowed_methods = ( + connection.features.disallowed_simple_test_case_connection_methods + ) + for name, _ in disallowed_methods: method = getattr(connection, name) setattr(connection, name, method.wrapped)