Skip to content

Commit bfb1a23

Browse files
authored
Merge pull request #616 from splitgraph/bugfix/get-primary-key
Fix potential SQL injection in `get_primary_key`.
2 parents 190fe80 + 379afc3 commit bfb1a23

3 files changed

Lines changed: 32 additions & 26 deletions

File tree

splitgraph/engine/__init__.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -418,15 +418,13 @@ def get_full_table_schema(self, schema: str, table_name: str) -> "TableSchema":
418418
assert schema != "pg_temp"
419419

420420
results = self.run_sql(
421-
SQL(
422-
"SELECT c.attnum, c.attname, "
423-
"pg_catalog.format_type(c.atttypid, c.atttypmod), "
424-
"col_description('{}.{}'::regclass, c.attnum) "
425-
"FROM pg_attribute c JOIN pg_class t ON c.attrelid = t.oid "
426-
"JOIN pg_namespace n ON t.relnamespace = n.oid "
427-
"WHERE n.nspname = %s AND t.relname = %s AND NOT c.attisdropped "
428-
"AND c.attnum >= 0 ORDER BY c.attnum "
429-
).format(Identifier(schema), Identifier(table_name)),
421+
"SELECT c.attnum, c.attname, "
422+
"pg_catalog.format_type(c.atttypid, c.atttypmod), pgd.description "
423+
"FROM pg_attribute c JOIN pg_class t ON c.attrelid = t.oid "
424+
"JOIN pg_namespace n ON t.relnamespace = n.oid "
425+
"LEFT JOIN pg_description pgd ON pgd.objoid = t.oid AND pgd.objsubid = c.attnum "
426+
"WHERE n.nspname = %s AND t.relname = %s AND NOT c.attisdropped "
427+
"AND c.attnum >= 0 ORDER BY c.attnum ",
430428
(schema, table_name),
431429
)
432430

splitgraph/engine/postgres/engine.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ def get_conn_str(conn_params: Dict[str, Optional[str]]) -> str:
207207
return f"postgresql://{username}:{password}@{server}:{port}/{dbname}"
208208

209209

210+
def _quote_ident(val: str) -> str:
211+
return '"%s"' % val.replace('"', '""')
212+
213+
210214
class PsycopgEngine(SQLEngine):
211215
"""Postgres SQL engine backed by a Psycopg connection."""
212216

@@ -588,12 +592,16 @@ def get_primary_keys(self, schema: str, table: str) -> List[Tuple[str, str]]:
588592
return cast(
589593
List[Tuple[str, str]],
590594
self.run_sql(
591-
SQL(
592-
"""SELECT a.attname, format_type(a.atttypid, a.atttypmod)
593-
FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid
594-
AND a.attnum = ANY(i.indkey)
595-
WHERE i.indrelid = '{}.{}'::regclass AND i.indisprimary"""
596-
).format(Identifier(schema), Identifier(table)),
595+
"""SELECT c.column_name, c.data_type
596+
FROM information_schema.table_constraints tc
597+
JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name)
598+
JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema
599+
AND tc.table_name = c.table_name AND ccu.column_name = c.column_name
600+
WHERE constraint_type = 'PRIMARY KEY'
601+
AND tc.constraint_schema = %s
602+
AND tc.table_name = %s
603+
""",
604+
(schema, table),
597605
return_shape=ResultShape.MANY_MANY,
598606
),
599607
)
@@ -828,11 +836,11 @@ def track_tables(self, tables: List[Tuple[str, str]]) -> None:
828836
"""Install the audit trigger on the required tables"""
829837
self.run_sql(
830838
SQL(";").join(
831-
SQL("SELECT {}.audit_table('{}.{}')").format(
832-
Identifier(_AUDIT_SCHEMA), Identifier(s), Identifier(t)
839+
itertools.repeat(
840+
SQL("SELECT {}.audit_table(%s)").format(Identifier(_AUDIT_SCHEMA)), len(tables)
833841
)
834-
for s, t in tables
835-
)
842+
),
843+
["{}.{}".format(_quote_ident(s), _quote_ident(t)) for s, t in tables],
836844
)
837845

838846
def untrack_tables(self, tables: List[Tuple[str, str]]) -> None:

test/splitgraph/commands/test_commit_diff.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,15 +1068,15 @@ def test_create_object_out_of_band(local_engine_empty):
10681068
)
10691069

10701070

1071-
def test_unicode_columns(local_engine_empty):
1071+
def test_unicode_columns_and_quotes_in_table_names(local_engine_empty):
10721072
OUTPUT.init()
1073-
OUTPUT.run_sql("CREATE TABLE таблица (key INTEGER PRIMARY KEY, столбец VARCHAR)")
1074-
OUTPUT.run_sql("COMMENT ON COLUMN таблица.столбец IS 'комментарий';")
1075-
OUTPUT.run_sql("INSERT INTO таблица (key, столбец) VALUES (1, 'one'), (2, 'two')")
1073+
OUTPUT.run_sql('CREATE TABLE "таблица\'" (key INTEGER PRIMARY KEY, столбец VARCHAR)')
1074+
OUTPUT.run_sql("COMMENT ON COLUMN \"таблица'\".столбец IS 'комментарий';")
1075+
OUTPUT.run_sql("INSERT INTO \"таблица'\" (key, столбец) VALUES (1, 'one'), (2, 'two')")
10761076

10771077
image = OUTPUT.commit()
10781078

1079-
assert image.get_table("таблица").table_schema == [
1079+
assert image.get_table("таблица'").table_schema == [
10801080
TableColumn(ordinal=1, name="key", pg_type="integer", is_pk=True, comment=None),
10811081
TableColumn(
10821082
ordinal=2,
@@ -1087,10 +1087,10 @@ def test_unicode_columns(local_engine_empty):
10871087
),
10881088
]
10891089
image.checkout()
1090-
assert OUTPUT.run_sql("SELECT * FROM таблица WHERE столбец = 'two'") == [(2, "two")]
1090+
assert OUTPUT.run_sql("SELECT * FROM \"таблица'\" WHERE столбец = 'two'") == [(2, "two")]
10911091

10921092
image.checkout(layered=True)
1093-
assert OUTPUT.run_sql("SELECT * FROM таблица WHERE столбец = 'one'") == [(1, "one")]
1093+
assert OUTPUT.run_sql("SELECT * FROM \"таблица'\" WHERE столбец = 'one'") == [(1, "one")]
10941094

10951095

10961096
def test_commit_diff_views(pg_repo_local):

0 commit comments

Comments
 (0)