Skip to content

Commit 0b3e413

Browse files
Merge pull request #196 from cuappdev/sophie/auth
Implement auth
2 parents f9a5859 + 3cdd610 commit 0b3e413

7 files changed

Lines changed: 192 additions & 17 deletions

File tree

app_factory.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
import logging
2+
from datetime import timedelta, timezone
3+
from flask_jwt_extended import JWTManager
24
from datetime import datetime
35
from flask import Flask, render_template
46
from graphene import Schema
57
from graphql.utils import schema_printer
8+
from src.utils.constants import JWT_SECRET_KEY
69
from src.database import db_session, init_db
710
from src.database import Base as db
811
from src.database import db_url, db_user, db_password, db_name, db_host, db_port
912
from flask_migrate import Migrate
1013
from src.schema import Query, Mutation
1114
from flasgger import Swagger
1215
from flask_graphql import GraphQLView
16+
from src.models.token_blacklist import TokenBlocklist
17+
1318

1419
# Set up logging at module level
1520
logging.basicConfig(format="%(asctime)s %(levelname)-8s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S")
@@ -51,6 +56,17 @@ def create_app(run_migrations=False):
5156
schema = Schema(query=Query, mutation=Mutation)
5257
swagger = Swagger(app)
5358

59+
app.config["JWT_SECRET_KEY"] = JWT_SECRET_KEY
60+
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=1)
61+
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=30)
62+
63+
jwt = JWTManager(app)
64+
65+
@jwt.token_in_blocklist_loader
66+
def check_if_token_revoked(jwt_header, jwt_payload: dict) -> bool:
67+
jti = jwt_payload["jti"]
68+
return db_session.query(TokenBlocklist.id).filter_by(jti=jti).scalar() is not None
69+
5470
# Configure routes
5571
logger.info("Configuring routes")
5672

@@ -158,6 +174,13 @@ def scrape_classes():
158174
except Exception as e:
159175
logging.error(f"Error in scrape_classes: {e}")
160176

177+
@scheduler.task("interval", id="cleanup_expired_tokens", hours=24)
178+
def cleanup_expired_tokens():
179+
logger.info("Deleting expired tokens...")
180+
now = datetime.now(timezone.utc)
181+
db_session.query(TokenBlocklist).filter(TokenBlocklist.expires_at < now).delete()
182+
db_session.commit()
183+
161184
# Update hourly average capacity every hour
162185
@scheduler.task("cron", id="update_capacity", hour="*")
163186
def scheduled_job():
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""add token_blacklist table
2+
3+
Revision ID: 7245f58bb00a
4+
Revises: 0fde4435424e
5+
Create Date: 2025-03-12 17:46:57.085233
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
11+
12+
# revision identifiers, used by Alembic.
13+
revision = '7245f58bb00a'
14+
down_revision = '0fde4435424e'
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade():
20+
# Create the token_blacklist table
21+
op.execute("""
22+
DO $$
23+
BEGIN
24+
IF NOT EXISTS (
25+
SELECT 1
26+
FROM information_schema.tables
27+
WHERE table_name = 'token_blacklist'
28+
) THEN
29+
CREATE TABLE token_blacklist (
30+
id SERIAL PRIMARY KEY,
31+
jti VARCHAR(36) NOT NULL,
32+
expires_at TIMESTAMP NOT NULL
33+
);
34+
CREATE INDEX ix_token_blacklist_jti ON token_blacklist(jti);
35+
END IF;
36+
END $$;
37+
""")
38+
# Create an index on the jti column for faster lookups
39+
40+
41+
def downgrade():
42+
# Drop the index and then the table
43+
op.execute("DROP TABLE IF EXISTS token_blacklist;")

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,5 @@ wasmer-compiler-cranelift==1.1.0
7878
wcwidth==0.2.6
7979
Werkzeug==2.2.2
8080
zipp==3.15.0
81-
sentry-sdk==2.13.0
81+
sentry-sdk==2.13.0
82+
flask_jwt_extended==4.7.1

schema.graphql

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,6 @@ type CreateReport {
7171

7272
scalar DateTime
7373

74-
enum DayOfWeekEnum {
75-
MONDAY
76-
TUESDAY
77-
WEDNESDAY
78-
THURSDAY
79-
FRIDAY
80-
SATURDAY
81-
SUNDAY
82-
}
83-
8474
enum DayOfWeekGraphQLEnum {
8575
MONDAY
8676
TUESDAY
@@ -158,6 +148,15 @@ type HourlyAverageCapacity {
158148

159149
scalar JSONString
160150

151+
type LoginUser {
152+
accessToken: String
153+
refreshToken: String
154+
}
155+
156+
type LogoutUser {
157+
success: Boolean
158+
}
159+
161160
enum MuscleGroup {
162161
ABDOMINALS
163162
CHEST
@@ -179,6 +178,9 @@ type Mutation {
179178
enterGiveaway(giveawayId: Int!, userNetId: String!): GiveawayInstance
180179
setWorkoutGoals(userId: Int!, workoutGoal: [String]!): User
181180
logWorkout(facilityId: Int!, userId: Int!, workoutTime: DateTime!): Workout
181+
loginUser(netId: String!): LoginUser
182+
logoutUser: LogoutUser
183+
refreshAccessToken: RefreshAccessToken
182184
createReport(createdAt: DateTime!, description: String!, gymId: Int!, issue: String!): CreateReport
183185
deleteUser(userId: Int!): User
184186
}
@@ -222,6 +224,10 @@ type Query {
222224
getHourlyAverageCapacitiesByFacilityId(facilityId: Int): [HourlyAverageCapacity]
223225
}
224226

227+
type RefreshAccessToken {
228+
newAccessToken: String
229+
}
230+
225231
type Report {
226232
id: ID!
227233
createdAt: DateTime!
@@ -246,7 +252,7 @@ type User {
246252
name: String!
247253
activeStreak: Int
248254
maxStreak: Int
249-
workoutGoal: [DayOfWeekEnum]
255+
workoutGoal: [DayOfWeekGraphQLEnum]
250256
giveaways: [Giveaway]
251257
}
252258

src/models/token_blacklist.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from sqlalchemy import Column, String, Integer, DateTime
2+
from src.database import Base
3+
4+
5+
class TokenBlocklist(Base):
6+
"""
7+
Represents a JWT token that has been revoked (blacklisted).
8+
9+
Attributes:
10+
- `id` The primary key of the token record.
11+
- `jti` The unique identifier (JWT ID) of the token. Indexed for fast lookup.
12+
- `expires_at` The DateTime when the token expires.
13+
"""
14+
15+
__tablename__ = "token_blacklist"
16+
17+
id = Column(Integer, primary_key=True)
18+
jti = Column(String(36), index=True, nullable=False)
19+
expires_at = Column(DateTime, nullable=False)

0 commit comments

Comments
 (0)