diff --git a/src/postgrest/src/postgrest/_async/request_builder.py b/src/postgrest/src/postgrest/_async/request_builder.py index 855e46db..9523d2ca 100644 --- a/src/postgrest/src/postgrest/_async/request_builder.py +++ b/src/postgrest/src/postgrest/_async/request_builder.py @@ -27,6 +27,7 @@ from ..utils import model_validate_json ReqConfig = RequestConfig[AsyncClient] +QueryBuilderT = TypeVar("QueryBuilderT", bound="AsyncQueryRequestBuilder") def get_retry_delay(resp: Response, attempt_count: int) -> int: @@ -59,6 +60,17 @@ class AsyncQueryRequestBuilder: def __init__(self, request: ReqConfig): self.request = request + def select(self: QueryBuilderT, *columns: str) -> QueryBuilderT: + _, params, _, _ = pre_select(*columns, count=None) + self.request.params = self.request.params.add("select", params["select"]) + if prefer_headers := self.request.headers.get_list("Prefer", split_commas=True): + prefer_headers = [h for h in prefer_headers if not h.startswith("return=")] + prefer_headers.append("return=representation") + self.request.headers["Prefer"] = ",".join(prefer_headers) + else: + self.request.headers["Prefer"] = "return=representation" + return self + def retry(self, enabled: bool) -> Self: self.request.retry_enabled = enabled return self diff --git a/src/postgrest/src/postgrest/_sync/request_builder.py b/src/postgrest/src/postgrest/_sync/request_builder.py index 12777316..b6891f56 100644 --- a/src/postgrest/src/postgrest/_sync/request_builder.py +++ b/src/postgrest/src/postgrest/_sync/request_builder.py @@ -27,6 +27,7 @@ from ..utils import model_validate_json ReqConfig = RequestConfig[Client] +QueryBuilderT = TypeVar("QueryBuilderT", bound="SyncQueryRequestBuilder") def get_retry_delay(resp: Response, attempt_count: int) -> int: @@ -59,6 +60,17 @@ class SyncQueryRequestBuilder: def __init__(self, request: ReqConfig): self.request = request + def select(self: QueryBuilderT, *columns: str) -> QueryBuilderT: + _, params, _, _ = pre_select(*columns, count=None) + self.request.params = self.request.params.add("select", params["select"]) + if prefer_headers := self.request.headers.get_list("Prefer", split_commas=True): + prefer_headers = [h for h in prefer_headers if not h.startswith("return=")] + prefer_headers.append("return=representation") + self.request.headers["Prefer"] = ",".join(prefer_headers) + else: + self.request.headers["Prefer"] = "return=representation" + return self + def retry(self, enabled: bool) -> Self: self.request.retry_enabled = enabled return self diff --git a/src/postgrest/tests/_async/test_request_builder.py b/src/postgrest/tests/_async/test_request_builder.py index 356be18f..9cdc0baa 100644 --- a/src/postgrest/tests/_async/test_request_builder.py +++ b/src/postgrest/tests/_async/test_request_builder.py @@ -7,7 +7,7 @@ from postgrest import AsyncRequestBuilder, AsyncSingleRequestBuilder from postgrest._async.request_builder import RequestConfig from postgrest.base_request_builder import APIResponse, SingleAPIResponse -from postgrest.types import JSON, CountMethod +from postgrest.types import JSON, CountMethod, ReturnMethod @pytest.fixture @@ -120,6 +120,26 @@ def test_upsert(self, request_builder: AsyncRequestBuilder): assert builder.request.http_method == "POST" assert builder.request.json == {"key1": "val1"} + def test_insert_with_select(self, request_builder: AsyncRequestBuilder): + builder = request_builder.insert({"key1": "val1"}).select("id", "key1") + + assert builder.request.params["select"] == "id,key1" + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation" + ] + + def test_insert_with_select_forces_representation( + self, request_builder: AsyncRequestBuilder + ): + builder = request_builder.insert( + {"key1": "val1"}, returning=ReturnMethod.minimal + ).select("id") + + assert builder.request.params["select"] == "id" + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation", + ] + def test_bulk_upsert_with_default(self, request_builder: AsyncRequestBuilder): builder = request_builder.upsert( [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False @@ -168,6 +188,12 @@ def test_update_with_max_affected(self, request_builder: AsyncRequestBuilder): assert builder.request.http_method == "PATCH" assert builder.request.json == {"key1": "val1"} + def test_update_with_select(self, request_builder: AsyncRequestBuilder): + builder = request_builder.update({"key1": "val1"}).eq("id", 1).select("id") + + assert builder.request.params["id"] == "eq.1" + assert builder.request.params["select"] == "id" + class TestDelete: def test_delete(self, request_builder: AsyncRequestBuilder): @@ -198,6 +224,12 @@ def test_delete_with_max_affected(self, request_builder: AsyncRequestBuilder): assert builder.request.http_method == "DELETE" assert builder.request.json == {} + def test_delete_with_select(self, request_builder: AsyncRequestBuilder): + builder = request_builder.delete().eq("id", 1).select("id") + + assert builder.request.params["id"] == "eq.1" + assert builder.request.params["select"] == "id" + class TestTextSearch: def test_text_search(self, request_builder: AsyncRequestBuilder): diff --git a/src/postgrest/tests/_sync/test_request_builder.py b/src/postgrest/tests/_sync/test_request_builder.py index 5217de24..435f8ab5 100644 --- a/src/postgrest/tests/_sync/test_request_builder.py +++ b/src/postgrest/tests/_sync/test_request_builder.py @@ -7,7 +7,7 @@ from postgrest import SyncRequestBuilder, SyncSingleRequestBuilder from postgrest._async.request_builder import RequestConfig from postgrest.base_request_builder import APIResponse, SingleAPIResponse -from postgrest.types import JSON, CountMethod +from postgrest.types import JSON, CountMethod, ReturnMethod @pytest.fixture @@ -120,6 +120,26 @@ def test_upsert(self, request_builder: SyncRequestBuilder): assert builder.request.http_method == "POST" assert builder.request.json == {"key1": "val1"} + def test_insert_with_select(self, request_builder: SyncRequestBuilder): + builder = request_builder.insert({"key1": "val1"}).select("id", "key1") + + assert builder.request.params["select"] == "id,key1" + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation" + ] + + def test_insert_with_select_forces_representation( + self, request_builder: SyncRequestBuilder + ): + builder = request_builder.insert( + {"key1": "val1"}, returning=ReturnMethod.minimal + ).select("id") + + assert builder.request.params["select"] == "id" + assert builder.request.headers.get_list("prefer", True) == [ + "return=representation", + ] + def test_bulk_upsert_with_default(self, request_builder: SyncRequestBuilder): builder = request_builder.upsert( [{"key1": "val1", "key2": "val2"}, {"key3": "val3"}], default_to_null=False @@ -168,6 +188,12 @@ def test_update_with_max_affected(self, request_builder: SyncRequestBuilder): assert builder.request.http_method == "PATCH" assert builder.request.json == {"key1": "val1"} + def test_update_with_select(self, request_builder: SyncRequestBuilder): + builder = request_builder.update({"key1": "val1"}).eq("id", 1).select("id") + + assert builder.request.params["id"] == "eq.1" + assert builder.request.params["select"] == "id" + class TestDelete: def test_delete(self, request_builder: SyncRequestBuilder): @@ -198,6 +224,12 @@ def test_delete_with_max_affected(self, request_builder: SyncRequestBuilder): assert builder.request.http_method == "DELETE" assert builder.request.json == {} + def test_delete_with_select(self, request_builder: SyncRequestBuilder): + builder = request_builder.delete().eq("id", 1).select("id") + + assert builder.request.params["id"] == "eq.1" + assert builder.request.params["select"] == "id" + class TestTextSearch: def test_text_search(self, request_builder: SyncRequestBuilder):