Skip to content

Commit d65c563

Browse files
committed
feat: enhance SQL identifier handling and validation in Python bindings
1 parent 2922748 commit d65c563

9 files changed

Lines changed: 102 additions & 22 deletions

File tree

bindings/python/examples/16_import_database_vs_transactional_graph_ingest.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
"""
5757

5858
import argparse
59+
import re
5960
import shutil
6061
import time
6162
from pathlib import Path
@@ -66,6 +67,17 @@
6667
ColumnDef = Tuple[str, str]
6768

6869
NUMERIC_COLUMN_TYPES = {"INTEGER", "LONG"}
70+
SAFE_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
71+
72+
73+
def _validated_identifier(identifier: str) -> str:
74+
if not SAFE_IDENTIFIER_RE.fullmatch(identifier):
75+
raise ValueError(f"Unsafe SQL identifier: {identifier!r}")
76+
return identifier
77+
78+
79+
def _quote_identifier(identifier: str) -> str:
80+
return f"`{_validated_identifier(identifier)}`"
6981

7082

7183
def result_int(row, *keys: str) -> int:
@@ -131,9 +143,10 @@ def edge_endpoints(edge_id: int, vertex_count: int) -> Tuple[int, int]:
131143

132144

133145
def build_rid_lookup_for_vertex_type(db, vertex_type: str) -> Dict[int, str]:
146+
safe_vertex_type = _quote_identifier(vertex_type)
134147
rows = db.query(
135148
"sql",
136-
f"SELECT Id, @rid as rid FROM {vertex_type}",
149+
f"SELECT Id, @rid as rid FROM {safe_vertex_type}", # nosec B608 - validated identifier
137150
).to_list()
138151
rid_lookup: Dict[int, str] = {}
139152
for row in rows:
@@ -167,11 +180,11 @@ def query_one_or_none(result_set):
167180
def collect_vertex_sample(
168181
db, vertex_type: str, vertex_id: int, props: List[ColumnDef]
169182
) -> dict:
183+
safe_vertex_type = _quote_identifier(vertex_type)
170184
row = query_one_or_none(
171185
db.query(
172186
"sql",
173-
# vertex_type is a constant from this script; vertex_id is bound as parameter.
174-
f"SELECT FROM {vertex_type} WHERE Id = ?",
187+
f"SELECT FROM {safe_vertex_type} WHERE Id = ?", # nosec B608 - validated identifier
175188
vertex_id,
176189
)
177190
)
@@ -230,6 +243,7 @@ def collect_graph_signature(
230243
str(db_path),
231244
jvm_kwargs={"heap_size": heap_size} if heap_size else None,
232245
) as db:
246+
safe_vertex_type = _quote_identifier(vertex_type)
233247
vertex_int_props = [
234248
name for name, kind in vertex_props if kind in NUMERIC_COLUMN_TYPES
235249
]
@@ -242,7 +256,10 @@ def collect_graph_signature(
242256
"sum(Id) AS sum_id",
243257
"min(Id) AS min_id",
244258
"max(Id) AS max_id",
245-
] + [f"sum({name}) AS sum_{name}" for name in vertex_int_props]
259+
] + [
260+
f"sum({_quote_identifier(name)}) AS sum_{_validated_identifier(name)}"
261+
for name in vertex_int_props
262+
]
246263

247264
edge_match_aggregate_fields = [
248265
"count(r) AS count",
@@ -254,8 +271,7 @@ def collect_graph_signature(
254271
vertex_aggregate = query_one_or_none(
255272
db.query(
256273
"sql",
257-
# vertex_aggregate_fields and vertex_type are script-local constants.
258-
f"SELECT {', '.join(vertex_aggregate_fields)} FROM {vertex_type}",
274+
f"SELECT {', '.join(vertex_aggregate_fields)} FROM {safe_vertex_type}", # nosec B608 - validated identifier
259275
)
260276
)
261277
edge_aggregate = query_one_or_none(

bindings/python/examples/20_graph_algorithms_route_planning.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,21 @@
2929
from __future__ import annotations
3030

3131
import argparse
32+
import re
3233
import shutil
3334
from pathlib import Path
3435

3536
import arcadedb_embedded as arcadedb
3637

38+
SAFE_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
39+
40+
41+
def _quote_identifier(identifier: str) -> str:
42+
if not SAFE_IDENTIFIER_RE.fullmatch(identifier):
43+
raise ValueError(f"Unsafe SQL identifier: {identifier!r}")
44+
return f"`{identifier}`"
45+
46+
3747
CITIES = [
3848
{"code": "ALPHA", "name": "Alpha Hub", "lat": 0.0, "lon": 0.0},
3949
{"code": "BRAVO", "name": "Bravo Junction", "lat": 1.0, "lon": 0.0},
@@ -284,8 +294,7 @@ def insert_seed_data(db) -> None:
284294
for route in ROUTES:
285295
db.command(
286296
"sql",
287-
# route['edge_type'] is a constant from the demo schema.
288-
f"CREATE EDGE {route['edge_type']} "
297+
f"CREATE EDGE {_quote_identifier(route['edge_type'])} " # nosec B608 - validated identifier
289298
"FROM (SELECT FROM City WHERE code = ? LIMIT 1) "
290299
"TO (SELECT FROM City WHERE code = ? LIMIT 1) "
291300
"SET distance = ?, duration = ?, risk = ?, lane = ?",
@@ -758,7 +767,7 @@ def run_reopen_phase(db_path: Path) -> None:
758767
route_count = sum(
759768
reopened_db.query(
760769
"sql",
761-
f"SELECT count(*) AS count FROM {edge_type}",
770+
f"SELECT count(*) AS count FROM {_quote_identifier(edge_type)}", # nosec B608 - validated identifier
762771
)
763772
.first()
764773
.get("count")

bindings/python/examples/22_graph_analytical_view_sql.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from __future__ import annotations
3737

3838
import argparse
39+
import re
3940
import shutil
4041
import time
4142
from pathlib import Path
@@ -51,6 +52,20 @@
5152
"hub": "regional",
5253
"bridge": "interchange",
5354
}
55+
SAFE_GRAPH_LITERAL_RE = re.compile(r"^[A-Za-z0-9_-]+$")
56+
57+
58+
def _sql_string_literal(value: str) -> str:
59+
if not SAFE_GRAPH_LITERAL_RE.fullmatch(value):
60+
raise ValueError(f"Unsafe SQL literal: {value!r}")
61+
return f"'{value}'"
62+
63+
64+
def _validated_limit(value: int) -> int:
65+
limit = int(value)
66+
if limit <= 0:
67+
raise ValueError(f"SQL limit must be positive: {value!r}")
68+
return limit
5469

5570

5671
def parse_args() -> argparse.Namespace:
@@ -539,52 +554,54 @@ def query_direct_neighbor_sample(
539554

540555

541556
def query_two_hop_summary(db, origin_code: str) -> dict:
542-
# origin_code is a script-local constant from the demo dataset.
557+
safe_origin_code = _sql_string_literal(origin_code)
543558
result = db.query(
544559
"sql",
545560
f"""
546561
SELECT count(*) AS destination_count FROM (
547-
MATCH {{type: City, as: src, where: (code = '{origin_code}')}}
562+
MATCH {{type: City, as: src, where: (code = {safe_origin_code})}}
548563
-ROAD->
549564
{{type: City, as: mid}}
550565
-ROAD->
551566
{{type: City, as: dst}}
552567
RETURN DISTINCT dst.code AS code
553568
)
554-
""",
569+
""", # nosec B608 - value validated via _sql_string_literal
555570
)
556571
row = result.first()
557572
require(row is not None, "Expected a two-hop summary row")
558573
return {"destination_count": row.get("destination_count")}
559574

