Skip to content

Commit 7c2a138

Browse files
committed
Tests are added and code is reformatted via make valid
Signed-off-by: RISHI GARG <134256793+Rishi-source@users.noreply.github.com>
1 parent c75a134 commit 7c2a138

File tree

3 files changed

+174
-37
lines changed

3 files changed

+174
-37
lines changed

vulnerabilities/pagination.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import logging
22
import re
33

4-
from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator
4+
from django.core.paginator import EmptyPage
5+
from django.core.paginator import PageNotAnInteger
6+
from django.core.paginator import Paginator
57
from django.db.models.query import QuerySet
68
from rest_framework.pagination import PageNumberPagination
79

@@ -25,11 +27,11 @@ class PaginatedListViewMixin:
2527
{"value": 50, "label": "50 per page"},
2628
{"value": 100, "label": "100 per page"},
2729
]
28-
29-
max_pages_without_truncation = 5 # it is a value for number of pages without truncation like is total number of pages are less than this number the pagination will show all pages.
30-
pages_around_current = 2 # number of pages to be shown around current page
31-
truncation_threshold_start = 4 # it is a threshold for start of truncation
32-
truncation_threshold_end = 3 # it is a threshold for end of truncation
30+
31+
max_pages_without_truncation = 5 # it is a value for number of pages without truncation like is total number of pages are less than this number the pagination will show all pages.
32+
pages_around_current = 2 # number of pages to be shown around current page
33+
truncation_threshold_start = 4 # it is a threshold for start of truncation
34+
truncation_threshold_end = 3 # it is a threshold for end of truncation
3335

3436
def get_queryset(self):
3537
"""
@@ -40,7 +42,7 @@ def get_queryset(self):
4042
except Exception as e:
4143
logger.error(f"Error in get_queryset: {e}")
4244
return self.model.objects.none()
43-
45+
4446
if not queryset or not isinstance(queryset, QuerySet):
4547
queryset = self.model.objects.none()
4648
return queryset
@@ -51,22 +53,24 @@ def sanitize_page_size(self, raw_page_size):
5153
"""
5254
if not raw_page_size:
5355
return self.paginate_default
54-
55-
clean_page_size = re.sub(r"\D", "", str(raw_page_size)) # it remove all non-digit characters like if 50abcd is their then it takes out 50
56+
57+
clean_page_size = re.sub(
58+
r"\D", "", str(raw_page_size)
59+
) # it remove all non-digit characters like if 50abcd is their then it takes out 50
5660
if not clean_page_size:
5761
return self.paginate_default
58-
62+
5963
try:
6064
page_size = int(clean_page_size)
6165
except (ValueError, TypeError):
6266
logger.info("Invalid page_size input attempted")
6367
return self.paginate_default
64-
68+
6569
valid_sizes = {choice["value"] for choice in self.page_size_choices}
6670
if page_size not in valid_sizes:
6771
logger.warning(f"Attempted to use unauthorized page size: {page_size}")
6872
return self.paginate_default
69-
73+
7074
return page_size
7175

7276
def get_paginate_by(self, queryset=None):
@@ -85,7 +89,7 @@ def get_page_range(self, paginator, page_obj):
8589
if num_pages <= self.max_pages_without_truncation:
8690
return list(map(str, range(1, num_pages + 1)))
8791
pages = [1]
88-
92+
8993
if current_page > self.truncation_threshold_start:
9094
pages.append("...")
9195
start = max(2, current_page - self.pages_around_current)
@@ -102,7 +106,7 @@ def paginate_queryset(self, queryset, page_size):
102106
queryset = self.model.objects.none()
103107
paginator = Paginator(queryset, page_size)
104108
try:
105-
page_number = int(self.request.GET.get("page", "1"))
109+
page_number = int(self.request.GET.get("page", "1"))
106110
except (ValueError, TypeError):
107111
logger.error("Invalid page number input")
108112
page_number = 1
@@ -118,12 +122,11 @@ def get_context_data(self, **kwargs):
118122
"""
119123
Return a mapping of pagination-related context data, preserving filters.
120124
"""
121-
queryset = kwargs.pop('queryset', None) or self.get_queryset()
125+
queryset = kwargs.pop("queryset", None) or self.get_queryset()
122126
page_size = self.get_paginate_by()
123127
paginator, page, object_list, is_paginated = self.paginate_queryset(queryset, page_size)
124128
page_range = self.get_page_range(paginator, page)
125129

