Skip to content

Commit 18359da

Browse files
committed
refactor: create json objects using Pydantic
1 parent 3dcdaf9 commit 18359da

7 files changed

Lines changed: 232 additions & 62 deletions

File tree

app/function_schemas.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# -*- coding: utf-8 -*-
2+
"""Pydantic models for OpenAI function calling schemas."""
3+
4+
from enum import Enum
5+
from typing import Optional
6+
7+
from pydantic import BaseModel, Field
8+
9+
10+
class SpecializationArea(str, Enum):
11+
"""Available specialization areas for courses."""
12+
13+
AI = "AI"
14+
MOBILE = "mobile"
15+
WEB = "web"
16+
DATABASE = "database"
17+
NETWORK = "network"
18+
NEURAL_NETWORKS = "neural networks"
19+
20+
21+
class GetCoursesParams(BaseModel):
22+
"""Parameters for the get_courses function."""
23+
24+
max_cost: Optional[float] = Field(
25+
None, description="The maximum cost that a student is willing to pay for a course."
26+
)
27+
description: Optional[SpecializationArea] = Field(
28+
None, description="Areas of specialization for courses in the catalogue."
29+
)
30+
31+
32+
class RegisterCourseParams(BaseModel):
33+
"""Parameters for the register_course function."""
34+
35+
course_code: str = Field(description="The unique code for the course.")
36+
email: str = Field(description="The email address of the new user.")
37+
full_name: str = Field(description="The full name of the new user.")

app/prompt.py

Lines changed: 3 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -29,63 +29,6 @@
2929
logger = get_logger(__name__)
3030

3131

