Skip to content

Commit f09f7d1

Browse files
committed
more efficient blog loading
1 parent 259b44f commit f09f7d1

1 file changed

Lines changed: 65 additions & 26 deletions

File tree

app/services/blog/blog_handler.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import sqlalchemy.exc
1212
from fastapi import UploadFile
1313
from pydantic import BaseModel, Field, model_validator
14-
from sqlalchemy import Select, delete, select
14+
from sqlalchemy import Select, delete, func, select
1515
from sqlalchemy.ext.asyncio import AsyncSession
1616
from sqlalchemy.orm import selectinload
1717

@@ -143,7 +143,7 @@ class Paginator(BaseModel, arbitrary_types_allowed=True):
143143
is_only_page: bool = False
144144

145145

146-
async def get_blog_posts( # noqa: PLR0913 (too-many-arguments)
146+
async def get_blog_posts( # noqa: PLR0913,PLR0914 (too-many-arguments, too-many-locals)
147147
*,
148148
db: AsyncSession,
149149
can_see_unpublished: bool,
@@ -154,42 +154,81 @@ async def get_blog_posts( # noqa: PLR0913 (too-many-arguments)
154154
results_per_page: int = 20,
155155
page: int = 1,
156156
) -> Paginator:
157-
"""Get blog posts."""
157+
"""Get blog posts.
158+
159+
Uses a window-function `COUNT(*) OVER()` to retrieve the total number of
160+
matching rows in the same round-trip as the paginated data. The fallback
161+
path (separate COUNT query + re-fetch) is only taken when the requested
162+
page is beyond the last page — an uncommon edge case.
163+
"""
158164
stmt = _get_bp_list_statement()
165+
# Build a parallel count statement for the out-of-range-page fallback.
159166
count_stmt = select(sqlalchemy.func.count()).select_from(db_models.BlogPost)
160167
if not can_see_unpublished:
161-
stmt = stmt.filter(db_models.BlogPost.is_published.is_(True))
162-
count_stmt = count_stmt.where(db_models.BlogPost.is_published.is_(True))
168+
filter_expr = db_models.BlogPost.is_published.is_(True)
169+
stmt = stmt.filter(filter_expr)
170+
count_stmt = count_stmt.where(filter_expr)
163171
if tags:
164172
tags_list = transforms.to_list(tags, lowercase=True)
165-
stmt = stmt.filter(db_models.BlogPost.tags.any(db_models.BlogPostTag.tag.in_(tags_list)))
166-
count_stmt = count_stmt.where(
167-
db_models.BlogPost.tags.any(db_models.BlogPostTag.tag.in_(tags_list))
168-
)
173+
filter_expr = db_models.BlogPost.tags.any(db_models.BlogPostTag.tag.in_(tags_list))
174+
stmt = stmt.filter(filter_expr)
175+
count_stmt = count_stmt.where(filter_expr)
169176
if search:
170-
stmt = stmt.filter(db_models.BlogPost.ts_vector.match(search))
171-
count_stmt = count_stmt.where(db_models.BlogPost.ts_vector.match(search))
172-
173-
# Get total blog posts matching results here
174-
count_result = await db.execute(count_stmt)
175-
total_results = count_result.scalar() or 0
176-
total_pages = _calculate_total_pages(
177-
total_results=total_results, results_per_page=results_per_page
178-
)
179-
actual_page = min(page, total_pages)
180-
actual_page = max(actual_page, 1)
177+
filter_expr = db_models.BlogPost.ts_vector.match(search)
178+
stmt = stmt.filter(filter_expr)
179+
count_stmt = count_stmt.where(filter_expr)
181180

182181
order_by = getattr(db_models.BlogPost, order_by_field)
183182
if not asc:
184183
order_by = order_by.desc()
185-
limit, offset = _calculate_limit_offset(results_per_page=results_per_page, page=actual_page)
186-
stmt = stmt.order_by(order_by).limit(limit).offset(offset)
187-
result = await db.execute(stmt)
188-
blog_posts = list(result.scalars().all())
184+
185+
page = max(page, 1)
186+
limit, offset = _calculate_limit_offset(results_per_page=results_per_page, page=page)
187+
188+
# Single query: annotate each row with COUNT(*) OVER() so we learn the
189+
# total matching rows without a separate round-trip.
190+
windowed_stmt = (
191+
stmt
192+
.add_columns(func.count().over().label("total"))
193+
.order_by(order_by)
194+
.limit(limit)
195+
.offset(offset)
196+
)
197+
result = await db.execute(windowed_stmt)
198+
rows = result.all()
199+
200+
if rows:
201+
# Happy path: window function gives us the total for free.
202+
total_results: int = rows[0][1]
203+
blog_posts = [row[0] for row in rows]
204+
actual_page = page
205+
elif page == 1:
206+
# Page 1 returned nothing — there are simply zero matching posts.
207+
total_results = 0
208+
blog_posts = []
209+
actual_page = 1
210+
else:
211+
# Page is beyond the last page. Fall back to a COUNT query and
212+
# re-fetch the correct (clamped) page. This path is rarely hit.
213+
count_result = await db.execute(count_stmt)
214+
total_results = count_result.scalar() or 0
215+
total_pages_inner = _calculate_total_pages(
216+
total_results=total_results, results_per_page=results_per_page
217+
)
218+
actual_page = min(page, max(total_pages_inner, 1))
219+
limit, offset = _calculate_limit_offset(results_per_page=results_per_page, page=actual_page)
220+
refetch_result = await db.execute(stmt.order_by(order_by).limit(limit).offset(offset))
221+
blog_posts = list(refetch_result.scalars().all())
222+
223+
total_pages = _calculate_total_pages(
224+
total_results=total_results, results_per_page=results_per_page
225+
)
226+
actual_page = min(max(actual_page, 1), max(total_pages, 1))
227+
_, final_offset = _calculate_limit_offset(results_per_page=results_per_page, page=actual_page)
189228
return Paginator(
190229
blog_posts=blog_posts,
191-
min_result=offset + 1,
192-
max_result=offset + len(blog_posts),
230+
min_result=final_offset + 1,
231+
max_result=final_offset + len(blog_posts),
193232
total_results=total_results,
194233
total_pages=total_pages,
195234
current_page=actual_page,

0 commit comments

Comments
 (0)