Skip to content

Commit a8a24b1

Browse files
committed
Address reviews
- create schema map in controller for each request - put back test for invalid order param with 200 - unit test for get_schema_fields_map helper
1 parent 3c69bb6 commit a8a24b1

5 files changed

Lines changed: 46 additions & 12 deletions

File tree

server/mergin/sync/public_api_v2_controller.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,8 +437,9 @@ def list_workspace_projects(workspace_id, page, per_page, order_params=None, q=N
437437
projects = projects.filter(Project.name.ilike(f"%{q}%"))
438438

439439
if order_params:
440+
schema_map = get_schema_fields_map(ProjectSchemaV2)
440441
order_by_params = parse_order_params(
441-
Project, order_params, field_map=ProjectSchemaV2.field_map
442+
Project, order_params, field_map=schema_map
442443
)
443444
projects = projects.order_by(*order_by_params)
444445

server/mergin/sync/schemas_v2.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
Project,
1212
ProjectVersion,
1313
)
14-
from ..utils import get_schema_fields_map
1514

1615

1716
class ProjectSchema(ma.SQLAlchemyAutoSchema):
@@ -47,6 +46,3 @@ class Meta:
4746
"workspace",
4847
"role",
4948
)
50-
51-
52-
ProjectSchema.field_map = get_schema_fields_map(ProjectSchema)

server/mergin/tests/test_public_api_v2.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,10 @@ def test_list_workspace_projects(client):
651651
resp_data = json.loads(response.data)
652652
assert resp_data["projects"][0]["name"] == project_name
653653

654+
# invalid order param
655+
response = client.get(url + f"?page=1&per_page=10&order_params=invalid DESC")
656+
assert response.status_code == 200
657+
654658
# no permissions to workspace
655659
user2 = add_user("user", "password")
656660
login(client, user2.username, "password")

server/mergin/tests/test_utils.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@
77
import json
88
import pytest
99
from flask import url_for, current_app
10+
from marshmallow import Schema, fields
1011
from sqlalchemy import desc
1112
import os
1213
from unittest.mock import patch
1314
from pathvalidate import sanitize_filename
1415
from pygeodiff import GeoDiff
1516
from pathlib import PureWindowsPath
1617

17-
from ..utils import save_diagnostic_log_file
18+
from ..utils import save_diagnostic_log_file, get_schema_fields_map
1819

1920
from ..sync.utils import (
2021
is_reserved_word,
@@ -297,3 +298,27 @@ def test_save_diagnostic_log_file(client, app):
297298
with open(saved_file_path, "r") as f:
298299
content = f.read()
299300
assert content == body.decode("utf-8")
301+
302+
303+
def test_get_schema_fields_map():
304+
"""Test that schema map correctly resolves DB attributes, keeps all fields, and ignores virtual fields."""
305+
306+
# dummy schema for testing
307+
class TestSchema(Schema):
308+
# standard field -> map 'name': 'name'
309+
name = fields.String()
310+
# aliased field -> map 'size': 'disk_usage
311+
size = fields.Integer(attribute="disk_usage")
312+
# virtual fields -> skip
313+
version = fields.Function(lambda obj: "v1")
314+
role = fields.Method("get_role")
315+
# excluded field - set to None in schema inheritance -> skip
316+
hidden_field = None
317+
318+
schema_map = get_schema_fields_map(TestSchema)
319+
320+
expected_map = {
321+
"name": "name",
322+
"size": "disk_usage",
323+
}
324+
assert schema_map == expected_map

server/mergin/utils.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
from datetime import datetime, timedelta, timezone
99
from enum import Enum
1010
import os
11-
from flask import current_app, abort
11+
from flask import current_app
1212
from flask_sqlalchemy import Model
13-
from marshmallow import Schema
13+
from marshmallow import Schema, fields
1414
from pathvalidate import sanitize_filename
1515
from sqlalchemy import Column, JSON
1616
from sqlalchemy.sql.elements import UnaryExpression
17-
from typing import Optional
18-
17+
from typing import Optional, Type
1918

2019
OrderParam = namedtuple("OrderParam", "name direction")
2120

@@ -151,14 +150,23 @@ def save_diagnostic_log_file(app: str, username: str, body: bytes) -> str:
151150
return file_name
152151

153152

154-
def get_schema_fields_map(schema: Schema) -> dict:
153+
def get_schema_fields_map(schema: Type[Schema]) -> dict:
155154
"""
156155
Creates a mapping of schema field names to corresponding DB columns.
157156
This allows sorting by the API field name (e.g. 'size') while
158157
actually sorting by the database column (e.g. 'disk_usage').
159158
"""
160159
mapping = {}
161160
for name, field in schema._declared_fields.items():
162-
if field and field.attribute:
161+
# some fields could have been overridden with None to be excluded
162+
if not field:
163+
continue
164+
# skip virtual fields as DB cannot sort by them
165+
if isinstance(field, (fields.Function, fields.Method)):
166+
continue
167+
if field.attribute:
163168
mapping[name] = field.attribute
169+
# keep the map complete
170+
else:
171+
mapping[name] = name
164172
return mapping

0 commit comments

Comments
 (0)