Skip to content

Commit e7a81a7

Browse files
committed
AI code review changes round 1
1 parent 5209e63 commit e7a81a7

11 files changed

Lines changed: 66 additions & 62 deletions

File tree

.github/copilot-instructions.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,8 @@ python -m scripts.alembic upgrade head
112112
- The `logged_in_basic_user` / `logged_in_admin_user` fixtures call `test_client.cookies.clear()` on teardown. If a test logs in **manually** (e.g., `POST /login` or `POST /auth/token`) without using these fixtures, it **must** call `test_client.cookies.clear()` at the end to avoid leaking cookie state into subsequent tests.
113113
- `BASIC_COOKIE` and `ADMIN_COOKIE` in `tests/__init__.py` are module-level dicts shared across the session. They are cleared between test modules by `_clear_tokens()` called from the `clean_db` fixture teardown.
114114
- Flash messages are stored in the **session cookie**. A stale session cookie from a previous test (e.g., one that logged in) can cause flash messages from a subsequent test to be lost or doubled.
115+
116+
## DO these things
117+
118+
- Run `pre-commit run --all-files` frequently to apply all linters and formatters locally. The `pre-commit` config is in `.pre-commit-config.yaml`
119+
- `ai_workspace/` at the project root is in `.gitignore` and safe for writing debug scripts and output files. Use this instead of writing to `/tmp/` or other shared locations.

app/datastore/database.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# connect_args={"check_same_thread": False} for SQLite
1717

1818
ENGINE: AsyncEngine | None = None
19+
SESSION_MAKER: async_sessionmaker[AsyncSession] | None = None
1920

2021

