@@ -492,3 +492,67 @@ def test_db_access_in_test_module(self, django_pytester: DjangoPytester) -> None
492492 'or the "db" or "transactional_db" fixtures to enable it.'
493493 ]
494494 )
495+
496+
497+ def test_custom_django_db_setup (django_pytester : DjangoPytester ) -> None :
498+ pytest .importorskip ("xdist" )
499+ pytest .importorskip ("psycopg" )
500+
501+ django_pytester .makeconftest (
502+ """
503+ import pytest
504+ import psycopg
505+ from django.conf import settings as django_settings
506+
507+ def run_sql(query, fetch=False, db='default'):
508+ conn = psycopg.connect(
509+ user=django_settings.DATABASES[db]['USER'],
510+ password=django_settings.DATABASES[db]['PASSWORD'],
511+ host=django_settings.DATABASES[db]['HOST'],
512+ port=django_settings.DATABASES['default']['PORT']
513+ )
514+ (cur := conn.cursor()).execute(query)
515+ response = cur.fetchone() if fetch else None
516+ conn.close()
517+ return response
518+
519+ @pytest.fixture(scope='session')
520+ def django_db_createdb(request, django_db_createdb) -> bool:
521+ db_name = f'test_{django_settings.DATABASES["default"]["NAME"]}'
522+ if xdist_suffix := getattr(request.config, 'workerinput', {}).get('workerid'):
523+ db_name = f'{db_name}_{xdist_suffix}'
524+ db_exists = (result := run_sql(query=f"SELECT EXISTS (SELECT 1 FROM pg_database WHERE datname='{db_name}')", fetch=True)) and result and result[0]
525+ if django_db_createdb or not db_exists:
526+ run_sql('CREATE EXTENSION IF NOT EXISTS vector')
527+ return django_db_createdb or not db_exists
528+
529+ @pytest.fixture(scope='session')
530+ def django_db_setup(django_db_setup, django_db_blocker, django_db_createdb) -> None:
531+ del django_db_setup
532+ if django_db_createdb:
533+ with django_db_blocker.unblock():
534+ call_command('flush', '--noinput')
535+ call_command('loaddata', *pathlib.Path().glob('tests/db_fixtures/**/*.yaml'))
536+ """
537+ )
538+
539+ django_pytester .create_test_module (
540+ """
541+ import pytest
542+ from .app.models import Item
543+
544+ @pytest.mark.django_db
545+ def test_simple():
546+ assert Item.objects.count() == 0
547+ """
548+ )
549+
550+ result = django_pytester .runpytest_subprocess ("-vv" , "--reuse-db" , "-n" , "auto" )
551+ print (result .stdout .str ())
552+ print (result .stderr .str ())
553+ result .assert_outcomes (passed = 1 )
554+
555+ result = django_pytester .runpytest_subprocess ("-vv" , "--reuse-db" , "-n" , "auto" )
556+ print (result .stdout .str ())
557+ print (result .stderr .str ())
558+ result .assert_outcomes (passed = 1 )
0 commit comments