Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/1225.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert Config kept in TypedDict into a dataclass
17 changes: 10 additions & 7 deletions pytest_postgresql/config.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Plugin's configuration."""

from dataclasses import dataclass
from pathlib import Path
from typing import Any, TypedDict
from typing import Any

from _pytest._py.path import LocalPath
from pytest import FixtureRequest


class PostgresqlConfigDict(TypedDict):
"""Typed Config dictionary."""
@dataclass(frozen=True)
class PostgreSQLConfig:
"""PostgreSQL Config."""

exec: str
host: str
Expand All @@ -25,16 +27,16 @@ class PostgresqlConfigDict(TypedDict):
drop_test_database: bool


def get_config(request: FixtureRequest) -> PostgresqlConfigDict:
"""Return a dictionary with config options."""
def get_config(request: FixtureRequest) -> PostgreSQLConfig:
"""Return a PostgreSQLConfig instance with configuration options."""

def get_postgresql_option(option: str) -> Any:
name = "postgresql_" + option
return request.config.getoption(name) or request.config.getini(name)

load_paths = detect_paths(get_postgresql_option("load"))
load_paths: list[Path | str] = detect_paths(get_postgresql_option("load"))

return PostgresqlConfigDict(
cfg = PostgreSQLConfig(
exec=get_postgresql_option("exec"),
host=get_postgresql_option("host"),
port=get_postgresql_option("port"),
Expand All @@ -50,6 +52,7 @@ def get_postgresql_option(option: str) -> Any:
postgres_options=get_postgresql_option("postgres_options"),
drop_test_database=request.config.getoption("postgresql_drop_test_database"),
)
return cfg


def detect_paths(load_paths: list[LocalPath | str]) -> list[Path | str]:
Expand Down
2 changes: 1 addition & 1 deletion pytest_postgresql/factories/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
password=pg_password,
isolation_level=isolation_level,
)
if config["drop_test_database"]:
if config.drop_test_database:
janitor.drop()
with janitor:
db_connection: Connection = psycopg.connect(
Expand Down
16 changes: 8 additions & 8 deletions pytest_postgresql/factories/noprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor]
:returns: tcp executor-like object
"""
config = get_config(request)
pg_host = host or config["host"]
pg_port = port or config["port"] or 5432
pg_user = user or config["user"]
pg_password = password or config["password"]
pg_dbname = xdistify_dbname(dbname or config["dbname"])
pg_options = options or config["options"]
pg_load = load or config["load"]
drop_test_database = config["drop_test_database"]
pg_host = host or config.host
pg_port = port or config.port or 5432
pg_user = user or config.user
pg_password = password or config.password
pg_dbname = xdistify_dbname(dbname or config.dbname)
pg_options = options or config.options
pg_load = load or config.load
drop_test_database = config.drop_test_database

noop_exec = NoopExecutor(
host=pg_host,
Expand Down
32 changes: 16 additions & 16 deletions pytest_postgresql/factories/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@
from port_for import PortForException, get_port
from pytest import FixtureRequest, TempPathFactory

from pytest_postgresql.config import PostgresqlConfigDict, get_config
from pytest_postgresql.config import PostgreSQLConfig, get_config
from pytest_postgresql.exceptions import ExecutableMissingException
from pytest_postgresql.executor import PostgreSQLExecutor
from pytest_postgresql.janitor import DatabaseJanitor

PortType = port_for.PortType # mypy requires explicit export


def _pg_exe(executable: str | None, config: PostgresqlConfigDict) -> str:
def _pg_exe(executable: str | None, config: PostgreSQLConfig) -> str:
"""If executable is set, use it. Otherwise best effort to find the executable."""
postgresql_ctl = executable or config["exec"]
postgresql_ctl = executable or config.exec
# check if that executable exists, as it's no on systems' PATH
# only replace it if executable isn't passed manually
if not os.path.exists(postgresql_ctl) and executable is None:
Expand All @@ -50,9 +50,9 @@ def _pg_exe(executable: str | None, config: PostgresqlConfigDict) -> str:
return postgresql_ctl


def _pg_port(port: PortType | None, config: PostgresqlConfigDict, excluded_ports: Iterable[int]) -> int:
def _pg_port(port: PortType | None, config: PostgreSQLConfig, excluded_ports: Iterable[int]) -> int:
"""User specified port, otherwise find an unused port from config."""
pg_port = get_port(port, excluded_ports) or get_port(config["port"], excluded_ports)
pg_port = get_port(port, excluded_ports) or get_port(config.port, excluded_ports)
assert pg_port is not None
return pg_port

Expand Down Expand Up @@ -115,8 +115,8 @@ def postgresql_proc_fixture(
:returns: tcp executor
"""
config = get_config(request)
pg_dbname = dbname or config["dbname"]
pg_load = load or config["load"]
pg_dbname = dbname or config.dbname
pg_load = load or config.load
postgresql_ctl = _pg_exe(executable, config)
port_path = tmp_path_factory.getbasetemp()
if hasattr(request.config, "workerinput"):
Expand All @@ -138,7 +138,7 @@ def postgresql_proc_fixture(
port_file.write(f"pg_port {pg_port}\n")
break
except FileExistsError:
if n >= config["port_search_count"]:
if n >= config.port_search_count:
raise PortForException(
f"Attempted {n} times to select ports. "
f"All attempted ports: {', '.join(map(str, used_ports))} are already "
Expand All @@ -151,17 +151,17 @@ def postgresql_proc_fixture(

postgresql_executor = PostgreSQLExecutor(
executable=postgresql_ctl,
host=host or config["host"],
host=host or config.host,
port=pg_port,
user=user or config["user"],
password=password or config["password"],
user=user or config.user,
password=password or config.password,
dbname=pg_dbname,
options=options or config["options"],
options=options or config.options,
datadir=str(datadir),
unixsocketdir=unixsocketdir or config["unixsocketdir"],
unixsocketdir=unixsocketdir or config.unixsocketdir,
logfile=str(logfile_path),
startparams=startparams or config["startparams"],
postgres_options=postgres_options or config["postgres_options"],
startparams=startparams or config.startparams,
postgres_options=postgres_options or config.postgres_options,
)
# start server
with postgresql_executor:
Expand All @@ -174,7 +174,7 @@ def postgresql_proc_fixture(
version=postgresql_executor.version,
password=postgresql_executor.password,
)
if config["drop_test_database"]:
if config.drop_test_database:
janitor.drop()
with janitor:
for load_element in pg_load:
Expand Down
2 changes: 1 addition & 1 deletion tests/examples/test_assert_port_search_count_is_ten.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
def test_assert_port_search_count_is_ten(request: FixtureRequest) -> None:
"""Asserts that port_search_count is 10."""
config = get_config(request)
assert config["port_search_count"] == 10
assert config.port_search_count == 10
22 changes: 11 additions & 11 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ def version(self) -> Any:
def test_unsupported_version(request: FixtureRequest) -> None:
"""Check that the error gets raised on unsupported postgres version."""
config = get_config(request)
port = get_port(config["port"])
port = get_port(config.port)
assert port is not None
executor = PatchedPostgreSQLExecutor(
executable=config["exec"],
host=config["host"],
executable=config.exec,
host=config.host,
port=port,
datadir="/tmp/error",
unixsocketdir=config["unixsocketdir"],
unixsocketdir=config.unixsocketdir,
logfile="/tmp/version.error.log",
startparams=config["startparams"],
startparams=config.startparams,
dbname="random_name",
)

Expand All @@ -84,12 +84,12 @@ def test_executor_init_with_password(
datadir, logfile_path = process._prepare_dir(tmpdir, port)
executor = PostgreSQLExecutor(
executable=pg_exe,
host=config["host"],
host=config.host,
port=port,
datadir=str(datadir),
unixsocketdir=config["unixsocketdir"],
unixsocketdir=config.unixsocketdir,
logfile=str(logfile_path),
startparams=config["startparams"],
startparams=config.startparams,
password="somepassword",
dbname="somedatabase",
)
Expand All @@ -109,12 +109,12 @@ def test_executor_init_bad_tmp_path(
datadir, logfile_path = process._prepare_dir(tmpdir, port)
executor = PostgreSQLExecutor(
executable=pg_exe,
host=config["host"],
host=config.host,
port=port,
datadir=str(datadir),
unixsocketdir=config["unixsocketdir"],
unixsocketdir=config.unixsocketdir,
logfile=str(logfile_path),
startparams=config["startparams"],
startparams=config.startparams,
password="some password",
dbname="some database",
)
Expand Down
Loading