Skip to content

Commit d14ddb9

Browse files
authored
add unit tests to sqlite_helpers.py (#62)
* add unit tests to sqlite_helpers.py * add unit test yml * tests -> test * install tzdata * ZoneInfo('UTC')
1 parent 9e94388 commit d14ddb9

3 files changed

Lines changed: 257 additions & 9 deletions

File tree

.github/workflows/unit-tests.yml

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
name: Cleezy Unit Tests
2+
3+
on:
4+
push:
5+
branches: [dev]
6+
pull_request:
7+
branches: [dev]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
13+
steps:
14+
- name: Checkout code
15+
uses: actions/checkout@v3
16+
17+
- name: Set up Python
18+
uses: actions/setup-python@v4
19+
with:
20+
python-version: '3.11'
21+
22+
- name: Install dependencies
23+
run: |
24+
python -m pip install --upgrade pip
25+
pip install -r requirements.txt
26+
pip install tzdata # see https://github.com/celery/kombu/issues/2106
27+
28+
- name: Run tests
29+
run: |
30+
python -m unittest discover -s test

modules/sqlite_helpers.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1+
from datetime import datetime
2+
import logging
13
import sqlite3
2-
from datetime import datetime, timedelta
4+
import typing
35
from zoneinfo import ZoneInfo
4-
import logging
5-
from modules.args import get_args
6+
67

78
ROWS_PER_PAGE = 25
89

910
logger = logging.getLogger(__name__)
10-
args = get_args()
11-
expiration_date_timezone = ZoneInfo(args.expiration_date_timezone)
1211

1312
def maybe_create_table(sqlite_file: str) -> bool:
1413
db = sqlite3.connect(sqlite_file)
@@ -39,7 +38,7 @@ def maybe_create_table(sqlite_file: str) -> bool:
3938
return False
4039

4140

42-
def insert_url(sqlite_file: str, url: str, alias: str, expiration_date: str):
41+
def insert_url(sqlite_file: str, url: str, alias: str, expiration_date: typing.Union[str, None] = None):
4342
db = sqlite3.connect(sqlite_file)
4443
cursor = db.cursor()
4544
timestamp = datetime.now()
@@ -93,7 +92,6 @@ def get_urls(sqlite_file, page=0, search=None, sort_by="created_at", order="DESC
9392
def get_url(sqlite_file: str, alias: str): #return the string for url entry for a specified alias
9493
db = sqlite3.connect(sqlite_file)
9594
cursor = db.cursor()
96-
9795
try:
9896
sql = "SELECT * FROM urls WHERE alias = ?"
9997
cursor.execute(sql, (alias,))
@@ -125,14 +123,15 @@ def delete_url(sqlite_file: str, alias: str): #delete entry in the database from
125123
def maybe_delete_expired_url(sqlite_file, sqlite_row) -> bool: #returns True if url expired and deleted, otherwise False
126124
db = sqlite3.connect(sqlite_file)
127125
cursor = db.cursor()
126+
utc_tz = ZoneInfo('UTC')
128127

129128
expiration_datetime = None
130129
# sqlite_row[5] represents the expiration datetime e.g., "2024-11-04 18:05:24.006593"
131130
if sqlite_row[5] is not None:
132131
expiration_datetime = datetime.strptime(sqlite_row[5], "%Y-%m-%d %H:%M:%S.%f")
133-
expiration_datetime = expiration_datetime.replace(tzinfo=expiration_date_timezone)
132+
expiration_datetime = expiration_datetime.replace(tzinfo=utc_tz)
134133

135-
now = datetime.now(expiration_date_timezone)
134+
now = datetime.now(tz=utc_tz)
136135
if expiration_datetime is not None and expiration_datetime < now:
137136
sql = "DELETE FROM urls WHERE alias = ?"
138137
cursor.execute(sql, (sqlite_row[2], ))

test/test_sqlite_helpers.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
import datetime
2+
import os
3+
import sqlite3
4+
import sys
5+
import tempfile
6+
import unittest
7+
from unittest import mock
8+
9+
# this allows imports from the modules folder to work
10+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
11+
from modules import sqlite_helpers
12+
13+
14+
class TestDatabaseSetup(unittest.TestCase):
15+
EXAMPLE_DATETIME = datetime.datetime(1996, 12, 24, 12, 0, 0)
16+
EXAMPLE_URL = "https://sce.sjsu.edu/"
17+
18+
def test_maybe_create_table(self):
19+
with tempfile.NamedTemporaryFile() as tmp:
20+
result = sqlite_helpers.maybe_create_table(tmp.name)
21+
self.assertTrue(result)
22+
23+
db = sqlite3.connect(tmp.name)
24+
cursor = db.cursor()
25+
cursor.execute(
26+
"SELECT name FROM sqlite_master WHERE type='table' AND name='urls';"
27+
)
28+
# get first element in response, cursor.fetchone() returns ('urls',)
29+
[table] = cursor.fetchone()
30+
self.assertEqual(table, "urls")
31+
32+
@mock.patch("sqlite3.connect")
33+
def test_maybe_create_table_handles_exception(self, mock_connect):
34+
# Mock connection and cursor
35+
mock_conn = mock.MagicMock()
36+
mock_cursor = mock.MagicMock()
37+
mock_cursor.execute.side_effect = Exception("Simulated execute error")
38+
39+
mock_conn.cursor.return_value = mock_cursor
40+
mock_connect.return_value = mock_conn
41+
42+
with tempfile.NamedTemporaryFile() as tmp:
43+
result = sqlite_helpers.maybe_create_table(tmp.name)
44+
self.assertFalse(result)
45+
46+
@mock.patch("modules.sqlite_helpers.datetime")
47+
def test_insert_url(self, mock_datetime):
48+
mock_datetime.fromisoformat.return_value = self.EXAMPLE_DATETIME
49+
mock_datetime.now.return_value = self.EXAMPLE_DATETIME
50+
cases = [
51+
("does not set expiration date if it is None", None),
52+
("sets expiration date if it is passed in", self.EXAMPLE_DATETIME),
53+
]
54+
55+
for test_name, expiration_value in cases:
56+
with self.subTest(test_name=test_name, expiration_value=expiration_value):
57+
with tempfile.NamedTemporaryFile() as tmp:
58+
sqlite_helpers.maybe_create_table(tmp.name)
59+
result = sqlite_helpers.insert_url(
60+
tmp.name, self.EXAMPLE_URL, "home", expiration_value
61+
)
62+
self.assertEqual(result, self.EXAMPLE_DATETIME)
63+
64+
db = sqlite3.connect(tmp.name)
65+
cursor = db.cursor()
66+
cursor.execute("SELECT * FROM urls;")
67+
68+
[
69+
row_id,
70+
url,
71+
alias,
72+
created_at,
73+
used,
74+
expires_at,
75+
] = cursor.fetchone()
76+
self.assertEqual(row_id, 1)
77+
self.assertEqual(url, self.EXAMPLE_URL)
78+
self.assertEqual(alias, "home")
79+
self.assertEqual(created_at, "1996-12-24 12:00:00")
80+
self.assertEqual(used, 1)
81+
if expiration_value is None:
82+
self.assertIsNone(expires_at)
83+
else:
84+
self.assertEqual(expires_at, "1996-12-24 12:00:00")
85+
86+
def test_insert_url_duplicate_alias_not_allowed(self):
87+
with tempfile.NamedTemporaryFile() as tmp:
88+
sqlite_helpers.maybe_create_table(tmp.name)
89+
result = sqlite_helpers.insert_url(tmp.name, self.EXAMPLE_URL, "home")
90+
self.assertIsNotNone(result)
91+
result_duplicate_alias = sqlite_helpers.insert_url(
92+
tmp.name, self.EXAMPLE_URL, "home"
93+
)
94+
self.assertIsNone(result_duplicate_alias)
95+
96+
def test_get_urls(self):
97+
with tempfile.NamedTemporaryFile() as tmp:
98+
sqlite_helpers.maybe_create_table(tmp.name)
99+
result = sqlite_helpers.insert_url(tmp.name, self.EXAMPLE_URL, "home")
100+
self.assertIsNotNone(result, self.EXAMPLE_DATETIME)
101+
result_duplicate_alias = sqlite_helpers.insert_url(
102+
tmp.name, self.EXAMPLE_URL, "home"
103+
)
104+
self.assertIsNone(result_duplicate_alias)
105+
106+
def test_get_urls_page(self):
107+
cases = [
108+
# the first url should be the most recently created one
109+
# with alias url_29
110+
# the last url should be url_5, and the remaining urls
111+
# (url_0 through url_4) were not included in the first page
112+
("Returns the first page of urls", 0, "url_29", "url_5"),
113+
("Returns the second page of urls", 1, "url_4", "url_0"),
114+
]
115+
for test_name, page, first_alias, last_alias in cases:
116+
with self.subTest(
117+
test_name=test_name, first_alias=first_alias, last_alias=last_alias
118+
):
119+
with tempfile.NamedTemporaryFile() as tmp:
120+
sqlite_helpers.maybe_create_table(tmp.name)
121+
# create 30 urls with aliases url_0 to url_29
122+
for i in range(30):
123+
alias = f"url_{i}"
124+
sqlite_helpers.insert_url(tmp.name, self.EXAMPLE_URL, alias)
125+
stuff = sqlite_helpers.get_urls(tmp.name, page)
126+
self.assertEqual(stuff[0]["alias"], first_alias)
127+
self.assertEqual(stuff[-1]["alias"], last_alias)
128+
129+
def test_get_urls_page_search(self):
130+
with tempfile.NamedTemporaryFile() as tmp:
131+
sqlite_helpers.maybe_create_table(tmp.name)
132+
# create 30 urls with aliases url_0 to url_29
133+
for i in range(30):
134+
alias = f"url_{i}"
135+
sqlite_helpers.insert_url(tmp.name, self.EXAMPLE_URL, alias)
136+
# we should get a full page of answers with a vague search term
137+
stuff = sqlite_helpers.get_urls(tmp.name, search="url")
138+
self.assertEqual(len(stuff), sqlite_helpers.ROWS_PER_PAGE)
139+
# we should get one answer with a specfic search term
140+
stuff = sqlite_helpers.get_urls(tmp.name, search="url_0")
141+
self.assertEqual(len(stuff), 1)
142+
# we should get nothing with a not found search term
143+
stuff = sqlite_helpers.get_urls(tmp.name, search="not_real")
144+
self.assertEqual(len(stuff), 0)
145+
146+
def test_get_url(self):
147+
with tempfile.NamedTemporaryFile() as tmp:
148+
sqlite_helpers.maybe_create_table(tmp.name)
149+
sqlite_helpers.insert_url(tmp.name, self.EXAMPLE_URL, "home")
150+
result = sqlite_helpers.get_url(tmp.name, "home")
151+
self.assertEqual(result, self.EXAMPLE_URL)
152+
# not found alias returns None
153+
result = sqlite_helpers.get_url(tmp.name, "not_real")
154+
self.assertIsNone(result, self.EXAMPLE_URL)
155+
156+
def test_get_url_expired(self):
157+
# querying a url that has expired returns nothing
158+
with tempfile.NamedTemporaryFile() as tmp:
159+
sqlite_helpers.maybe_create_table(tmp.name)
160+
sqlite_helpers.insert_url(
161+
tmp.name, self.EXAMPLE_URL, "home", self.EXAMPLE_DATETIME.isoformat()
162+
)
163+
result = sqlite_helpers.get_url(tmp.name, "home")
164+
self.assertIsNone(result)
165+
166+
def test_delete_url(self):
167+
with tempfile.NamedTemporaryFile() as tmp:
168+
sqlite_helpers.maybe_create_table(tmp.name)
169+
sqlite_helpers.insert_url(tmp.name, self.EXAMPLE_URL, "home")
170+
self.assertIsNotNone(sqlite_helpers.get_url(tmp.name, "home"))
171+
172+
result = sqlite_helpers.delete_url(tmp.name, "home")
173+
self.assertTrue(result)
174+
self.assertIsNone(sqlite_helpers.get_url(tmp.name, "home"))
175+
176+
# trying to delete the same url again returls false
177+
result_second_call = sqlite_helpers.get_url(tmp.name, "home")
178+
self.assertFalse(result_second_call)
179+
180+
def test_get_number_of_entries(self):
181+
with tempfile.NamedTemporaryFile() as tmp:
182+
sqlite_helpers.maybe_create_table(tmp.name)
183+
for i in range(30):
184+
alias = f"url_{i}"
185+
sqlite_helpers.insert_url(tmp.name, self.EXAMPLE_URL, alias)
186+
result = sqlite_helpers.get_number_of_entries(tmp.name)
187+
self.assertEqual(result, 30)
188+
result_with_search = sqlite_helpers.get_number_of_entries(tmp.name, "url_1")
189+
# matches url_1 and url_10 - url_19
190+
self.assertEqual(result_with_search, 11)
191+
result_with_search = sqlite_helpers.get_number_of_entries(
192+
tmp.name, "not_real"
193+
)
194+
self.assertEqual(result_with_search, 0)
195+
196+
def test_increment_used_column(self):
197+
with tempfile.NamedTemporaryFile() as tmp:
198+
sqlite_helpers.maybe_create_table(tmp.name)
199+
sqlite_helpers.insert_url(tmp.name, self.EXAMPLE_URL, "home")
200+
db = sqlite3.connect(tmp.name)
201+
cursor = db.cursor()
202+
203+
cursor.execute("SELECT used FROM urls;")
204+
[used] = cursor.fetchone()
205+
self.assertEqual(used, 1)
206+
207+
sqlite_helpers.increment_used_column(tmp.name, "home")
208+
cursor.execute("SELECT used FROM urls;")
209+
[used] = cursor.fetchone()
210+
self.assertEqual(used, 2)
211+
212+
sqlite_helpers.increment_used_column(tmp.name, "home", 20)
213+
cursor.execute("SELECT used FROM urls;")
214+
[used] = cursor.fetchone()
215+
self.assertEqual(used, 22)
216+
217+
218+
if __name__ == "__main__":
219+
unittest.main()

0 commit comments

Comments
 (0)