32-
def tool_factory_get_courses() -> ChatCompletionFunctionToolParam:
33-
"""Factory function to create a tool for getting courses"""
34-
return ChatCompletionFunctionToolParam(
35-
type="function",
36-
function={
37-
"name": "get_courses",
38-
"description": "returns up to 10 rows of course detail data, filtered by the maximum cost a student is willing to pay for a course and the area of specialization.\n",
39-
"parameters": {
40-
"type": "object",
41-
"required": [],
42-
"properties": {
43-
"max_cost": {
44-
"type": "number",
45-
"description": "the maximum cost that a student is willing to pay for a course.",
46-
},
47-
"description": {
48-
"enum": ["AI", "mobile", "web", "database", "network", "neural networks"],
49-
"type": "string",
50-
"description": "areas of specialization for courses in the catalogue.",
51-
},
52-
},
53-
"additionalProperties": False,
54-
},
55-
},
56-
)
57-
58-
59-
def tool_factory_register() -> ChatCompletionFunctionToolParam:
60-
"""Factory function to create a tool for registering a user"""
61-
return ChatCompletionFunctionToolParam(
62-
type="function",
63-
function={
64-
"name": "register_course",
65-
"description": "Register a student in a course with the provided details.",
66-
"parameters": {
67-
"type": "object",
68-
"required": ["course_code", "email", "full_name"],
69-
"properties": {
70-
"course_code": {
71-
"type": "string",
72-
"description": "The unique code for the course.",
73-
},
74-
"email": {
75-
"type": "string",
76-
"description": "The email address of the new user.",
77-
},
78-
"full_name": {
79-
"type": "string",
80-
"description": "The full name of the new user.",
81-
},
82-
},
83-
"additionalProperties": False,
84-
},
85-
},
86-
)
87-
88-
8932
messages: list[
9033
Union[
9134
ChatCompletionSystemMessageParam,
@@ -103,6 +46,7 @@ def tool_factory_register() -> ChatCompletionFunctionToolParam:
10346
Your task is to assist users with their queries related to the platform,
10447
including course information, enrollment procedures, and general support.
10548
You should respond in a concise and clear manner, providing accurate information based on the user's request.
49+
If you ask a follow up question, then place it at the bottom of the response and precede it with "QUESTION:".
10650
""",
10751
),
10852
ChatCompletionAssistantMessageParam(
@@ -206,7 +150,7 @@ def completion(prompt: str) -> tuple[ChatCompletion, list[str]]:
206150
model=model,
207151
messages=messages,
208152
tool_choice={"type": "function", "function": {"name": "get_courses"}},
209-
tools=[tool_factory_get_courses()],
153+
tools=[stackademy_app.tool_factory_get_courses()],
210154
temperature=temperature,
211155
max_tokens=max_tokens,
212156
)
@@ -223,7 +167,7 @@ def completion(prompt: str) -> tuple[ChatCompletion, list[str]]:
223167
response = openai.chat.completions.create(
224168
model=model,
225169
messages=messages,
226-
tools=[tool_factory_get_courses(), tool_factory_register()],
170+
tools=[stackademy_app.tool_factory_get_courses(), stackademy_app.tool_factory_register()],
227171
temperature=temperature,
228172
max_tokens=max_tokens,
229173
)

app/register.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,22 @@ def main():
2222
response, functions_called = completion(prompt=user_prompt)
2323
while response.choices[0].message.content != "Goodbye!":
2424
message = response.choices[0].message
25-
logger.info("ChatGPT: %s", message.content)
25+
response_message = message.content or ""
26+
logger.info("ChatGPT: %s", response_message.strip())
27+
28+
# Check if there's a follow-up question in the response
29+
if "QUESTION:" in response_message:
30+
question_line = [
31+
line.strip() for line in response_message.split("\n") if line.strip().startswith("QUESTION:")
32+
][0]
33+
followup_question = question_line.replace("QUESTION:", "").strip() + " "
34+
else:
35+
followup_question = None
2636

2737
if "get_courses" in functions_called:
28-
user_prompt = input("Would you like to register for a course? ")
38+
user_prompt = input(followup_question or "Would you like to register for a course? ")
2939
elif "register_course" in functions_called:
30-
user_prompt = input("Can I help you with anything else? ")
40+
user_prompt = input(followup_question or "Can I help you with anything else? ")
3141

3242
response, functions_called = completion(prompt=user_prompt)
3343

app/response_models.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# -*- coding: utf-8 -*-
2+
"""Response models for structured outputs from OpenAI."""
3+
4+
from typing import List, Optional
5+
6+
from pydantic import BaseModel, Field
7+
8+
9+
class Course(BaseModel):
10+
"""A course in the Stackademy catalog."""
11+
12+
course_code: str = Field(description="The unique code for the course")
13+
course_name: str = Field(description="The name of the course")
14+
description: str = Field(description="Course description")
15+
cost: float = Field(description="Cost of the course")
16+
prerequisite_course_code: Optional[str] = Field(description="Prerequisite course code", default=None)
17+
prerequisite_course_name: Optional[str] = Field(description="Prerequisite course name", default=None)
18+
19+
20+
class CourseSearchResponse(BaseModel):
21+
"""Response model for course search results."""
22+
23+
courses: List[Course] = Field(description="List of courses matching the search criteria")
24+
total_count: int = Field(description="Total number of courses found")
25+
26+
27+
class RegistrationResponse(BaseModel):
28+
"""Response model for course registration."""
29+
30+
success: bool = Field(description="Whether the registration was successful")
31+
message: str = Field(description="Human-readable message about the registration result")
32+
registration_id: Optional[str] = Field(description="Unique registration ID if successful", default=None)

app/stackademy.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33

44
from typing import Any, Dict, List, Optional
55

6+
from openai.types.chat import ChatCompletionFunctionToolParam
7+
68
from app.database import db
79
from app.exceptions import ConfigurationException
10+
from app.function_schemas import GetCoursesParams, RegisterCourseParams
811
from app.logging_config import get_logger, setup_logging
912

1013

@@ -20,6 +23,28 @@ def __init__(self):
2023
"""Initialize the Stackademy application."""
2124
self.db = db
2225

26+
def tool_factory_get_courses(self) -> ChatCompletionFunctionToolParam:
27+
"""LLM Factory function to create a tool for getting courses"""
28+
return ChatCompletionFunctionToolParam(
29+
type="function",
30+
function={
31+
"name": "get_courses",
32+
"description": "returns up to 10 rows of course detail data, filtered by the maximum cost a student is willing to pay for a course and the area of specialization.",
33+
"parameters": GetCoursesParams.model_json_schema(),
34+
},
35+
)
36+
37+
def tool_factory_register(self) -> ChatCompletionFunctionToolParam:
38+
"""LLMFactory function to create a tool for registering a user"""
39+
return ChatCompletionFunctionToolParam(
40+
type="function",
41+
function={
42+
"name": "register_course",
43+
"description": "Register a student in a course with the provided details.",
44+
"parameters": RegisterCourseParams.model_json_schema(),
45+
},
46+
)
47+
2348
def test_database_connection(self) -> bool:
2449
"""
2550
Test the database connection.

app/structured_outputs.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# -*- coding: utf-8 -*-
2+
"""Example of using OpenAI's structured outputs with Pydantic models."""
3+
4+
import json
5+
from typing import Optional
6+
7+
import openai
8+
from pydantic import ValidationError
9+
10+
from app import settings
11+
from app.function_schemas import (
12+
GetCoursesParams,
13+
RegisterCourseParams,
14+
SpecializationArea,
15+
)
16+
from app.logging_config import get_logger
17+
from app.response_models import Course, CourseSearchResponse, RegistrationResponse
18+
from app.stackademy import stackademy_app
19+
20+
21+
logger = get_logger(__name__)
22+
23+
24+
def get_courses_with_structured_output(
25+
description: Optional[str] = None, max_cost: Optional[float] = None
26+
) -> CourseSearchResponse:
27+
"""
28+
Get courses using structured output parsing.
29+
30+
This ensures the response conforms to our expected schema.
31+
"""
32+
try:
33+
# Convert string to enum if provided
34+
specialization_area = None
35+
if description:
36+
37+
try:
38+
specialization_area = SpecializationArea(description)
39+
except ValueError:
40+
logger.warning("Invalid specialization area: %s", description)
41+
specialization_area = None
42+
43+
# Validate input parameters using Pydantic
44+
params = GetCoursesParams(description=specialization_area, max_cost=max_cost)
45+
46+
# Get raw course data
47+
courses_data = stackademy_app.get_courses(
48+
description=params.description if params.description else None, max_cost=params.max_cost
49+
)
50+
51+
courses = [Course(**course_dict) for course_dict in courses_data]
52+
53+
# Create structured response
54+
return CourseSearchResponse(courses=courses, total_count=len(courses))
55+
56+
except ValidationError as e:
57+
logger.error("Parameter validation error: %s", e)
58+
return CourseSearchResponse(courses=[], total_count=0)
59+
# pylint: disable=broad-except
60+
except Exception as e:
61+
logger.error("Error getting courses: %s", e)
62+
return CourseSearchResponse(courses=[], total_count=0)
63+
64+
65+
def register_course_with_structured_output(course_code: str, email: str, full_name: str) -> RegistrationResponse:
66+
"""
67+
Register for a course using structured output.
68+
"""
69+
try:
70+
# Validate input parameters
71+
params = RegisterCourseParams(course_code=course_code, email=email, full_name=full_name)
72+
73+
# Attempt registration
74+
success = stackademy_app.register_course(
75+
course_code=params.course_code, email=params.email, full_name=params.full_name
76+
)
77+
78+
if success:
79+
return RegistrationResponse(
80+
success=True,
81+
message=f"Successfully registered {full_name} for course {course_code}",
82+
registration_id=f"REG-{course_code}-{hash(email) % 10000:04d}",
83+
)
84+
return RegistrationResponse(success=False, message="Registration failed. Please try again later.")
85+
86+
except ValidationError as e:
87+
logger.error("Parameter validation error: %s", e)
88+
return RegistrationResponse(success=False, message=f"Invalid parameters: {e}")
89+
# pylint: disable=broad-except
90+
except Exception as e:
91+
logger.error("Registration error: %s", e)
92+
return RegistrationResponse(success=False, message="An unexpected error occurred during registration.")
93+
94+
95+
# Example of using OpenAI's beta structured outputs (requires openai>=1.0.0)
96+
# pylint: disable=unused-argument
97+
def completion_with_structured_output(prompt: str, response_model: type):
98+
"""
99+
Example of using OpenAI's structured output parsing.
100+
101+
This is available in the beta API and ensures responses conform to schemas.
102+
"""
103+
try:
104+
openai.api_key = settings.OPENAI_API_KEY
105+
106+
# Note: This is a beta feature and may not be available in all OpenAI versions
107+
# response = openai.beta.chat.completions.parse(
108+
# model=settings.OPENAI_API_MODEL,
109+
# messages=[{"role": "user", "content": prompt}],
110+
# response_format=response_model,
111+
# )
112+
#
113+
# return response.choices[0].message.parsed
114+
115+
# For now, we'll return a placeholder since this is beta
116+
logger.info("Structured output parsing would be used here with model: %s", response_model.__name__)
117+
return None
118+
# pylint: disable=broad-except
119+
except Exception as e:
120+
logger.error("Structured completion error: %s", e)
121+
return None

requirements/prod.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88

99
python-dotenv==1.1.1
1010
openai==2.0.0
11+
pydantic==2.10.0
1112
PyMySQL==1.1.1

0 commit comments

Comments
 (0)