126-
127130
context = super().get_context_data(
128131
object_list=object_list,
129132
page_obj=page,
@@ -134,14 +137,15 @@ def get_context_data(self, **kwargs):
134137

135138
previous_page_url = page.previous_page_number() if page.has_previous() else None
136139
next_page_url = page.next_page_number() if page.has_next() else None
137-
context.update({
138-
"current_page_size": page_size,
139-
"page_size_choices": self.page_size_choices,
140-
"total_count": paginator.count,
141-
"page_range": page_range,
142-
"search": self.request.GET.get("search", ""),
143-
"previous_page_url": previous_page_url,
144-
"next_page_url": next_page_url,
145-
}
140+
context.update(
141+
{
142+
"current_page_size": page_size,
143+
"page_size_choices": self.page_size_choices,
144+
"total_count": paginator.count,
145+
"page_range": page_range,
146+
"search": self.request.GET.get("search", ""),
147+
"previous_page_url": previous_page_url,
148+
"next_page_url": next_page_url,
149+
}
146150
)
147151
return context
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
#
2+
# Copyright (c) nexB Inc. and others. All rights reserved.
3+
# VulnerableCode is a trademark of nexB Inc.
4+
# SPDX-License-Identifier: Apache-2.0
5+
# See http://www.apache.org/licenses/LICENSE-2.0 for the license text.
6+
# See https://github.com/aboutcode-org/vulnerablecode for support or download.
7+
# See https://aboutcode.org for more information about nexB OSS projects.
8+
#
9+
10+
from django.test import TestCase
11+
from django.urls import reverse
12+
13+
from vulnerabilities.models import Package
14+
15+
16+
class PaginationFunctionalityTests(TestCase):
17+
@classmethod
18+
def setUpTestData(cls):
19+
for i in range(150):
20+
Package.objects.create(
21+
type="test",
22+
namespace="test",
23+
name=f"package{i}",
24+
version=str(i),
25+
qualifiers={},
26+
subpath="",
27+
)
28+
29+
def test_default_pagination(self):
30+
response = self.client.get(reverse("package_search"))
31+
self.assertEqual(response.status_code, 200)
32+
page_obj = response.context["page_obj"]
33+
self.assertIsNotNone(page_obj)
34+
self.assertEqual(len(page_obj.object_list), 20)
35+
self.assertEqual(response.context["total_count"], 150)
36+
self.assertEqual(response.context["current_page_size"], 20)
37+
38+
def test_page_size_variations(self):
39+
valid_page_sizes = [20, 50, 100]
40+
for size in valid_page_sizes:
41+
url = f"{reverse('package_search')}?page_size={size}"
42+
response = self.client.get(url)
43+
self.assertEqual(response.status_code, 200)
44+
self.assertIn(response.context["current_page_size"], [20, size])
45+
46+
def test_page_navigation(self):
47+
response = self.client.get(reverse("package_search"))
48+
first_page = response.context["page_obj"]
49+
self.assertEqual(first_page.number, 1)
50+
self.assertTrue(first_page.has_next())
51+
self.assertFalse(first_page.has_previous())
52+
self.assertGreater(first_page.paginator.num_pages, 1)
53+
54+
55+
class PaginationSecurityTests(TestCase):
56+
@classmethod
57+
def setUpTestData(cls):
58+
for i in range(50):
59+
Package.objects.create(
60+
type="test",
61+
namespace="test",
62+
name=f"package{i}",
63+
version=str(i),
64+
qualifiers={},
65+
subpath="",
66+
)
67+
68+
def test_invalid_page_size_inputs(self):
69+
malicious_inputs = [
70+
"abc",
71+
"-10",
72+
"0",
73+
"9999999999",
74+
"11",
75+
"<script>",
76+
"../../etc/passwd",
77+
"' OR 1=1 --",
78+
"",
79+
]
80+
for input_value in malicious_inputs:
81+
url = f"{reverse('package_search')}?page_size={input_value}"
82+
response = self.client.get(url)
83+
self.assertEqual(response.status_code, 200)
84+
self.assertEqual(response.context["current_page_size"], 20)
85+
86+
def test_sql_injection_prevention(self):
87+
sql_injection_payloads = [
88+
"1' OR '1'='1",
89+
"1; DROP TABLE packages;",
90+
"' UNION SELECT * FROM auth_user--",
91+
"1 OR 1=1",
92+
]
93+
initial_package_count = Package.objects.count()
94+
for payload in sql_injection_payloads:
95+
urls = [
96+
f"{reverse('package_search')}?page={payload}",
97+
f"{reverse('package_search')}?page_size={payload}",
98+
]
99+
for url in urls:
100+
response = self.client.get(url)
101+
self.assertEqual(response.status_code, 200)
102+
self.assertEqual(Package.objects.count(), initial_package_count)
103+
104+
105+
class PaginationEdgeCaseTests(TestCase):
106+
@classmethod
107+
def setUpTestData(cls):
108+
for i in range(5):
109+
Package.objects.create(
110+
type="test",
111+
namespace="test",
112+
name=f"package{i}",
113+
version=str(i),
114+
)
115+
116+
def test_small_dataset_pagination(self):
117+
response = self.client.get(reverse("package_search"))
118+
self.assertEqual(response.status_code, 200)
119+
self.assertLessEqual(len(response.context["page_obj"].object_list), 20)
120+
121+
def test_out_of_range_page_number(self):
122+
out_of_range_urls = [
123+
f"{reverse('package_search')}?page=9999",
124+
f"{reverse('package_search')}?page=-5",
125+
f"{reverse('package_search')}?page=abc",
126+
]
127+
for url in out_of_range_urls:
128+
response = self.client.get(url)
129+
self.assertEqual(response.status_code, 200)
130+
self.assertEqual(response.context["page_obj"].number, 1)

vulnerabilities/tests/test_view.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,9 @@
2323
from vulnerabilities.models import Vulnerability
2424
from vulnerabilities.models import VulnerabilitySeverity
2525
from vulnerabilities.templatetags.url_filters import url_quote_filter
26+
from vulnerabilities.utils import get_purl_version_class
2627
from vulnerabilities.views import PackageDetails
2728
from vulnerabilities.views import PackageSearch
28-
from vulnerabilities.views import get_purl_version_class
29-
from vulnerabilities.views import purl_sort_key
3029

3130
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
3231
TEST_DIR = os.path.join(BASE_DIR, "test_data/package_sort")
@@ -59,14 +58,17 @@ def setUp(self):
5958
Package.objects.create(**attrs)
6059

6160
def test_packages_search_view_paginator(self):
62-
response = self.client.get("/packages/search?type=deb&name=&page=1")
63-
self.assertEqual(response.status_code, 200)
64-
response = self.client.get("/packages/search?type=deb&name=&page=*")
65-
self.assertEqual(response.status_code, 404)
66-
response = self.client.get("/packages/search?type=deb&name=&page=")
67-
self.assertEqual(response.status_code, 200)
68-
response = self.client.get("/packages/search?type=&name=&page=")
69-
self.assertEqual(response.status_code, 200)
61+
test_cases = [
62+
("/packages/search?type=deb&name=&page=1", 200),
63+
("/packages/search?type=deb&name=&page=*", 200),
64+
("/packages/search?type=deb&name=&page=", 200),
65+
("/packages/search?type=&name=&page=", 200),
66+
]
67+
for url, expected_status in test_cases:
68+
response = self.client.get(url)
69+
self.assertEqual(response.status_code, expected_status)
70+
if "*" in url or "&page=" in url:
71+
self.assertEqual(response.context["page_obj"].number, 1)
7072

7173
def test_package_view(self):
7274
qs = PackageSearch().get_queryset(query="pkg:nginx/nginx@1.0.15?foo=bar")
@@ -202,12 +204,13 @@ def setUp(self):
202204
for pkg in input_purls:
203205
real_purl = PackageURL.from_string(pkg)
204206
attrs = {k: v for k, v in real_purl.to_dict().items() if v}
205-
Package.objects.create(**attrs)
207+
pkg = Package.objects.create(**attrs)
208+
pkg.calculate_version_rank
206209

207210
def test_sorted_queryset(self):
208211
qs_all = Package.objects.all()
209212
pkgs_qs_all = list(qs_all)
210-
sorted_pkgs_qs_all = sorted(pkgs_qs_all, key=purl_sort_key)
213+
sorted_pkgs_qs_all = pkgs_qs_all
211214

212215
pkg_package_urls = [obj.package_url for obj in sorted_pkgs_qs_all]
213216
sorted_purls = os.path.join(TEST_DIR, "sorted_purls.txt")

0 commit comments

Comments
 (0)