Skip to content

Commit ef2a8f4

Browse files
authored
fix NB 4.5.4 user arg removal from BaseTable (#191)
1 parent b8d5203 commit ef2a8f4

5 files changed

Lines changed: 71 additions & 16 deletions

File tree

validity/netbox_changes/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from validity import config
1111

1212

13+
def get_base_table_kwargs(self):
14+
return {"user": self.request.user} if config.netbox_version < "4.5.4" else {}
15+
16+
1317
StrFilterLookup = locate("strawberry_django.StrFilterLookup") if config.netbox_version >= "4.5.5" else FilterLookup[str]
1418

1519
if config.netbox_version >= "4.5.0":

validity/tests/test_views.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
import textwrap
23
from functools import partial
34
from http import HTTPStatus
@@ -7,6 +8,7 @@
78
import pytest
89
from base import ViewTest
910
from django.urls import reverse
11+
from django.utils import timezone
1012
from django.utils.functional import classproperty
1113
from factories import (
1214
BackupPointFactory,
@@ -167,6 +169,15 @@ def test_get_serialized_state(admin_client, item, monkeypatch):
167169
assert resp.status_code == HTTPStatus.OK
168170

169171

172+
@pytest.mark.parametrize("query_params", [{}, {"sort": "test"}, {"sort": "-created"}])
173+
@pytest.mark.django_db
174+
def test_device_results(admin_client, query_params):
175+
device = DeviceFactory()
176+
CompTestResultFactory(device=device)
177+
resp = admin_client.get(f"/dcim/devices/{device.pk}/results/", query_params)
178+
assert resp.status_code == HTTPStatus.OK
179+
180+
170181
@pytest.mark.parametrize("query_params", [{}, {"sort": "device"}, {"sort": "-device"}])
171182
@pytest.mark.django_db
172183
def test_report_devices(admin_client, query_params):
@@ -263,35 +274,75 @@ def test_datasource_devices(admin_client):
263274
assert resp.status_code == HTTPStatus.OK
264275

265276

266-
class TestRunTests:
267-
url = "/plugins/validity/tests/run/"
277+
class TestRunTestsView:
278+
"""Covers RunTestsView (validity.views.script.RunTestsView)."""
279+
280+
@staticmethod
281+
def _url():
282+
return reverse("plugins:validity:compliancetest_run")
268283

269284
def test_get(self, admin_client):
270-
resp = admin_client.get(self.url)
285+
resp = admin_client.get(self._url())
271286
assert resp.status_code == HTTPStatus.OK
272287

273288
@pytest.mark.parametrize(
274289
"form_data, status_code, has_workers",
275290
[
276291
({}, HTTPStatus.FOUND, True),
277292
({}, HTTPStatus.OK, False),
278-
({"devices": [1, 2]}, HTTPStatus.OK, True), # devices do not exist
293+
({"devices": [1, 2]}, HTTPStatus.OK, True), # devices do not exist — invalid choices
279294
],
280295
)
281296
def test_post(self, admin_client, di, form_data, status_code, has_workers):
282297
launcher = Mock(**{"has_workers": has_workers, "return_value.pk": 1})
283298
with di.override({dependencies.runtests_launcher: lambda: launcher}):
284-
result = admin_client.post(self.url, form_data)
299+
result = admin_client.post(self._url(), form_data)
285300
assert result.status_code == status_code
286301
if status_code == HTTPStatus.FOUND: # if form is valid
287302
launcher.assert_called_once()
288303
assert isinstance(launcher.call_args.args[0], RunTestsParams)
289304

305+
@pytest.mark.django_db
306+
def test_post_with_valid_devices(self, admin_client, di):
307+
d1, d2 = DeviceFactory(), DeviceFactory()
308+
launcher = Mock(has_workers=True, return_value=Mock(pk=1))
309+
with di.override({dependencies.runtests_launcher: lambda: launcher}):
310+
resp = admin_client.post(self._url(), {"devices": [d1.pk, d2.pk]})
311+
assert resp.status_code == HTTPStatus.FOUND
312+
launcher.assert_called_once()
313+
assert isinstance(launcher.call_args.args[0], RunTestsParams)
314+
290315

291316
@pytest.mark.parametrize("job_factory", [RunTestsJobFactory, DSBackupJobFactory])
292-
def test_scriptresult(admin_client, job_factory):
293-
job = job_factory(status="completed")
294-
resp = admin_client.get(f"/plugins/validity/scripts/results/{job.pk}/")
317+
@pytest.mark.django_db
318+
def test_script_result_view_completed_job(admin_client, job_factory):
319+
"""
320+
Full GET for a finished job: ScriptResultView only builds the log table when ``job.completed``
321+
(the *timestamp* field) is set — same as real jobs after ``terminate()``. ``status`` alone is not enough.
322+
"""
323+
completed_at = timezone.now()
324+
job = job_factory(
325+
status="completed",
326+
started=completed_at - datetime.timedelta(minutes=1),
327+
completed=completed_at,
328+
data={"output": "test output", "log": []},
329+
)
330+
assert job.completed, "need completion timestamp set or get_table is skipped (differs from browser)"
331+
332+
url = reverse("plugins:validity:script_result", kwargs={"pk": job.pk})
333+
resp = admin_client.get(url)
334+
assert resp.status_code == HTTPStatus.OK, getattr(resp, "content", b"")[:2000]
335+
336+
337+
@pytest.mark.parametrize("job_factory", [RunTestsJobFactory, DSBackupJobFactory])
338+
@pytest.mark.django_db
339+
def test_script_result_view_incomplete_job(admin_client, job_factory):
340+
"""Running job has no completion timestamp, so get_table is not used."""
341+
job = job_factory(status="running", started=timezone.now())
342+
assert not job.completed
343+
344+
url = reverse("plugins:validity:script_result", kwargs={"pk": job.pk})
345+
resp = admin_client.get(url)
295346
assert resp.status_code == HTTPStatus.OK
296347

297348

validity/views/base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from utilities.views import ObjectPermissionRequiredMixin as _ObjectPermissionRequiredMixin
1414
from utilities.views import ViewTab
1515

16-
from validity import filtersets, forms, models, scripts, tables
16+
from validity import filtersets, forms, models, netbox_changes, scripts, tables
1717
from validity.utils.misc import partialcls
1818

1919

@@ -100,8 +100,7 @@ def get_table(self, **kwargs):
100100
table.exclude = (self.result_relation,)
101101
return table
102102

103-
def get_table_kwargs(self):
104-
return {"user": self.request.user}
103+
get_table_kwargs = netbox_changes.get_base_table_kwargs
105104

106105
def get_queryset(self):
107106
return self.queryset.filter(**{self.result_relation: self.kwargs["pk"]})

validity/views/report.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from netbox.views import generic
99
from utilities.views import ViewTab, register_model_view
1010

11-
from validity import config, filtersets, forms, models, tables
11+
from validity import filtersets, forms, models, netbox_changes, tables
1212
from validity.choices import DeviceGroupByChoices, SeverityChoices
1313
from .base import FilterViewWithForm, ObjectPermissionRequiredMixin, TestResultBaseView
1414

@@ -102,8 +102,7 @@ def get_table(self, **kwargs):
102102
table.configure(self.request)
103103
return table
104104

105-
def get_table_kwargs(self):
106-
return {"user": self.request.user} if config.netbox_version < "4.5.4" else {}
105+
get_table_kwargs = netbox_changes.get_base_table_kwargs
107106

108107
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
109108
return super().get_context_data(**kwargs) | {

validity/views/script.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from validity import di
1414
from validity.forms import RunTestsForm
15-
from validity.netbox_changes import get_logs
15+
from validity.netbox_changes import get_base_table_kwargs, get_logs
1616
from validity.scripts import Launcher, RunTestsParams, ScriptParams
1717
from validity.tables import ScriptResultTable
1818
from .base import LauncherMixin
@@ -52,10 +52,12 @@ class ScriptResultView(PermissionRequiredMixin, TableMixin, ObjectView):
5252

5353
def get_table(self, job, request, bulk_actions=False):
5454
logs = [entry | {"index": i} for i, entry in enumerate(get_logs(job), start=1)]
55-
table = self.table_class(logs, user=request.user)
55+
table = self.table_class(logs, **self.get_table_kwargs())
5656
table.configure(request)
5757
return table
5858

59+
get_table_kwargs = get_base_table_kwargs
60+
5961
def get(self, request, **kwargs):
6062
job = self.get_object(**kwargs)
6163
table = self.get_table(job, request) if job.completed else None

0 commit comments

Comments
 (0)