From 20838276d0bbac94b8c7f23a8dd618e202ef00d1 Mon Sep 17 00:00:00 2001 From: Giorgio Maria Federico Birnthaler Date: Sun, 8 Feb 2026 19:31:31 +0100 Subject: [PATCH 1/3] feat(postgrest): support select chaining on write builders --- .../src/postgrest/_async/request_builder.py | 13 +++++++ .../src/postgrest/_sync/request_builder.py | 13 +++++++ .../tests/_async/test_request_builder.py | 35 ++++++++++++++++++- .../tests/_sync/test_request_builder.py | 35 ++++++++++++++++++- 4 files changed, 94 insertions(+), 2 deletions(-) diff --git a/src/postgrest/src/postgrest/_async/request_builder.py b/src/postgrest/src/postgrest/_async/request_builder.py index 3c996922..76c5711e 100644 --- a/src/postgrest/src/postgrest/_async/request_builder.py +++ b/src/postgrest/src/postgrest/_async/request_builder.py @@ -26,12 +26,25 @@ from ..utils import model_validate_json ReqConfig = RequestConfig[AsyncClient] +QueryBuilderT = TypeVar("QueryBuilderT", bound="AsyncQueryRequestBuilder") class AsyncQueryRequestBuilder: def __init__(self, request: ReqConfig): self.request = request + def select(self: QueryBuilderT, *columns: str) -> QueryBuilderT: + _, params, _, _ = pre_select(*columns, count=None) + self.request.params = self.request.params.add("select", params["select"]) + prefer_header = self.request.headers.get("Prefer") + if not prefer_header: + self.request.headers["Prefer"] = "return=representation" + elif "return=representation" not in [ + value.strip() for value in prefer_header.split(",") + ]: + self.request.headers["Prefer"] = f"{prefer_header},return=representation" + return self + async def execute(self) -> APIResponse: """Execute the query. diff --git a/src/postgrest/src/postgrest/_sync/request_builder.py b/src/postgrest/src/postgrest/_sync/request_builder.py index a5340403..33398e5a 100644 --- a/src/postgrest/src/postgrest/_sync/request_builder.py +++ b/src/postgrest/src/postgrest/_sync/request_builder.py @@ -26,12 +26,25 @@ from ..utils import model_validate_json ReqConfig = RequestConfig[Client] +QueryBuilderT = TypeVar("QueryBuilderT", bound="SyncQueryRequestBuilder") class SyncQueryRequestBuilder: def __init__(self, request: ReqConfig): self.request = request + def select(self: QueryBuilderT, *columns: str) -> QueryBuilderT: + _, params, _, _ = pre_select(*columns, count=None) + self.request.params = self.request.params.add("select", params["select"]) + prefer_header = self.request.headers.get("Prefer") + if not prefer_header: + self.request.headers["Prefer"] = "return=representation" + elif "return=representation" not in [ + value.strip() for value in prefer_header.split(",") + ]: + self.request.headers["Prefer"] = f"{prefer_header},return=representation" + return self + def execute(self) -> APIResponse: """Execute the query. diff --git a/src/postgrest/tests/_async/test_request_builder.py b/src/postgrest/tests/_async/test_request_builder.py index 356be18f..0c807730 100644 --- a/src/postgrest/tests/_async/test_request_builder.py +++ b/src/postgrest/tests/_async/test_request_builder.py @@ -7,7 +7,7 @@ from postgrest import AsyncRequestBuilder, AsyncSingleRequestBuilder from postgrest._async.request_builder import RequestConfig from postgrest.base_request_builder import APIResponse, SingleAPIResponse -from postgrest.types import JSON, CountMethod +from postgrest.types import JSON, CountMethod, ReturnMethod @pytest.fixture @@ -120,6 +120,27 @@ def test_upsert(self, request_builder: AsyncRequestBuilder): assert builder.request.http_method == "POST" assert builder.request.json == {"key1": "val1"} + def test_insert_with_select(self, request_builder: AsyncRequestBuilder): + builder = request_builder.insert({"key1": "val1"}).select("id", "key1") + + assert builder.request.params["select"] == "id,key1" + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation" + ] + + def test_insert_with_select_forces_representation( + self, request_builder: AsyncRequestBuilder + ): + builder = request_builder.insert( + {"key1": "val1"}, returning=ReturnMethod.minimal + ).select("id") + + assert builder.request.params["select"] == "id" + assert builder.request.headers.get_list("prefer", True) == [ + "return=minimal", + "return=representation", + ] + def test_bulk_upsert_with_default(self, request_builder: AsyncRequestBuilder): builder = request_builder.upsert( [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False @@ -168,6 +189,12 @@ def test_update_with_max_affected(self, request_builder: AsyncRequestBuilder): assert builder.request.http_method == "PATCH" assert builder.request.json == {"key1": "val1"} + def test_update_with_select(self, request_builder: AsyncRequestBuilder): + builder = request_builder.update({"key1": "val1"}).eq("id", 1).select("id") + + assert builder.request.params["id"] == "eq.1" + assert builder.request.params["select"] == "id" + class TestDelete: def test_delete(self, request_builder: AsyncRequestBuilder): @@ -198,6 +225,12 @@ def test_delete_with_max_affected(self, request_builder: AsyncRequestBuilder): assert builder.request.http_method == "DELETE" assert builder.request.json == {} + def test_delete_with_select(self, request_builder: AsyncRequestBuilder): + builder = request_builder.delete().eq("id", 1).select("id") + + assert builder.request.params["id"] == "eq.1" + assert builder.request.params["select"] == "id" + class TestTextSearch: def test_text_search(self, request_builder: AsyncRequestBuilder): diff --git a/src/postgrest/tests/_sync/test_request_builder.py b/src/postgrest/tests/_sync/test_request_builder.py index 5217de24..e439f3be 100644 --- a/src/postgrest/tests/_sync/test_request_builder.py +++ b/src/postgrest/tests/_sync/test_request_builder.py @@ -7,7 +7,7 @@ from postgrest import SyncRequestBuilder, SyncSingleRequestBuilder from postgrest._async.request_builder import RequestConfig from postgrest.base_request_builder import APIResponse, SingleAPIResponse -from postgrest.types import JSON, CountMethod +from postgrest.types import JSON, CountMethod, ReturnMethod @pytest.fixture @@ -120,6 +120,27 @@ def test_upsert(self, request_builder: SyncRequestBuilder): assert builder.request.http_method == "POST" assert builder.request.json == {"key1": "val1"} + def test_insert_with_select(self, request_builder: SyncRequestBuilder): + builder = request_builder.insert({"key1": "val1"}).select("id", "key1") + + assert builder.request.params["select"] == "id,key1" + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation" + ] + + def test_insert_with_select_forces_representation( + self, request_builder: SyncRequestBuilder + ): + builder = request_builder.insert( + {"key1": "val1"}, returning=ReturnMethod.minimal + ).select("id") + + assert builder.request.params["select"] == "id" + assert builder.request.headers.get_list("prefer", True) == [ + "return=minimal", + "return=representation", + ] + def test_bulk_upsert_with_default(self, request_builder: SyncRequestBuilder): builder = request_builder.upsert( [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False @@ -168,6 +189,12 @@ def test_update_with_max_affected(self, request_builder: SyncRequestBuilder): assert builder.request.http_method == "PATCH" assert builder.request.json == {"key1": "val1"} + def test_update_with_select(self, request_builder: SyncRequestBuilder): + builder = request_builder.update({"key1": "val1"}).eq("id", 1).select("id") + + assert builder.request.params["id"] == "eq.1" + assert builder.request.params["select"] == "id" + class TestDelete: def test_delete(self, request_builder: SyncRequestBuilder): @@ -198,6 +225,12 @@ def test_delete_with_max_affected(self, request_builder: SyncRequestBuilder): assert builder.request.http_method == "DELETE" assert builder.request.json == {} + def test_delete_with_select(self, request_builder: SyncRequestBuilder): + builder = request_builder.delete().eq("id", 1).select("id") + + assert builder.request.params["id"] == "eq.1" + assert builder.request.params["select"] == "id" + class TestTextSearch: def test_text_search(self, request_builder: SyncRequestBuilder): From 0e5be5f4ec80d2c1cf2bfef71c5d44f086f22a06 Mon Sep 17 00:00:00 2001 From: Leonardo Santiago Date: Wed, 29 Apr 2026 10:04:25 -0300 Subject: [PATCH 2/3] fix(postgrest): only ever set return= header once fixes edge cases where the `select` would set multiple values on the header --- .../src/postgrest/_async/request_builder.py | 11 +++++------ .../src/postgrest/_sync/request_builder.py | 15 +++++++-------- .../tests/_async/test_request_builder.py | 1 - src/postgrest/tests/_sync/test_request_builder.py | 1 - 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/postgrest/src/postgrest/_async/request_builder.py b/src/postgrest/src/postgrest/_async/request_builder.py index c605be8b..9523d2ca 100644 --- a/src/postgrest/src/postgrest/_async/request_builder.py +++ b/src/postgrest/src/postgrest/_async/request_builder.py @@ -63,13 +63,12 @@ def __init__(self, request: ReqConfig): def select(self: QueryBuilderT, *columns: str) -> QueryBuilderT: _, params, _, _ = pre_select(*columns, count=None) self.request.params = self.request.params.add("select", params["select"]) - prefer_header = self.request.headers.get("Prefer") - if not prefer_header: + if prefer_headers := self.request.headers.get_list("Prefer", split_commas=True): + prefer_headers = [h for h in prefer_headers if not h.startswith("return=")] + prefer_headers.append("return=representation") + self.request.headers["Prefer"] = ",".join(prefer_headers) + else: self.request.headers["Prefer"] = "return=representation" - elif "return=representation" not in [ - value.strip() for value in prefer_header.split(",") - ]: - self.request.headers["Prefer"] = f"{prefer_header},return=representation" return self def retry(self, enabled: bool) -> Self: diff --git a/src/postgrest/src/postgrest/_sync/request_builder.py b/src/postgrest/src/postgrest/_sync/request_builder.py index 6f541219..de6b3b77 100644 --- a/src/postgrest/src/postgrest/_sync/request_builder.py +++ b/src/postgrest/src/postgrest/_sync/request_builder.py @@ -1,6 +1,6 @@ from __future__ import annotations -import time +import asyncio from typing import Any, Generic, Literal, Optional, TypeVar, Union, overload from httpx import BasicAuth, Client, Headers, QueryParams, Response @@ -51,7 +51,7 @@ def send_with_retry(req: ReqConfig) -> Response: resp = req.send(headers) if resp.is_success or not req.should_retry(resp, attempt_count=attempt_count): break - time.sleep(get_retry_delay(resp, attempt_count)) + asyncio.sleep(get_retry_delay(resp, attempt_count)) attempt_count += 1 return resp @@ -63,13 +63,12 @@ def __init__(self, request: ReqConfig): def select(self: QueryBuilderT, *columns: str) -> QueryBuilderT: _, params, _, _ = pre_select(*columns, count=None) self.request.params = self.request.params.add("select", params["select"]) - prefer_header = self.request.headers.get("Prefer") - if not prefer_header: + if prefer_headers := self.request.headers.get_list("Prefer", split_commas=True): + prefer_headers = [h for h in prefer_headers if not h.startswith("return=")] + prefer_headers.append("return=representation") + self.request.headers["Prefer"] = ",".join(prefer_headers) + else: self.request.headers["Prefer"] = "return=representation" - elif "return=representation" not in [ - value.strip() for value in prefer_header.split(",") - ]: - self.request.headers["Prefer"] = f"{prefer_header},return=representation" return self def retry(self, enabled: bool) -> Self: diff --git a/src/postgrest/tests/_async/test_request_builder.py b/src/postgrest/tests/_async/test_request_builder.py index 0c807730..9cdc0baa 100644 --- a/src/postgrest/tests/_async/test_request_builder.py +++ b/src/postgrest/tests/_async/test_request_builder.py @@ -137,7 +137,6 @@ def test_insert_with_select_forces_representation( assert builder.request.params["select"] == "id" assert builder.request.headers.get_list("prefer", True) == [ - "return=minimal", "return=representation", ] diff --git a/src/postgrest/tests/_sync/test_request_builder.py b/src/postgrest/tests/_sync/test_request_builder.py index e439f3be..435f8ab5 100644 --- a/src/postgrest/tests/_sync/test_request_builder.py +++ b/src/postgrest/tests/_sync/test_request_builder.py @@ -137,7 +137,6 @@ def test_insert_with_select_forces_representation( assert builder.request.params["select"] == "id" assert builder.request.headers.get_list("prefer", True) == [ - "return=minimal", "return=representation", ] From 56ed182e48465192a66b6a79ed094c8915298753 Mon Sep 17 00:00:00 2001 From: Leonardo Santiago Date: Wed, 29 Apr 2026 11:26:56 -0300 Subject: [PATCH 3/3] fix(postgrest): use time instead of asyncio --- src/postgrest/src/postgrest/_sync/request_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/postgrest/src/postgrest/_sync/request_builder.py b/src/postgrest/src/postgrest/_sync/request_builder.py index de6b3b77..b6891f56 100644 --- a/src/postgrest/src/postgrest/_sync/request_builder.py +++ b/src/postgrest/src/postgrest/_sync/request_builder.py @@ -1,6 +1,6 @@ from __future__ import annotations -import asyncio +import time from typing import Any, Generic, Literal, Optional, TypeVar, Union, overload from httpx import BasicAuth, Client, Headers, QueryParams, Response @@ -51,7 +51,7 @@ def send_with_retry(req: ReqConfig) -> Response: resp = req.send(headers) if resp.is_success or not req.should_retry(resp, attempt_count=attempt_count): break - asyncio.sleep(get_retry_delay(resp, attempt_count)) + time.sleep(get_retry_delay(resp, attempt_count)) attempt_count += 1 return resp