|
| 1 | +#!/usr/bin/env python |
| 2 | +"""Shared pytest fixtures for deterministic OSMnx tests.""" |
| 3 | + |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +import json |
| 7 | +import os |
| 8 | +from pathlib import Path |
| 9 | +from typing import TYPE_CHECKING |
| 10 | +from typing import Any |
| 11 | +from typing import TypeAlias |
| 12 | + |
| 13 | +import matplotlib as mpl |
| 14 | +import networkx as nx |
| 15 | +import pytest |
| 16 | +import requests |
| 17 | +from shapely import LineString |
| 18 | + |
| 19 | +mpl.use("Agg") |
| 20 | + |
| 21 | +if TYPE_CHECKING: |
| 22 | + from collections.abc import Iterator |
| 23 | + |
| 24 | +LOCATION_POINT = (37.791427, -122.410018) |
| 25 | +ADDRESS = "Transamerica Pyramid, 600 Montgomery Street, San Francisco, California, USA" |
| 26 | +PLACE = {"city": "Piedmont", "state": "California", "country": "USA"} |
| 27 | +TAGS: dict[str, bool | str | list[str]] = { |
| 28 | + "landuse": True, |
| 29 | + "building": True, |
| 30 | + "highway": True, |
| 31 | + "amenity": True, |
| 32 | +} |
| 33 | + |
| 34 | +_ResponseJson: TypeAlias = dict[str, object] | list[dict[str, object]] |
| 35 | +HTTP_OK = 200 |
| 36 | +HTTP_ERROR = 500 |
| 37 | + |
| 38 | + |
| 39 | +def _drive_graph() -> nx.MultiDiGraph: |
| 40 | + import osmnx as ox # noqa: PLC0415 |
| 41 | + |
| 42 | + return ox.graph_from_point( |
| 43 | + LOCATION_POINT, |
| 44 | + dist=500, |
| 45 | + network_type="drive", |
| 46 | + simplify=False, |
| 47 | + retain_all=True, |
| 48 | + ) |
| 49 | + |
| 50 | + |
| 51 | +def _toy_graph(*, crs: str = "epsg:4326") -> nx.MultiDiGraph: |
| 52 | + G = nx.MultiDiGraph(crs=crs) |
| 53 | + G.add_node(1, x=0.0, y=0.0, street_count=1, elevation=0.0) |
| 54 | + G.add_node(2, x=1.0, y=0.0, street_count=2, elevation=10.0) |
| 55 | + G.add_node(3, x=2.0, y=0.0, street_count=1, elevation=20.0) |
| 56 | + G.add_edge( |
| 57 | + 1, |
| 58 | + 2, |
| 59 | + osmid=10, |
| 60 | + length=1.0, |
| 61 | + highway="residential", |
| 62 | + maxspeed="25 mph", |
| 63 | + geometry=LineString([(0, 0), (1, 0)]), |
| 64 | + ) |
| 65 | + G.add_edge( |
| 66 | + 2, |
| 67 | + 3, |
| 68 | + osmid=11, |
| 69 | + length=1.0, |
| 70 | + highway=["primary", "secondary"], |
| 71 | + maxspeed=["30 mph", "50"], |
| 72 | + geometry=LineString([(1, 0), (2, 0)]), |
| 73 | + ) |
| 74 | + return G |
| 75 | + |
| 76 | + |
| 77 | +class _Response(requests.Response): |
| 78 | + def __init__(self, payload: _ResponseJson, *, ok: bool = True, status_code: int = 200) -> None: |
| 79 | + super().__init__() |
| 80 | + self._payload = payload |
| 81 | + self.status_code = status_code if ok or status_code != HTTP_OK else HTTP_ERROR |
| 82 | + self.reason = "OK" if ok else "Error" |
| 83 | + self._content = json.dumps(payload).encode() |
| 84 | + self.url = "https://example.com/api" |
| 85 | + |
| 86 | + def json(self, **kwargs: object) -> _ResponseJson: |
| 87 | + del kwargs |
| 88 | + return self._payload |
| 89 | + |
| 90 | + |
| 91 | +drive_graph = _drive_graph |
| 92 | +toy_graph = _toy_graph |
| 93 | +Response = _Response |
| 94 | + |
| 95 | +HTTP_CACHE_DIR = Path(__file__).parent / "input_data" / "http_cache" |
| 96 | + |
| 97 | + |
| 98 | +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: |
| 99 | + """ |
| 100 | + Skip live online tests unless explicitly requested. |
| 101 | +
|
| 102 | + Parameters |
| 103 | + ---------- |
| 104 | + config |
| 105 | + Pytest configuration object. |
| 106 | + items |
| 107 | + Collected test items. |
| 108 | + """ |
| 109 | + del config |
| 110 | + |
| 111 | + if os.environ.get("OSMNX_RUN_ONLINE_TESTS"): |
| 112 | + return |
| 113 | + |
| 114 | + skip_online = pytest.mark.skip(reason="set OSMNX_RUN_ONLINE_TESTS=1 to run online tests") |
| 115 | + for item in items: |
| 116 | + if item.get_closest_marker("online"): |
| 117 | + item.add_marker(skip_online) |
| 118 | + |
| 119 | + |
| 120 | +@pytest.fixture(autouse=True) |
| 121 | +def _isolate_settings(tmp_path: Path) -> Iterator[None]: |
| 122 | + """ |
| 123 | + Restore global settings after each test and isolate generated files. |
| 124 | +
|
| 125 | + Parameters |
| 126 | + ---------- |
| 127 | + tmp_path |
| 128 | + Temporary directory unique to the test. |
| 129 | +
|
| 130 | + Yields |
| 131 | + ------ |
| 132 | + None |
| 133 | + Control returns to pytest after each test. |
| 134 | + """ |
| 135 | + import osmnx as ox # noqa: PLC0415 |
| 136 | + |
| 137 | + original_settings = { |
| 138 | + name: getattr(ox.settings, name) |
| 139 | + for name in dir(ox.settings) |
| 140 | + if not name.startswith("_") and name.islower() |
| 141 | + } |
| 142 | + |
| 143 | + ox.settings.data_folder = tmp_path / "data" |
| 144 | + ox.settings.logs_folder = tmp_path / "logs" |
| 145 | + ox.settings.imgs_folder = tmp_path / "imgs" |
| 146 | + ox.settings.cache_folder = tmp_path / "cache" |
| 147 | + ox.settings.log_console = False |
| 148 | + ox.settings.log_file = False |
| 149 | + ox.settings.use_cache = True |
| 150 | + |
| 151 | + yield |
| 152 | + |
| 153 | + for name, value in original_settings.items(): |
| 154 | + setattr(ox.settings, name, value) |
| 155 | + |
| 156 | + |
| 157 | +@pytest.fixture(autouse=True) |
| 158 | +def _block_network( |
| 159 | + monkeypatch: pytest.MonkeyPatch, |
| 160 | + request: pytest.FixtureRequest, |
| 161 | +) -> None: |
| 162 | + """ |
| 163 | + Prevent accidental live HTTP calls in the default offline suite. |
| 164 | +
|
| 165 | + Parameters |
| 166 | + ---------- |
| 167 | + monkeypatch |
| 168 | + Pytest fixture for temporary object replacement. |
| 169 | + request |
| 170 | + Pytest request object for the active test. |
| 171 | + """ |
| 172 | + if request.node.get_closest_marker("online") and os.environ.get("OSMNX_RUN_ONLINE_TESTS"): |
| 173 | + return |
| 174 | + |
| 175 | + def _blocked_request(*args: Any, **kwargs: Any) -> None: # noqa: ANN401, ARG001 |
| 176 | + msg = ( |
| 177 | + "Network access is blocked in offline tests. Mark the test with " |
| 178 | + "`@pytest.mark.online` and set OSMNX_RUN_ONLINE_TESTS=1 to allow it." |
| 179 | + ) |
| 180 | + raise AssertionError(msg) |
| 181 | + |
| 182 | + monkeypatch.setattr(requests, "get", _blocked_request) |
| 183 | + monkeypatch.setattr(requests, "post", _blocked_request) |
| 184 | + |
| 185 | + |
| 186 | +@pytest.fixture |
| 187 | +def http_cache() -> Path: |
| 188 | + """ |
| 189 | + Point OSMnx cache lookups at committed raw HTTP response fixtures. |
| 190 | +
|
| 191 | + Returns |
| 192 | + ------- |
| 193 | + pathlib.Path |
| 194 | + Directory containing committed raw HTTP cache files. |
| 195 | + """ |
| 196 | + import osmnx as ox # noqa: PLC0415 |
| 197 | + |
| 198 | + ox.settings.cache_folder = HTTP_CACHE_DIR |
| 199 | + ox.settings.use_cache = True |
| 200 | + ox.settings.overpass_rate_limit = False |
| 201 | + return HTTP_CACHE_DIR |
0 commit comments