@@ -207,6 +207,10 @@ def get_conn_str(conn_params: Dict[str, Optional[str]]) -> str:
207207 return f"postgresql://{ username } :{ password } @{ server } :{ port } /{ dbname } "
208208
209209
210+ def _quote_ident (val : str ) -> str :
211+ return '"%s"' % val .replace ('"' , '""' )
212+
213+
210214class PsycopgEngine (SQLEngine ):
211215 """Postgres SQL engine backed by a Psycopg connection."""
212216
@@ -588,12 +592,16 @@ def get_primary_keys(self, schema: str, table: str) -> List[Tuple[str, str]]:
588592 return cast (
589593 List [Tuple [str , str ]],
590594 self .run_sql (
591- SQL (
592- """SELECT a.attname, format_type(a.atttypid, a.atttypmod)
593- FROM pg_index i JOIN pg_attribute a ON a.attrelid = i.indrelid
594- AND a.attnum = ANY(i.indkey)
595- WHERE i.indrelid = '{}.{}'::regclass AND i.indisprimary"""
596- ).format (Identifier (schema ), Identifier (table )),
595+ """SELECT c.column_name, c.data_type
596+ FROM information_schema.table_constraints tc
597+ JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name)
598+ JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema
599+ AND tc.table_name = c.table_name AND ccu.column_name = c.column_name
600+ WHERE constraint_type = 'PRIMARY KEY'
601+ AND tc.constraint_schema = %s
602+ AND tc.table_name = %s
603+ """ ,
604+ (schema , table ),
597605 return_shape = ResultShape .MANY_MANY ,
598606 ),
599607 )
@@ -828,11 +836,11 @@ def track_tables(self, tables: List[Tuple[str, str]]) -> None:
828836 """Install the audit trigger on the required tables"""
829837 self .run_sql (
830838 SQL (";" ).join (
831- SQL ( "SELECT {}.audit_table('{}.{}')" ). format (
832- Identifier ( _AUDIT_SCHEMA ), Identifier ( s ), Identifier ( t )
839+ itertools . repeat (
840+ SQL ( "SELECT {}.audit_table(%s)" ). format ( Identifier ( _AUDIT_SCHEMA )), len ( tables )
833841 )
834- for s , t in tables
835- )
842+ ),
843+ [ "{}.{}" . format ( _quote_ident ( s ), _quote_ident ( t )) for s , t in tables ],
836844 )
837845
838846 def untrack_tables (self , tables : List [Tuple [str , str ]]) -> None :
0 commit comments