From 25064f5475b0d2d755489ef4fe6d91737b53e0de Mon Sep 17 00:00:00 2001 From: Ashesh Vashi Date: Thu, 9 Apr 2026 16:01:25 +0530 Subject: [PATCH 1/3] fix: enforce data isolation and harden shared servers in server mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit pgAdmin 4 in server mode had no data isolation between users — any authenticated user could access other users' private servers, background processes, and debugger state by guessing object IDs. The shared server feature (21 audit issues) leaked owner credentials, saved passwords to wrong records, and corrupted owner data via SQLAlchemy session mutations. Centralized access control: - New server_access.py with get_server(), get_server_group(), get_user_server_query() replacing ~20 scattered unfiltered queries - connection_manager() raises ObjectGone (HTTP 410) in server mode when access is denied — fixes 155+ unguarded callers - UserScopedMixin.for_user() on 10 models replaces scattered user_id=current_user.id filters Shared server isolation (addresses all 21 audit issues): - get_shared_server_properties() expunges server from session, suppresses passexec_cmd/post_connection_sql, overrides all 6 SSL/passfile keys, strips owner-only keys - _is_non_owner() helper centralises 15+ inline checks - SENSITIVE_CONN_KEYS module-level constant (DRY) - Sanitizes connection_params on SharedServer creation - Tunnel/DB password save branches on ownership - change_password() checks SharedServer.password for non-owners - clear_saved/sshtunnel_password() use get_shared_server() - update_connection_parameter() routes to SharedServer copy - Only owner can trigger delete_shared_server (unshare) - SharedServer lookup uses (osid, user_id) not name (Issue 20) - Unique constraint on SharedServer(osid, user_id) with migration - ServerManager suppresses passexec/post_connection_sql for non-owners at creation time - Session restore includes shared servers - wal_replay()/check_pgpass() use get_server() - tunnel_port/tunnel_keep_alive copied from owner - delete_shared_server() accepts user_id parameter - get_shared_server() raises on None, catches IntegrityError - create_shared_server() exception handler uses rollback Tool/module hardening: - All tool endpoints use get_server() - Debugger function arguments scoped by user_id (migration) - Background processes use Process.for_user() - Workspace adhoc servers scoped by user_id Migration (schema version 49 -> 50): - Add user_id to debugger_function_arguments composite PK - Add indexes on server, sharedserver, servergroup - Add unique constraint on sharedserver(osid, user_id) - Fix ca00ec32581b to use raw SQL --- .../add_user_id_to_debugger_func_args_.py | 149 ++++++++ web/migrations/versions/ca00ec32581b_.py | 11 +- web/pgadmin/browser/server_groups/__init__.py | 35 +- .../browser/server_groups/servers/__init__.py | 359 +++++++++++------- .../servers/databases/__init__.py | 9 +- .../databases/schemas/views/__init__.py | 10 +- .../browser/server_groups/servers/utils.py | 64 +++- web/pgadmin/misc/bgprocess/processes.py | 32 +- web/pgadmin/misc/cloud/__init__.py | 5 +- web/pgadmin/misc/workspaces/__init__.py | 70 ++-- web/pgadmin/model/__init__.py | 86 ++++- web/pgadmin/tools/debugger/__init__.py | 41 +- web/pgadmin/tools/erd/__init__.py | 3 +- web/pgadmin/tools/import_export/__init__.py | 8 +- web/pgadmin/tools/psql/__init__.py | 3 +- web/pgadmin/tools/schema_diff/__init__.py | 68 +++- web/pgadmin/tools/sqleditor/__init__.py | 39 +- web/pgadmin/tools/user_management/__init__.py | 2 +- web/pgadmin/utils/__init__.py | 14 +- web/pgadmin/utils/driver/psycopg3/__init__.py | 67 +++- web/pgadmin/utils/server_access.py | 156 ++++++++ 21 files changed, 920 insertions(+), 311 deletions(-) create mode 100644 web/migrations/versions/add_user_id_to_debugger_func_args_.py create mode 100644 web/pgadmin/utils/server_access.py diff --git a/web/migrations/versions/add_user_id_to_debugger_func_args_.py b/web/migrations/versions/add_user_id_to_debugger_func_args_.py new file mode 100644 index 00000000000..d8cd046da5e --- /dev/null +++ b/web/migrations/versions/add_user_id_to_debugger_func_args_.py @@ -0,0 +1,149 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Add user_id to debugger_function_arguments and indexes for data isolation + +Revision ID: add_user_id_dbg_args +Revises: add_tools_ai_perm +Create Date: 2026-04-08 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'add_user_id_dbg_args' +down_revision = 'add_tools_ai_perm' +branch_labels = None +depends_on = None + + +def upgrade(): + conn = op.get_bind() + dialect = conn.dialect.name + + # --- DebuggerFunctionArguments: add user_id to composite PK --- + if dialect == 'sqlite': + # SQLite cannot ALTER composite PKs. Recreate the table. + # Existing debugger argument data is ephemeral (cached function + # args) so dropping is acceptable. + op.execute( + 'DROP TABLE IF EXISTS debugger_function_arguments' + ) + op.create_table( + 'debugger_function_arguments', + sa.Column('user_id', sa.Integer(), + sa.ForeignKey('user.id'), nullable=False), + sa.Column('server_id', sa.Integer(), nullable=False), + sa.Column('database_id', sa.Integer(), nullable=False), + sa.Column('schema_id', sa.Integer(), nullable=False), + sa.Column('function_id', sa.Integer(), nullable=False), + sa.Column('arg_id', sa.Integer(), nullable=False), + sa.Column('is_null', sa.Integer(), nullable=False), + sa.Column('is_expression', sa.Integer(), nullable=False), + sa.Column('use_default', sa.Integer(), nullable=False), + sa.Column('value', sa.String(), nullable=True), + sa.PrimaryKeyConstraint( + 'user_id', 'server_id', 'database_id', + 'schema_id', 'function_id', 'arg_id' + ), + sa.CheckConstraint('is_null >= 0 AND is_null <= 1'), + sa.CheckConstraint( + 'is_expression >= 0 AND is_expression <= 1'), + sa.CheckConstraint( + 'use_default >= 0 AND use_default <= 1'), + ) + else: + # PostgreSQL: add column, backfill from server owner, recreate + # PK using batch_alter_table for portability. + op.add_column( + 'debugger_function_arguments', + sa.Column('user_id', sa.Integer(), + sa.ForeignKey('user.id'), nullable=True) + ) + # Backfill: assign user_id from the server's owner + op.execute( + 'UPDATE debugger_function_arguments ' + 'SET user_id = s.user_id ' + 'FROM server s ' + 'WHERE debugger_function_arguments.server_id = s.id' + ) + # Delete orphans (rows with no matching server) + op.execute( + 'DELETE FROM debugger_function_arguments ' + 'WHERE user_id IS NULL' + ) + op.alter_column( + 'debugger_function_arguments', 'user_id', nullable=False + ) + # Recreate PK with user_id using batch_alter_table + with op.batch_alter_table( + 'debugger_function_arguments' + ) as batch: + batch.drop_constraint( + 'debugger_function_arguments_pkey', type_='primary' + ) + batch.create_primary_key( + 'debugger_function_arguments_pkey', + ['user_id', 'server_id', 'database_id', + 'schema_id', 'function_id', 'arg_id'] + ) + + # --- Indexes for data isolation query performance --- + # Only create indexes on tables that exist (sharedserver may be + # absent in older schemas that haven't run all prior migrations). + inspector = sa.inspect(conn) + index_stmts = [ + ('server', + 'CREATE INDEX IF NOT EXISTS ix_server_user_id ' + 'ON server (user_id)'), + ('server', + 'CREATE INDEX IF NOT EXISTS ix_server_servergroup_id ' + 'ON server (servergroup_id)'), + ('sharedserver', + 'CREATE INDEX IF NOT EXISTS ix_sharedserver_user_id ' + 'ON sharedserver (user_id)'), + ('sharedserver', + 'CREATE INDEX IF NOT EXISTS ix_sharedserver_osid ' + 'ON sharedserver (osid)'), + ('servergroup', + 'CREATE INDEX IF NOT EXISTS ix_servergroup_user_id ' + 'ON servergroup (user_id)'), + ] + for table_name, stmt in index_stmts: + if inspector.has_table(table_name): + op.execute(stmt) + + # --- Unique constraint on SharedServer(osid, user_id) --- + # Prevents duplicate SharedServer records from TOCTOU race. + # First remove duplicates (keep lowest id per osid+user_id). + if inspector.has_table('sharedserver'): + if dialect == 'sqlite': + op.execute( + 'DELETE FROM sharedserver WHERE id NOT IN ' + '(SELECT MIN(id) FROM sharedserver ' + 'GROUP BY osid, user_id)' + ) + else: + op.execute( + 'DELETE FROM sharedserver s1 USING ' + 'sharedserver s2 WHERE s1.osid = s2.osid ' + 'AND s1.user_id = s2.user_id ' + 'AND s1.id > s2.id' + ) + with op.batch_alter_table('sharedserver') as batch: + batch.create_unique_constraint( + 'uq_sharedserver_osid_user', + ['osid', 'user_id'] + ) + + +def downgrade(): + # pgAdmin only upgrades, downgrade not implemented. + pass diff --git a/web/migrations/versions/ca00ec32581b_.py b/web/migrations/versions/ca00ec32581b_.py index 6d566cd1788..64a3ba12f30 100644 --- a/web/migrations/versions/ca00ec32581b_.py +++ b/web/migrations/versions/ca00ec32581b_.py @@ -15,8 +15,6 @@ """ from alembic import op -from sqlalchemy.orm.session import Session -from pgadmin.model import DebuggerFunctionArguments # revision identifiers, used by Alembic. revision = 'ca00ec32581b' @@ -26,11 +24,10 @@ def upgrade(): - session = Session(bind=op.get_bind()) - - debugger_records = session.query(DebuggerFunctionArguments).all() - if debugger_records: - session.delete(debugger_records) + # Use raw SQL instead of importing the model class, because + # model changes in later migrations (e.g. adding user_id) would + # cause this migration to fail on fresh databases. + op.execute('DELETE FROM debugger_function_arguments') def downgrade(): diff --git a/web/pgadmin/browser/server_groups/__init__.py b/web/pgadmin/browser/server_groups/__init__.py index e0212d277a9..1d70695c316 100644 --- a/web/pgadmin/browser/server_groups/__init__.py +++ b/web/pgadmin/browser/server_groups/__init__.py @@ -25,6 +25,8 @@ from pgadmin.model import db, ServerGroup, Server import config from pgadmin.utils.preferences import Preferences +from pgadmin.utils.server_access import get_server_group, \ + get_server_groups_for_user def get_icon_css_class(group_id, group_user_id, @@ -286,7 +288,7 @@ def update(self, gid): def properties(self, gid): """Update the server-group properties""" - sg = ServerGroup.query.filter(ServerGroup.id == gid).first() + sg = get_server_group(gid) if sg is None: return make_json_response( @@ -296,7 +298,8 @@ def properties(self, gid): ) else: return ajax_response( - response={'id': sg.id, 'name': sg.name, 'user_id': sg.user_id}, + response={'id': sg.id, 'name': sg.name, + 'user_id': sg.user_id}, status=200 ) @@ -373,8 +376,9 @@ def dependents(self, gid): @staticmethod def get_all_server_groups(): """ - Returns the list of server groups to show in server mode and - if there is any shared server in the group. + Returns the list of server groups to show in server mode. + Includes groups owned by the user and groups containing + shared servers accessible to this user. :return: server groups """ @@ -383,17 +387,18 @@ def get_all_server_groups(): pref = Preferences.module('browser') hide_shared_server = pref.preference('hide_shared_server').get() - server_groups = ServerGroup.query.all() - groups = [] - for group in server_groups: - if hide_shared_server and \ - ServerGroupModule.has_shared_server(group.id) and \ - group.user_id != current_user.id: - continue - if group.user_id == current_user.id or \ - ServerGroupModule.has_shared_server(group.id): + server_groups = get_server_groups_for_user() + + if hide_shared_server: + groups = [] + for group in server_groups: + if group.user_id != current_user.id and \ + ServerGroupModule.has_shared_server(group.id): + continue groups.append(group) - return groups + return groups + + return server_groups @pga_login_required def nodes(self, gid=None): @@ -421,7 +426,7 @@ def nodes(self, gid=None): ) ) else: - group = ServerGroup.query.filter(ServerGroup.id == gid).first() + group = get_server_group(gid) if not group: return gone( diff --git a/web/pgadmin/browser/server_groups/servers/__init__.py b/web/pgadmin/browser/server_groups/servers/__init__.py index 24825994980..cbcf79a3cea 100644 --- a/web/pgadmin/browser/server_groups/servers/__init__.py +++ b/web/pgadmin/browser/server_groups/servers/__init__.py @@ -39,12 +39,30 @@ from pgadmin.utils.constants import UNAUTH_REQ, MIMETYPE_APP_JS, \ SERVER_CONNECTION_CLOSED, RESTRICTION_TYPE_SQL from sqlalchemy import or_ +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import object_session from sqlalchemy.orm.attributes import flag_modified from pgadmin.utils.preferences import Preferences from .... import socketio as sio from pgadmin.utils import get_complete_file_path from pgadmin.settings.utils import with_object_filters +from pgadmin.utils.server_access import get_server, \ + get_user_server_query, get_server_group + + +# File-path keys in connection_params that are per-user and must +# not be copied from the owner to a new SharedServer or leaked +# through the property merge. +SENSITIVE_CONN_KEYS = frozenset({ + 'passfile', 'sslcert', 'sslkey', + 'sslrootcert', 'sslcrl', 'sslcrldir', +}) + + +def _is_non_owner(server): + """True if the server is shared and the current user is not + the owner. Centralises the check used in 15+ places.""" + return server.shared and server.user_id != current_user.id def has_any(data, keys): @@ -151,15 +169,30 @@ def script_load(self): @staticmethod def get_shared_server_properties(server, sharedserver): """ - Return shared server properties + Return shared server properties. + + Overlays per-user SharedServer values onto the owner's Server + object. Security-sensitive fields that are absent from the + SharedServer model (passexec_cmd, post_connection_sql) are + suppressed for non-owners. + + The server is expunged from the SQLAlchemy session before + mutation so that the owner's record is never dirtied. :param server: :param sharedserver: - :return: shared server + :return: shared server (detached) """ + # Detach from session so in-place mutations are never + # flushed back to the owner's Server row. + sess = object_session(server) + if sess is not None: + sess.expunge(server) + server.bgcolor = sharedserver.bgcolor server.fgcolor = sharedserver.fgcolor server.name = sharedserver.name server.role = sharedserver.role + server.service = sharedserver.service server.use_ssh_tunnel = sharedserver.use_ssh_tunnel server.tunnel_host = sharedserver.tunnel_host server.tunnel_port = sharedserver.tunnel_port @@ -169,24 +202,36 @@ def get_shared_server_properties(server, sharedserver): server.save_password = sharedserver.save_password server.tunnel_identity_file = sharedserver.tunnel_identity_file server.tunnel_prompt_password = sharedserver.tunnel_prompt_password - if hasattr(server, 'connection_params') and \ - hasattr(sharedserver, 'connection_params') and \ - 'passfile' in server.connection_params and \ - 'passfile' in sharedserver.connection_params: - server.connection_params['passfile'] = \ - sharedserver.connection_params['passfile'] + + # Override per-user connection_params keys. Use the + # SharedServer value whenever it is present, regardless of + # whether the owner's Server has the same key. + s_conn = getattr(server, 'connection_params', None) \ + or {} + ss_conn = getattr(sharedserver, 'connection_params', + None) or {} + for key in SENSITIVE_CONN_KEYS: + if key in ss_conn: + s_conn[key] = ss_conn[key] + elif key in s_conn: + # Owner has this key but non-owner doesn't — + # remove it so the owner's path doesn't leak. + del s_conn[key] + server.connection_params = s_conn + server.servergroup_id = sharedserver.servergroup_id - if hasattr(server, 'connection_params') and \ - hasattr(sharedserver, 'connection_params') and \ - 'sslcert' in server.connection_params and \ - 'sslcert' in sharedserver.connection_params: - server.connection_params['sslcert'] = \ - sharedserver.connection_params['sslcert'] server.username = sharedserver.username server.server_owner = sharedserver.server_owner server.password = sharedserver.password server.prepare_threshold = sharedserver.prepare_threshold + # Suppress owner-only fields that are absent from SharedServer + # and dangerous when inherited (privilege escalation / code + # execution). + server.passexec_cmd = None + server.passexec_expiration = None + server.post_connection_sql = None + return server def get_servers(self, all_servers, hide_shared_server, gid): @@ -203,12 +248,13 @@ def get_servers(self, all_servers, hide_shared_server, gid): if server.discovery_id and \ not server.shared and \ config.SERVER_MODE and \ - len(SharedServer.query.filter_by( + SharedServer.query.filter_by( user_id=current_user.id, - name=server.name).all()) > 0 and not hide_shared_server: + osid=server.id).first() is not None \ + and not hide_shared_server: continue - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): shared_server = self.get_shared_server(server, gid) @@ -245,8 +291,7 @@ def get_nodes(self, gid, object_filters): """Return a JSON document listing the server groups for the user""" hide_shared_server = get_preferences() - servers = Server.query.filter( - or_(Server.user_id == current_user.id, Server.shared), + servers = get_user_server_query().filter( Server.servergroup_id == gid, Server.is_adhoc == 0) driver = get_driver(PG_DEFAULT_DRIVER) @@ -392,6 +437,18 @@ def create_shared_server(data, gid): try: db.session.rollback() user = User.query.filter_by(id=data.user_id).first() + + # Strip owner's sensitive file paths from + # connection_params — each user should configure + # their own SSL/passfile paths. + safe_conn_params = {} + if data.connection_params: + safe_conn_params = { + k: v for k, v in + data.connection_params.items() + if k not in SENSITIVE_CONN_KEYS + } + shared_server = SharedServer( osid=data.id, user_id=current_user.id, @@ -410,43 +467,57 @@ def create_shared_server(data, gid): service=data.service if data.service else None, use_ssh_tunnel=data.use_ssh_tunnel, tunnel_host=data.tunnel_host, - tunnel_port=22, + tunnel_port=data.tunnel_port + if data.tunnel_port else 22, tunnel_username=None, tunnel_authentication=0, tunnel_identity_file=None, - tunnel_keep_alive=0, + tunnel_keep_alive=data.tunnel_keep_alive + if data.tunnel_keep_alive else 0, tunnel_prompt_password=0, shared=True, - connection_params=data.connection_params, + connection_params=safe_conn_params, prepare_threshold=data.prepare_threshold ) db.session.add(shared_server) db.session.commit() except Exception as e: - if shared_server: - db.session.delete(shared_server) - db.session.commit() - + db.session.rollback() raise e @staticmethod def get_shared_server(server, gid): """ - return the shared server + Return the SharedServer record for the current user, + creating one lazily if it doesn't exist. The unique + constraint on (osid, user_id) prevents duplicates from + concurrent requests. :param server: :param gid: - :return: shared_server + :return: shared_server (never None) + :raises: Exception if SharedServer cannot be created """ shared_server = SharedServer.query.filter_by( - name=server.name, user_id=current_user.id, - servergroup_id=int(gid), osid=server.id).first() + user_id=current_user.id, + osid=server.id).first() if shared_server is None: - ServerModule.create_shared_server(server, int(gid)) + try: + ServerModule.create_shared_server( + server, int(gid)) + except IntegrityError: + # Unique constraint violation from a concurrent + # request — the record now exists. + db.session.rollback() shared_server = SharedServer.query.filter_by( - name=server.name, user_id=current_user.id, - servergroup_id=int(gid), osid=server.id).first() + user_id=current_user.id, + osid=server.id).first() + + if shared_server is None: + raise Exception( + "Failed to create shared server record " + "for server {0}".format(server.id)) return shared_server @@ -495,17 +566,28 @@ class ServerNode(PGChildNodeView): 'clear_sshtunnel_password': [{'put': 'clear_sshtunnel_password'}], }) - def update_connection_parameter(self, data, server): + def update_connection_parameter(self, data, server, sharedserver=None): """ This function is used to update the connection parameters. """ if 'connection_params' in data and \ hasattr(server, 'connection_params'): - existing_conn_params = getattr(server, 'connection_params') + # For shared servers accessed by non-owners, apply changes + # to the SharedServer's connection_params (a copy) so we + # don't mutate the owner's Server record in-place. + if sharedserver is not None and \ + server.shared and \ + server.user_id != current_user.id: + existing_conn_params = dict( + sharedserver.connection_params or {}) + else: + existing_conn_params = getattr( + server, 'connection_params') new_conn_params = data['connection_params'] if 'deleted' in new_conn_params: for item in new_conn_params['deleted']: - del existing_conn_params[item['name']] + if item['name'] in existing_conn_params: + del existing_conn_params[item['name']] if 'added' in new_conn_params: for item in new_conn_params['added']: existing_conn_params[item['name']] = item['value'] @@ -560,15 +642,13 @@ def nodes(self, gid): Return a JSON document listing the servers under this server group for the user. """ - servers = Server.query.filter( - or_(Server.user_id == current_user.id, - Server.shared), + servers = get_user_server_query().filter( Server.servergroup_id == gid, Server.is_adhoc == 0) driver = get_driver(PG_DEFAULT_DRIVER) for server in servers: - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): shared_server = ServerModule.get_shared_server(server, gid) server = \ ServerModule.get_shared_server_properties(server, @@ -627,24 +707,22 @@ def nodes(self, gid): @pga_login_required def node(self, gid, sid): """Return a JSON document listing the server groups for the user""" - server = Server.query.filter_by(id=sid).first() - - if server.shared and server.user_id != current_user.id: - shared_server = ServerModule.get_shared_server(server, gid) - server = ServerModule.get_shared_server_properties(server, - shared_server) + server = get_server(sid) if server is None: return make_json_response( status=410, success=0, errormsg=gettext( - gettext( - "Could not find the server with id# {0}." - ).format(sid) - ) + "Could not find the server with id# {0}." + ).format(sid) ) + if _is_non_owner(server): + shared_server = ServerModule.get_shared_server(server, gid) + server = ServerModule.get_shared_server_properties(server, + shared_server) + manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(server.id) conn = manager.connection() connected = conn.connected() @@ -693,16 +771,20 @@ def node(self, gid, sid): ), ) - def delete_shared_server(self, server_name, gid, osid): + def delete_shared_server(self, gid, osid, user_id=None): """ - Delete the shared server - :param server_name: - :return: + Delete SharedServer records for a given original server. + :param gid: Server group ID + :param osid: Original server ID + :param user_id: If set, only delete for this user. + If None, delete for ALL users (owner unshare/delete). """ try: - shared_server = SharedServer.query.filter_by(name=server_name, - servergroup_id=gid, - osid=osid) + filters = dict(servergroup_id=gid, osid=osid) + if user_id is not None: + filters['user_id'] = user_id + shared_server = SharedServer.query.filter_by( + **filters) for s in shared_server: get_driver(PG_DEFAULT_DRIVER).delete_manager(s.id) db.session.delete(s) @@ -738,7 +820,7 @@ def delete(self, gid, sid): get_driver(PG_DEFAULT_DRIVER).delete_manager(s.id) db.session.delete(s) db.session.commit() - self.delete_shared_server(server_name, gid, sid) + self.delete_shared_server(gid, sid) QueryHistory.clear_history(current_user.id, sid) except Exception as e: @@ -754,7 +836,7 @@ def delete(self, gid, sid): @pga_login_required def update(self, gid, sid): """Update the server settings""" - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) sharedserver = None if server is None: @@ -821,7 +903,7 @@ def update(self, gid, sid): data['db_res'] = ','.join(data['db_res']) # Update connection parameter if any. - self.update_connection_parameter(data, server) + self.update_connection_parameter(data, server, sharedserver) self.update_tags(data, server) if 'connection_params' in data and \ @@ -878,7 +960,7 @@ def update(self, gid, sid): server.name, server_icon_and_background( connected, manager, sharedserver) - if server.shared and server.user_id != current_user.id + if _is_non_owner(server) else server_icon_and_background( connected, manager, server), True, @@ -902,7 +984,7 @@ def _update_server_details(server, sharedserver, if value == '': value = None - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): setattr(sharedserver, config_param_map[arg], value) else: setattr(server, config_param_map[arg], value) @@ -921,17 +1003,20 @@ def _set_valid_attr_value(self, gid, data, config_param_map, server, value = data[arg] if arg == 'password': value = encrypt(data[arg], crypt_key) - # sqlite3 do not have boolean type so we need to convert - # it manually to integer - if 'shared' in data and not data['shared']: - # Delete the shared server from DB if server - # owner uncheck shared property - self.delete_shared_server(server.name, gid, server.id) + # sqlite3 do not have boolean type so we need to + # convert it manually to integer. + # Only the owner may unshare — this deletes ALL + # users' SharedServer records. + if 'shared' in data and not data['shared'] \ + and not _is_non_owner(server): + self.delete_shared_server(gid, server.id) if arg in ('sslcompression', 'use_ssh_tunnel', - 'tunnel_authentication', 'kerberos_conn', 'shared'): + 'tunnel_authentication', + 'kerberos_conn', 'shared'): value = 1 if value else 0 self._update_server_details(server, sharedserver, - config_param_map, arg, value) + config_param_map, arg, + value) idx += 1 return idx @@ -956,19 +1041,16 @@ def list(self, gid, object_filters): """ Return list of attributes of all servers. """ - servers = Server.query.filter( - or_(Server.user_id == current_user.id, Server.shared), + servers = get_user_server_query().filter( Server.servergroup_id == gid, Server.is_adhoc == 0).order_by(Server.name) - sg = ServerGroup.query.filter_by( - id=gid - ).first() + sg = get_server_group(gid) res = [] driver = get_driver(PG_DEFAULT_DRIVER) for server in servers: - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): shared_server = ServerModule.get_shared_server(server, gid) server = \ ServerModule.get_shared_server_properties(server, @@ -1002,8 +1084,7 @@ def list(self, gid, object_filters): def properties(self, gid, sid): """Return list of attributes of a server""" - server = Server.query.filter_by( - id=sid).first() + server = get_server(sid) if server is None: return make_json_response( @@ -1026,7 +1107,7 @@ def properties(self, gid, sid): # port and user when server is connected display_connection_str = self.update_connection_string(manager, server) - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): shared_server = ServerModule.get_shared_server(server, gid) server = ServerModule.get_shared_server_properties(server, shared_server) @@ -1079,10 +1160,13 @@ def properties(self, gid, sid): 'db_res': get_db_restriction(server.db_res_type, server.db_res), 'db_res_type': server.db_res_type, 'passexec_cmd': - server.passexec_cmd if server.passexec_cmd else None, + server.passexec_cmd + if server.passexec_cmd and + not _is_non_owner(server) else None, 'passexec_expiration': - server.passexec_expiration if server.passexec_expiration - else None, + server.passexec_expiration + if server.passexec_expiration and + not _is_non_owner(server) else None, 'service': server.service if server.service else None, 'use_ssh_tunnel': use_ssh_tunnel, 'tunnel_host': tunnel_host, @@ -1102,7 +1186,8 @@ def properties(self, gid, sid): 'connection_string': display_connection_str, 'prepare_threshold': server.prepare_threshold, 'tags': tags, - 'post_connection_sql': server.post_connection_sql, + 'post_connection_sql': server.post_connection_sql + if not _is_non_owner(server) else None, } return ajax_response(response) @@ -1395,7 +1480,12 @@ def supported_servers(self, **kwargs): def connect_status(self, gid, sid): """Check and return the connection status.""" - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=self.not_found_error_msg() + ) manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid) conn = manager.connection() connected = conn.connected() @@ -1464,18 +1554,16 @@ def connect(self, gid, sid, is_qt=False, server=None): # function in that case no need to fetch the server detail based on # sid. if server is None: - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) + + if server is None: + return bad_request(self.not_found_error_msg()) shared_server = None - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): shared_server = ServerModule.get_shared_server(server, gid) - sess = object_session(server) - if sess is not None: - sess.expunge(server) server = ServerModule.get_shared_server_properties(server, shared_server) - if server is None: - return bad_request(self.not_found_error_msg()) # Return if username is blank and the server is shared if server.username is None and not server.service and \ @@ -1617,12 +1705,8 @@ def connect(self, gid, sid, is_qt=False, server=None): else: if save_password and config.ALLOW_SAVE_PASSWORD: try: - # If DB server is running in trust mode then password may - # not be available but we don't need to ask password - # every time user try to connect # 1 is True in SQLite as no boolean type - setattr(server, 'save_password', 1) - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): setattr(shared_server, 'save_password', 1) else: setattr(server, 'save_password', 1) @@ -1630,7 +1714,7 @@ def connect(self, gid, sid, is_qt=False, server=None): # Save the encrypted password using the user's login # password key, if there is any password to save if password: - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): setattr(shared_server, 'password', password) else: setattr(server, 'password', password) @@ -1646,7 +1730,11 @@ def connect(self, gid, sid, is_qt=False, server=None): if save_tunnel_password and config.ALLOW_SAVE_TUNNEL_PASSWORD: try: # Save the encrypted tunnel password. - setattr(server, 'tunnel_password', tunnel_password) + if _is_non_owner(server): + setattr(shared_server, 'tunnel_password', + tunnel_password) + else: + setattr(server, 'tunnel_password', tunnel_password) db.session.commit() except Exception as e: # Release Connection @@ -1693,7 +1781,7 @@ def connect(self, gid, sid, is_qt=False, server=None): def disconnect(self, gid, sid): """Disconnect the Server.""" - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) if server is None: return bad_request(self.not_found_error_msg()) @@ -1818,7 +1906,7 @@ def change_password(self, gid, sid): raise CryptKeyMissing # Fetch Server Details - server = Server.query.filter_by(id=sid).first() + server = get_server(sid, only_owned=False) if server is None: return bad_request(self.not_found_error_msg()) @@ -1905,11 +1993,24 @@ def change_password(self, gid, sid): # Store password in sqlite only if no pgpass file if not is_passfile: password = encrypt(data['newPassword'], crypt_key) - # Check if old password was stored in pgadmin4 sqlite database. - # If yes then update that password. - if server.password is not None and config.ALLOW_SAVE_PASSWORD: - setattr(server, 'password', password) - db.session.commit() + # Check if old password was stored in pgadmin4 + # sqlite database. If yes then update that password. + # For non-owners of shared servers, check the + # SharedServer record (not the owner's Server). + if config.ALLOW_SAVE_PASSWORD: + if server.shared and \ + server.user_id != current_user.id: + shared_server = \ + ServerModule.get_shared_server( + server, gid) + if shared_server and \ + shared_server.password is not None: + setattr(shared_server, 'password', + password) + db.session.commit() + elif server.password is not None: + setattr(server, 'password', password) + db.session.commit() # Also update password in connection manager. manager.password = password manager.update_session() @@ -1929,9 +2030,7 @@ def wal_replay(self, sid, pause=True): """ Utility function for wal_replay for resume/pause. """ - server = Server.query.filter_by( - user_id=current_user.id, id=sid - ).first() + server = get_server(sid) if server is None: return make_json_response( @@ -2015,9 +2114,7 @@ def check_pgpass(self, gid, sid): sid: Server id """ is_pgpass = False - server = Server.query.filter_by( - user_id=current_user.id, id=sid - ).first() + server = get_server(sid) if server is None: return make_json_response( @@ -2108,38 +2205,22 @@ def clear_saved_password(self, gid, sid): :return: """ try: - server = Server.query.filter_by(id=sid).first() - shared_server = None + server = get_server(sid, only_owned=False) if server is None: return make_json_response( success=0, info=self.not_found_error_msg() ) - if server.shared and server.user_id != current_user.id: - shared_server = SharedServer.query.filter_by( - name=server.name, user_id=current_user.id, - servergroup_id=gid, osid=server.id).first() - - if shared_server is None: - return make_json_response( - success=0, - info=gettext("Could not find the required server.") - ) - server = ServerModule. \ - get_shared_server_properties(server, shared_server) - - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): + shared_server = ServerModule.get_shared_server( + server, gid) setattr(shared_server, 'password', None) + if shared_server.save_password: + setattr(shared_server, 'save_password', 0) else: setattr(server, 'password', None) - - # If password was saved then clear the flag also - # 0 is False in SQLite db - if server.save_password: - if server.shared and server.user_id != current_user.id: - setattr(shared_server, 'save_password', 0) - else: + if server.save_password: setattr(server, 'save_password', 0) db.session.commit() except Exception as e: @@ -2165,13 +2246,19 @@ def clear_sshtunnel_password(self, gid, sid): :return: """ try: - server = Server.query.filter_by(id=sid).first() + server = get_server(sid, only_owned=False) if server is None: return make_json_response( success=0, info=self.not_found_error_msg() ) - setattr(server, 'tunnel_password', None) + + if _is_non_owner(server): + shared_server = ServerModule.get_shared_server( + server, gid) + setattr(shared_server, 'tunnel_password', None) + else: + setattr(server, 'tunnel_password', None) db.session.commit() except Exception as e: current_app.logger.error( diff --git a/web/pgadmin/browser/server_groups/servers/databases/__init__.py b/web/pgadmin/browser/server_groups/servers/databases/__init__.py index 1e88b59d1a4..1922db32c93 100644 --- a/web/pgadmin/browser/server_groups/servers/databases/__init__.py +++ b/web/pgadmin/browser/server_groups/servers/databases/__init__.py @@ -34,6 +34,7 @@ from pgadmin.tools.schema_diff.node_registry import SchemaDiffRegistry from pgadmin.model import db, Server, Database +from pgadmin.utils.server_access import get_server from pgadmin.browser.utils import underscore_escape from pgadmin.utils.constants import TWO_PARAM_STRING @@ -579,7 +580,9 @@ def connect(self, gid, sid, did): 'already_connected': already_connected, 'connected': True, 'info_prefix': TWO_PARAM_STRING. - format(Server.query.filter_by(id=sid)[0].name, conn.db) + format(getattr( + get_server(sid), 'name', None) or + _('Unknown'), conn.db) } ) @@ -602,7 +605,9 @@ def disconnect(self, gid, sid, did): 'icon': 'icon-database-not-connected', 'connected': False, 'info_prefix': TWO_PARAM_STRING. - format(Server.query.filter_by(id=sid)[0].name, conn.db) + format(getattr( + get_server(sid), 'name', None) or + _('Unknown'), conn.db) } ) diff --git a/web/pgadmin/browser/server_groups/servers/databases/schemas/views/__init__.py b/web/pgadmin/browser/server_groups/servers/databases/schemas/views/__init__.py index d978ebefd74..5b73863bf9a 100644 --- a/web/pgadmin/browser/server_groups/servers/databases/schemas/views/__init__.py +++ b/web/pgadmin/browser/server_groups/servers/databases/schemas/views/__init__.py @@ -29,8 +29,9 @@ from pgadmin.utils.driver import get_driver from pgadmin.tools.schema_diff.node_registry import SchemaDiffRegistry from .schema_diff_view_utils import SchemaDiffViewCompare -from pgadmin.utils import does_utility_exist, get_server +from pgadmin.utils import does_utility_exist from pgadmin.model import Server +from pgadmin.utils.server_access import get_server from pgadmin.misc.bgprocess.processes import BatchProcess, IProcessDesc from pgadmin.utils.constants import SERVER_NOT_FOUND @@ -2317,8 +2318,7 @@ def refresh_data(self, gid, sid, did, scid, vid): res['rows'][0]['name']) # Fetch the server details like hostname, port, roles etc - server = Server.query.filter_by( - id=sid).first() + server = get_server(sid) if server is None: return make_json_response( @@ -2436,9 +2436,7 @@ def check_utility_exists(self, gid, sid, did, scid, vid): Returns: None """ - server = Server.query.filter_by( - id=sid, user_id=current_user.id - ).first() + server = get_server(sid) if server is None: return make_json_response( diff --git a/web/pgadmin/browser/server_groups/servers/utils.py b/web/pgadmin/browser/server_groups/servers/utils.py index d9ef4842a8f..3377d99cdae 100644 --- a/web/pgadmin/browser/server_groups/servers/utils.py +++ b/web/pgadmin/browser/server_groups/servers/utils.py @@ -13,13 +13,14 @@ import keyring from flask_login import current_user from werkzeug.exceptions import InternalServerError -from flask import render_template +from flask import render_template, has_request_context from pgadmin.utils.constants import ( KEY_RING_USERNAME_FORMAT, KEY_RING_SERVICE_NAME, KEY_RING_TUNNEL_FORMAT, KEY_RING_DESKTOP_USER, SSL_MODES, RESTRICTION_TYPE_DATABASES, RESTRICTION_TYPE_SQL) from pgadmin.utils.crypto import encrypt, decrypt from pgadmin.model import db, Server, SharedServer +from pgadmin.utils.server_access import get_user_server_query from flask import current_app from pgadmin.utils.master_password import set_masterpass_check_text from pgadmin.utils.driver import get_driver @@ -324,7 +325,10 @@ def migrate_passwords_from_pgadmin_db(servers, old_key, enc_key): def get_servers_with_saved_passwords(): - all_server = Server.query.filter(Server.is_adhoc == 0) + all_server = Server.query.filter( + Server.user_id == current_user.id, + Server.is_adhoc == 0 + ) servers_with_pwd_in_os_secret = [] servers_with_pwd_in_pgadmin_db = [] saved_password_servers = [] @@ -648,32 +652,56 @@ def check_ssl_fields(data): def disconnect_from_all_servers(): """ - This function is used to disconnect all the servers + This function is used to disconnect all the servers for the + current user (owned + shared). """ - all_servers = Server.query.all() + all_servers = get_user_server_query().all() for server in all_servers: - manager = get_driver(config.PG_DEFAULT_DRIVER).connection_manager( - server.id) - # Check if any psql terminal is running for the current disconnecting - # server. If any terminate the psql tool connection. - if 'sid_soid_mapping' in current_app.config and str(server.id) in \ - current_app.config['sid_soid_mapping'] and \ - str(server.id) in current_app.config['sid_soid_mapping']: - for i in current_app.config['sid_soid_mapping'][str(server.id)]: - sio.emit('disconnect-psql', namespace='/pty', to=i) - - manager.release() + try: + manager = get_driver( + config.PG_DEFAULT_DRIVER + ).connection_manager(server.id) + # Only emit disconnect-psql for servers owned by the + # current user — shared servers may have other users' + # PSQL sessions mapped to the same sid. + if server.user_id == current_user.id and \ + 'sid_soid_mapping' in current_app.config \ + and str(server.id) in \ + current_app.config['sid_soid_mapping']: + for i in current_app.config[ + 'sid_soid_mapping'][str(server.id)]: + sio.emit( + 'disconnect-psql', + namespace='/pty', to=i + ) + manager.release() + except Exception: + current_app.logger.warning( + 'Failed to disconnect server %s', + server.id, exc_info=True + ) def delete_adhoc_servers(sid=None): """ - This function will remove all the adhoc servers. + This function will remove adhoc servers. When called with a + current_user context, scopes to the current user. When called + during app startup (no user context), cleans all adhoc servers. """ try: + has_user = (has_request_context() and + current_user and current_user.is_authenticated) if sid is not None: - db.session.query(Server).filter(Server.id == sid).delete() + q = db.session.query(Server).filter( + Server.id == sid, Server.is_adhoc == 1) + if has_user: + q = q.filter(Server.user_id == current_user.id) + q.delete() else: - db.session.query(Server).filter(Server.is_adhoc == 1).delete() + q = db.session.query(Server).filter(Server.is_adhoc == 1) + if has_user: + q = q.filter(Server.user_id == current_user.id) + q.delete() db.session.commit() # Reset the sequence again diff --git a/web/pgadmin/misc/bgprocess/processes.py b/web/pgadmin/misc/bgprocess/processes.py index 5f59ec0abe9..9a44a452d95 100644 --- a/web/pgadmin/misc/bgprocess/processes.py +++ b/web/pgadmin/misc/bgprocess/processes.py @@ -153,7 +153,7 @@ def __init__(self, **kwargs): self.manager_obj = kwargs['manager_obj'] def _retrieve_process(self, _id): - p = Process.query.filter_by(pid=_id, user_id=current_user.id).first() + p = Process.for_user(pid=_id).first() if p is None: raise LookupError(PROCESS_NOT_FOUND) @@ -372,9 +372,7 @@ def start(self, cb=None): # There is no way to find out the error message from this process # as standard output, and standard error were redirected to # devnull. - p = Process.query.filter_by( - pid=self.id, user_id=current_user.id - ).first() + p = Process.for_user(pid=self.id).first() p.start_time = p.end_time = get_current_time() if not p.exit_code: p.exit_code = self.ecode @@ -382,9 +380,7 @@ def start(self, cb=None): db.session.commit() else: # Update the process state to "Started" - p = Process.query.filter_by( - pid=self.id, user_id=current_user.id - ).first() + p = Process.for_user(pid=self.id).first() p.process_state = PROCESS_STARTED db.session.commit() @@ -530,9 +526,7 @@ def update_cloud_details(self): """ _pid = self.id - _process = Process.query.filter_by( - user_id=current_user.id, pid=_pid - ).first() + _process = Process.for_user(pid=_pid).first() if _process is None: raise LookupError(PROCESS_NOT_FOUND) @@ -588,9 +582,7 @@ def status(self, out=0, err=0): out_completed = err_completed = False process_output = (out != -1 and err != -1) - j = Process.query.filter_by( - pid=self.id, user_id=current_user.id - ).first() + j = Process.for_user(pid=self.id).first() enc = sys.getdefaultencoding() if enc == 'ascii': enc = 'utf-8' @@ -739,7 +731,7 @@ def _check_process_desc(p): @staticmethod def list(): - processes = Process.query.filter_by(user_id=current_user.id) + processes = Process.for_user() changed = False browser_preference = Preferences.module('browser') @@ -812,9 +804,7 @@ def acknowledge(_pid): And, delete the process information from the configuration, and the log files related to the process, if it has already been completed. """ - p = Process.query.filter_by( - user_id=current_user.id, pid=_pid - ).first() + p = Process.for_user(pid=_pid).first() if p is None: raise LookupError(PROCESS_NOT_FOUND) @@ -886,9 +876,7 @@ def set_env_variables(self, server, **kwargs): def stop_process(_pid): """ """ - p = Process.query.filter_by( - user_id=current_user.id, pid=_pid - ).first() + p = Process.for_user(pid=_pid).first() if p is None: raise LookupError(PROCESS_NOT_FOUND) @@ -910,9 +898,7 @@ def stop_process(_pid): @staticmethod def update_server_id(_pid, _sid): - p = Process.query.filter_by( - user_id=current_user.id, pid=_pid - ).first() + p = Process.for_user(pid=_pid).first() if p is None: raise LookupError(PROCESS_NOT_FOUND) diff --git a/web/pgadmin/misc/cloud/__init__.py b/web/pgadmin/misc/cloud/__init__.py index 6d61d3a97a1..f28484c4797 100644 --- a/web/pgadmin/misc/cloud/__init__.py +++ b/web/pgadmin/misc/cloud/__init__.py @@ -212,8 +212,9 @@ def clear_cloud_session(pid=None): @pga_login_required def update_cloud_process(sid): """Update Cloud Server Process""" - _process = Process.query.filter_by(user_id=current_user.id, - server_id=sid).first() + _process = Process.for_user(server_id=sid).first() + if _process is None: + return success_return() _process.acknowledge = None db.session.commit() return success_return() diff --git a/web/pgadmin/misc/workspaces/__init__.py b/web/pgadmin/misc/workspaces/__init__.py index 1a99037a7e1..afb20b5e81b 100644 --- a/web/pgadmin/misc/workspaces/__init__.py +++ b/web/pgadmin/misc/workspaces/__init__.py @@ -17,6 +17,7 @@ from flask_security import current_user from pgadmin.utils import PgAdminModule from pgadmin.model import db, Server +from pgadmin.utils.server_access import get_server from pgadmin.utils.driver import get_driver from pgadmin.utils.ajax import bad_request, make_json_response from pgadmin.browser.server_groups.servers.utils import ( @@ -132,7 +133,8 @@ def adhoc_connect_server(): username=new_username, name=new_server_name, role=new_role, - service=new_service + service=new_service, + user_id=current_user.id ).all() # If found matching servers then compare the connection_params as @@ -143,22 +145,27 @@ def adhoc_connect_server(): server = existing_server break else: - server = Server.query.filter_by(host=new_host, - port=new_port, - maintenance_db=new_db, - username=new_username, - name=new_server_name, - role=new_role, - service=new_service, - connection_params=connection_params - ).first() + server = Server.query.filter_by( + host=new_host, port=new_port, + maintenance_db=new_db, + username=new_username, + name=new_server_name, + role=new_role, + service=new_service, + connection_params=connection_params, + user_id=current_user.id + ).first() # If server is none then no server with the above combination is found. if server is None: # Check if sid is present in data if it is then used that sid. if ('sid' in data and data['sid'] is not None and int(data['sid']) > 0): - server = Server.query.filter_by(id=data['sid']).first() + server = get_server(data['sid']) + if server is None: + return bad_request(gettext( + "Could not find the required server." + )) # Clone the server object server = server.clone() @@ -220,23 +227,30 @@ def check_and_delete_adhoc_server(sid): This function is used to check for adhoc server and if all Query Tool and PSQL connections are closed then delete that server. """ - server = Server.query.filter_by(id=sid).first() - if server.is_adhoc: - # Check PSQL connections. If more connections are open for - # the given sid return from the function. - psql_connections = get_open_psql_connections() - if sid in psql_connections.values(): + server = get_server(sid) + if server is None: + # Server may be deleted or inaccessible; still attempt + # best-effort cleanup of adhoc state. + delete_adhoc_servers(sid) + return + if not server.is_adhoc: + return + + # Check PSQL connections. If more connections are open for + # the given sid return from the function. + psql_connections = get_open_psql_connections() + if sid in psql_connections.values(): + return + + # Check Query Tool connections for the given sid + manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid) + for key, value in manager.connections.items(): + if key.startswith('CONN') and value.connected(): return - # Check Query Tool connections for the given sid - manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid) - for key, value in manager.connections.items(): - if key.startswith('CONN') and value.connected(): - return - - # Assumption at this point all the Query Tool and PSQL connections - # is closed, so now we can release the manager - manager.release() + # Assumption at this point all the Query Tool and PSQL connections + # is closed, so now we can release the manager + manager.release() - # Delete the adhoc server from the pgadmin database - delete_adhoc_servers(sid) + # Delete the adhoc server from the pgadmin database + delete_adhoc_servers(sid) diff --git a/web/pgadmin/model/__init__.py b/web/pgadmin/model/__init__.py index 69b934683dd..62d89ca9412 100644 --- a/web/pgadmin/model/__init__.py +++ b/web/pgadmin/model/__init__.py @@ -33,7 +33,7 @@ # ########################################################################## -SCHEMA_VERSION = 49 +SCHEMA_VERSION = 50 ########################################################################## # @@ -51,6 +51,60 @@ SERVER_ID = 'server.id' CASCADE_STR = "all, delete-orphan" + +class UserScopedMixin: + """Mixin for models that store per-user data. + + Provides for_user() as the default scoped query entry point. + Models with a 'user_id' column or a 'uid' column are supported + automatically — the mixin detects which column name is used. + + Usage: + # Instead of: + Process.query.filter_by(user_id=current_user.id, pid=pid) + # Use: + Process.for_user(pid=pid) + """ + + @classmethod + def _user_column(cls): + """Return the user-scoping column for this model.""" + if hasattr(cls, 'user_id'): + return cls.user_id + if hasattr(cls, 'uid'): + return cls.uid + raise AttributeError( + f"{cls.__name__} has no user_id or uid column" + ) + + @classmethod + def _user_column_name(cls): + """Return the column name string ('user_id' or 'uid').""" + if hasattr(cls, 'user_id'): + return 'user_id' + if hasattr(cls, 'uid'): + return 'uid' + raise AttributeError( + f"{cls.__name__} has no user_id or uid column" + ) + + @classmethod + def for_user(cls, user_id=None, **kwargs): + """Query scoped to a specific user (defaults to current_user). + + Args: + user_id: Explicit user ID. If None, uses current_user.id. + **kwargs: Additional filter_by arguments. + + Returns: + A SQLAlchemy query filtered by the user's ID. + """ + from flask_security import current_user as cu + uid = user_id if user_id is not None else cu.id + kwargs[cls._user_column_name()] = uid + return cls.query.filter_by(**kwargs) + + # Define models roles_users = db.Table( 'roles_users', @@ -158,7 +212,7 @@ class User(db.Model, UserMixin): locked = db.Column(db.Boolean(), default=False) -class Setting(db.Model): +class Setting(db.Model, UserScopedMixin): """Define a setting object""" __tablename__ = 'setting' user_id = db.Column(db.Integer, db.ForeignKey(USER_ID), primary_key=True) @@ -166,7 +220,7 @@ class Setting(db.Model): value = db.Column(db.Text()) -class ServerGroup(db.Model): +class ServerGroup(db.Model, UserScopedMixin): """Define a server group for the treeview""" __tablename__ = 'servergroup' id = db.Column(db.Integer, primary_key=True) @@ -185,7 +239,7 @@ def serialize(self): } -class Server(db.Model): +class Server(db.Model, UserScopedMixin): """Define a registered Postgres server""" __tablename__ = 'server' id = db.Column(db.Integer, primary_key=True) @@ -306,7 +360,7 @@ class Preferences(db.Model): name = db.Column(db.String(1024), nullable=False) -class UserPreference(db.Model): +class UserPreference(db.Model, UserScopedMixin): """Define the preference for a particular user.""" __tablename__ = 'user_preferences' pid = db.Column( @@ -318,9 +372,13 @@ class UserPreference(db.Model): value = db.Column(db.String(1024), nullable=False) -class DebuggerFunctionArguments(db.Model): +class DebuggerFunctionArguments(db.Model, UserScopedMixin): """Define the debugger input function arguments.""" __tablename__ = 'debugger_function_arguments' + user_id = db.Column( + db.Integer, db.ForeignKey(USER_ID), + nullable=False, primary_key=True + ) server_id = db.Column(db.Integer(), nullable=False, primary_key=True) database_id = db.Column(db.Integer(), nullable=False, primary_key=True) schema_id = db.Column(db.Integer(), nullable=False, primary_key=True) @@ -349,7 +407,7 @@ class DebuggerFunctionArguments(db.Model): value = db.Column(db.String(), nullable=True) -class Process(db.Model): +class Process(db.Model, UserScopedMixin): """Define the Process table.""" __tablename__ = 'process' pid = db.Column(db.String(), nullable=False, primary_key=True) @@ -382,7 +440,7 @@ class Keys(db.Model): value = db.Column(db.String(), nullable=False) -class QueryHistoryModel(db.Model): +class QueryHistoryModel(db.Model, UserScopedMixin): """Define the history SQL table.""" __tablename__ = 'query_history' srno = db.Column(db.Integer(), nullable=False, primary_key=True) @@ -397,7 +455,7 @@ class QueryHistoryModel(db.Model): last_updated_flag = db.Column(db.String(), nullable=False) -class ApplicationState(db.Model): +class ApplicationState(db.Model, UserScopedMixin): """Define the application state SQL table.""" __tablename__ = 'application_state' uid = db.Column(db.Integer(), db.ForeignKey(USER_ID), nullable=False, @@ -422,10 +480,14 @@ class Database(db.Model): ) -class SharedServer(db.Model): +class SharedServer(db.Model, UserScopedMixin): """Define a shared Postgres server""" __tablename__ = 'sharedserver' + __table_args__ = ( + db.UniqueConstraint('osid', 'user_id', + name='uq_sharedserver_osid_user'), + ) id = db.Column(db.Integer, primary_key=True) osid = db.Column( db.Integer, @@ -510,7 +572,7 @@ class Macros(db.Model): key_code = db.Column(db.Integer, nullable=False) -class UserMacros(db.Model): +class UserMacros(db.Model, UserScopedMixin): """Define the macro for a particular user.""" __tablename__ = 'user_macros' id = db.Column(db.Integer, primary_key=True, autoincrement=True) @@ -524,7 +586,7 @@ class UserMacros(db.Model): sql = db.Column(db.Text(), nullable=False) -class UserMFA(db.Model): +class UserMFA(db.Model, UserScopedMixin): """Stores the options for the MFA for a particular user.""" __tablename__ = 'user_mfa' user_id = db.Column(db.Integer, db.ForeignKey(USER_ID), primary_key=True) diff --git a/web/pgadmin/tools/debugger/__init__.py b/web/pgadmin/tools/debugger/__init__.py index e9a25deca0c..7d00f0fae17 100644 --- a/web/pgadmin/tools/debugger/__init__.py +++ b/web/pgadmin/tools/debugger/__init__.py @@ -16,7 +16,7 @@ from flask import render_template, request, current_app from flask_babel import gettext -from flask_security import permissions_required +from flask_security import permissions_required, current_user from pgadmin.user_login_check import pga_login_required from werkzeug.user_agent import UserAgent @@ -35,7 +35,9 @@ import get_extension_details from pgadmin.utils.constants import PREF_LABEL_KEYBOARD_SHORTCUTS, \ SERVER_CONNECTION_CLOSED -from pgadmin.tools.user_management.PgAdminPermissions import AllPermissionTypes +from pgadmin.tools.user_management.PgAdminPermissions \ + import AllPermissionTypes +from pgadmin.utils.server_access import get_server from pgadmin.preferences import preferences MODULE_NAME = 'debugger' @@ -1803,12 +1805,19 @@ def get_arguments_sqlite(sid, did, scid, func_id): - Function Id """ + if get_server(sid) is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) + """Get the count of the existing data available in sqlite database""" dbg_func_args_count = int(DebuggerFunctionArguments.query.filter_by( server_id=sid, database_id=did, schema_id=scid, - function_id=func_id + function_id=func_id, + user_id=current_user.id ).count()) args_data = [] @@ -1819,7 +1828,8 @@ def get_arguments_sqlite(sid, did, scid, func_id): server_id=sid, database_id=did, schema_id=scid, - function_id=func_id + function_id=func_id, + user_id=current_user.id ) args_list = dbg_func_args.all() @@ -1888,6 +1898,12 @@ def set_arguments_sqlite(sid, did, scid, func_id): - Function Id """ + if get_server(sid) is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) + if request.data: data = json.loads(request.data) @@ -1899,7 +1915,8 @@ def set_arguments_sqlite(sid, did, scid, func_id): database_id=data[i]['database_id'], schema_id=data[i]['schema_id'], function_id=data[i]['function_id'], - arg_id=data[i]['arg_id']).count()) + arg_id=data[i]['arg_id'], + user_id=current_user.id).count()) # handle the Array list sent from the client array_string = '' @@ -1918,7 +1935,8 @@ def set_arguments_sqlite(sid, did, scid, func_id): database_id=data[i]['database_id'], schema_id=data[i]['schema_id'], function_id=data[i]['function_id'], - arg_id=data[i]['arg_id'] + arg_id=data[i]['arg_id'], + user_id=current_user.id ).first() dbg_func_args.is_null = data[i]['is_null'] @@ -1932,6 +1950,7 @@ def set_arguments_sqlite(sid, did, scid, func_id): schema_id=data[i]['schema_id'], function_id=data[i]['function_id'], arg_id=data[i]['arg_id'], + user_id=current_user.id, is_null=data[i]['is_null'], is_expression=data[i]['is_expression'], use_default=data[i]['use_default'], @@ -1977,12 +1996,20 @@ def clear_arguments_sqlite(sid, did, scid, func_id): - Function Id """ + if get_server(sid) is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) + try: db.session.query(DebuggerFunctionArguments) \ .filter(DebuggerFunctionArguments.server_id == sid, DebuggerFunctionArguments.database_id == did, DebuggerFunctionArguments.schema_id == scid, - DebuggerFunctionArguments.function_id == func_id) \ + DebuggerFunctionArguments.function_id == func_id, + DebuggerFunctionArguments.user_id == + current_user.id) \ .delete() db.session.commit() diff --git a/web/pgadmin/tools/erd/__init__.py b/web/pgadmin/tools/erd/__init__.py index 16d9e4b7b8e..b70eb9f63d9 100644 --- a/web/pgadmin/tools/erd/__init__.py +++ b/web/pgadmin/tools/erd/__init__.py @@ -20,6 +20,7 @@ SHORTCUT_FIELDS as shortcut_fields from pgadmin.utils.ajax import make_json_response, internal_server_error from pgadmin.model import Server +from pgadmin.utils.server_access import get_server from config import PG_DEFAULT_DRIVER, ALLOW_SAVE_PASSWORD from pgadmin.utils.driver import get_driver from pgadmin.browser.utils import underscore_unescape @@ -556,7 +557,7 @@ def panel(trans_id): if "linux" in _platform: is_linux_platform = True - s = Server.query.filter_by(id=int(params['sid'])).first() + s = get_server(int(params['sid'])) if s: params.update({ diff --git a/web/pgadmin/tools/import_export/__init__.py b/web/pgadmin/tools/import_export/__init__.py index 30edc4f18c1..d7bd7c065b3 100644 --- a/web/pgadmin/tools/import_export/__init__.py +++ b/web/pgadmin/tools/import_export/__init__.py @@ -23,6 +23,7 @@ from config import PG_DEFAULT_DRIVER from pgadmin.model import Server +from pgadmin.utils.server_access import get_server from pgadmin.utils.constants import SERVER_NOT_FOUND from pgadmin.settings import get_setting, store_setting from pgadmin.tools.user_management.PgAdminPermissions import AllPermissionTypes @@ -97,9 +98,7 @@ def cmd_arg(x): def get_server_name(self): # Fetch the server details like hostname, port, roles etc - s = Server.query.filter_by( - id=self.sid, user_id=current_user.id - ).first() + s = get_server(self.sid) if s is None: return _("Not available") @@ -293,8 +292,7 @@ def create_import_export_job(sid): data = json.loads(request.data) # Fetch the server details like hostname, port, roles etc - server = Server.query.filter_by( - id=sid).first() + server = get_server(sid) if server is None: return bad_request(errormsg=_("Could not find the specified server.")) diff --git a/web/pgadmin/tools/psql/__init__.py b/web/pgadmin/tools/psql/__init__.py index 1b51cca66da..ab1e22fd3c0 100644 --- a/web/pgadmin/tools/psql/__init__.py +++ b/web/pgadmin/tools/psql/__init__.py @@ -29,6 +29,7 @@ from pgadmin.utils import get_complete_file_path from pgadmin.authenticate import socket_login_required from pgadmin.model import Server +from pgadmin.utils.server_access import get_server if _platform == 'win32': # Check Windows platform support for WinPty api, Disable psql @@ -98,7 +99,7 @@ def panel(trans_id): if 'sid_soid_mapping' not in app.config: app.config['sid_soid_mapping'] = dict() - s = Server.query.filter_by(id=int(params['sid'])).first() + s = get_server(int(params['sid'])) if s: data = _get_database_role(params['sid'], params['did']) if data: diff --git a/web/pgadmin/tools/schema_diff/__init__.py b/web/pgadmin/tools/schema_diff/__init__.py index bc244e0b2ac..2470a4db1c5 100644 --- a/web/pgadmin/tools/schema_diff/__init__.py +++ b/web/pgadmin/tools/schema_diff/__init__.py @@ -31,6 +31,8 @@ from pgadmin.authenticate import socket_login_required from pgadmin import socketio from pgadmin.tools.user_management.PgAdminPermissions import AllPermissionTypes +from pgadmin.utils.server_access import \ + get_server as get_server_access, get_user_server_query MODULE_NAME = 'schema_diff' COMPARE_MSG = gettext("Comparing objects...") @@ -283,18 +285,14 @@ def servers(): from pgadmin.browser.server_groups.servers import\ server_icon_and_background - for server in Server.query.filter( - or_(Server.user_id == current_user.id, Server.shared), + for server in get_user_server_query().filter( Server.is_adhoc == 0): shared_server = SharedServer.query.filter_by( - name=server.name, user_id=current_user.id, - servergroup_id=server.servergroup_id).first() + user_id=current_user.id, + osid=server.id).first() - if server.discovery_id: - auto_detected_server = server.name - - if shared_server and shared_server.name == auto_detected_server: + if server.discovery_id and shared_server: continue manager = driver.connection_manager(server.id) @@ -336,7 +334,13 @@ def get_server(sid, did): """Return a JSON document listing the server groups for the user""" driver = get_driver(PG_DEFAULT_DRIVER) - server = Server.query.filter_by(id=sid).first() + server = get_server_access(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext( + "Could not find the required server.") + ) manager = driver.connection_manager(sid) conn = manager.connection(did=did) connected = conn.connected() @@ -375,7 +379,12 @@ def connect_server(sid): data={} ) - server = Server.query.filter_by(id=sid).first() + server = get_server_access(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) view = SchemaDiffRegistry.get_node_view('server') return view.connect(server.servergroup_id, sid) @@ -387,7 +396,12 @@ def connect_server(sid): ) @pga_login_required def connect_database(sid, did): - server = Server.query.filter_by(id=sid).first() + server = get_server_access(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) view = SchemaDiffRegistry.get_node_view('database') return view.connect(server.servergroup_id, sid, did) @@ -407,7 +421,13 @@ def databases(sid): try: view = SchemaDiffRegistry.get_node_view('database') - server = Server.query.filter_by(id=sid).first() + server = get_server_access(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext( + "Could not find the required server.") + ) response = view.nodes(gid=server.servergroup_id, sid=sid, is_schema_diff=True) databases = json.loads(response.data)['data'] @@ -495,6 +515,15 @@ def compare_database(params): fetch_compare_schemas(params['source_sid'], params['source_did'], params['target_sid'], params['target_did']) + if schema_result is None: + socketio.emit( + 'compare_database_failed', + gettext( + "Failed to fetch schemas from the" + " server."), + namespace=SOCKETIO_NAMESPACE, to=request.sid) + return + total_schema = len(schema_result['source_only']) + len( schema_result['target_only']) + len( schema_result['in_both_database']) @@ -722,11 +751,15 @@ def check_version_compatibility(sid, tid): """Check the version compatibility of source and target servers.""" driver = get_driver(PG_DEFAULT_DRIVER) - src_server = Server.query.filter_by(id=sid).first() + src_server = get_server_access(sid) + if src_server is None: + return False, gettext("Could not find the source server.") src_manager = driver.connection_manager(src_server.id) src_conn = src_manager.connection() - tar_server = Server.query.filter_by(id=tid).first() + tar_server = get_server_access(tid) + if tar_server is None: + return False, gettext("Could not find the target server.") tar_manager = driver.connection_manager(tar_server.id) target_conn = tar_manager.connection() @@ -759,7 +792,9 @@ def get_schemas(sid, did): """ try: view = SchemaDiffRegistry.get_node_view('schema') - server = Server.query.filter_by(id=sid).first() + server = get_server_access(sid) + if server is None: + return None response = view.nodes(gid=server.servergroup_id, sid=sid, did=did, is_schema_diff=True) schemas = json.loads(response.data)['data'] @@ -912,6 +947,9 @@ def fetch_compare_schemas(source_sid, source_did, target_sid, target_did): source_schemas = get_schemas(source_sid, source_did) target_schemas = get_schemas(target_sid, target_did) + if source_schemas is None or target_schemas is None: + return None + src_schema_dict = {item['label']: item['_id'] for item in source_schemas} tar_schema_dict = {item['label']: item['_id'] for item in target_schemas} diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index fe69994f197..c9e26df2f00 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -63,6 +63,8 @@ ERROR_FETCHING_DATA, MY_STORAGE, ACCESS_DENIED_MESSAGE, \ ERROR_MSG_FAIL_TO_PROMOTE_QT from pgadmin.model import Server, ServerGroup +from pgadmin.utils.server_access import get_server, \ + get_server_groups_for_user, get_user_server_query from pgadmin.tools.schema_diff.node_registry import SchemaDiffRegistry from pgadmin.settings import get_setting from pgadmin.utils.preferences import Preferences @@ -225,7 +227,12 @@ def initialize_viewdata(trans_id, cmd_type, obj_type, sgid, sid, did, obj_id): 'password': _data['password'] if 'password' in _data else None } - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) if kwargs.get('password', None) is None: kwargs['encpass'] = server.password else: @@ -374,7 +381,7 @@ def panel(trans_id): params['bgcolor'] = None params['fgcolor'] = None - s = Server.query.filter_by(id=int(params['sid'])).first() + s = get_server(int(params['sid'])) if s: if s.shared and s.user_id != current_user.id: # Import here to avoid circular dependency @@ -512,7 +519,12 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, dbname=None, **kwargs): kwargs.pop('conn_id') conn_id_ac = str(secrets.choice(range(1, 9999999))) - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) + if server is None: + return True, internal_server_error( + errormsg=gettext( + "Could not find the required server.") + ), '', '' if server.shared and server.user_id != current_user.id: # Import here to avoid circular dependency from pgadmin.browser.server_groups.servers import ServerModule @@ -2344,8 +2356,13 @@ def _check_server_connection_status(sgid, sid=None): driver = get_driver(PG_DEFAULT_DRIVER) from pgadmin.browser.server_groups.servers import \ server_icon_and_background - server = Server.query.filter_by( - id=sid).first() + server = get_server(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext( + "Could not find the required server.") + ) manager = driver.connection_manager(server.id) conn = manager.connection() @@ -2393,11 +2410,10 @@ def get_new_connection_data(sgid=None, sid=None): driver = get_driver(PG_DEFAULT_DRIVER) from pgadmin.browser.server_groups.servers import \ server_icon_and_background - server_groups = ServerGroup.query.all() + server_groups = get_server_groups_for_user() server_group_data = {server_group.name: [] for server_group in server_groups} - servers = Server.query.filter( - or_(Server.user_id == current_user.id, Server.shared), + servers = get_user_server_query().filter( Server.is_adhoc == 0) for server in servers: @@ -2654,7 +2670,12 @@ def get_new_connection_role(sgid, sid=None): @pga_login_required def connect_server(sid): # Check if server is already connected then no need to reconnect again. - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) driver = get_driver(PG_DEFAULT_DRIVER) manager = driver.connection_manager(sid) diff --git a/web/pgadmin/tools/user_management/__init__.py b/web/pgadmin/tools/user_management/__init__.py index 36118731515..52497d4b330 100644 --- a/web/pgadmin/tools/user_management/__init__.py +++ b/web/pgadmin/tools/user_management/__init__.py @@ -759,7 +759,7 @@ def delete_user(uid): ServerGroup.query.filter_by(user_id=uid).delete() - Process.query.filter_by(user_id=uid).delete() + Process.for_user(user_id=uid).delete() # Delete Shared servers for current user. SharedServer.query.filter_by(user_id=uid).delete() diff --git a/web/pgadmin/utils/__init__.py b/web/pgadmin/utils/__init__.py index 6e4f5b22cad..7f2136d99f5 100644 --- a/web/pgadmin/utils/__init__.py +++ b/web/pgadmin/utils/__init__.py @@ -358,14 +358,14 @@ def does_utility_exist(file): return error_msg -def get_server(sid): - """ - # Fetch the server etc - :param sid: - :return: server +def get_server(sid, only_owned=False): + """Fetch a server by ID with access check. + + Delegates to server_access.get_server(). Kept here for backward + compatibility — existing callers import from pgadmin.utils. """ - server = Server.query.filter_by(id=sid).first() - return server + from pgadmin.utils.server_access import get_server as _get_server + return _get_server(sid, only_owned=only_owned) def get_binary_path_versions(binary_path: str) -> dict: diff --git a/web/pgadmin/utils/driver/psycopg3/__init__.py b/web/pgadmin/utils/driver/psycopg3/__init__.py index 5bb606c3dd0..0695e83f2a7 100644 --- a/web/pgadmin/utils/driver/psycopg3/__init__.py +++ b/web/pgadmin/utils/driver/psycopg3/__init__.py @@ -16,6 +16,7 @@ import datetime import re from flask import session +from flask_babel import gettext from flask_login import current_user from werkzeug.exceptions import InternalServerError import psycopg @@ -23,6 +24,9 @@ import config from pgadmin.model import Server +from pgadmin.utils.server_access import get_server, \ + get_user_server_query +from pgadmin.utils.exception import ObjectGone from .keywords import scan_keyword from ..abstract import BaseDriver from .connection import Connection @@ -67,20 +71,29 @@ def __init__(self, **kwargs): def _restore_connections_from_session(self): """ Used internally by connection_manager to restore connections - from sessions. + from sessions. Includes both owned and shared servers so + non-owner connections survive session restore. """ if session.sid not in self.managers: self.managers[session.sid] = managers = dict() if '__pgsql_server_managers' in session: session_managers = \ session['__pgsql_server_managers'].copy() - for server in \ - Server.query.filter_by( - user_id=current_user.id, is_adhoc=0): + servers = get_user_server_query().filter( + Server.is_adhoc == 0) + for server in servers: manager = managers[str(server.id)] = \ ServerManager(server) + # Suppress owner-only fields for non-owners + # of shared servers so passexec_cmd and + # post_connection_sql don't leak. + if server.shared and \ + server.user_id != current_user.id: + manager.passexec = None + manager.post_connection_sql = None if server.id in session_managers: - manager._restore(session_managers[server.id]) + manager._restore( + session_managers[server.id]) manager.update_session() return managers @@ -100,9 +113,27 @@ def connection_manager(self, sid=None): assert (sid is not None and isinstance(sid, int)) managers = None - server_data = Server.query.filter_by(id=sid).first() - if server_data is None: - return None + # In server mode, verify the current user has access to this + # server. This is the primary security boundary — all + # check_precondition decorators and tool endpoints flow + # through connection_manager(). + if config.SERVER_MODE: + if current_user and current_user.is_authenticated: + server_data = get_server(sid) + else: + raise ObjectGone( + gettext("Server not found.")) + if server_data is None: + raise ObjectGone( + gettext("Server not found.")) + else: + # Desktop mode — single user, no isolation needed. + # Return None instead of raising so callers that + # handle None gracefully (e.g., test teardown, + # cleanup paths) are not disrupted. + server_data = Server.query.filter_by(id=sid).first() + if server_data is None: + return None if session.sid not in self.managers: with connection_restore_lock: @@ -119,14 +150,18 @@ def connection_manager(self, sid=None): managers['pinged'] = datetime.datetime.now() if str(sid) not in managers: - s = Server.query.filter_by(id=sid).first() - - if not s: - return None - - managers[str(sid)] = ServerManager(s) - - return managers[str(sid)] + # server_data was already access-checked above; + # it cannot be None at this point. + manager = ServerManager(server_data) + # Suppress owner-only fields for non-owners of + # shared servers. + if config.SERVER_MODE and server_data.shared and \ + server_data.user_id != current_user.id: + manager.passexec = None + manager.post_connection_sql = None + managers[str(sid)] = manager + + return manager return managers[str(sid)] diff --git a/web/pgadmin/utils/server_access.py b/web/pgadmin/utils/server_access.py new file mode 100644 index 00000000000..1e0c8fe6ad8 --- /dev/null +++ b/web/pgadmin/utils/server_access.py @@ -0,0 +1,156 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Centralized server access-checking utilities for data isolation. + +In server mode, multiple users share the same pgAdmin instance. These +helpers enforce that users can only access servers they own or that +have been explicitly shared with them via SharedServer entries. +""" + +from sqlalchemy import or_ +from flask_security import current_user + +from pgadmin.model import db, Server, ServerGroup +import config + + +def _is_admin(): + """Check if current user has Administrator role.""" + return current_user.has_role('Administrator') + + +def get_server(sid, only_owned=False): + """Fetch a server by ID, verifying the current user has access. + + Args: + sid: Server ID. + only_owned: If True, only return servers owned by the current + user. Use this for write operations (change_password, + clear_saved_password, etc.) that must not mutate another + user's server record via shared access. + + Returns the server if: + - Desktop mode (single user, no isolation needed), OR + - The user owns it, OR + - The server is shared AND only_owned is False, OR + - The user has the Administrator role. + + Returns None otherwise (caller should return 404). + + Note: In pgAdmin, Server.shared=True means the server is visible + to all authenticated users. SharedServer records are created + lazily for per-user customization, not for access control. + """ + if not config.SERVER_MODE: + return Server.query.filter_by(id=sid).first() + + if only_owned: + return Server.query.filter_by( + id=sid, user_id=current_user.id).first() + + # Single query: owned OR shared + server = Server.query.filter( + Server.id == sid, + or_( + Server.user_id == current_user.id, + Server.shared + ) + ).first() + + if server is not None: + return server + + # Administrators can access all servers + if _is_admin(): + return Server.query.filter_by(id=sid).first() + + return None + + +def get_server_group(gid): + """Fetch a server group by ID, verifying user access. + + Returns the group if: + - Desktop mode, OR + - The user owns it, OR + - It contains shared servers (Server.shared=True), OR + - The user has the Administrator role. + + Returns None otherwise. + """ + if not config.SERVER_MODE: + return ServerGroup.query.filter_by(id=gid).first() + + sg = ServerGroup.query.filter( + ServerGroup.id == gid, + or_( + ServerGroup.user_id == current_user.id, + ServerGroup.id.in_( + db.session.query(Server.servergroup_id).filter( + Server.shared + ) + ) + ) + ).first() + + if sg is not None: + return sg + + if _is_admin(): + return ServerGroup.query.filter_by(id=gid).first() + + return None + + +def get_server_groups_for_user(): + """Return server groups visible to the current user. + + Includes groups owned by the user plus groups containing shared + servers (Server.shared=True, visible to all authenticated users). + Administrators see all groups. + """ + if not config.SERVER_MODE: + return ServerGroup.query.filter_by( + user_id=current_user.id + ).all() + + if _is_admin(): + return ServerGroup.query.all() + + return ServerGroup.query.filter( + or_( + ServerGroup.user_id == current_user.id, + ServerGroup.id.in_( + db.session.query(Server.servergroup_id).filter( + Server.shared + ) + ) + ) + ).all() + + +def get_user_server_query(): + """Return a base query for servers accessible to the current user. + + Includes owned servers + shared servers (visible to all users). + Administrators see all servers. + """ + if not config.SERVER_MODE: + return Server.query + + if _is_admin(): + return Server.query + + return Server.query.filter( + or_( + Server.user_id == current_user.id, + Server.shared + ) + ) From 801d287f45673a390e35246d2031e9fd5165c00d Mon Sep 17 00:00:00 2001 From: Ashesh Vashi Date: Thu, 9 Apr 2026 16:01:39 +0530 Subject: [PATCH 2/3] test: data isolation and shared server security tests Integration tests (5 cases): private server denied, shared server accessible, passexec/post_connection_sql suppressed, SSL paths stripped, rename preserves access. Unit tests with mocks (22 cases): merge logic, sanitization, routing, ownership guards, expunge verification, error handling. Server group isolation (1 case), batch process mock updates (4). --- .../tests/test_server_data_isolation.py | 352 +++++++++++++ .../servers/tests/test_shared_server_unit.py | 487 ++++++++++++++++++ .../tests/test_sg_data_isolation.py | 78 +++ .../tools/backup/tests/test_batch_process.py | 2 + .../import_export/tests/test_batch_process.py | 2 + .../tests/test_batch_process_maintenance.py | 2 + .../tools/restore/tests/test_batch_process.py | 2 + 7 files changed, 925 insertions(+) create mode 100644 web/pgadmin/browser/server_groups/servers/tests/test_server_data_isolation.py create mode 100644 web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py create mode 100644 web/pgadmin/browser/server_groups/tests/test_sg_data_isolation.py diff --git a/web/pgadmin/browser/server_groups/servers/tests/test_server_data_isolation.py b/web/pgadmin/browser/server_groups/servers/tests/test_server_data_isolation.py new file mode 100644 index 00000000000..bbee83e7a7a --- /dev/null +++ b/web/pgadmin/browser/server_groups/servers/tests/test_server_data_isolation.py @@ -0,0 +1,352 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Tests for server data isolation between users in server mode.""" + +import json +import config +from pgadmin.utils.route import BaseTestGenerator +from regression.python_test_utils import test_utils as utils +from regression.test_setup import config_data +from regression.python_test_utils.test_utils import \ + create_user_wise_test_client + +test_user_details = None +if config.SERVER_MODE: + test_user_details = \ + config_data['pgAdmin4_test_non_admin_credentials'] + + +class ServerDataIsolationGetTestCase(BaseTestGenerator): + """Verify that a non-admin user cannot access another user's + private (non-shared) server by ID.""" + + scenarios = [ + ('User B gets 410 for User A private server', + dict(is_positive_test=False)), + ] + + def setUp(self): + self.server_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Create a private (non-shared) server as the admin user + self.server['shared'] = False + url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP) + response = self.tester.post( + url, + data=json.dumps(self.server), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.server_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """Non-admin user should NOT be able to GET another user's + private server.""" + if not self.server_id: + raise Exception("Server not found to test isolation") + + url = '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id) + response = self.tester.get(url, follow_redirects=True) + # Expect 410 Gone (server not accessible to this user) + self.assertEqual( + response.status_code, 410, + 'Non-admin user should not access another user\'s ' + 'private server. Got status {0}'.format( + response.status_code) + ) + + def tearDown(self): + if self.server_id is None: + return + # Clean up with the admin tester (which owns the server) + utils.delete_server_with_api( + self.__class__.tester, self.server_id) + + +class SharedServerAccessTestCase(BaseTestGenerator): + """Verify that a shared server IS accessible by a non-admin + user (positive test — shared servers should work after the + isolation fixes).""" + + scenarios = [ + ('User B can access shared server from User A', + dict(is_positive_test=True)), + ] + + def setUp(self): + self.server_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Create a shared server as the admin user + self.server['shared'] = True + url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP) + response = self.tester.post( + url, + data=json.dumps(self.server), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.server_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """Non-admin user SHOULD be able to GET a shared server.""" + if not self.server_id: + raise Exception("Server not found to test shared access") + + url = '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id) + response = self.tester.get(url, follow_redirects=True) + self.assertEqual( + response.status_code, 200, + 'Non-admin user should be able to access shared server.' + ' Got status {0}'.format(response.status_code) + ) + + def tearDown(self): + if self.server_id is None: + return + utils.delete_server_with_api( + self.__class__.tester, self.server_id) + + +class SharedServerFieldSuppressionTestCase(BaseTestGenerator): + """Verify that owner-only sensitive fields are suppressed + when a non-owner accesses a shared server's properties.""" + + scenarios = [ + ('Shared server suppresses passexec_cmd and ' + 'post_connection_sql for non-owner', + dict(is_positive_test=True)), + ] + + def setUp(self): + self.server_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Create a shared server with sensitive owner-only fields + self.server['shared'] = True + self.server['passexec_cmd'] = '/usr/bin/get-secret' + self.server['passexec_expiration'] = 100 + self.server['post_connection_sql'] = 'SET role admin;' + url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP) + response = self.tester.post( + url, + data=json.dumps(self.server), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.server_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """Non-owner should NOT see passexec_cmd or + post_connection_sql in properties response.""" + if not self.server_id: + raise Exception("Server not found to test suppression") + + url = '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id) + response = self.tester.get(url, follow_redirects=True) + self.assertEqual(response.status_code, 200) + data = json.loads(response.data.decode('utf-8')) + + # passexec_cmd must be None/null for non-owners + self.assertIsNone( + data.get('passexec_cmd'), + 'passexec_cmd should be suppressed for non-owners.' + ' Got: {0}'.format(data.get('passexec_cmd')) + ) + self.assertIsNone( + data.get('passexec_expiration'), + 'passexec_expiration should be suppressed for ' + 'non-owners.' + ) + # post_connection_sql must be None/null for non-owners + self.assertIsNone( + data.get('post_connection_sql'), + 'post_connection_sql should be suppressed for ' + 'non-owners. Got: {0}'.format( + data.get('post_connection_sql')) + ) + + def tearDown(self): + if self.server_id is None: + return + utils.delete_server_with_api( + self.__class__.tester, self.server_id) + + +class SharedServerConnectionParamsIsolationTestCase( + BaseTestGenerator): + """Verify that owner's SSL file paths in connection_params + are not leaked to non-owners of shared servers.""" + + scenarios = [ + ('Shared server strips owner SSL paths for non-owner', + dict(is_positive_test=True)), + ] + + def setUp(self): + self.server_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Create shared server with owner SSL paths + self.server['shared'] = True + # Set connection_params with owner-specific paths + conn_params = self.server.get('connection_params', {}) + conn_params['sslcert'] = '/home/owner/.ssl/cert.pem' + conn_params['sslkey'] = '/home/owner/.ssl/key.pem' + conn_params['sslrootcert'] = '/home/owner/.ssl/ca.pem' + self.server['connection_params'] = conn_params + url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP) + response = self.tester.post( + url, + data=json.dumps(self.server), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.server_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """Non-owner should NOT see owner's SSL file paths + in connection_params.""" + if not self.server_id: + raise Exception("Server not found") + + url = '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id) + response = self.tester.get(url, follow_redirects=True) + self.assertEqual(response.status_code, 200) + data = json.loads(response.data.decode('utf-8')) + + conn_params = data.get('connection_params', {}) + # Owner SSL paths should be stripped for non-owners + # (non-owner has no SharedServer SSL paths configured, + # so keys should be absent) + for key in ('sslcert', 'sslkey', 'sslrootcert', + 'sslcrl', 'sslcrldir'): + val = None + if isinstance(conn_params, list): + for item in conn_params: + if item.get('name') == key: + val = item.get('value') + break + elif isinstance(conn_params, dict): + val = conn_params.get(key) + self.assertIsNone( + val, + 'Owner SSL path "{0}" should not leak to ' + 'non-owner. Got: {1}'.format(key, val) + ) + + def tearDown(self): + if self.server_id is None: + return + utils.delete_server_with_api( + self.__class__.tester, self.server_id) + + +class SharedServerRenameDoesNotOrphanTestCase(BaseTestGenerator): + """Verify that renaming a shared server does not create + orphan SharedServer records (Issue 20 fix — lookup uses + osid, not name).""" + + scenarios = [ + ('Rename shared server preserves non-owner access', + dict(is_positive_test=True)), + ] + + def setUp(self): + self.server_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Save admin tester BEFORE the decorator replaces it. + self.admin_tester = self.tester + + self.server['shared'] = True + url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP) + response = self.tester.post( + url, + data=json.dumps(self.server), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.server_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """After owner renames the shared server, non-owner + should still be able to access it.""" + if not self.server_id: + raise Exception("Server not found") + + # First access as non-owner to create SharedServer record + url = '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id) + response = self.tester.get(url, follow_redirects=True) + self.assertEqual(response.status_code, 200) + + # Rename the server as admin (saved in setUp before + # the decorator replaced self.tester). + response = self.admin_tester.put( + '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id), + data=json.dumps( + {'name': 'renamed_shared_server'}), + content_type='html/json' + ) + self.assertIn( + response.status_code, [200], + 'Admin should be able to rename shared server.' + ) + + # Access again as non-owner — should still work + response = self.tester.get(url, follow_redirects=True) + self.assertEqual( + response.status_code, 200, + 'Non-owner should still access shared server after ' + 'rename. Got status {0}'.format(response.status_code) + ) + + def tearDown(self): + if self.server_id is None: + return + utils.delete_server_with_api( + self.__class__.tester, self.server_id) diff --git a/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py b/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py new file mode 100644 index 00000000000..913030ecb36 --- /dev/null +++ b/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py @@ -0,0 +1,487 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Unit tests for shared server isolation logic using mocks. + +These tests verify the security-critical merge, suppression, and +sanitization logic without requiring a running PostgreSQL server +or HTTP infrastructure. +""" + +from unittest.mock import MagicMock, patch, call +from pgadmin.utils.route import BaseTestGenerator + +SRV_MODULE = 'pgadmin.browser.server_groups.servers' + + +def _make_server(**overrides): + """Create a mock Server object with sensible defaults.""" + defaults = dict( + id=1, user_id=100, name='OwnerServer', + shared=True, host='db.owner.com', port=5432, + maintenance_db='postgres', username='owner', + password=b'enc_owner_pass', role=None, + bgcolor=None, fgcolor=None, service=None, + use_ssh_tunnel=0, tunnel_host=None, + tunnel_port=5522, tunnel_authentication=0, + tunnel_username=None, tunnel_password=None, + tunnel_identity_file=None, + tunnel_prompt_password=0, tunnel_keep_alive=30, + save_password=1, servergroup_id=1, + server_owner='owner_user', prepare_threshold=5, + passexec_cmd='/usr/bin/vault-get-secret', + passexec_expiration=300, + post_connection_sql='SET role admin;', + connection_params={ + 'sslmode': 'verify-full', + 'sslcert': '/home/owner/.ssl/cert.pem', + 'sslkey': '/home/owner/.ssl/key.pem', + 'sslrootcert': '/home/owner/.ssl/ca.pem', + 'passfile': '/home/owner/.pgpass', + 'connect_timeout': '10', + }, + discovery_id=None, db_res=None, db_res_type=None, + kerberos_conn=False, cloud_status=0, + shared_username='shared_user', tags=None, + is_adhoc=0, + ) + defaults.update(overrides) + server = MagicMock() + for k, v in defaults.items(): + setattr(server, k, v) + return server + + +def _make_shared_server(**overrides): + """Create a mock SharedServer object.""" + defaults = dict( + id=10, osid=1, user_id=200, + server_owner='owner_user', servergroup_id=2, + name='MySharedView', host='db.owner.com', + port=5432, maintenance_db='postgres', + username='nonowner', password=b'enc_nonowner', + save_password=0, role='readonly', + bgcolor='#ff0000', fgcolor='#ffffff', + service='my_pg_service', + use_ssh_tunnel=1, tunnel_host='bastion.local', + tunnel_port=2222, tunnel_authentication=1, + tunnel_username='tunneluser', + tunnel_password=b'enc_tunnel', + tunnel_identity_file='/home/user/.ssh/id_rsa', + tunnel_prompt_password=0, + tunnel_keep_alive=60, shared=True, + prepare_threshold=10, + connection_params={ + 'sslmode': 'verify-full', + 'sslcert': '/home/nonowner/.ssl/cert.pem', + 'connect_timeout': '10', + }, + ) + defaults.update(overrides) + ss = MagicMock() + for k, v in defaults.items(): + setattr(ss, k, v) + return ss + + +class TestGetSharedServerProperties(BaseTestGenerator): + """Unit tests for ServerModule.get_shared_server_properties() + using mock objects.""" + + scenarios = [ + ('Merge suppresses passexec_cmd', + dict(test_method='test_suppresses_passexec')), + ('Merge suppresses post_connection_sql', + dict(test_method='test_suppresses_post_sql')), + ('Merge strips owner SSL paths not in SharedServer', + dict(test_method='test_strips_owner_ssl_paths')), + ('Merge applies SharedServer SSL paths', + dict(test_method='test_applies_ss_ssl_paths')), + ('Merge overrides service from SharedServer', + dict(test_method='test_overrides_service')), + ('Merge overrides tunnel fields', + dict(test_method='test_overrides_tunnel')), + ('Merge handles None connection_params', + dict(test_method='test_none_conn_params')), + ] + + @patch('pgadmin.browser.server_groups.servers.' + 'object_session', return_value=None) + def runTest(self, mock_sess): + getattr(self, self.test_method)() + + def _merge(self, server=None, ss=None): + from pgadmin.browser.server_groups.servers import \ + ServerModule + if server is None: + server = _make_server() + if ss is None: + ss = _make_shared_server() + return ServerModule.get_shared_server_properties( + server, ss) + + def test_suppresses_passexec(self): + result = self._merge() + self.assertIsNone(result.passexec_cmd) + self.assertIsNone(result.passexec_expiration) + + def test_suppresses_post_sql(self): + result = self._merge() + self.assertIsNone(result.post_connection_sql) + + def test_strips_owner_ssl_paths(self): + result = self._merge() + cp = result.connection_params + # Owner had sslkey, sslrootcert, passfile + # SharedServer did not -- should be removed. + self.assertNotIn('sslkey', cp) + self.assertNotIn('sslrootcert', cp) + self.assertNotIn('passfile', cp) + + def test_applies_ss_ssl_paths(self): + result = self._merge() + cp = result.connection_params + # SharedServer had sslcert -- should override. + self.assertEqual( + cp['sslcert'], + '/home/nonowner/.ssl/cert.pem') + # Non-sensitive params preserved from owner. + self.assertEqual(cp['sslmode'], 'verify-full') + self.assertEqual(cp['connect_timeout'], '10') + + def test_overrides_service(self): + result = self._merge() + self.assertEqual(result.service, 'my_pg_service') + + def test_overrides_tunnel(self): + result = self._merge() + self.assertEqual(result.tunnel_host, 'bastion.local') + self.assertEqual(result.tunnel_port, 2222) + self.assertEqual(result.tunnel_username, 'tunneluser') + self.assertEqual(result.tunnel_authentication, 1) + self.assertEqual( + result.tunnel_identity_file, + '/home/user/.ssh/id_rsa') + + def test_none_conn_params(self): + server = _make_server(connection_params=None) + ss = _make_shared_server(connection_params=None) + result = self._merge(server, ss) + # Should not crash; connection_params becomes {} + self.assertEqual(result.connection_params, {}) + + +class TestCreateSharedServerSanitization(BaseTestGenerator): + """Verify create_shared_server() strips sensitive + connection_params keys.""" + + scenarios = [ + ('Sanitizes connection_params on creation', + dict(test_method='test_sanitizes_conn_params')), + ('Copies tunnel_port from owner', + dict(test_method='test_copies_tunnel_port')), + ('Copies tunnel_keep_alive from owner', + dict(test_method='test_copies_tunnel_keep_alive')), + ('Handles None connection_params', + dict(test_method='test_none_conn_params')), + ] + + @patch('pgadmin.browser.server_groups.servers.db') + @patch('pgadmin.browser.server_groups.servers.User') + @patch('pgadmin.browser.server_groups.servers.current_user') + @patch('pgadmin.browser.server_groups.servers.SharedServer') + def runTest(self, mock_ss_cls, mock_cu, mock_user, + mock_db): + mock_cu.id = 200 + mock_user.query.filter_by.return_value \ + .first.return_value = MagicMock(username='owner') + # Capture the SharedServer() constructor call + self.captured_kwargs = {} + + def capture_init(**kwargs): + self.captured_kwargs = kwargs + return MagicMock() + + mock_ss_cls.side_effect = capture_init + getattr(self, self.test_method)() + + def _create(self, server=None): + from pgadmin.browser.server_groups.servers import \ + ServerModule + if server is None: + server = _make_server() + ServerModule.create_shared_server(server, 1) + + def test_sanitizes_conn_params(self): + self._create() + cp = self.captured_kwargs.get('connection_params', {}) + # Sensitive keys must be stripped + for key in ('sslcert', 'sslkey', 'sslrootcert', + 'passfile'): + self.assertNotIn( + key, cp, + 'Sensitive key "{0}" should be stripped ' + 'on SharedServer creation'.format(key)) + # Non-sensitive keys preserved + self.assertEqual(cp.get('sslmode'), 'verify-full') + self.assertEqual(cp.get('connect_timeout'), '10') + + def test_copies_tunnel_port(self): + server = _make_server(tunnel_port=2222) + self._create(server) + self.assertEqual( + self.captured_kwargs.get('tunnel_port'), 2222) + + def test_copies_tunnel_keep_alive(self): + server = _make_server(tunnel_keep_alive=45) + self._create(server) + self.assertEqual( + self.captured_kwargs.get('tunnel_keep_alive'), 45) + + def test_none_conn_params(self): + server = _make_server(connection_params=None) + self._create(server) + cp = self.captured_kwargs.get('connection_params', {}) + self.assertEqual(cp, {}) + + +class TestMergeExpungesServer(BaseTestGenerator): + """Verify get_shared_server_properties() expunges the server + from the SQLAlchemy session before mutation.""" + + scenarios = [ + ('Expunge called when server is in session', + dict(test_method='test_expunge_called')), + ('No crash when server not in session', + dict(test_method='test_no_session')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + def test_expunge_called(self): + from pgadmin.browser.server_groups.servers import \ + ServerModule + server = _make_server() + ss = _make_shared_server() + mock_session = MagicMock() + with patch(SRV_MODULE + '.object_session', + return_value=mock_session): + ServerModule.get_shared_server_properties( + server, ss) + mock_session.expunge.assert_called_once_with(server) + + def test_no_session(self): + from pgadmin.browser.server_groups.servers import \ + ServerModule + server = _make_server() + ss = _make_shared_server() + with patch(SRV_MODULE + '.object_session', + return_value=None): + # Should not crash + result = ServerModule.get_shared_server_properties( + server, ss) + self.assertIsNone(result.passexec_cmd) + + +class TestUpdateConnectionParameter(BaseTestGenerator): + """Verify update_connection_parameter() routes changes + to SharedServer for non-owners.""" + + scenarios = [ + ('Non-owner changes go to SharedServer copy', + dict(test_method='test_nonowner_routing')), + ('Owner changes go to Server directly', + dict(test_method='test_owner_routing')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + @patch(SRV_MODULE + '.current_user') + def test_nonowner_routing(self, mock_cu): + mock_cu.id = 200 # Non-owner + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server( + connection_params={'sslmode': 'require'}) + ss = _make_shared_server( + connection_params={'sslmode': 'require'}) + + data = {'connection_params': { + 'changed': [{'name': 'sslmode', 'value': 'verify'}] + }} + + node = ServerNode.__new__(ServerNode) + node.update_connection_parameter(data, server, ss) + + # The result should be in data, not mutating server + self.assertEqual( + data['connection_params']['sslmode'], 'verify') + # Owner's server should NOT be mutated + self.assertEqual( + server.connection_params['sslmode'], 'require') + + @patch(SRV_MODULE + '.current_user') + def test_owner_routing(self, mock_cu): + mock_cu.id = 100 # Owner + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server( + connection_params={'sslmode': 'require'}) + + data = {'connection_params': { + 'changed': [{'name': 'sslmode', 'value': 'verify'}] + }} + + node = ServerNode.__new__(ServerNode) + node.update_connection_parameter(data, server, None) + + # Owner path mutates server directly + self.assertEqual( + data['connection_params']['sslmode'], 'verify') + + +class TestUpdateServerDetails(BaseTestGenerator): + """Verify _update_server_details routes writes to + SharedServer for non-owners.""" + + scenarios = [ + ('Non-owner write goes to SharedServer', + dict(test_method='test_nonowner_write')), + ('Owner write goes to Server', + dict(test_method='test_owner_write')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + @patch(SRV_MODULE + '.current_user') + def test_nonowner_write(self, mock_cu): + mock_cu.id = 200 + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server() + ss = _make_shared_server() + config_map = {'name': 'name'} + + ServerNode._update_server_details( + server, ss, config_map, 'name', 'NewName') + + self.assertEqual(ss.name, 'NewName') + # Server should not be modified + self.assertEqual(server.name, 'OwnerServer') + + @patch(SRV_MODULE + '.current_user') + def test_owner_write(self, mock_cu): + mock_cu.id = 100 + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server() + config_map = {'name': 'name'} + + ServerNode._update_server_details( + server, None, config_map, 'name', 'NewName') + + self.assertEqual(server.name, 'NewName') + + +class TestDeleteSharedServerOwnerGuard(BaseTestGenerator): + """Verify that only the owner can trigger + delete_shared_server via _set_valid_attr_value.""" + + scenarios = [ + ('Non-owner shared=false does not delete', + dict(test_method='test_nonowner_no_delete')), + ('Owner shared=false triggers delete', + dict(test_method='test_owner_deletes')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + @patch(SRV_MODULE + '.get_crypt_key', + return_value=(True, b'key')) + @patch(SRV_MODULE + '.current_user') + def test_nonowner_no_delete(self, mock_cu, mock_ck): + mock_cu.id = 200 + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server() + ss = _make_shared_server() + node = ServerNode.__new__(ServerNode) + node.delete_shared_server = MagicMock() + + data = {'shared': False} + config_map = {'shared': 'shared'} + + node._set_valid_attr_value( + 1, data, config_map, server, ss) + + node.delete_shared_server.assert_not_called() + + @patch(SRV_MODULE + '.get_crypt_key', + return_value=(True, b'key')) + @patch(SRV_MODULE + '.current_user') + def test_owner_deletes(self, mock_cu, mock_ck): + mock_cu.id = 100 # Owner + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server() + node = ServerNode.__new__(ServerNode) + node.delete_shared_server = MagicMock() + + data = {'shared': False} + config_map = {'shared': 'shared'} + + node._set_valid_attr_value( + 1, data, config_map, server, None) + + node.delete_shared_server.assert_called_once_with( + 1, server.id) + + +class TestGetSharedServerRaisesOnNone(BaseTestGenerator): + """Verify get_shared_server() raises if SharedServer + cannot be created.""" + + scenarios = [ + ('Raises when SharedServer is None after create', + dict(test_method='test_raises_on_none')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + @patch(SRV_MODULE + '.SharedServer') + @patch(SRV_MODULE + '.current_user') + def test_raises_on_none(self, mock_cu, mock_ss): + mock_cu.id = 200 + # Both queries return None + mock_ss.query.filter_by.return_value \ + .first.return_value = None + + from pgadmin.browser.server_groups.servers import \ + ServerModule + + server = _make_server() + + with patch.object(ServerModule, 'create_shared_server'): + with self.assertRaises(Exception) as ctx: + ServerModule.get_shared_server(server, 1) + + self.assertIn( + 'Failed to create shared server', + str(ctx.exception)) diff --git a/web/pgadmin/browser/server_groups/tests/test_sg_data_isolation.py b/web/pgadmin/browser/server_groups/tests/test_sg_data_isolation.py new file mode 100644 index 00000000000..8f870156cfe --- /dev/null +++ b/web/pgadmin/browser/server_groups/tests/test_sg_data_isolation.py @@ -0,0 +1,78 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Tests for ServerGroup data isolation between users in server mode.""" + +import json +import config +from pgadmin.utils.route import BaseTestGenerator +from regression.python_test_utils import test_utils as utils +from regression.test_setup import config_data +from regression.python_test_utils.test_utils import \ + create_user_wise_test_client +from pgadmin.model import db, ServerGroup + +test_user_details = None +if config.SERVER_MODE: + test_user_details = \ + config_data['pgAdmin4_test_non_admin_credentials'] + + +class ServerGroupIsolationTestCase(BaseTestGenerator): + """Verify that a non-admin user cannot fetch another user's + server group properties by ID.""" + + scenarios = [ + ('User B cannot fetch User A server group properties', + dict(is_positive_test=False)), + ] + + def setUp(self): + self.sg_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Create a server group as the admin user + url = '/browser/server_group/obj/' + response = self.tester.post( + url, + data=json.dumps({'name': 'isolation_test_group'}), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.sg_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """Non-admin user should NOT see another user's server + group properties.""" + if not self.sg_id: + raise Exception("Server group not created") + + url = '/browser/server_group/obj/{0}'.format(self.sg_id) + response = self.tester.get(url, content_type='html/json') + self.assertEqual( + response.status_code, 410, + 'Non-admin user should not access another user\'s ' + 'server group. Got status {0}'.format( + response.status_code) + ) + + def tearDown(self): + # Clean up with admin + if self.sg_id is None: + return + sg = ServerGroup.query.filter_by(id=self.sg_id).first() + if sg: + db.session.delete(sg) + db.session.commit() diff --git a/web/pgadmin/tools/backup/tests/test_batch_process.py b/web/pgadmin/tools/backup/tests/test_batch_process.py index d0921261e58..08f6c6952bc 100644 --- a/web/pgadmin/tools/backup/tests/test_batch_process.py +++ b/web/pgadmin/tools/backup/tests/test_batch_process.py @@ -195,6 +195,7 @@ def __init__(self, desc, args, cmd): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by mock_result = process_mock.query.filter_by.return_value mock_result.first.return_value = TestMockProcess( backup_obj, self.class_params['args'], self.class_params['cmd']) @@ -239,6 +240,7 @@ def __init__(self, desc, args, cmd): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by process_mock.query.filter_by.return_value = [ TestMockProcess(backup_obj, self.class_params['args'], diff --git a/web/pgadmin/tools/import_export/tests/test_batch_process.py b/web/pgadmin/tools/import_export/tests/test_batch_process.py index da42e436703..b9fbe124b77 100644 --- a/web/pgadmin/tools/import_export/tests/test_batch_process.py +++ b/web/pgadmin/tools/import_export/tests/test_batch_process.py @@ -204,6 +204,7 @@ def __init__(self, desc, args, cmd): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by mock_result = process_mock.query.filter_by.return_value mock_result.first.return_value = TestMockProcess( import_export_obj, self.class_params['args'], @@ -250,6 +251,7 @@ def __init__(self, desc, args, cmd): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by process_mock.query.filter_by.return_value = [ TestMockProcess(import_export_obj, self.class_params['args'], diff --git a/web/pgadmin/tools/maintenance/tests/test_batch_process_maintenance.py b/web/pgadmin/tools/maintenance/tests/test_batch_process_maintenance.py index b2ca169f5d9..5e7ea7d579b 100644 --- a/web/pgadmin/tools/maintenance/tests/test_batch_process_maintenance.py +++ b/web/pgadmin/tools/maintenance/tests/test_batch_process_maintenance.py @@ -137,6 +137,7 @@ def __init__(self, desc, args, cmd): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by mock_result = process_mock.query.filter_by.return_value mock_result.first.return_value = TestMockProcess( maintenance_obj, self.class_params['args'], @@ -177,6 +178,7 @@ def __init__(self, desc, args, cmd): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by process_mock.query.filter_by.return_value = [ TestMockProcess(maintenance_obj, self.class_params['args'], diff --git a/web/pgadmin/tools/restore/tests/test_batch_process.py b/web/pgadmin/tools/restore/tests/test_batch_process.py index b0045d4e8dc..3f4f0bfa155 100644 --- a/web/pgadmin/tools/restore/tests/test_batch_process.py +++ b/web/pgadmin/tools/restore/tests/test_batch_process.py @@ -134,6 +134,7 @@ def __init__(self, desc, args, cmd): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by mock_result = process_mock.query.filter_by.return_value mock_result.first.return_value = TestMockProcess( restore_obj, self.class_params['args'], @@ -174,6 +175,7 @@ def __init__(self, desc, args, cmd): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by process_mock.query.filter_by.return_value = [ TestMockProcess(restore_obj, self.class_params['args'], From 08c70f7f9df2008df3b4160b0ffab9bc56378769 Mon Sep 17 00:00:00 2001 From: Ashesh Vashi Date: Thu, 9 Apr 2026 17:06:38 +0530 Subject: [PATCH 3/3] test: add sslcrl and sslcrldir to sensitive key test coverage Add sslcrl and sslcrldir to mock server connection_params and to all test assertions that verify sensitive key stripping. Previously only 4 of 6 SENSITIVE_CONN_KEYS were tested. --- .../servers/tests/test_shared_server_unit.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py b/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py index 913030ecb36..7b41af9055c 100644 --- a/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py +++ b/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py @@ -43,6 +43,8 @@ def _make_server(**overrides): 'sslcert': '/home/owner/.ssl/cert.pem', 'sslkey': '/home/owner/.ssl/key.pem', 'sslrootcert': '/home/owner/.ssl/ca.pem', + 'sslcrl': '/home/owner/.ssl/crl.pem', + 'sslcrldir': '/home/owner/.ssl/crl.d', 'passfile': '/home/owner/.pgpass', 'connect_timeout': '10', }, @@ -138,9 +140,11 @@ def test_suppresses_post_sql(self): def test_strips_owner_ssl_paths(self): result = self._merge() cp = result.connection_params - # Owner had sslkey, sslrootcert, passfile - # SharedServer did not -- should be removed. + # Owner had sslkey, sslrootcert, sslcrl, sslcrldir, + # passfile — SharedServer did not — should be removed. self.assertNotIn('sslkey', cp) + self.assertNotIn('sslcrl', cp) + self.assertNotIn('sslcrldir', cp) self.assertNotIn('sslrootcert', cp) self.assertNotIn('passfile', cp) @@ -223,7 +227,7 @@ def test_sanitizes_conn_params(self): cp = self.captured_kwargs.get('connection_params', {}) # Sensitive keys must be stripped for key in ('sslcert', 'sslkey', 'sslrootcert', - 'passfile'): + 'sslcrl', 'sslcrldir', 'passfile'): self.assertNotIn( key, cp, 'Sensitive key "{0}" should be stripped '