Skip to content

Commit a4954aa

Browse files
committed
chore: add unit tests
1 parent d9ba4b0 commit a4954aa

3 files changed

Lines changed: 269 additions & 57 deletions

File tree

app/prompt.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -127,33 +127,58 @@ def process_tool_calls(message: ChatCompletionMessage) -> list[str]:
127127
def completion(prompt: str) -> tuple[ChatCompletion, list[str]]:
128128
"""LLM text completion"""
129129

130-
openai.api_key = settings.OPENAI_API_KEY
131-
model = settings.OPENAI_API_MODEL
132-
temperature = settings.OPENAI_API_TEMPERATURE
133-
max_tokens = settings.OPENAI_API_MAX_TOKENS
130+
def handle_completion(tools, tool_choice) -> ChatCompletion:
131+
"""Handle the OpenAI chat completion call."""
132+
openai.api_key = settings.OPENAI_API_KEY
133+
model = settings.OPENAI_API_MODEL
134+
135+
try:
136+
response = openai.chat.completions.create(
137+
model=model,
138+
messages=messages,
139+
tools=tools,
140+
tool_choice=tool_choice,
141+
temperature=settings.OPENAI_API_TEMPERATURE,
142+
max_tokens=settings.OPENAI_API_MAX_TOKENS,
143+
)
144+
logger.debug("OpenAI response: %s", response.model_dump())
145+
return response
146+
except openai.RateLimitError as e:
147+
logger.error("OpenAI rate limit exceeded: %s", e)
148+
raise
149+
except openai.APIConnectionError as e:
150+
logger.error("OpenAI API connection error: %s", e)
151+
raise
152+
except openai.AuthenticationError as e:
153+
logger.error("OpenAI authentication error: %s", e)
154+
raise
155+
except openai.BadRequestError as e:
156+
logger.error("OpenAI bad request error: %s", e)
157+
raise
158+
except openai.APIError as e:
159+
logger.error("OpenAI API error: %s", e)
160+
raise
161+
# pylint: disable=broad-except
162+
except Exception as e:
163+
logger.error("Unexpected error during OpenAI completion: %s", e)
164+
raise
165+
134166
messages.append(ChatCompletionUserMessageParam(role="user", content=prompt))
135167
functions_called = []
136168