560575

561576
def query_hub_inbound_count(db, hub_code: str) -> int:
577+
safe_hub_code = _sql_string_literal(hub_code)
562578
result = db.query(
563579
"sql",
564580
f"""
565581
SELECT count(*) AS inbound_count FROM (
566582
MATCH {{type: City, as: src}}
567583
-ROAD->
568-
{{type: City, as: hub, where: (code = '{hub_code}')}}
584+
{{type: City, as: hub, where: (code = {safe_hub_code})}}
569585
RETURN src.code AS code
570586
)
571-
""",
587+
""", # nosec B608 - value validated via _sql_string_literal
572588
)
573589
row = result.first()
574590
require(row is not None, "Expected an inbound count row")
575591
return row.get("inbound_count")
576592

577593

578594
def query_region_sample(db, sample_limit: int) -> list[dict]:
595+
safe_sample_limit = _validated_limit(sample_limit)
579596
result = db.query(
580597
"sql",
581598
f"""
582599
SELECT region, count(*) AS city_count, avg(demand_index) AS avg_demand
583600
FROM City
584601
GROUP BY region
585602
ORDER BY region
586-
LIMIT {sample_limit}
587-
""",
603+
LIMIT {safe_sample_limit}
604+
""", # nosec B608 - value validated via _validated_limit
588605
)
589606
return rows_to_dicts(result, ["region", "city_count", "avg_demand"])
590607

bindings/python/scripts/Dockerfile.build

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
ARG PYTHON_VERSION=3.12
1212
ARG PACKAGE_NAME=arcadedb-embedded
1313
ARG PACKAGE_DESCRIPTION="ArcadeDB embedded multi-model database with bundled JRE - no Java installation required"
14-
ARG ARCADEDB_TAG=latest
14+
# Deterministic placeholder. The wrapper build script must pass the real ArcadeDB tag.
15+
ARG ARCADEDB_TAG=build-arg-required
1516
ARG TARGET_PLATFORM=linux-x64
1617
# When set to 1, prefer jars provided in bindings/python/local-jars/lib from the build context.
1718
# If no local jars are present, the build fails fast to avoid silently falling back.