2122
def get_engine(
@@ -26,7 +27,7 @@ def get_engine(
2627
new: bool = False,
2728
) -> AsyncEngine:
2829
"""Return the database engine, creating a new one if needed."""
29-
global ENGINE # noqa: PLW0603 (global-statement)
30+
global ENGINE, SESSION_MAKER # noqa: PLW0603 (global-statement)
3031
if not new and ENGINE is not None:
3132
return ENGINE
3233

@@ -39,14 +40,21 @@ def get_engine(
3940
pool_size=pool_size,
4041
pool_pre_ping=True,
4142
)
43+
SESSION_MAKER = None # reset when engine changes
4244
return ENGINE
4345

4446

47+
def get_session_maker() -> async_sessionmaker[AsyncSession]:
48+
"""Return the async session maker, creating a new one if needed."""
49+
global SESSION_MAKER # noqa: PLW0603 (global-statement)
50+
if SESSION_MAKER is None:
51+
SESSION_MAKER = async_sessionmaker(get_engine(), expire_on_commit=False)
52+
return SESSION_MAKER
53+
54+
4555
async def get_db_session() -> AsyncGenerator[AsyncSession]:
4656
"""Start a SessionLocal transaction and yield it."""
47-
engine = get_engine()
48-
async_session = async_sessionmaker(engine, expire_on_commit=False)
49-
async with async_session() as session:
57+
async with get_session_maker()() as session:
5058
yield session
5159

5260

app/services/blog/blog_handler.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ class SaveBlogResponse(BaseModel, arbitrary_types_allowed=True):
5656
field_errors: defaultdict[str, list[str]] = defaultdict(list)
5757

5858

59+
SaveBlogResponse.model_rebuild()
60+
61+
5962
class CommentInputPreview(BaseModel, arbitrary_types_allowed=True):
6063
"""Input for comment preview."""
6164

@@ -89,6 +92,9 @@ class SaveCommentResponse(BaseModel, arbitrary_types_allowed=True):
8992
field_errors: defaultdict[str, list[str]] = defaultdict(list)
9093

9194

95+
SaveCommentResponse.model_rebuild()
96+
97+
9298
def _get_bp_statement() -> Select:
9399
return select(db_models.BlogPost).options(
94100
selectinload(db_models.BlogPost.tags),
@@ -286,29 +292,21 @@ async def _get_bp_from_slug_history(db: AsyncSession, slug: str) -> db_models.Bl
286292
"""Get a blog post from its slug history."""
287293
try:
288294
stmt = (
289-
select(db_models.OldBlogPostSlug)
290-
.options(
291-
selectinload(db_models.OldBlogPostSlug.blog_post)
292-
.selectinload(db_models.BlogPost.comments)
293-
.selectinload(db_models.BlogPostComment.user),
294-
selectinload(db_models.OldBlogPostSlug.blog_post).selectinload(
295-
db_models.BlogPost.tags
296-
),
295+
_get_bp_statement()
296+
.join(
297+
db_models.OldBlogPostSlug,
298+
db_models.OldBlogPostSlug.blog_post_id == db_models.BlogPost.id,
297299
)
298300
.filter(db_models.OldBlogPostSlug.slug == slug)
299301
)
300302
result = await db.execute(stmt)
301-
slug_object = result.scalars().one()
303+
return result.scalars().one()
302304
except sqlalchemy.exc.NoResultFound as e:
303305
raise errors.BlogPostNotFoundError from e
304-
else:
305-
return slug_object.blog_post
306306

307307

308308
async def save_blog_post(db: AsyncSession, data: SaveBlogInput) -> SaveBlogResponse:
309309
"""Save a blog post."""
310-
SaveBlogResponse.model_rebuild()
311-
312310
field_errors: defaultdict[str, list[str]] = defaultdict(list)
313311
try:
314312
blog_post = await _save_bp_to_db(data=data, db=db)
@@ -389,7 +387,7 @@ async def update_existing_bp_fields( # noqa: C901 (complexity)
389387
blog_post.series_id = data.series_id
390388
if blog_post.series_position != data.series_position:
391389
blog_post.series_position = data.series_position
392-
blog_post.updated_timestamp = datetime.now().astimezone(UTC)
390+
blog_post.updated_timestamp = datetime.now(UTC)
393391
return blog_post
394392

395393

@@ -414,7 +412,7 @@ async def set_new_bp_fields(
414412
html_description = markdown_parser.markdown_to_html(data.description)
415413
html_content = markdown_parser.markdown_to_html(data.content)
416414
tags = await _get_bp_tags(db=db, tags=data.tags)
417-
now = datetime.now().astimezone(UTC)
415+
now = datetime.now(UTC)
418416
return db_models.BlogPost(
419417
title=data.title,
420418
slug=blog_utils.get_slug(data.title),
@@ -575,7 +573,7 @@ async def commit_media_to_db( # noqa: PLR0913 (too-many-arguments)
575573
name=name,
576574
locations=locations_str,
577575
media_type=media_type,
578-
created_timestamp=datetime.now().astimezone(UTC),
576+
created_timestamp=datetime.now(UTC),
579577
position=position,
580578
)
581579
db.add(bp_media_object)
@@ -623,7 +621,7 @@ async def update_existing_comment(
623621
html_content = generate_comment_html(md_content)
624622
comment.md_content = md_content
625623
comment.html_content = html_content
626-
comment.updated_timestamp = datetime.now().astimezone(UTC)
624+
comment.updated_timestamp = datetime.now(UTC)
627625
if current_user.is_authenticated:
628626
comment.user_id = current_user.id
629627
await db.commit()
@@ -634,7 +632,7 @@ async def update_existing_comment(
634632
def generate_comment(data: CommentInputPreview) -> db_models.BlogPostComment:
635633
"""Generate a blog post comment."""
636634
html_content = generate_comment_html(data.content)
637-
now = datetime.now().astimezone(UTC)
635+
now = datetime.now(UTC)
638636
return db_models.BlogPostComment(
639637
blog_post_id=data.bp_id,
640638
name=data.name,
@@ -680,18 +678,20 @@ def can_edit_comment(
680678
comment: db_models.BlogPostComment, current_user: db_models.User | UnauthenticatedUser
681679
) -> bool:
682680
"""Check if a user can edit this comment."""
683-
return comment.user_id == current_user.id or comment.guest_id == current_user.guest_id
681+
if current_user.is_authenticated:
682+
return comment.user_id == current_user.id
683+
return bool(comment.guest_id and comment.guest_id == current_user.guest_id)
684684

685685

686686
def can_delete_comment(
687687
comment: db_models.BlogPostComment, current_user: db_models.User | UnauthenticatedUser
688688
) -> bool:
689689
"""Check if a user can delete this comment. Allows admin to delete any comment."""
690-
return (
691-
comment.user_id == current_user.id
692-
or comment.guest_id == current_user.guest_id
693-
or current_user.is_admin
694-
)
690+
if current_user.is_admin:
691+
return True
692+
if current_user.is_authenticated:
693+
return comment.user_id == current_user.id
694+
return bool(comment.guest_id and comment.guest_id == current_user.guest_id)
695695

696696

697697
async def get_comment_from_id(db: AsyncSession, comment_id: int) -> db_models.BlogPostComment:

app/services/general/encryption_handler.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99

1010
from app.settings import settings
1111

12-
ENCRYPTION_KEY = settings.encryption_key
13-
1412

1513
def derive_iv(plaintext: str) -> bytes:
1614
"""Derive a consistent IV from the plaintext using a hash function.
@@ -39,7 +37,7 @@ def encrypt(data: str) -> str:
3937
str: The Base64 encoded encrypted data including the IV.
4038
4139
"""
42-
key_bytes = bytes.fromhex(ENCRYPTION_KEY)
40+
key_bytes = bytes.fromhex(settings.encryption_key)
4341
data_bytes = data.encode("utf-8")
4442
iv = derive_iv(data)
4543
cipher = Cipher(algorithms.AES(key_bytes), modes.CBC(iv), backend=default_backend())
@@ -68,7 +66,7 @@ def decrypt(encoded_data: str) -> str:
6866
str: The decrypted data as a string.
6967
7068
"""
71-
key_bytes = bytes.fromhex(ENCRYPTION_KEY)
69+
key_bytes = bytes.fromhex(settings.encryption_key)
7270
encrypted_data_with_iv = base64.b64decode(encoded_data)
7371

7472
iv = encrypted_data_with_iv[:16]
@@ -89,7 +87,7 @@ def decrypt(encoded_data: str) -> str:
8987
if __name__ == "__main__": # pragma: no cover
9088
# Encryption key as a 64-character hex string (256 bits / 32 bytes)
9189
# key = os.urandom(32).hex() # noqa: ERA001 (commented-out code)
92-
print(f"Key (hex): {ENCRYPTION_KEY}") # noqa: T201 (print used for example)
90+
print(f"Key (hex): {settings.encryption_key}") # noqa: T201 (print used for example)
9391

9492
import uuid
9593

app/services/general/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ def to_bool(value: Any) -> bool:
1818
return False
1919

2020

21-
def too_bool_validator(v: Any, _: ValidationInfo) -> bool:
22-
"""Cooerce any value to a bool."""
21+
def to_bool_validator(v: Any, _: ValidationInfo) -> bool:
22+
"""Coerce any value to a bool."""
2323
return to_bool(v)
2424

2525

26-
CoercedBool = Annotated[bool, BeforeValidator(too_bool_validator)]
26+
CoercedBool = Annotated[bool, BeforeValidator(to_bool_validator)]
2727

2828

2929
def to_list(obj: Any, *, lowercase: bool = False) -> list[str]:

app/services/users/user_handler.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ async def create_pw_reset_token(
170170
# Delete any existing tokens older than PW_RESET_TOKEN_EXPIRATION_MINUTES
171171
stmt = delete(db_models.PasswordResetToken).where(
172172
db_models.PasswordResetToken.created_timestamp
173-
< datetime.now().astimezone(UTC) - timedelta(minutes=PW_RESET_TOKEN_EXPIRATION_MINUTES)
173+
< datetime.now(UTC) - timedelta(minutes=PW_RESET_TOKEN_EXPIRATION_MINUTES)
174174
)
175175
await db.execute(stmt)
176176

@@ -179,9 +179,8 @@ async def create_pw_reset_token(
179179
pw_reset_token = db_models.PasswordResetToken(
180180
user_id=user.id,
181181
encrypted_query=encryption_handler.encrypt(query),
182-
created_timestamp=datetime.now().astimezone(UTC),
183-
expires_timestamp=datetime.now().astimezone(UTC)
184-
+ timedelta(minutes=PW_RESET_TOKEN_EXPIRATION_MINUTES),
182+
created_timestamp=datetime.now(UTC),
183+
expires_timestamp=datetime.now(UTC) + timedelta(minutes=PW_RESET_TOKEN_EXPIRATION_MINUTES),
185184
)
186185
db.add(pw_reset_token)
187186
await db.commit()
@@ -202,26 +201,14 @@ async def assert_token_is_valid(db: AsyncSession, query: str) -> db_models.Passw
202201
except sqlalchemy.orm.exc.NoResultFound as e:
203202
raise errors.PasswordResetTokenNotFoundError from e
204203
token_dt = token.expires_timestamp.astimezone(UTC)
205-
if token_dt < datetime.now().astimezone(UTC):
204+
if token_dt < datetime.now(UTC):
206205
raise errors.PasswordResetTokenExpiredError
207206
return token
208207

209208

210209
async def reset_password_from_token(db: AsyncSession, query: str, password: str) -> db_models.User:
211210
"""Reset a user's password from a password reset token."""
212-
encrypted_query = encryption_handler.encrypt(query)
213-
214-
get_token_stmt = select(db_models.PasswordResetToken).where(
215-
db_models.PasswordResetToken.encrypted_query == encrypted_query
216-
)
217-
get_token_result = await db.execute(get_token_stmt)
218-
try:
219-
pw_reset_token = get_token_result.scalars().one()
220-
except sqlalchemy.orm.exc.NoResultFound as e:
221-
raise errors.PasswordResetTokenNotFoundError from e
222-
token_dt = pw_reset_token.expires_timestamp.astimezone(UTC)
223-
if token_dt < datetime.now().astimezone(UTC):
224-
raise errors.PasswordResetTokenExpiredError
211+
pw_reset_token = await assert_token_is_valid(db, query)
225212
get_user_stmt = select(db_models.User).where(db_models.User.id == pw_reset_token.user_id)
226213
get_user_result = await db.execute(get_user_stmt)
227214
try:

app/web/html/routes/auth.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ async def login_for_access_token(
3434
value=token.access_token,
3535
httponly=True,
3636
secure=True,
37+
samesite="lax",
3738
)
3839
return token
3940

app/web/html/routes/sitemap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
@router.get("/sitemap.xml", response_model=None)
17-
@aiocache.cached(key="constant")
17+
@aiocache.cached(key="sitemap", ttl=3600)
1818
async def sitemap(request: Request, db: DBSession) -> HTMLResponse:
1919
"""Return the sitemap page."""
2020
return HTMLResponse(

app/web/html/routes/users.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
REGISTER_TEMPLATE = "users/register.html"
3232

3333

34+
def _safe_redirect(url: str | None) -> str:
35+
"""Return a safe same-origin redirect path, defaulting to '/'."""
36+
if url and url.startswith("/") and not url.startswith("//"):
37+
return url
38+
return "/"
39+
40+
3441
class LoginForm(Form):
3542
"""Form for user login page."""
3643

@@ -86,7 +93,7 @@ async def login_post(
8693
},
8794
status_code=status.HTTP_422_UNPROCESSABLE_CONTENT,
8895
)
89-
redirect_url = login_form.redirect_url.data or "/"
96+
redirect_url = _safe_redirect(login_form.redirect_url.data)
9097

9198
# Why not a RedirectResponse?
9299
# Because we need to set the HX-Redirect header to perform
@@ -240,7 +247,7 @@ async def register_post(
240247
async def logout(request: Request) -> RedirectResponse:
241248
"""Log the user out and redirect to the home page."""
242249
form_data = await request.form()
243-
redirect_url = str(form_data.get("next", "/"))
250+
redirect_url = _safe_redirect(str(form_data.get("next", "/")))
244251
response = RedirectResponse(
245252
redirect_url,
246253
status_code=status.HTTP_303_SEE_OTHER,

app/web/main.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def create_app() -> FastAPI:
4242
allow_methods=["*"],
4343
allow_headers=["*"],
4444
)
45-
app.add_middleware(SessionMiddleware, secret_key=settings.session_secret)
45+
app.add_middleware(SessionMiddleware, secret_key=settings.session_secret, max_age=86400)
4646

4747
@app.get("/api")
4848
async def api_home(request: Request) -> RedirectResponse:
@@ -63,8 +63,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None]: # noqa: ARG001 (unuse
6363
async with engine.begin() as conn:
6464
if settings.db_create_tables:
6565
await conn.run_sync(db_models.Base.metadata.create_all)
66-
else:
67-
yield
6866
except sqlalchemy.exc.OperationalError as e: # pragma: no cover
6967
err_msg = (
7068
"Could not connect to the database. Check the connection string."

0 commit comments

Comments
 (0)