137-
response = openai.chat.completions.create(
138-
model=model,
139-
messages=messages,
169+
response = handle_completion(
140170
tool_choice={"type": "function", "function": {"name": "get_courses"}},
141171
tools=[stackademy_app.tool_factory_get_courses()],
142-
temperature=temperature,
143-
max_tokens=max_tokens,
144172
)
145173
logger.debug("Initial response: %s", response.model_dump())
146174

147175
message = response.choices[0].message
148176
while message.tool_calls:
149177
functions_called = process_tool_calls(message)
150178

151-
response = openai.chat.completions.create(
152-
model=model,
153-
messages=messages,
179+
response = handle_completion(
154180
tools=[stackademy_app.tool_factory_get_courses(), stackademy_app.tool_factory_register()],
155-
temperature=temperature,
156-
max_tokens=max_tokens,
181+
tool_choice="auto",
157182
)
158183
message = response.choices[0].message
159184
logger.debug("Updated response: %s", response.model_dump())

app/stackademy.py

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ def __init__(self):
5353
"""Initialize the Stackademy application."""
5454
self.db = db
5555

56+
def _log_success(self, message: str) -> None:
57+
"""
58+
Log a success message with colorized console output.
59+
60+
Args:
61+
message: The success message to log
62+
"""
63+
print(f"\033[1;92m{message}\033[0m")
64+
logger.info(message)
65+
5666
def tool_factory_get_courses(self) -> ChatCompletionFunctionToolParam:
5767
"""LLM Factory function to create a tool for getting courses"""
5868
return ChatCompletionFunctionToolParam(
@@ -147,10 +157,14 @@ def verify_course(self, course_code: str) -> bool:
147157
bool: True if the course exists, False otherwise
148158
"""
149159
query = "SELECT * FROM courses WHERE course_code = %s"
150-
logger.info("verify_course() course_code: %s", course_code)
151160
try:
152161
result = self.db.execute_query(query, (course_code,))
153-
return len(result) > 0
162+
retval = len(result) > 0
163+
if retval:
164+
logger.info("verified course_code: %s", course_code)
165+
else:
166+
logger.warning("course_code not found: %s", course_code)
167+
return retval
154168
# pylint: disable=broad-except
155169
except Exception as e:
156170
logger.error("Failed to retrieve courses: %s", e)
@@ -173,49 +187,9 @@ def register_course(self, course_code: str, email: str, full_name: str) -> bool:
173187
logger.error("Course code %s does not exist.", course_code)
174188
return False
175189

176-
# Print success message in bold bright green
177190
success_message = f"Successfully registered {full_name} ({email}) for course {course_code}."
178-
print(f"\033[1;92m{success_message}\033[0m")
179-
logger.info(success_message)
191+
self._log_success(success_message)
180192
return True
181193

182194

183-
def main():
184-
"""Main function to demonstrate database functionality."""
185-
print("Stackademy MySQL Database Demo")
186-
print("=" * 50)
187-
188-
try:
189-
# Initialize the application
190-
app = Stackademy()
191-
192-
# Test database connection
193-
logger.info("Testing database connection...")
194-
if not app.test_database_connection():
195-
logger.error("Database connection failed. Please check your configuration.")
196-
return
197-
logger.info("✅ Database connection successful!")
198-
199-
# Get courses
200-
logger.info("Retrieving courses...")
201-
courses = app.get_courses(description="python")
202-
for course in courses:
203-
logger.info(
204-
" - %s (%s) - %s - $%s",
205-
course["course_name"],
206-
course["course_code"],
207-
course["description"],
208-
course["cost"],
209-
)
210-
211-
except ConfigurationException as e:
212-
logger.error("Configuration error: %s", e)
213-
# pylint: disable=broad-except
214-
except Exception as e:
215-
logger.error("Application error: %s", e)
216-
217-
218195
stackademy_app = Stackademy()
219-
220-
if __name__ == "__main__":
221-
main()

app/tests/test_stackademy.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# -*- coding: utf-8 -*-
2+
# pylint: disable=wrong-import-position
3+
# pylint: disable=R0801
4+
"""Test Stackademy application."""
5+
6+
# python stuff
7+
import os
8+
import sys
9+
import unittest
10+
from pathlib import Path
11+
from unittest.mock import Mock, patch
12+
13+
from app.exceptions import ConfigurationException
14+
from app.logging_config import get_logger
15+
from app.stackademy import Stackademy
16+
17+
18+
HERE = os.path.abspath(os.path.dirname(__file__))
19+
PROJECT_ROOT = str(Path(HERE).parent.parent)
20+
PYTHON_ROOT = str(Path(PROJECT_ROOT).parent)
21+
if PYTHON_ROOT not in sys.path:
22+
sys.path.append(PYTHON_ROOT) # noqa: E402
23+
24+
25+
logger = get_logger(__name__)
26+
27+
28+
class TestStackademy(unittest.TestCase):
29+
"""Test Stackademy application."""
30+
31+
def setUp(self):
32+
"""Set up test fixtures before each test method."""
33+
self.app = Stackademy()
34+
35+
def test_stackademy_initialization(self):
36+
"""Test that the Stackademy application initializes successfully."""
37+
self.assertIsNotNone(self.app)
38+
self.assertIsNotNone(self.app.db)
39+
40+
def test_database_connection_success(self):
41+
"""Test successful database connection."""
42+
# Mock the database connection to return True
43+
with patch.object(self.app.db, "test_connection", return_value=True):
44+
result = self.app.test_database_connection()
45+
self.assertTrue(result)
46+
logger.info("Database connection test passed successfully")
47+
48+
def test_database_connection_failure(self):
49+
"""Test database connection failure."""
50+
# Mock the database connection to raise an exception
51+
with patch.object(self.app.db, "test_connection", side_effect=Exception("Connection failed")):
52+
result = self.app.test_database_connection()
53+
self.assertFalse(result)
54+
logger.info("Database connection failure test passed")
55+
56+
def test_get_courses_with_description_filter(self):
57+
"""Test retrieving courses with description filter."""
58+
# Mock course data
59+
mock_courses = [
60+
{
61+
"course_code": "PY101",
62+
"course_name": "Python Fundamentals",
63+
"description": "Learn Python programming basics",
64+
"cost": 299.99,
65+
"prerequisite_course_code": None,
66+
"prerequisite_course_name": None,
67+
},
68+
{
69+
"course_code": "PY201",
70+
"course_name": "Advanced Python",
71+
"description": "Advanced Python programming techniques",
72+
"cost": 399.99,
73+
"prerequisite_course_code": "PY101",
74+
"prerequisite_course_name": "Python Fundamentals",
75+
},
76+
]
77+
78+
# Mock the database query
79+
with patch.object(self.app.db, "execute_query", return_value=mock_courses):
80+
courses = self.app.get_courses(description="python")
81+
82+
self.assertEqual(len(courses), 2)
83+
self.assertEqual(courses[0]["course_code"], "PY101")
84+
self.assertEqual(courses[1]["course_code"], "PY201")
85+
86+
# Log course information as in the original code
87+
logger.info("Retrieved %d courses with python description", len(courses))
88+
for course in courses:
89+
logger.info(
90+
" - %s (%s) - %s - $%s",
91+
course["course_name"],
92+
course["course_code"],
93+
course["description"],
94+
course["cost"],
95+
)
96+
97+
def test_get_courses_with_cost_filter(self):
98+
"""Test retrieving courses with maximum cost filter."""
99+
mock_courses = [
100+
{
101+
"course_code": "WEB101",
102+
"course_name": "Web Development Basics",
103+
"description": "Introduction to web development",
104+
"cost": 199.99,
105+
"prerequisite_course_code": None,
106+
"prerequisite_course_name": None,
107+
}
108+
]
109+
110+
with patch.object(self.app.db, "execute_query", return_value=mock_courses):
111+
courses = self.app.get_courses(max_cost=250.0)
112+
113+
self.assertEqual(len(courses), 1)
114+
self.assertLessEqual(courses[0]["cost"], 250.0)
115+
logger.info("Retrieved courses under $250: %d", len(courses))
116+
117+
def test_get_courses_database_error(self):
118+
"""Test get_courses when database error occurs."""
119+
with patch.object(self.app.db, "execute_query", side_effect=Exception("Database error")):
120+
courses = self.app.get_courses(description="python")
121+
122+
self.assertEqual(len(courses), 0)
123+
logger.info("Database error handling test passed")
124+
125+
def test_get_courses_no_results(self):
126+
"""Test get_courses when no courses match criteria."""
127+
with patch.object(self.app.db, "execute_query", return_value=[]):
128+
courses = self.app.get_courses(description="nonexistent")
129+
130+
self.assertEqual(len(courses), 0)
131+
logger.info("No results test passed")
132+
133+
def test_application_workflow_with_configuration_exception(self):
134+
"""Test application workflow that raises ConfigurationException."""
135+
# pylint: disable=broad-exception-caught
136+
try:
137+
# Simulate a configuration error
138+
raise ConfigurationException("Invalid configuration setting")
139+
except ConfigurationException as e:
140+
logger.error("Configuration error: %s", e)
141+
self.assertIsInstance(e, ConfigurationException)
142+
143+
def test_application_workflow_with_general_exception(self):
144+
"""Test application workflow that raises general exception."""
145+
# pylint: disable=broad-exception-caught,broad-except
146+
try:
147+
# Simulate a general application error
148+
raise RuntimeError("General application error")
149+
except Exception as e:
150+
logger.error("Application error: %s", e)
151+
self.assertIsInstance(e, Exception)
152+
153+
def test_full_application_workflow(self):
154+
"""Test the complete application workflow as shown in the example."""
155+
mock_courses = [
156+
{
157+
"course_code": "PY101",
158+
"course_name": "Python Fundamentals",
159+
"description": "Learn Python programming from scratch",
160+
"cost": 299.99,
161+
"prerequisite_course_code": None,
162+
"prerequisite_course_name": None,
163+
},
164+
{
165+
"course_code": "AI201",
166+
"course_name": "Python for AI",
167+
"description": "Python programming for artificial intelligence",
168+
"cost": 499.99,
169+
"prerequisite_course_code": "PY101",
170+
"prerequisite_course_name": "Python Fundamentals",
171+
},
172+
]
173+
174+
# pylint: disable=broad-exception-caught
175+
try:
176+
# Initialize the application
177+
app = Stackademy()
178+
self.assertIsNotNone(app)
179+
180+
# Test database connection
181+
logger.info("Testing database connection...")
182+
with patch.object(app.db, "test_connection", return_value=True):
183+
if not app.test_database_connection():
184+
logger.error("Database connection failed. Please check your configuration.")
185+
self.fail("Database connection should have succeeded")
186+
logger.info("Database connection successful!")
187+
188+
# Get courses
189+
logger.info("Retrieving courses...")
190+
with patch.object(app.db, "execute_query", return_value=mock_courses):
191+
courses = app.get_courses(description="python")
192+
193+
self.assertEqual(len(courses), 2)
194+
195+
for course in courses:
196+
logger.info(
197+
" - %s (%s) - %s - $%s",
198+
course["course_name"],
199+
course["course_code"],
200+
course["description"],
201+
course["cost"],
202+
)
203+
204+
except ConfigurationException as e:
205+
logger.error("Configuration error: %s", e)
206+
self.fail(f"Unexpected ConfigurationException: {e}")
207+
except Exception as e:
208+
logger.error("Application error: %s", e)
209+
self.fail(f"Unexpected application error: {e}")
210+
211+
212+
if __name__ == "__main__":
213+
unittest.main()

0 commit comments

Comments
 (0)