bindings/python/scripts/build.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,13 @@ else
292292
# Determine Docker build platform (always Linux for cross-compilation)
293293
# We build ON linux/amd64 or linux/arm64, but FOR any target platform
294294
DOCKER_PLATFORM="${PLATFORM}"
295+
296+
if [[ -z "$DOCKER_TAG" ]]; then
297+
echo -e "${RED}❌ Missing ArcadeDB Docker tag${NC}"
298+
echo -e "${YELLOW}💡 Pass a version so Docker builds remain reproducible${NC}"
299+
exit 1
300+
fi
301+
295302
# Build Docker image
296303
echo -e "${CYAN}📦 Building Docker image...${NC}"
297304

bindings/python/src/arcadedb_embedded/vector.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def to_java_float_array(vector):
3737
import numpy as np
3838

3939
if isinstance(vector, np.ndarray):
40-
vector = vector.tolist()
40+
return jtypes.JArray(jtypes.JFloat)(vector)
4141
except ImportError:
4242
pass
4343

@@ -69,12 +69,14 @@ def to_java_byte_array(vector):
6969
import numpy as np
7070

7171
if isinstance(vector, np.ndarray):
72+
if vector.dtype in (np.int8, np.uint8):
73+
return jtypes.JArray(jtypes.JByte)(vector.tobytes())
7274
vector = vector.tolist()
7375
except ImportError:
7476
pass
7577

7678
if isinstance(vector, (bytes, bytearray)):
77-
vector = list(vector)
79+
return jtypes.JArray(jtypes.JByte)(vector)
7880
elif not isinstance(vector, list):
7981
vector = list(vector)
8082

bindings/python/tests/test_cypher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def test_opencypher_edge_typed_constraint_command(temp_db_path):
539539

540540

541541
def test_opencypher_is_typed_value_predicate(temp_db_path):
542-
"""OpenCypher should expose the GQL IS TYPED value predicate."""
542+
"""Verify that OpenCypher exposes the GQL IS TYPED value predicate."""
543543
with arcadedb.create_database(temp_db_path) as db:
544544
_ensure_opencypher(db)
545545

bindings/python/tests/test_vector.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,35 @@ def test_db(tmp_path):
1919
db.drop()
2020

2121

22+
def test_to_java_byte_array_accepts_bytes_like_fast_paths():
23+
"""to_java_byte_array should preserve signed byte semantics for bytes-like
24+
inputs."""
25+
np = pytest.importorskip("numpy")
26+
27+
assert list(arcadedb.to_java_byte_array(bytes([0, 127, 255]))) == [0, 127, -1]
28+
assert list(arcadedb.to_java_byte_array(bytearray([1, 2, 255]))) == [1, 2, -1]
29+
assert list(arcadedb.to_java_byte_array(np.array([0, 127, -1], dtype=np.int8))) == [
30+
0,
31+
127,
32+
-1,
33+
]
34+
assert list(
35+
arcadedb.to_java_byte_array(np.array([0, 127, 255], dtype=np.uint8))
36+
) == [0, 127, -1]
37+
38+
39+
def test_to_java_float_array_accepts_numpy_directly():
40+
"""to_java_float_array should accept NumPy arrays without Python-list copies."""
41+
np = pytest.importorskip("numpy")
42+
43+
assert list(
44+
arcadedb.to_java_float_array(np.array([1.0, 2.0, 3.0], dtype=np.float32))
45+
) == pytest.approx([1.0, 2.0, 3.0])
46+
assert list(
47+
arcadedb.to_java_float_array(np.array([1.0, 2.0, 3.0], dtype=np.float64))
48+
) == pytest.approx([1.0, 2.0, 3.0])
49+
50+
2251
class TestLSMVectorIndex:
2352
"""Test LSM Vector Index functionality."""
2453

bindings/python/tests/test_vector_sql.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def test_create_index_with_native_int8_encoding_sql(self, test_db):
370370
assert str(metadata.quantizationType) == "NONE"
371371

372372
def test_vector_neighbors_on_native_int8_storage_sql(self, test_db):
373-
"""vectorNeighbors should work against native INT8-encoded storage."""
373+
"""Verify that vectorNeighbors works against native INT8-encoded storage."""
374374
test_db.command("sql", "CREATE VERTEX TYPE SqlNativeInt8SearchDoc")
375375
test_db.command("sql", "CREATE PROPERTY SqlNativeInt8SearchDoc.id STRING")
376376
test_db.command("sql", "CREATE PROPERTY SqlNativeInt8SearchDoc.vec BINARY")
@@ -619,7 +619,6 @@ def test_vector_neighbors_by_key_opencypher(self, test_db):
619619

620620
def test_vector_neighbors_group_by_sql(self, test_db):
621621
"""SQL vector.neighbors should support groupBy/groupSize options."""
622-
623622
test_db.command("sql", "CREATE DOCUMENT TYPE GroupedDoc")
624623
test_db.command("sql", "CREATE PROPERTY GroupedDoc.source_file STRING")
625624
test_db.command("sql", "CREATE PROPERTY GroupedDoc.embedding ARRAY_OF_FLOATS")

0 commit comments

Comments
 (0)