-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_db.py
More file actions
94 lines (68 loc) · 2.5 KB
/
test_db.py
File metadata and controls
94 lines (68 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""Tests for database module."""
import pytest
import sqlite3
from db import create_db
def test_create_db():
    """Verify that ``create_db()`` hands back a ``sqlite3.Connection``."""
    connection = create_db()
    is_sqlite_connection = isinstance(connection, sqlite3.Connection)
    assert is_sqlite_connection
    connection.close()
def test_table_exists():
    """Check that the ``paper_authorships`` table exists.

    The table is looked up by name in ``sqlite_master``. The connection is
    closed in a ``finally`` clause so that a failing assertion or query does
    not leave it open.
    """
    conn = create_db()
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name='paper_authorships'"
        )
        result = cursor.fetchone()
        # fetchone() returns None when no row matched, i.e. no such table.
        assert result is not None
        assert result[0] == "paper_authorships"
    finally:
        conn.close()
def test_distinct_years():
    """Check the number of distinct ``year`` values in ``paper_authorships``.

    The connection is closed in a ``finally`` clause so that a failing
    assertion or query does not leave it open.
    """
    conn = create_db()
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(DISTINCT year) FROM paper_authorships")
        year_count = cursor.fetchone()[0]
        # NOTE(review): 19 is presumably tied to the dataset that create_db()
        # loads -- confirm and update whenever that data changes.
        assert year_count == 19
    finally:
        conn.close()
def test_total_records():
    """Check the total number of rows in ``paper_authorships``.

    The connection is closed in a ``finally`` clause so that a failing
    assertion or query does not leave it open.
    """
    conn = create_db()
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT COUNT(*) FROM paper_authorships")
        total_records = cursor.fetchone()[0]
        # NOTE(review): 172164 is presumably the row count of the dataset
        # that create_db() loads -- confirm and update when the data changes.
        assert total_records == 172164
    finally:
        conn.close()
def test_conferences():
    """Check that exactly the expected conferences are present.

    The distinct ``conference`` values must include NeurIPS, ICML and ICLR,
    and there must be exactly three of them. The connection is closed in a
    ``finally`` clause so that a failing assertion or query does not leave it
    open.
    """
    conn = create_db()
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT DISTINCT conference FROM paper_authorships ORDER BY conference"
        )
        conferences = [row[0] for row in cursor.fetchall()]
        # Separate membership checks so a failure names the missing conference.
        assert "NeurIPS" in conferences
        assert "ICML" in conferences
        assert "ICLR" in conferences
        # The length check rules out any conference beyond the three above.
        assert len(conferences) == 3
    finally:
        conn.close()
def test_table_schema():
    """Check that ``paper_authorships`` has the expected columns, in order.

    The connection is closed in a ``finally`` clause so that a failing
    assertion or query does not leave it open.
    """
    conn = create_db()
    try:
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(paper_authorships)")
        # Each table_info row describes one column; index 1 is its name.
        columns = [row[1] for row in cursor.fetchall()]
        expected_columns = ["Conference", "Year", "Title", "Author", "Affiliation"]
        # List equality checks both the names (case-sensitively) and their order.
        assert columns == expected_columns
    finally:
        conn.close()
def test_sample_data():
    """Check that one row can be fetched and has the expected shape and types.

    Fetches a single row from ``paper_authorships`` and checks that it has
    five fields of the expected Python types. The connection is closed in a
    ``finally`` clause so that a failing assertion or query does not leave it
    open.
    """
    conn = create_db()
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM paper_authorships LIMIT 1")
        row = cursor.fetchone()
        # fetchone() returns None when the table is empty.
        assert row is not None
        assert len(row) == 5  # Conference, Year, Title, Author, Affiliation
        # NOTE(review): a NULL field would come back as None and fail these
        # type checks -- confirm the fetched row is expected to have no NULLs.
        assert isinstance(row[0], str)  # Conference
        assert isinstance(row[1], int)  # Year
        assert isinstance(row[2], str)  # Title
        assert isinstance(row[3], str)  # Author
        assert isinstance(row[4], str)  # Affiliation
    finally:
        conn.close()