diff --git a/web/migrations/versions/add_user_id_to_debugger_func_args_.py b/web/migrations/versions/add_user_id_to_debugger_func_args_.py new file mode 100644 index 000000000..d8cd046da --- /dev/null +++ b/web/migrations/versions/add_user_id_to_debugger_func_args_.py @@ -0,0 +1,149 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Add user_id to debugger_function_arguments and indexes for data isolation + +Revision ID: add_user_id_dbg_args +Revises: add_tools_ai_perm +Create Date: 2026-04-08 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'add_user_id_dbg_args' +down_revision = 'add_tools_ai_perm' +branch_labels = None +depends_on = None + + +def upgrade(): + conn = op.get_bind() + dialect = conn.dialect.name + + # --- DebuggerFunctionArguments: add user_id to composite PK --- + if dialect == 'sqlite': + # SQLite cannot ALTER composite PKs. Recreate the table. + # Existing debugger argument data is ephemeral (cached function + # args) so dropping is acceptable. + op.execute( + 'DROP TABLE IF EXISTS debugger_function_arguments' + ) + op.create_table( + 'debugger_function_arguments', + sa.Column('user_id', sa.Integer(), + sa.ForeignKey('user.id'), nullable=False), + sa.Column('server_id', sa.Integer(), nullable=False), + sa.Column('database_id', sa.Integer(), nullable=False), + sa.Column('schema_id', sa.Integer(), nullable=False), + sa.Column('function_id', sa.Integer(), nullable=False), + sa.Column('arg_id', sa.Integer(), nullable=False), + sa.Column('is_null', sa.Integer(), nullable=False), + sa.Column('is_expression', sa.Integer(), nullable=False), + sa.Column('use_default', sa.Integer(), nullable=False), + sa.Column('value', sa.String(), nullable=True), + sa.PrimaryKeyConstraint( + 'user_id', 'server_id', 'database_id', + 'schema_id', 'function_id', 'arg_id' + ), + sa.CheckConstraint('is_null >= 0 AND is_null <= 1'), + sa.CheckConstraint( + 'is_expression >= 0 AND is_expression <= 1'), + sa.CheckConstraint( + 'use_default >= 0 AND use_default <= 1'), + ) + else: + # PostgreSQL: add column, backfill from server owner, recreate + # PK using batch_alter_table for portability. + op.add_column( + 'debugger_function_arguments', + sa.Column('user_id', sa.Integer(), + sa.ForeignKey('user.id'), nullable=True) + ) + # Backfill: assign user_id from the server's owner + op.execute( + 'UPDATE debugger_function_arguments ' + 'SET user_id = s.user_id ' + 'FROM server s ' + 'WHERE debugger_function_arguments.server_id = s.id' + ) + # Delete orphans (rows with no matching server) + op.execute( + 'DELETE FROM debugger_function_arguments ' + 'WHERE user_id IS NULL' + ) + op.alter_column( + 'debugger_function_arguments', 'user_id', nullable=False + ) + # Recreate PK with user_id using batch_alter_table + with op.batch_alter_table( + 'debugger_function_arguments' + ) as batch: + batch.drop_constraint( + 'debugger_function_arguments_pkey', type_='primary' + ) + batch.create_primary_key( + 'debugger_function_arguments_pkey', + ['user_id', 'server_id', 'database_id', + 'schema_id', 'function_id', 'arg_id'] + ) + + # --- Indexes for data isolation query performance --- + # Only create indexes on tables that exist (sharedserver may be + # absent in older schemas that haven't run all prior migrations). + inspector = sa.inspect(conn) + index_stmts = [ + ('server', + 'CREATE INDEX IF NOT EXISTS ix_server_user_id ' + 'ON server (user_id)'), + ('server', + 'CREATE INDEX IF NOT EXISTS ix_server_servergroup_id ' + 'ON server (servergroup_id)'), + ('sharedserver', + 'CREATE INDEX IF NOT EXISTS ix_sharedserver_user_id ' + 'ON sharedserver (user_id)'), + ('sharedserver', + 'CREATE INDEX IF NOT EXISTS ix_sharedserver_osid ' + 'ON sharedserver (osid)'), + ('servergroup', + 'CREATE INDEX IF NOT EXISTS ix_servergroup_user_id ' + 'ON servergroup (user_id)'), + ] + for table_name, stmt in index_stmts: + if inspector.has_table(table_name): + op.execute(stmt) + + # --- Unique constraint on SharedServer(osid, user_id) --- + # Prevents duplicate SharedServer records from TOCTOU race. + # First remove duplicates (keep lowest id per osid+user_id). + if inspector.has_table('sharedserver'): + if dialect == 'sqlite': + op.execute( + 'DELETE FROM sharedserver WHERE id NOT IN ' + '(SELECT MIN(id) FROM sharedserver ' + 'GROUP BY osid, user_id)' + ) + else: + op.execute( + 'DELETE FROM sharedserver s1 USING ' + 'sharedserver s2 WHERE s1.osid = s2.osid ' + 'AND s1.user_id = s2.user_id ' + 'AND s1.id > s2.id' + ) + with op.batch_alter_table('sharedserver') as batch: + batch.create_unique_constraint( + 'uq_sharedserver_osid_user', + ['osid', 'user_id'] + ) + + +def downgrade(): + # pgAdmin only upgrades, downgrade not implemented. + pass diff --git a/web/migrations/versions/ca00ec32581b_.py b/web/migrations/versions/ca00ec32581b_.py index 6d566cd17..64a3ba12f 100644 --- a/web/migrations/versions/ca00ec32581b_.py +++ b/web/migrations/versions/ca00ec32581b_.py @@ -15,8 +15,6 @@ Create Date: 2018-08-29 15:33:57.855491 """ from alembic import op -from sqlalchemy.orm.session import Session -from pgadmin.model import DebuggerFunctionArguments # revision identifiers, used by Alembic. revision = 'ca00ec32581b' @@ -26,11 +24,10 @@ depends_on = None def upgrade(): - session = Session(bind=op.get_bind()) - - debugger_records = session.query(DebuggerFunctionArguments).all() - if debugger_records: - session.delete(debugger_records) + # Use raw SQL instead of importing the model class, because + # model changes in later migrations (e.g. adding user_id) would + # cause this migration to fail on fresh databases. + op.execute('DELETE FROM debugger_function_arguments') def downgrade(): diff --git a/web/pgadmin/browser/server_groups/__init__.py b/web/pgadmin/browser/server_groups/__init__.py index e0212d277..1d70695c3 100644 --- a/web/pgadmin/browser/server_groups/__init__.py +++ b/web/pgadmin/browser/server_groups/__init__.py @@ -25,6 +25,8 @@ from sqlalchemy import exc from pgadmin.model import db, ServerGroup, Server import config from pgadmin.utils.preferences import Preferences +from pgadmin.utils.server_access import get_server_group, \ + get_server_groups_for_user def get_icon_css_class(group_id, group_user_id, @@ -286,7 +288,7 @@ class ServerGroupView(NodeView): def properties(self, gid): """Update the server-group properties""" - sg = ServerGroup.query.filter(ServerGroup.id == gid).first() + sg = get_server_group(gid) if sg is None: return make_json_response( @@ -296,7 +298,8 @@ class ServerGroupView(NodeView): ) else: return ajax_response( - response={'id': sg.id, 'name': sg.name, 'user_id': sg.user_id}, + response={'id': sg.id, 'name': sg.name, + 'user_id': sg.user_id}, status=200 ) @@ -373,8 +376,9 @@ class ServerGroupView(NodeView): @staticmethod def get_all_server_groups(): """ - Returns the list of server groups to show in server mode and - if there is any shared server in the group. + Returns the list of server groups to show in server mode. + Includes groups owned by the user and groups containing + shared servers accessible to this user. :return: server groups """ @@ -383,17 +387,18 @@ class ServerGroupView(NodeView): pref = Preferences.module('browser') hide_shared_server = pref.preference('hide_shared_server').get() - server_groups = ServerGroup.query.all() - groups = [] - for group in server_groups: - if hide_shared_server and \ - ServerGroupModule.has_shared_server(group.id) and \ - group.user_id != current_user.id: - continue - if group.user_id == current_user.id or \ - ServerGroupModule.has_shared_server(group.id): + server_groups = get_server_groups_for_user() + + if hide_shared_server: + groups = [] + for group in server_groups: + if group.user_id != current_user.id and \ + ServerGroupModule.has_shared_server(group.id): + continue groups.append(group) - return groups + return groups + + return server_groups @pga_login_required def nodes(self, gid=None): @@ -421,7 +426,7 @@ class ServerGroupView(NodeView): ) ) else: - group = ServerGroup.query.filter(ServerGroup.id == gid).first() + group = get_server_group(gid) if not group: return gone( diff --git a/web/pgadmin/browser/server_groups/servers/__init__.py b/web/pgadmin/browser/server_groups/servers/__init__.py index 248259949..cbcf79a3c 100644 --- a/web/pgadmin/browser/server_groups/servers/__init__.py +++ b/web/pgadmin/browser/server_groups/servers/__init__.py @@ -39,12 +39,30 @@ from pgadmin.browser.server_groups.servers.utils import \ from pgadmin.utils.constants import UNAUTH_REQ, MIMETYPE_APP_JS, \ SERVER_CONNECTION_CLOSED, RESTRICTION_TYPE_SQL from sqlalchemy import or_ +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import object_session from sqlalchemy.orm.attributes import flag_modified from pgadmin.utils.preferences import Preferences from .... import socketio as sio from pgadmin.utils import get_complete_file_path from pgadmin.settings.utils import with_object_filters +from pgadmin.utils.server_access import get_server, \ + get_user_server_query, get_server_group + + +# File-path keys in connection_params that are per-user and must +# not be copied from the owner to a new SharedServer or leaked +# through the property merge. +SENSITIVE_CONN_KEYS = frozenset({ + 'passfile', 'sslcert', 'sslkey', + 'sslrootcert', 'sslcrl', 'sslcrldir', +}) + + +def _is_non_owner(server): + """True if the server is shared and the current user is not + the owner. Centralises the check used in 15+ places.""" + return server.shared and server.user_id != current_user.id def has_any(data, keys): @@ -151,15 +169,30 @@ class ServerModule(sg.ServerGroupPluginModule): @staticmethod def get_shared_server_properties(server, sharedserver): """ - Return shared server properties + Return shared server properties. + + Overlays per-user SharedServer values onto the owner's Server + object. Security-sensitive fields that are absent from the + SharedServer model (passexec_cmd, post_connection_sql) are + suppressed for non-owners. + + The server is expunged from the SQLAlchemy session before + mutation so that the owner's record is never dirtied. :param server: :param sharedserver: - :return: shared server + :return: shared server (detached) """ + # Detach from session so in-place mutations are never + # flushed back to the owner's Server row. + sess = object_session(server) + if sess is not None: + sess.expunge(server) + server.bgcolor = sharedserver.bgcolor server.fgcolor = sharedserver.fgcolor server.name = sharedserver.name server.role = sharedserver.role + server.service = sharedserver.service server.use_ssh_tunnel = sharedserver.use_ssh_tunnel server.tunnel_host = sharedserver.tunnel_host server.tunnel_port = sharedserver.tunnel_port @@ -169,24 +202,36 @@ class ServerModule(sg.ServerGroupPluginModule): server.save_password = sharedserver.save_password server.tunnel_identity_file = sharedserver.tunnel_identity_file server.tunnel_prompt_password = sharedserver.tunnel_prompt_password - if hasattr(server, 'connection_params') and \ - hasattr(sharedserver, 'connection_params') and \ - 'passfile' in server.connection_params and \ - 'passfile' in sharedserver.connection_params: - server.connection_params['passfile'] = \ - sharedserver.connection_params['passfile'] + + # Override per-user connection_params keys. Use the + # SharedServer value whenever it is present, regardless of + # whether the owner's Server has the same key. + s_conn = getattr(server, 'connection_params', None) \ + or {} + ss_conn = getattr(sharedserver, 'connection_params', + None) or {} + for key in SENSITIVE_CONN_KEYS: + if key in ss_conn: + s_conn[key] = ss_conn[key] + elif key in s_conn: + # Owner has this key but non-owner doesn't — + # remove it so the owner's path doesn't leak. + del s_conn[key] + server.connection_params = s_conn + server.servergroup_id = sharedserver.servergroup_id - if hasattr(server, 'connection_params') and \ - hasattr(sharedserver, 'connection_params') and \ - 'sslcert' in server.connection_params and \ - 'sslcert' in sharedserver.connection_params: - server.connection_params['sslcert'] = \ - sharedserver.connection_params['sslcert'] server.username = sharedserver.username server.server_owner = sharedserver.server_owner server.password = sharedserver.password server.prepare_threshold = sharedserver.prepare_threshold + # Suppress owner-only fields that are absent from SharedServer + # and dangerous when inherited (privilege escalation / code + # execution). + server.passexec_cmd = None + server.passexec_expiration = None + server.post_connection_sql = None + return server def get_servers(self, all_servers, hide_shared_server, gid): @@ -203,12 +248,13 @@ class ServerModule(sg.ServerGroupPluginModule): if server.discovery_id and \ not server.shared and \ config.SERVER_MODE and \ - len(SharedServer.query.filter_by( + SharedServer.query.filter_by( user_id=current_user.id, - name=server.name).all()) > 0 and not hide_shared_server: + osid=server.id).first() is not None \ + and not hide_shared_server: continue - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): shared_server = self.get_shared_server(server, gid) @@ -245,8 +291,7 @@ class ServerModule(sg.ServerGroupPluginModule): """Return a JSON document listing the server groups for the user""" hide_shared_server = get_preferences() - servers = Server.query.filter( - or_(Server.user_id == current_user.id, Server.shared), + servers = get_user_server_query().filter( Server.servergroup_id == gid, Server.is_adhoc == 0) driver = get_driver(PG_DEFAULT_DRIVER) @@ -392,6 +437,18 @@ class ServerModule(sg.ServerGroupPluginModule): try: db.session.rollback() user = User.query.filter_by(id=data.user_id).first() + + # Strip owner's sensitive file paths from + # connection_params — each user should configure + # their own SSL/passfile paths. + safe_conn_params = {} + if data.connection_params: + safe_conn_params = { + k: v for k, v in + data.connection_params.items() + if k not in SENSITIVE_CONN_KEYS + } + shared_server = SharedServer( osid=data.id, user_id=current_user.id, @@ -410,43 +467,57 @@ class ServerModule(sg.ServerGroupPluginModule): service=data.service if data.service else None, use_ssh_tunnel=data.use_ssh_tunnel, tunnel_host=data.tunnel_host, - tunnel_port=22, + tunnel_port=data.tunnel_port + if data.tunnel_port else 22, tunnel_username=None, tunnel_authentication=0, tunnel_identity_file=None, - tunnel_keep_alive=0, + tunnel_keep_alive=data.tunnel_keep_alive + if data.tunnel_keep_alive else 0, tunnel_prompt_password=0, shared=True, - connection_params=data.connection_params, + connection_params=safe_conn_params, prepare_threshold=data.prepare_threshold ) db.session.add(shared_server) db.session.commit() except Exception as e: - if shared_server: - db.session.delete(shared_server) - db.session.commit() - + db.session.rollback() raise e @staticmethod def get_shared_server(server, gid): """ - return the shared server + Return the SharedServer record for the current user, + creating one lazily if it doesn't exist. The unique + constraint on (osid, user_id) prevents duplicates from + concurrent requests. :param server: :param gid: - :return: shared_server + :return: shared_server (never None) + :raises: Exception if SharedServer cannot be created """ shared_server = SharedServer.query.filter_by( - name=server.name, user_id=current_user.id, - servergroup_id=int(gid), osid=server.id).first() + user_id=current_user.id, + osid=server.id).first() if shared_server is None: - ServerModule.create_shared_server(server, int(gid)) + try: + ServerModule.create_shared_server( + server, int(gid)) + except IntegrityError: + # Unique constraint violation from a concurrent + # request — the record now exists. + db.session.rollback() shared_server = SharedServer.query.filter_by( - name=server.name, user_id=current_user.id, - servergroup_id=int(gid), osid=server.id).first() + user_id=current_user.id, + osid=server.id).first() + + if shared_server is None: + raise Exception( + "Failed to create shared server record " + "for server {0}".format(server.id)) return shared_server @@ -495,17 +566,28 @@ class ServerNode(PGChildNodeView): 'clear_sshtunnel_password': [{'put': 'clear_sshtunnel_password'}], }) - def update_connection_parameter(self, data, server): + def update_connection_parameter(self, data, server, sharedserver=None): """ This function is used to update the connection parameters. """ if 'connection_params' in data and \ hasattr(server, 'connection_params'): - existing_conn_params = getattr(server, 'connection_params') + # For shared servers accessed by non-owners, apply changes + # to the SharedServer's connection_params (a copy) so we + # don't mutate the owner's Server record in-place. + if sharedserver is not None and \ + server.shared and \ + server.user_id != current_user.id: + existing_conn_params = dict( + sharedserver.connection_params or {}) + else: + existing_conn_params = getattr( + server, 'connection_params') new_conn_params = data['connection_params'] if 'deleted' in new_conn_params: for item in new_conn_params['deleted']: - del existing_conn_params[item['name']] + if item['name'] in existing_conn_params: + del existing_conn_params[item['name']] if 'added' in new_conn_params: for item in new_conn_params['added']: existing_conn_params[item['name']] = item['value'] @@ -560,15 +642,13 @@ class ServerNode(PGChildNodeView): Return a JSON document listing the servers under this server group for the user. """ - servers = Server.query.filter( - or_(Server.user_id == current_user.id, - Server.shared), + servers = get_user_server_query().filter( Server.servergroup_id == gid, Server.is_adhoc == 0) driver = get_driver(PG_DEFAULT_DRIVER) for server in servers: - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): shared_server = ServerModule.get_shared_server(server, gid) server = \ ServerModule.get_shared_server_properties(server, @@ -627,24 +707,22 @@ class ServerNode(PGChildNodeView): @pga_login_required def node(self, gid, sid): """Return a JSON document listing the server groups for the user""" - server = Server.query.filter_by(id=sid).first() - - if server.shared and server.user_id != current_user.id: - shared_server = ServerModule.get_shared_server(server, gid) - server = ServerModule.get_shared_server_properties(server, - shared_server) + server = get_server(sid) if server is None: return make_json_response( status=410, success=0, errormsg=gettext( - gettext( - "Could not find the server with id# {0}." - ).format(sid) - ) + "Could not find the server with id# {0}." + ).format(sid) ) + if _is_non_owner(server): + shared_server = ServerModule.get_shared_server(server, gid) + server = ServerModule.get_shared_server_properties(server, + shared_server) + manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(server.id) conn = manager.connection() connected = conn.connected() @@ -693,16 +771,20 @@ class ServerNode(PGChildNodeView): ), ) - def delete_shared_server(self, server_name, gid, osid): + def delete_shared_server(self, gid, osid, user_id=None): """ - Delete the shared server - :param server_name: - :return: + Delete SharedServer records for a given original server. + :param gid: Server group ID + :param osid: Original server ID + :param user_id: If set, only delete for this user. + If None, delete for ALL users (owner unshare/delete). """ try: - shared_server = SharedServer.query.filter_by(name=server_name, - servergroup_id=gid, - osid=osid) + filters = dict(servergroup_id=gid, osid=osid) + if user_id is not None: + filters['user_id'] = user_id + shared_server = SharedServer.query.filter_by( + **filters) for s in shared_server: get_driver(PG_DEFAULT_DRIVER).delete_manager(s.id) db.session.delete(s) @@ -738,7 +820,7 @@ class ServerNode(PGChildNodeView): get_driver(PG_DEFAULT_DRIVER).delete_manager(s.id) db.session.delete(s) db.session.commit() - self.delete_shared_server(server_name, gid, sid) + self.delete_shared_server(gid, sid) QueryHistory.clear_history(current_user.id, sid) except Exception as e: @@ -754,7 +836,7 @@ class ServerNode(PGChildNodeView): @pga_login_required def update(self, gid, sid): """Update the server settings""" - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) sharedserver = None if server is None: @@ -821,7 +903,7 @@ class ServerNode(PGChildNodeView): data['db_res'] = ','.join(data['db_res']) # Update connection parameter if any. - self.update_connection_parameter(data, server) + self.update_connection_parameter(data, server, sharedserver) self.update_tags(data, server) if 'connection_params' in data and \ @@ -878,7 +960,7 @@ class ServerNode(PGChildNodeView): server.name, server_icon_and_background( connected, manager, sharedserver) - if server.shared and server.user_id != current_user.id + if _is_non_owner(server) else server_icon_and_background( connected, manager, server), True, @@ -902,7 +984,7 @@ class ServerNode(PGChildNodeView): if value == '': value = None - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): setattr(sharedserver, config_param_map[arg], value) else: setattr(server, config_param_map[arg], value) @@ -921,17 +1003,20 @@ class ServerNode(PGChildNodeView): value = data[arg] if arg == 'password': value = encrypt(data[arg], crypt_key) - # sqlite3 do not have boolean type so we need to convert - # it manually to integer - if 'shared' in data and not data['shared']: - # Delete the shared server from DB if server - # owner uncheck shared property - self.delete_shared_server(server.name, gid, server.id) + # sqlite3 do not have boolean type so we need to + # convert it manually to integer. + # Only the owner may unshare — this deletes ALL + # users' SharedServer records. + if 'shared' in data and not data['shared'] \ + and not _is_non_owner(server): + self.delete_shared_server(gid, server.id) if arg in ('sslcompression', 'use_ssh_tunnel', - 'tunnel_authentication', 'kerberos_conn', 'shared'): + 'tunnel_authentication', + 'kerberos_conn', 'shared'): value = 1 if value else 0 self._update_server_details(server, sharedserver, - config_param_map, arg, value) + config_param_map, arg, + value) idx += 1 return idx @@ -956,19 +1041,16 @@ class ServerNode(PGChildNodeView): """ Return list of attributes of all servers. """ - servers = Server.query.filter( - or_(Server.user_id == current_user.id, Server.shared), + servers = get_user_server_query().filter( Server.servergroup_id == gid, Server.is_adhoc == 0).order_by(Server.name) - sg = ServerGroup.query.filter_by( - id=gid - ).first() + sg = get_server_group(gid) res = [] driver = get_driver(PG_DEFAULT_DRIVER) for server in servers: - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): shared_server = ServerModule.get_shared_server(server, gid) server = \ ServerModule.get_shared_server_properties(server, @@ -1002,8 +1084,7 @@ class ServerNode(PGChildNodeView): def properties(self, gid, sid): """Return list of attributes of a server""" - server = Server.query.filter_by( - id=sid).first() + server = get_server(sid) if server is None: return make_json_response( @@ -1026,7 +1107,7 @@ class ServerNode(PGChildNodeView): # port and user when server is connected display_connection_str = self.update_connection_string(manager, server) - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): shared_server = ServerModule.get_shared_server(server, gid) server = ServerModule.get_shared_server_properties(server, shared_server) @@ -1079,10 +1160,13 @@ class ServerNode(PGChildNodeView): 'db_res': get_db_restriction(server.db_res_type, server.db_res), 'db_res_type': server.db_res_type, 'passexec_cmd': - server.passexec_cmd if server.passexec_cmd else None, + server.passexec_cmd + if server.passexec_cmd and + not _is_non_owner(server) else None, 'passexec_expiration': - server.passexec_expiration if server.passexec_expiration - else None, + server.passexec_expiration + if server.passexec_expiration and + not _is_non_owner(server) else None, 'service': server.service if server.service else None, 'use_ssh_tunnel': use_ssh_tunnel, 'tunnel_host': tunnel_host, @@ -1102,7 +1186,8 @@ class ServerNode(PGChildNodeView): 'connection_string': display_connection_str, 'prepare_threshold': server.prepare_threshold, 'tags': tags, - 'post_connection_sql': server.post_connection_sql, + 'post_connection_sql': server.post_connection_sql + if not _is_non_owner(server) else None, } return ajax_response(response) @@ -1395,7 +1480,12 @@ class ServerNode(PGChildNodeView): def connect_status(self, gid, sid): """Check and return the connection status.""" - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=self.not_found_error_msg() + ) manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid) conn = manager.connection() connected = conn.connected() @@ -1464,19 +1554,17 @@ class ServerNode(PGChildNodeView): # function in that case no need to fetch the server detail based on # sid. if server is None: - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) - shared_server = None - if server.shared and server.user_id != current_user.id: - shared_server = ServerModule.get_shared_server(server, gid) - sess = object_session(server) - if sess is not None: - sess.expunge(server) - server = ServerModule.get_shared_server_properties(server, - shared_server) if server is None: return bad_request(self.not_found_error_msg()) + shared_server = None + if _is_non_owner(server): + shared_server = ServerModule.get_shared_server(server, gid) + server = ServerModule.get_shared_server_properties(server, + shared_server) + # Return if username is blank and the server is shared if server.username is None and not server.service and \ server.shared: @@ -1617,12 +1705,8 @@ class ServerNode(PGChildNodeView): else: if save_password and config.ALLOW_SAVE_PASSWORD: try: - # If DB server is running in trust mode then password may - # not be available but we don't need to ask password - # every time user try to connect # 1 is True in SQLite as no boolean type - setattr(server, 'save_password', 1) - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): setattr(shared_server, 'save_password', 1) else: setattr(server, 'save_password', 1) @@ -1630,7 +1714,7 @@ class ServerNode(PGChildNodeView): # Save the encrypted password using the user's login # password key, if there is any password to save if password: - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): setattr(shared_server, 'password', password) else: setattr(server, 'password', password) @@ -1646,7 +1730,11 @@ class ServerNode(PGChildNodeView): if save_tunnel_password and config.ALLOW_SAVE_TUNNEL_PASSWORD: try: # Save the encrypted tunnel password. - setattr(server, 'tunnel_password', tunnel_password) + if _is_non_owner(server): + setattr(shared_server, 'tunnel_password', + tunnel_password) + else: + setattr(server, 'tunnel_password', tunnel_password) db.session.commit() except Exception as e: # Release Connection @@ -1693,7 +1781,7 @@ class ServerNode(PGChildNodeView): def disconnect(self, gid, sid): """Disconnect the Server.""" - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) if server is None: return bad_request(self.not_found_error_msg()) @@ -1818,7 +1906,7 @@ class ServerNode(PGChildNodeView): raise CryptKeyMissing # Fetch Server Details - server = Server.query.filter_by(id=sid).first() + server = get_server(sid, only_owned=False) if server is None: return bad_request(self.not_found_error_msg()) @@ -1905,11 +1993,24 @@ class ServerNode(PGChildNodeView): # Store password in sqlite only if no pgpass file if not is_passfile: password = encrypt(data['newPassword'], crypt_key) - # Check if old password was stored in pgadmin4 sqlite database. - # If yes then update that password. - if server.password is not None and config.ALLOW_SAVE_PASSWORD: - setattr(server, 'password', password) - db.session.commit() + # Check if old password was stored in pgadmin4 + # sqlite database. If yes then update that password. + # For non-owners of shared servers, check the + # SharedServer record (not the owner's Server). + if config.ALLOW_SAVE_PASSWORD: + if server.shared and \ + server.user_id != current_user.id: + shared_server = \ + ServerModule.get_shared_server( + server, gid) + if shared_server and \ + shared_server.password is not None: + setattr(shared_server, 'password', + password) + db.session.commit() + elif server.password is not None: + setattr(server, 'password', password) + db.session.commit() # Also update password in connection manager. manager.password = password manager.update_session() @@ -1929,9 +2030,7 @@ class ServerNode(PGChildNodeView): """ Utility function for wal_replay for resume/pause. """ - server = Server.query.filter_by( - user_id=current_user.id, id=sid - ).first() + server = get_server(sid) if server is None: return make_json_response( @@ -2015,9 +2114,7 @@ class ServerNode(PGChildNodeView): sid: Server id """ is_pgpass = False - server = Server.query.filter_by( - user_id=current_user.id, id=sid - ).first() + server = get_server(sid) if server is None: return make_json_response( @@ -2108,38 +2205,22 @@ class ServerNode(PGChildNodeView): :return: """ try: - server = Server.query.filter_by(id=sid).first() - shared_server = None + server = get_server(sid, only_owned=False) if server is None: return make_json_response( success=0, info=self.not_found_error_msg() ) - if server.shared and server.user_id != current_user.id: - shared_server = SharedServer.query.filter_by( - name=server.name, user_id=current_user.id, - servergroup_id=gid, osid=server.id).first() - - if shared_server is None: - return make_json_response( - success=0, - info=gettext("Could not find the required server.") - ) - server = ServerModule. \ - get_shared_server_properties(server, shared_server) - - if server.shared and server.user_id != current_user.id: + if _is_non_owner(server): + shared_server = ServerModule.get_shared_server( + server, gid) setattr(shared_server, 'password', None) + if shared_server.save_password: + setattr(shared_server, 'save_password', 0) else: setattr(server, 'password', None) - - # If password was saved then clear the flag also - # 0 is False in SQLite db - if server.save_password: - if server.shared and server.user_id != current_user.id: - setattr(shared_server, 'save_password', 0) - else: + if server.save_password: setattr(server, 'save_password', 0) db.session.commit() except Exception as e: @@ -2165,13 +2246,19 @@ class ServerNode(PGChildNodeView): :return: """ try: - server = Server.query.filter_by(id=sid).first() + server = get_server(sid, only_owned=False) if server is None: return make_json_response( success=0, info=self.not_found_error_msg() ) - setattr(server, 'tunnel_password', None) + + if _is_non_owner(server): + shared_server = ServerModule.get_shared_server( + server, gid) + setattr(shared_server, 'tunnel_password', None) + else: + setattr(server, 'tunnel_password', None) db.session.commit() except Exception as e: current_app.logger.error( diff --git a/web/pgadmin/browser/server_groups/servers/databases/__init__.py b/web/pgadmin/browser/server_groups/servers/databases/__init__.py index 1e88b59d1..1922db32c 100644 --- a/web/pgadmin/browser/server_groups/servers/databases/__init__.py +++ b/web/pgadmin/browser/server_groups/servers/databases/__init__.py @@ -34,6 +34,7 @@ from pgadmin.tools.sqleditor.utils.query_history import QueryHistory from pgadmin.tools.schema_diff.node_registry import SchemaDiffRegistry from pgadmin.model import db, Server, Database +from pgadmin.utils.server_access import get_server from pgadmin.browser.utils import underscore_escape from pgadmin.utils.constants import TWO_PARAM_STRING @@ -579,7 +580,9 @@ class DatabaseView(PGChildNodeView): 'already_connected': already_connected, 'connected': True, 'info_prefix': TWO_PARAM_STRING. - format(Server.query.filter_by(id=sid)[0].name, conn.db) + format(getattr( + get_server(sid), 'name', None) or + _('Unknown'), conn.db) } ) @@ -602,7 +605,9 @@ class DatabaseView(PGChildNodeView): 'icon': 'icon-database-not-connected', 'connected': False, 'info_prefix': TWO_PARAM_STRING. - format(Server.query.filter_by(id=sid)[0].name, conn.db) + format(getattr( + get_server(sid), 'name', None) or + _('Unknown'), conn.db) } ) diff --git a/web/pgadmin/browser/server_groups/servers/databases/schemas/views/__init__.py b/web/pgadmin/browser/server_groups/servers/databases/schemas/views/__init__.py index d978ebefd..5b73863bf 100644 --- a/web/pgadmin/browser/server_groups/servers/databases/schemas/views/__init__.py +++ b/web/pgadmin/browser/server_groups/servers/databases/schemas/views/__init__.py @@ -29,8 +29,9 @@ from pgadmin.utils.ajax import make_json_response, internal_server_error, \ from pgadmin.utils.driver import get_driver from pgadmin.tools.schema_diff.node_registry import SchemaDiffRegistry from .schema_diff_view_utils import SchemaDiffViewCompare -from pgadmin.utils import does_utility_exist, get_server +from pgadmin.utils import does_utility_exist from pgadmin.model import Server +from pgadmin.utils.server_access import get_server from pgadmin.misc.bgprocess.processes import BatchProcess, IProcessDesc from pgadmin.utils.constants import SERVER_NOT_FOUND @@ -2317,8 +2318,7 @@ class MViewNode(ViewNode, VacuumSettings): res['rows'][0]['name']) # Fetch the server details like hostname, port, roles etc - server = Server.query.filter_by( - id=sid).first() + server = get_server(sid) if server is None: return make_json_response( @@ -2436,9 +2436,7 @@ class MViewNode(ViewNode, VacuumSettings): Returns: None """ - server = Server.query.filter_by( - id=sid, user_id=current_user.id - ).first() + server = get_server(sid) if server is None: return make_json_response( diff --git a/web/pgadmin/browser/server_groups/servers/tests/test_server_data_isolation.py b/web/pgadmin/browser/server_groups/servers/tests/test_server_data_isolation.py new file mode 100644 index 000000000..bbee83e7a --- /dev/null +++ b/web/pgadmin/browser/server_groups/servers/tests/test_server_data_isolation.py @@ -0,0 +1,352 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Tests for server data isolation between users in server mode.""" + +import json +import config +from pgadmin.utils.route import BaseTestGenerator +from regression.python_test_utils import test_utils as utils +from regression.test_setup import config_data +from regression.python_test_utils.test_utils import \ + create_user_wise_test_client + +test_user_details = None +if config.SERVER_MODE: + test_user_details = \ + config_data['pgAdmin4_test_non_admin_credentials'] + + +class ServerDataIsolationGetTestCase(BaseTestGenerator): + """Verify that a non-admin user cannot access another user's + private (non-shared) server by ID.""" + + scenarios = [ + ('User B gets 410 for User A private server', + dict(is_positive_test=False)), + ] + + def setUp(self): + self.server_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Create a private (non-shared) server as the admin user + self.server['shared'] = False + url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP) + response = self.tester.post( + url, + data=json.dumps(self.server), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.server_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """Non-admin user should NOT be able to GET another user's + private server.""" + if not self.server_id: + raise Exception("Server not found to test isolation") + + url = '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id) + response = self.tester.get(url, follow_redirects=True) + # Expect 410 Gone (server not accessible to this user) + self.assertEqual( + response.status_code, 410, + 'Non-admin user should not access another user\'s ' + 'private server. Got status {0}'.format( + response.status_code) + ) + + def tearDown(self): + if self.server_id is None: + return + # Clean up with the admin tester (which owns the server) + utils.delete_server_with_api( + self.__class__.tester, self.server_id) + + +class SharedServerAccessTestCase(BaseTestGenerator): + """Verify that a shared server IS accessible by a non-admin + user (positive test — shared servers should work after the + isolation fixes).""" + + scenarios = [ + ('User B can access shared server from User A', + dict(is_positive_test=True)), + ] + + def setUp(self): + self.server_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Create a shared server as the admin user + self.server['shared'] = True + url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP) + response = self.tester.post( + url, + data=json.dumps(self.server), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.server_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """Non-admin user SHOULD be able to GET a shared server.""" + if not self.server_id: + raise Exception("Server not found to test shared access") + + url = '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id) + response = self.tester.get(url, follow_redirects=True) + self.assertEqual( + response.status_code, 200, + 'Non-admin user should be able to access shared server.' + ' Got status {0}'.format(response.status_code) + ) + + def tearDown(self): + if self.server_id is None: + return + utils.delete_server_with_api( + self.__class__.tester, self.server_id) + + +class SharedServerFieldSuppressionTestCase(BaseTestGenerator): + """Verify that owner-only sensitive fields are suppressed + when a non-owner accesses a shared server's properties.""" + + scenarios = [ + ('Shared server suppresses passexec_cmd and ' + 'post_connection_sql for non-owner', + dict(is_positive_test=True)), + ] + + def setUp(self): + self.server_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Create a shared server with sensitive owner-only fields + self.server['shared'] = True + self.server['passexec_cmd'] = '/usr/bin/get-secret' + self.server['passexec_expiration'] = 100 + self.server['post_connection_sql'] = 'SET role admin;' + url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP) + response = self.tester.post( + url, + data=json.dumps(self.server), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.server_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """Non-owner should NOT see passexec_cmd or + post_connection_sql in properties response.""" + if not self.server_id: + raise Exception("Server not found to test suppression") + + url = '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id) + response = self.tester.get(url, follow_redirects=True) + self.assertEqual(response.status_code, 200) + data = json.loads(response.data.decode('utf-8')) + + # passexec_cmd must be None/null for non-owners + self.assertIsNone( + data.get('passexec_cmd'), + 'passexec_cmd should be suppressed for non-owners.' + ' Got: {0}'.format(data.get('passexec_cmd')) + ) + self.assertIsNone( + data.get('passexec_expiration'), + 'passexec_expiration should be suppressed for ' + 'non-owners.' + ) + # post_connection_sql must be None/null for non-owners + self.assertIsNone( + data.get('post_connection_sql'), + 'post_connection_sql should be suppressed for ' + 'non-owners. Got: {0}'.format( + data.get('post_connection_sql')) + ) + + def tearDown(self): + if self.server_id is None: + return + utils.delete_server_with_api( + self.__class__.tester, self.server_id) + + +class SharedServerConnectionParamsIsolationTestCase( + BaseTestGenerator): + """Verify that owner's SSL file paths in connection_params + are not leaked to non-owners of shared servers.""" + + scenarios = [ + ('Shared server strips owner SSL paths for non-owner', + dict(is_positive_test=True)), + ] + + def setUp(self): + self.server_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Create shared server with owner SSL paths + self.server['shared'] = True + # Set connection_params with owner-specific paths + conn_params = self.server.get('connection_params', {}) + conn_params['sslcert'] = '/home/owner/.ssl/cert.pem' + conn_params['sslkey'] = '/home/owner/.ssl/key.pem' + conn_params['sslrootcert'] = '/home/owner/.ssl/ca.pem' + self.server['connection_params'] = conn_params + url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP) + response = self.tester.post( + url, + data=json.dumps(self.server), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.server_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """Non-owner should NOT see owner's SSL file paths + in connection_params.""" + if not self.server_id: + raise Exception("Server not found") + + url = '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id) + response = self.tester.get(url, follow_redirects=True) + self.assertEqual(response.status_code, 200) + data = json.loads(response.data.decode('utf-8')) + + conn_params = data.get('connection_params', {}) + # Owner SSL paths should be stripped for non-owners + # (non-owner has no SharedServer SSL paths configured, + # so keys should be absent) + for key in ('sslcert', 'sslkey', 'sslrootcert', + 'sslcrl', 'sslcrldir'): + val = None + if isinstance(conn_params, list): + for item in conn_params: + if item.get('name') == key: + val = item.get('value') + break + elif isinstance(conn_params, dict): + val = conn_params.get(key) + self.assertIsNone( + val, + 'Owner SSL path "{0}" should not leak to ' + 'non-owner. Got: {1}'.format(key, val) + ) + + def tearDown(self): + if self.server_id is None: + return + utils.delete_server_with_api( + self.__class__.tester, self.server_id) + + +class SharedServerRenameDoesNotOrphanTestCase(BaseTestGenerator): + """Verify that renaming a shared server does not create + orphan SharedServer records (Issue 20 fix — lookup uses + osid, not name).""" + + scenarios = [ + ('Rename shared server preserves non-owner access', + dict(is_positive_test=True)), + ] + + def setUp(self): + self.server_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Save admin tester BEFORE the decorator replaces it. + self.admin_tester = self.tester + + self.server['shared'] = True + url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP) + response = self.tester.post( + url, + data=json.dumps(self.server), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.server_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """After owner renames the shared server, non-owner + should still be able to access it.""" + if not self.server_id: + raise Exception("Server not found") + + # First access as non-owner to create SharedServer record + url = '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id) + response = self.tester.get(url, follow_redirects=True) + self.assertEqual(response.status_code, 200) + + # Rename the server as admin (saved in setUp before + # the decorator replaced self.tester). + response = self.admin_tester.put( + '/browser/server/obj/{0}/{1}'.format( + utils.SERVER_GROUP, self.server_id), + data=json.dumps( + {'name': 'renamed_shared_server'}), + content_type='html/json' + ) + self.assertIn( + response.status_code, [200], + 'Admin should be able to rename shared server.' + ) + + # Access again as non-owner — should still work + response = self.tester.get(url, follow_redirects=True) + self.assertEqual( + response.status_code, 200, + 'Non-owner should still access shared server after ' + 'rename. Got status {0}'.format(response.status_code) + ) + + def tearDown(self): + if self.server_id is None: + return + utils.delete_server_with_api( + self.__class__.tester, self.server_id) diff --git a/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py b/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py new file mode 100644 index 000000000..7b41af905 --- /dev/null +++ b/web/pgadmin/browser/server_groups/servers/tests/test_shared_server_unit.py @@ -0,0 +1,491 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Unit tests for shared server isolation logic using mocks. + +These tests verify the security-critical merge, suppression, and +sanitization logic without requiring a running PostgreSQL server +or HTTP infrastructure. +""" + +from unittest.mock import MagicMock, patch, call +from pgadmin.utils.route import BaseTestGenerator + +SRV_MODULE = 'pgadmin.browser.server_groups.servers' + + +def _make_server(**overrides): + """Create a mock Server object with sensible defaults.""" + defaults = dict( + id=1, user_id=100, name='OwnerServer', + shared=True, host='db.owner.com', port=5432, + maintenance_db='postgres', username='owner', + password=b'enc_owner_pass', role=None, + bgcolor=None, fgcolor=None, service=None, + use_ssh_tunnel=0, tunnel_host=None, + tunnel_port=5522, tunnel_authentication=0, + tunnel_username=None, tunnel_password=None, + tunnel_identity_file=None, + tunnel_prompt_password=0, tunnel_keep_alive=30, + save_password=1, servergroup_id=1, + server_owner='owner_user', prepare_threshold=5, + passexec_cmd='/usr/bin/vault-get-secret', + passexec_expiration=300, + post_connection_sql='SET role admin;', + connection_params={ + 'sslmode': 'verify-full', + 'sslcert': '/home/owner/.ssl/cert.pem', + 'sslkey': '/home/owner/.ssl/key.pem', + 'sslrootcert': '/home/owner/.ssl/ca.pem', + 'sslcrl': '/home/owner/.ssl/crl.pem', + 'sslcrldir': '/home/owner/.ssl/crl.d', + 'passfile': '/home/owner/.pgpass', + 'connect_timeout': '10', + }, + discovery_id=None, db_res=None, db_res_type=None, + kerberos_conn=False, cloud_status=0, + shared_username='shared_user', tags=None, + is_adhoc=0, + ) + defaults.update(overrides) + server = MagicMock() + for k, v in defaults.items(): + setattr(server, k, v) + return server + + +def _make_shared_server(**overrides): + """Create a mock SharedServer object.""" + defaults = dict( + id=10, osid=1, user_id=200, + server_owner='owner_user', servergroup_id=2, + name='MySharedView', host='db.owner.com', + port=5432, maintenance_db='postgres', + username='nonowner', password=b'enc_nonowner', + save_password=0, role='readonly', + bgcolor='#ff0000', fgcolor='#ffffff', + service='my_pg_service', + use_ssh_tunnel=1, tunnel_host='bastion.local', + tunnel_port=2222, tunnel_authentication=1, + tunnel_username='tunneluser', + tunnel_password=b'enc_tunnel', + tunnel_identity_file='/home/user/.ssh/id_rsa', + tunnel_prompt_password=0, + tunnel_keep_alive=60, shared=True, + prepare_threshold=10, + connection_params={ + 'sslmode': 'verify-full', + 'sslcert': '/home/nonowner/.ssl/cert.pem', + 'connect_timeout': '10', + }, + ) + defaults.update(overrides) + ss = MagicMock() + for k, v in defaults.items(): + setattr(ss, k, v) + return ss + + +class TestGetSharedServerProperties(BaseTestGenerator): + """Unit tests for ServerModule.get_shared_server_properties() + using mock objects.""" + + scenarios = [ + ('Merge suppresses passexec_cmd', + dict(test_method='test_suppresses_passexec')), + ('Merge suppresses post_connection_sql', + dict(test_method='test_suppresses_post_sql')), + ('Merge strips owner SSL paths not in SharedServer', + dict(test_method='test_strips_owner_ssl_paths')), + ('Merge applies SharedServer SSL paths', + dict(test_method='test_applies_ss_ssl_paths')), + ('Merge overrides service from SharedServer', + dict(test_method='test_overrides_service')), + ('Merge overrides tunnel fields', + dict(test_method='test_overrides_tunnel')), + ('Merge handles None connection_params', + dict(test_method='test_none_conn_params')), + ] + + @patch('pgadmin.browser.server_groups.servers.' + 'object_session', return_value=None) + def runTest(self, mock_sess): + getattr(self, self.test_method)() + + def _merge(self, server=None, ss=None): + from pgadmin.browser.server_groups.servers import \ + ServerModule + if server is None: + server = _make_server() + if ss is None: + ss = _make_shared_server() + return ServerModule.get_shared_server_properties( + server, ss) + + def test_suppresses_passexec(self): + result = self._merge() + self.assertIsNone(result.passexec_cmd) + self.assertIsNone(result.passexec_expiration) + + def test_suppresses_post_sql(self): + result = self._merge() + self.assertIsNone(result.post_connection_sql) + + def test_strips_owner_ssl_paths(self): + result = self._merge() + cp = result.connection_params + # Owner had sslkey, sslrootcert, sslcrl, sslcrldir, + # passfile — SharedServer did not — should be removed. + self.assertNotIn('sslkey', cp) + self.assertNotIn('sslcrl', cp) + self.assertNotIn('sslcrldir', cp) + self.assertNotIn('sslrootcert', cp) + self.assertNotIn('passfile', cp) + + def test_applies_ss_ssl_paths(self): + result = self._merge() + cp = result.connection_params + # SharedServer had sslcert -- should override. + self.assertEqual( + cp['sslcert'], + '/home/nonowner/.ssl/cert.pem') + # Non-sensitive params preserved from owner. + self.assertEqual(cp['sslmode'], 'verify-full') + self.assertEqual(cp['connect_timeout'], '10') + + def test_overrides_service(self): + result = self._merge() + self.assertEqual(result.service, 'my_pg_service') + + def test_overrides_tunnel(self): + result = self._merge() + self.assertEqual(result.tunnel_host, 'bastion.local') + self.assertEqual(result.tunnel_port, 2222) + self.assertEqual(result.tunnel_username, 'tunneluser') + self.assertEqual(result.tunnel_authentication, 1) + self.assertEqual( + result.tunnel_identity_file, + '/home/user/.ssh/id_rsa') + + def test_none_conn_params(self): + server = _make_server(connection_params=None) + ss = _make_shared_server(connection_params=None) + result = self._merge(server, ss) + # Should not crash; connection_params becomes {} + self.assertEqual(result.connection_params, {}) + + +class TestCreateSharedServerSanitization(BaseTestGenerator): + """Verify create_shared_server() strips sensitive + connection_params keys.""" + + scenarios = [ + ('Sanitizes connection_params on creation', + dict(test_method='test_sanitizes_conn_params')), + ('Copies tunnel_port from owner', + dict(test_method='test_copies_tunnel_port')), + ('Copies tunnel_keep_alive from owner', + dict(test_method='test_copies_tunnel_keep_alive')), + ('Handles None connection_params', + dict(test_method='test_none_conn_params')), + ] + + @patch('pgadmin.browser.server_groups.servers.db') + @patch('pgadmin.browser.server_groups.servers.User') + @patch('pgadmin.browser.server_groups.servers.current_user') + @patch('pgadmin.browser.server_groups.servers.SharedServer') + def runTest(self, mock_ss_cls, mock_cu, mock_user, + mock_db): + mock_cu.id = 200 + mock_user.query.filter_by.return_value \ + .first.return_value = MagicMock(username='owner') + # Capture the SharedServer() constructor call + self.captured_kwargs = {} + + def capture_init(**kwargs): + self.captured_kwargs = kwargs + return MagicMock() + + mock_ss_cls.side_effect = capture_init + getattr(self, self.test_method)() + + def _create(self, server=None): + from pgadmin.browser.server_groups.servers import \ + ServerModule + if server is None: + server = _make_server() + ServerModule.create_shared_server(server, 1) + + def test_sanitizes_conn_params(self): + self._create() + cp = self.captured_kwargs.get('connection_params', {}) + # Sensitive keys must be stripped + for key in ('sslcert', 'sslkey', 'sslrootcert', + 'sslcrl', 'sslcrldir', 'passfile'): + self.assertNotIn( + key, cp, + 'Sensitive key "{0}" should be stripped ' + 'on SharedServer creation'.format(key)) + # Non-sensitive keys preserved + self.assertEqual(cp.get('sslmode'), 'verify-full') + self.assertEqual(cp.get('connect_timeout'), '10') + + def test_copies_tunnel_port(self): + server = _make_server(tunnel_port=2222) + self._create(server) + self.assertEqual( + self.captured_kwargs.get('tunnel_port'), 2222) + + def test_copies_tunnel_keep_alive(self): + server = _make_server(tunnel_keep_alive=45) + self._create(server) + self.assertEqual( + self.captured_kwargs.get('tunnel_keep_alive'), 45) + + def test_none_conn_params(self): + server = _make_server(connection_params=None) + self._create(server) + cp = self.captured_kwargs.get('connection_params', {}) + self.assertEqual(cp, {}) + + +class TestMergeExpungesServer(BaseTestGenerator): + """Verify get_shared_server_properties() expunges the server + from the SQLAlchemy session before mutation.""" + + scenarios = [ + ('Expunge called when server is in session', + dict(test_method='test_expunge_called')), + ('No crash when server not in session', + dict(test_method='test_no_session')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + def test_expunge_called(self): + from pgadmin.browser.server_groups.servers import \ + ServerModule + server = _make_server() + ss = _make_shared_server() + mock_session = MagicMock() + with patch(SRV_MODULE + '.object_session', + return_value=mock_session): + ServerModule.get_shared_server_properties( + server, ss) + mock_session.expunge.assert_called_once_with(server) + + def test_no_session(self): + from pgadmin.browser.server_groups.servers import \ + ServerModule + server = _make_server() + ss = _make_shared_server() + with patch(SRV_MODULE + '.object_session', + return_value=None): + # Should not crash + result = ServerModule.get_shared_server_properties( + server, ss) + self.assertIsNone(result.passexec_cmd) + + +class TestUpdateConnectionParameter(BaseTestGenerator): + """Verify update_connection_parameter() routes changes + to SharedServer for non-owners.""" + + scenarios = [ + ('Non-owner changes go to SharedServer copy', + dict(test_method='test_nonowner_routing')), + ('Owner changes go to Server directly', + dict(test_method='test_owner_routing')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + @patch(SRV_MODULE + '.current_user') + def test_nonowner_routing(self, mock_cu): + mock_cu.id = 200 # Non-owner + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server( + connection_params={'sslmode': 'require'}) + ss = _make_shared_server( + connection_params={'sslmode': 'require'}) + + data = {'connection_params': { + 'changed': [{'name': 'sslmode', 'value': 'verify'}] + }} + + node = ServerNode.__new__(ServerNode) + node.update_connection_parameter(data, server, ss) + + # The result should be in data, not mutating server + self.assertEqual( + data['connection_params']['sslmode'], 'verify') + # Owner's server should NOT be mutated + self.assertEqual( + server.connection_params['sslmode'], 'require') + + @patch(SRV_MODULE + '.current_user') + def test_owner_routing(self, mock_cu): + mock_cu.id = 100 # Owner + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server( + connection_params={'sslmode': 'require'}) + + data = {'connection_params': { + 'changed': [{'name': 'sslmode', 'value': 'verify'}] + }} + + node = ServerNode.__new__(ServerNode) + node.update_connection_parameter(data, server, None) + + # Owner path mutates server directly + self.assertEqual( + data['connection_params']['sslmode'], 'verify') + + +class TestUpdateServerDetails(BaseTestGenerator): + """Verify _update_server_details routes writes to + SharedServer for non-owners.""" + + scenarios = [ + ('Non-owner write goes to SharedServer', + dict(test_method='test_nonowner_write')), + ('Owner write goes to Server', + dict(test_method='test_owner_write')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + @patch(SRV_MODULE + '.current_user') + def test_nonowner_write(self, mock_cu): + mock_cu.id = 200 + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server() + ss = _make_shared_server() + config_map = {'name': 'name'} + + ServerNode._update_server_details( + server, ss, config_map, 'name', 'NewName') + + self.assertEqual(ss.name, 'NewName') + # Server should not be modified + self.assertEqual(server.name, 'OwnerServer') + + @patch(SRV_MODULE + '.current_user') + def test_owner_write(self, mock_cu): + mock_cu.id = 100 + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server() + config_map = {'name': 'name'} + + ServerNode._update_server_details( + server, None, config_map, 'name', 'NewName') + + self.assertEqual(server.name, 'NewName') + + +class TestDeleteSharedServerOwnerGuard(BaseTestGenerator): + """Verify that only the owner can trigger + delete_shared_server via _set_valid_attr_value.""" + + scenarios = [ + ('Non-owner shared=false does not delete', + dict(test_method='test_nonowner_no_delete')), + ('Owner shared=false triggers delete', + dict(test_method='test_owner_deletes')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + @patch(SRV_MODULE + '.get_crypt_key', + return_value=(True, b'key')) + @patch(SRV_MODULE + '.current_user') + def test_nonowner_no_delete(self, mock_cu, mock_ck): + mock_cu.id = 200 + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server() + ss = _make_shared_server() + node = ServerNode.__new__(ServerNode) + node.delete_shared_server = MagicMock() + + data = {'shared': False} + config_map = {'shared': 'shared'} + + node._set_valid_attr_value( + 1, data, config_map, server, ss) + + node.delete_shared_server.assert_not_called() + + @patch(SRV_MODULE + '.get_crypt_key', + return_value=(True, b'key')) + @patch(SRV_MODULE + '.current_user') + def test_owner_deletes(self, mock_cu, mock_ck): + mock_cu.id = 100 # Owner + from pgadmin.browser.server_groups.servers import \ + ServerNode + + server = _make_server() + node = ServerNode.__new__(ServerNode) + node.delete_shared_server = MagicMock() + + data = {'shared': False} + config_map = {'shared': 'shared'} + + node._set_valid_attr_value( + 1, data, config_map, server, None) + + node.delete_shared_server.assert_called_once_with( + 1, server.id) + + +class TestGetSharedServerRaisesOnNone(BaseTestGenerator): + """Verify get_shared_server() raises if SharedServer + cannot be created.""" + + scenarios = [ + ('Raises when SharedServer is None after create', + dict(test_method='test_raises_on_none')), + ] + + def runTest(self): + getattr(self, self.test_method)() + + @patch(SRV_MODULE + '.SharedServer') + @patch(SRV_MODULE + '.current_user') + def test_raises_on_none(self, mock_cu, mock_ss): + mock_cu.id = 200 + # Both queries return None + mock_ss.query.filter_by.return_value \ + .first.return_value = None + + from pgadmin.browser.server_groups.servers import \ + ServerModule + + server = _make_server() + + with patch.object(ServerModule, 'create_shared_server'): + with self.assertRaises(Exception) as ctx: + ServerModule.get_shared_server(server, 1) + + self.assertIn( + 'Failed to create shared server', + str(ctx.exception)) diff --git a/web/pgadmin/browser/server_groups/servers/utils.py b/web/pgadmin/browser/server_groups/servers/utils.py index d9ef4842a..3377d99cd 100644 --- a/web/pgadmin/browser/server_groups/servers/utils.py +++ b/web/pgadmin/browser/server_groups/servers/utils.py @@ -13,13 +13,14 @@ from ipaddress import ip_address import keyring from flask_login import current_user from werkzeug.exceptions import InternalServerError -from flask import render_template +from flask import render_template, has_request_context from pgadmin.utils.constants import ( KEY_RING_USERNAME_FORMAT, KEY_RING_SERVICE_NAME, KEY_RING_TUNNEL_FORMAT, KEY_RING_DESKTOP_USER, SSL_MODES, RESTRICTION_TYPE_DATABASES, RESTRICTION_TYPE_SQL) from pgadmin.utils.crypto import encrypt, decrypt from pgadmin.model import db, Server, SharedServer +from pgadmin.utils.server_access import get_user_server_query from flask import current_app from pgadmin.utils.master_password import set_masterpass_check_text from pgadmin.utils.driver import get_driver @@ -324,7 +325,10 @@ def migrate_passwords_from_pgadmin_db(servers, old_key, enc_key): def get_servers_with_saved_passwords(): - all_server = Server.query.filter(Server.is_adhoc == 0) + all_server = Server.query.filter( + Server.user_id == current_user.id, + Server.is_adhoc == 0 + ) servers_with_pwd_in_os_secret = [] servers_with_pwd_in_pgadmin_db = [] saved_password_servers = [] @@ -648,32 +652,56 @@ def check_ssl_fields(data): def disconnect_from_all_servers(): """ - This function is used to disconnect all the servers + This function is used to disconnect all the servers for the + current user (owned + shared). """ - all_servers = Server.query.all() + all_servers = get_user_server_query().all() for server in all_servers: - manager = get_driver(config.PG_DEFAULT_DRIVER).connection_manager( - server.id) - # Check if any psql terminal is running for the current disconnecting - # server. If any terminate the psql tool connection. - if 'sid_soid_mapping' in current_app.config and str(server.id) in \ - current_app.config['sid_soid_mapping'] and \ - str(server.id) in current_app.config['sid_soid_mapping']: - for i in current_app.config['sid_soid_mapping'][str(server.id)]: - sio.emit('disconnect-psql', namespace='/pty', to=i) - - manager.release() + try: + manager = get_driver( + config.PG_DEFAULT_DRIVER + ).connection_manager(server.id) + # Only emit disconnect-psql for servers owned by the + # current user — shared servers may have other users' + # PSQL sessions mapped to the same sid. + if server.user_id == current_user.id and \ + 'sid_soid_mapping' in current_app.config \ + and str(server.id) in \ + current_app.config['sid_soid_mapping']: + for i in current_app.config[ + 'sid_soid_mapping'][str(server.id)]: + sio.emit( + 'disconnect-psql', + namespace='/pty', to=i + ) + manager.release() + except Exception: + current_app.logger.warning( + 'Failed to disconnect server %s', + server.id, exc_info=True + ) def delete_adhoc_servers(sid=None): """ - This function will remove all the adhoc servers. + This function will remove adhoc servers. When called with a + current_user context, scopes to the current user. When called + during app startup (no user context), cleans all adhoc servers. """ try: + has_user = (has_request_context() and + current_user and current_user.is_authenticated) if sid is not None: - db.session.query(Server).filter(Server.id == sid).delete() + q = db.session.query(Server).filter( + Server.id == sid, Server.is_adhoc == 1) + if has_user: + q = q.filter(Server.user_id == current_user.id) + q.delete() else: - db.session.query(Server).filter(Server.is_adhoc == 1).delete() + q = db.session.query(Server).filter(Server.is_adhoc == 1) + if has_user: + q = q.filter(Server.user_id == current_user.id) + q.delete() db.session.commit() # Reset the sequence again diff --git a/web/pgadmin/browser/server_groups/tests/test_sg_data_isolation.py b/web/pgadmin/browser/server_groups/tests/test_sg_data_isolation.py new file mode 100644 index 000000000..8f870156c --- /dev/null +++ b/web/pgadmin/browser/server_groups/tests/test_sg_data_isolation.py @@ -0,0 +1,78 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Tests for ServerGroup data isolation between users in server mode.""" + +import json +import config +from pgadmin.utils.route import BaseTestGenerator +from regression.python_test_utils import test_utils as utils +from regression.test_setup import config_data +from regression.python_test_utils.test_utils import \ + create_user_wise_test_client +from pgadmin.model import db, ServerGroup + +test_user_details = None +if config.SERVER_MODE: + test_user_details = \ + config_data['pgAdmin4_test_non_admin_credentials'] + + +class ServerGroupIsolationTestCase(BaseTestGenerator): + """Verify that a non-admin user cannot fetch another user's + server group properties by ID.""" + + scenarios = [ + ('User B cannot fetch User A server group properties', + dict(is_positive_test=False)), + ] + + def setUp(self): + self.sg_id = None + if not config.SERVER_MODE: + self.skipTest( + 'Data isolation tests only apply to server mode.' + ) + + # Create a server group as the admin user + url = '/browser/server_group/obj/' + response = self.tester.post( + url, + data=json.dumps({'name': 'isolation_test_group'}), + content_type='html/json' + ) + self.assertEqual(response.status_code, 200) + response_data = json.loads(response.data.decode('utf-8')) + self.assertIn('node', response_data) + self.sg_id = response_data['node']['_id'] + + @create_user_wise_test_client(test_user_details) + def runTest(self): + """Non-admin user should NOT see another user's server + group properties.""" + if not self.sg_id: + raise Exception("Server group not created") + + url = '/browser/server_group/obj/{0}'.format(self.sg_id) + response = self.tester.get(url, content_type='html/json') + self.assertEqual( + response.status_code, 410, + 'Non-admin user should not access another user\'s ' + 'server group. Got status {0}'.format( + response.status_code) + ) + + def tearDown(self): + # Clean up with admin + if self.sg_id is None: + return + sg = ServerGroup.query.filter_by(id=self.sg_id).first() + if sg: + db.session.delete(sg) + db.session.commit() diff --git a/web/pgadmin/misc/bgprocess/processes.py b/web/pgadmin/misc/bgprocess/processes.py index 5f59ec0ab..9a44a452d 100644 --- a/web/pgadmin/misc/bgprocess/processes.py +++ b/web/pgadmin/misc/bgprocess/processes.py @@ -153,7 +153,7 @@ class BatchProcess: self.manager_obj = kwargs['manager_obj'] def _retrieve_process(self, _id): - p = Process.query.filter_by(pid=_id, user_id=current_user.id).first() + p = Process.for_user(pid=_id).first() if p is None: raise LookupError(PROCESS_NOT_FOUND) @@ -372,9 +372,7 @@ class BatchProcess: # There is no way to find out the error message from this process # as standard output, and standard error were redirected to # devnull. - p = Process.query.filter_by( - pid=self.id, user_id=current_user.id - ).first() + p = Process.for_user(pid=self.id).first() p.start_time = p.end_time = get_current_time() if not p.exit_code: p.exit_code = self.ecode @@ -382,9 +380,7 @@ class BatchProcess: db.session.commit() else: # Update the process state to "Started" - p = Process.query.filter_by( - pid=self.id, user_id=current_user.id - ).first() + p = Process.for_user(pid=self.id).first() p.process_state = PROCESS_STARTED db.session.commit() @@ -530,9 +526,7 @@ class BatchProcess: """ _pid = self.id - _process = Process.query.filter_by( - user_id=current_user.id, pid=_pid - ).first() + _process = Process.for_user(pid=_pid).first() if _process is None: raise LookupError(PROCESS_NOT_FOUND) @@ -588,9 +582,7 @@ class BatchProcess: out_completed = err_completed = False process_output = (out != -1 and err != -1) - j = Process.query.filter_by( - pid=self.id, user_id=current_user.id - ).first() + j = Process.for_user(pid=self.id).first() enc = sys.getdefaultencoding() if enc == 'ascii': enc = 'utf-8' @@ -739,7 +731,7 @@ class BatchProcess: @staticmethod def list(): - processes = Process.query.filter_by(user_id=current_user.id) + processes = Process.for_user() changed = False browser_preference = Preferences.module('browser') @@ -812,9 +804,7 @@ class BatchProcess: And, delete the process information from the configuration, and the log files related to the process, if it has already been completed. """ - p = Process.query.filter_by( - user_id=current_user.id, pid=_pid - ).first() + p = Process.for_user(pid=_pid).first() if p is None: raise LookupError(PROCESS_NOT_FOUND) @@ -886,9 +876,7 @@ class BatchProcess: def stop_process(_pid): """ """ - p = Process.query.filter_by( - user_id=current_user.id, pid=_pid - ).first() + p = Process.for_user(pid=_pid).first() if p is None: raise LookupError(PROCESS_NOT_FOUND) @@ -910,9 +898,7 @@ class BatchProcess: @staticmethod def update_server_id(_pid, _sid): - p = Process.query.filter_by( - user_id=current_user.id, pid=_pid - ).first() + p = Process.for_user(pid=_pid).first() if p is None: raise LookupError(PROCESS_NOT_FOUND) diff --git a/web/pgadmin/misc/cloud/__init__.py b/web/pgadmin/misc/cloud/__init__.py index 6d61d3a97..f28484c47 100644 --- a/web/pgadmin/misc/cloud/__init__.py +++ b/web/pgadmin/misc/cloud/__init__.py @@ -212,8 +212,9 @@ def clear_cloud_session(pid=None): @pga_login_required def update_cloud_process(sid): """Update Cloud Server Process""" - _process = Process.query.filter_by(user_id=current_user.id, - server_id=sid).first() + _process = Process.for_user(server_id=sid).first() + if _process is None: + return success_return() _process.acknowledge = None db.session.commit() return success_return() diff --git a/web/pgadmin/misc/workspaces/__init__.py b/web/pgadmin/misc/workspaces/__init__.py index 1a99037a7..afb20b5e8 100644 --- a/web/pgadmin/misc/workspaces/__init__.py +++ b/web/pgadmin/misc/workspaces/__init__.py @@ -17,6 +17,7 @@ from flask_babel import gettext from flask_security import current_user from pgadmin.utils import PgAdminModule from pgadmin.model import db, Server +from pgadmin.utils.server_access import get_server from pgadmin.utils.driver import get_driver from pgadmin.utils.ajax import bad_request, make_json_response from pgadmin.browser.server_groups.servers.utils import ( @@ -132,7 +133,8 @@ def adhoc_connect_server(): username=new_username, name=new_server_name, role=new_role, - service=new_service + service=new_service, + user_id=current_user.id ).all() # If found matching servers then compare the connection_params as @@ -143,22 +145,27 @@ def adhoc_connect_server(): server = existing_server break else: - server = Server.query.filter_by(host=new_host, - port=new_port, - maintenance_db=new_db, - username=new_username, - name=new_server_name, - role=new_role, - service=new_service, - connection_params=connection_params - ).first() + server = Server.query.filter_by( + host=new_host, port=new_port, + maintenance_db=new_db, + username=new_username, + name=new_server_name, + role=new_role, + service=new_service, + connection_params=connection_params, + user_id=current_user.id + ).first() # If server is none then no server with the above combination is found. if server is None: # Check if sid is present in data if it is then used that sid. if ('sid' in data and data['sid'] is not None and int(data['sid']) > 0): - server = Server.query.filter_by(id=data['sid']).first() + server = get_server(data['sid']) + if server is None: + return bad_request(gettext( + "Could not find the required server." + )) # Clone the server object server = server.clone() @@ -220,23 +227,30 @@ def check_and_delete_adhoc_server(sid): This function is used to check for adhoc server and if all Query Tool and PSQL connections are closed then delete that server. """ - server = Server.query.filter_by(id=sid).first() - if server.is_adhoc: - # Check PSQL connections. If more connections are open for - # the given sid return from the function. - psql_connections = get_open_psql_connections() - if sid in psql_connections.values(): + server = get_server(sid) + if server is None: + # Server may be deleted or inaccessible; still attempt + # best-effort cleanup of adhoc state. + delete_adhoc_servers(sid) + return + if not server.is_adhoc: + return + + # Check PSQL connections. If more connections are open for + # the given sid return from the function. + psql_connections = get_open_psql_connections() + if sid in psql_connections.values(): + return + + # Check Query Tool connections for the given sid + manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid) + for key, value in manager.connections.items(): + if key.startswith('CONN') and value.connected(): return - # Check Query Tool connections for the given sid - manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid) - for key, value in manager.connections.items(): - if key.startswith('CONN') and value.connected(): - return + # Assumption at this point all the Query Tool and PSQL connections + # is closed, so now we can release the manager + manager.release() - # Assumption at this point all the Query Tool and PSQL connections - # is closed, so now we can release the manager - manager.release() - - # Delete the adhoc server from the pgadmin database - delete_adhoc_servers(sid) + # Delete the adhoc server from the pgadmin database + delete_adhoc_servers(sid) diff --git a/web/pgadmin/model/__init__.py b/web/pgadmin/model/__init__.py index 69b934683..62d89ca94 100644 --- a/web/pgadmin/model/__init__.py +++ b/web/pgadmin/model/__init__.py @@ -33,7 +33,7 @@ import config # ########################################################################## -SCHEMA_VERSION = 49 +SCHEMA_VERSION = 50 ########################################################################## # @@ -51,6 +51,60 @@ USER_ID = 'user.id' SERVER_ID = 'server.id' CASCADE_STR = "all, delete-orphan" + +class UserScopedMixin: + """Mixin for models that store per-user data. + + Provides for_user() as the default scoped query entry point. + Models with a 'user_id' column or a 'uid' column are supported + automatically — the mixin detects which column name is used. + + Usage: + # Instead of: + Process.query.filter_by(user_id=current_user.id, pid=pid) + # Use: + Process.for_user(pid=pid) + """ + + @classmethod + def _user_column(cls): + """Return the user-scoping column for this model.""" + if hasattr(cls, 'user_id'): + return cls.user_id + if hasattr(cls, 'uid'): + return cls.uid + raise AttributeError( + f"{cls.__name__} has no user_id or uid column" + ) + + @classmethod + def _user_column_name(cls): + """Return the column name string ('user_id' or 'uid').""" + if hasattr(cls, 'user_id'): + return 'user_id' + if hasattr(cls, 'uid'): + return 'uid' + raise AttributeError( + f"{cls.__name__} has no user_id or uid column" + ) + + @classmethod + def for_user(cls, user_id=None, **kwargs): + """Query scoped to a specific user (defaults to current_user). + + Args: + user_id: Explicit user ID. If None, uses current_user.id. + **kwargs: Additional filter_by arguments. + + Returns: + A SQLAlchemy query filtered by the user's ID. + """ + from flask_security import current_user as cu + uid = user_id if user_id is not None else cu.id + kwargs[cls._user_column_name()] = uid + return cls.query.filter_by(**kwargs) + + # Define models roles_users = db.Table( 'roles_users', @@ -158,7 +212,7 @@ class User(db.Model, UserMixin): locked = db.Column(db.Boolean(), default=False) -class Setting(db.Model): +class Setting(db.Model, UserScopedMixin): """Define a setting object""" __tablename__ = 'setting' user_id = db.Column(db.Integer, db.ForeignKey(USER_ID), primary_key=True) @@ -166,7 +220,7 @@ class Setting(db.Model): value = db.Column(db.Text()) -class ServerGroup(db.Model): +class ServerGroup(db.Model, UserScopedMixin): """Define a server group for the treeview""" __tablename__ = 'servergroup' id = db.Column(db.Integer, primary_key=True) @@ -185,7 +239,7 @@ class ServerGroup(db.Model): } -class Server(db.Model): +class Server(db.Model, UserScopedMixin): """Define a registered Postgres server""" __tablename__ = 'server' id = db.Column(db.Integer, primary_key=True) @@ -306,7 +360,7 @@ class Preferences(db.Model): name = db.Column(db.String(1024), nullable=False) -class UserPreference(db.Model): +class UserPreference(db.Model, UserScopedMixin): """Define the preference for a particular user.""" __tablename__ = 'user_preferences' pid = db.Column( @@ -318,9 +372,13 @@ class UserPreference(db.Model): value = db.Column(db.String(1024), nullable=False) -class DebuggerFunctionArguments(db.Model): +class DebuggerFunctionArguments(db.Model, UserScopedMixin): """Define the debugger input function arguments.""" __tablename__ = 'debugger_function_arguments' + user_id = db.Column( + db.Integer, db.ForeignKey(USER_ID), + nullable=False, primary_key=True + ) server_id = db.Column(db.Integer(), nullable=False, primary_key=True) database_id = db.Column(db.Integer(), nullable=False, primary_key=True) schema_id = db.Column(db.Integer(), nullable=False, primary_key=True) @@ -349,7 +407,7 @@ class DebuggerFunctionArguments(db.Model): value = db.Column(db.String(), nullable=True) -class Process(db.Model): +class Process(db.Model, UserScopedMixin): """Define the Process table.""" __tablename__ = 'process' pid = db.Column(db.String(), nullable=False, primary_key=True) @@ -382,7 +440,7 @@ class Keys(db.Model): value = db.Column(db.String(), nullable=False) -class QueryHistoryModel(db.Model): +class QueryHistoryModel(db.Model, UserScopedMixin): """Define the history SQL table.""" __tablename__ = 'query_history' srno = db.Column(db.Integer(), nullable=False, primary_key=True) @@ -397,7 +455,7 @@ class QueryHistoryModel(db.Model): last_updated_flag = db.Column(db.String(), nullable=False) -class ApplicationState(db.Model): +class ApplicationState(db.Model, UserScopedMixin): """Define the application state SQL table.""" __tablename__ = 'application_state' uid = db.Column(db.Integer(), db.ForeignKey(USER_ID), nullable=False, @@ -422,10 +480,14 @@ class Database(db.Model): ) -class SharedServer(db.Model): +class SharedServer(db.Model, UserScopedMixin): """Define a shared Postgres server""" __tablename__ = 'sharedserver' + __table_args__ = ( + db.UniqueConstraint('osid', 'user_id', + name='uq_sharedserver_osid_user'), + ) id = db.Column(db.Integer, primary_key=True) osid = db.Column( db.Integer, @@ -510,7 +572,7 @@ class Macros(db.Model): key_code = db.Column(db.Integer, nullable=False) -class UserMacros(db.Model): +class UserMacros(db.Model, UserScopedMixin): """Define the macro for a particular user.""" __tablename__ = 'user_macros' id = db.Column(db.Integer, primary_key=True, autoincrement=True) @@ -524,7 +586,7 @@ class UserMacros(db.Model): sql = db.Column(db.Text(), nullable=False) -class UserMFA(db.Model): +class UserMFA(db.Model, UserScopedMixin): """Stores the options for the MFA for a particular user.""" __tablename__ = 'user_mfa' user_id = db.Column(db.Integer, db.ForeignKey(USER_ID), primary_key=True) diff --git a/web/pgadmin/tools/backup/tests/test_batch_process.py b/web/pgadmin/tools/backup/tests/test_batch_process.py index d0921261e..08f6c6952 100644 --- a/web/pgadmin/tools/backup/tests/test_batch_process.py +++ b/web/pgadmin/tools/backup/tests/test_batch_process.py @@ -195,6 +195,7 @@ class BatchProcessTest(BaseTestGenerator): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by mock_result = process_mock.query.filter_by.return_value mock_result.first.return_value = TestMockProcess( backup_obj, self.class_params['args'], self.class_params['cmd']) @@ -239,6 +240,7 @@ class BatchProcessTest(BaseTestGenerator): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by process_mock.query.filter_by.return_value = [ TestMockProcess(backup_obj, self.class_params['args'], diff --git a/web/pgadmin/tools/debugger/__init__.py b/web/pgadmin/tools/debugger/__init__.py index e9a25deca..7d00f0fae 100644 --- a/web/pgadmin/tools/debugger/__init__.py +++ b/web/pgadmin/tools/debugger/__init__.py @@ -16,7 +16,7 @@ import copy from flask import render_template, request, current_app from flask_babel import gettext -from flask_security import permissions_required +from flask_security import permissions_required, current_user from pgadmin.user_login_check import pga_login_required from werkzeug.user_agent import UserAgent @@ -35,7 +35,9 @@ from pgadmin.browser.server_groups.servers.databases.extensions.utils \ import get_extension_details from pgadmin.utils.constants import PREF_LABEL_KEYBOARD_SHORTCUTS, \ SERVER_CONNECTION_CLOSED -from pgadmin.tools.user_management.PgAdminPermissions import AllPermissionTypes +from pgadmin.tools.user_management.PgAdminPermissions \ + import AllPermissionTypes +from pgadmin.utils.server_access import get_server from pgadmin.preferences import preferences MODULE_NAME = 'debugger' @@ -1803,12 +1805,19 @@ def get_arguments_sqlite(sid, did, scid, func_id): - Function Id """ + if get_server(sid) is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) + """Get the count of the existing data available in sqlite database""" dbg_func_args_count = int(DebuggerFunctionArguments.query.filter_by( server_id=sid, database_id=did, schema_id=scid, - function_id=func_id + function_id=func_id, + user_id=current_user.id ).count()) args_data = [] @@ -1819,7 +1828,8 @@ def get_arguments_sqlite(sid, did, scid, func_id): server_id=sid, database_id=did, schema_id=scid, - function_id=func_id + function_id=func_id, + user_id=current_user.id ) args_list = dbg_func_args.all() @@ -1888,6 +1898,12 @@ def set_arguments_sqlite(sid, did, scid, func_id): - Function Id """ + if get_server(sid) is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) + if request.data: data = json.loads(request.data) @@ -1899,7 +1915,8 @@ def set_arguments_sqlite(sid, did, scid, func_id): database_id=data[i]['database_id'], schema_id=data[i]['schema_id'], function_id=data[i]['function_id'], - arg_id=data[i]['arg_id']).count()) + arg_id=data[i]['arg_id'], + user_id=current_user.id).count()) # handle the Array list sent from the client array_string = '' @@ -1918,7 +1935,8 @@ def set_arguments_sqlite(sid, did, scid, func_id): database_id=data[i]['database_id'], schema_id=data[i]['schema_id'], function_id=data[i]['function_id'], - arg_id=data[i]['arg_id'] + arg_id=data[i]['arg_id'], + user_id=current_user.id ).first() dbg_func_args.is_null = data[i]['is_null'] @@ -1932,6 +1950,7 @@ def set_arguments_sqlite(sid, did, scid, func_id): schema_id=data[i]['schema_id'], function_id=data[i]['function_id'], arg_id=data[i]['arg_id'], + user_id=current_user.id, is_null=data[i]['is_null'], is_expression=data[i]['is_expression'], use_default=data[i]['use_default'], @@ -1977,12 +1996,20 @@ def clear_arguments_sqlite(sid, did, scid, func_id): - Function Id """ + if get_server(sid) is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) + try: db.session.query(DebuggerFunctionArguments) \ .filter(DebuggerFunctionArguments.server_id == sid, DebuggerFunctionArguments.database_id == did, DebuggerFunctionArguments.schema_id == scid, - DebuggerFunctionArguments.function_id == func_id) \ + DebuggerFunctionArguments.function_id == func_id, + DebuggerFunctionArguments.user_id == + current_user.id) \ .delete() db.session.commit() diff --git a/web/pgadmin/tools/erd/__init__.py b/web/pgadmin/tools/erd/__init__.py index 16d9e4b7b..b70eb9f63 100644 --- a/web/pgadmin/tools/erd/__init__.py +++ b/web/pgadmin/tools/erd/__init__.py @@ -20,6 +20,7 @@ from pgadmin.utils import PgAdminModule, \ SHORTCUT_FIELDS as shortcut_fields from pgadmin.utils.ajax import make_json_response, internal_server_error from pgadmin.model import Server +from pgadmin.utils.server_access import get_server from config import PG_DEFAULT_DRIVER, ALLOW_SAVE_PASSWORD from pgadmin.utils.driver import get_driver from pgadmin.browser.utils import underscore_unescape @@ -556,7 +557,7 @@ def panel(trans_id): if "linux" in _platform: is_linux_platform = True - s = Server.query.filter_by(id=int(params['sid'])).first() + s = get_server(int(params['sid'])) if s: params.update({ diff --git a/web/pgadmin/tools/import_export/__init__.py b/web/pgadmin/tools/import_export/__init__.py index 30edc4f18..d7bd7c065 100644 --- a/web/pgadmin/tools/import_export/__init__.py +++ b/web/pgadmin/tools/import_export/__init__.py @@ -23,6 +23,7 @@ from pgadmin.utils.ajax import make_json_response, bad_request, unauthorized from config import PG_DEFAULT_DRIVER from pgadmin.model import Server +from pgadmin.utils.server_access import get_server from pgadmin.utils.constants import SERVER_NOT_FOUND from pgadmin.settings import get_setting, store_setting from pgadmin.tools.user_management.PgAdminPermissions import AllPermissionTypes @@ -97,9 +98,7 @@ class IEMessage(IProcessDesc): def get_server_name(self): # Fetch the server details like hostname, port, roles etc - s = Server.query.filter_by( - id=self.sid, user_id=current_user.id - ).first() + s = get_server(self.sid) if s is None: return _("Not available") @@ -293,8 +292,7 @@ def create_import_export_job(sid): data = json.loads(request.data) # Fetch the server details like hostname, port, roles etc - server = Server.query.filter_by( - id=sid).first() + server = get_server(sid) if server is None: return bad_request(errormsg=_("Could not find the specified server.")) diff --git a/web/pgadmin/tools/import_export/tests/test_batch_process.py b/web/pgadmin/tools/import_export/tests/test_batch_process.py index da42e4367..b9fbe124b 100644 --- a/web/pgadmin/tools/import_export/tests/test_batch_process.py +++ b/web/pgadmin/tools/import_export/tests/test_batch_process.py @@ -204,6 +204,7 @@ class BatchProcessTest(BaseTestGenerator): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by mock_result = process_mock.query.filter_by.return_value mock_result.first.return_value = TestMockProcess( import_export_obj, self.class_params['args'], @@ -250,6 +251,7 @@ class BatchProcessTest(BaseTestGenerator): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by process_mock.query.filter_by.return_value = [ TestMockProcess(import_export_obj, self.class_params['args'], diff --git a/web/pgadmin/tools/maintenance/tests/test_batch_process_maintenance.py b/web/pgadmin/tools/maintenance/tests/test_batch_process_maintenance.py index b2ca169f5..5e7ea7d57 100644 --- a/web/pgadmin/tools/maintenance/tests/test_batch_process_maintenance.py +++ b/web/pgadmin/tools/maintenance/tests/test_batch_process_maintenance.py @@ -137,6 +137,7 @@ class BatchProcessTest(BaseTestGenerator): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by mock_result = process_mock.query.filter_by.return_value mock_result.first.return_value = TestMockProcess( maintenance_obj, self.class_params['args'], @@ -177,6 +178,7 @@ class BatchProcessTest(BaseTestGenerator): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by process_mock.query.filter_by.return_value = [ TestMockProcess(maintenance_obj, self.class_params['args'], diff --git a/web/pgadmin/tools/psql/__init__.py b/web/pgadmin/tools/psql/__init__.py index 1b51cca66..ab1e22fd3 100644 --- a/web/pgadmin/tools/psql/__init__.py +++ b/web/pgadmin/tools/psql/__init__.py @@ -29,6 +29,7 @@ from ... import socketio as sio from pgadmin.utils import get_complete_file_path from pgadmin.authenticate import socket_login_required from pgadmin.model import Server +from pgadmin.utils.server_access import get_server if _platform == 'win32': # Check Windows platform support for WinPty api, Disable psql @@ -98,7 +99,7 @@ def panel(trans_id): if 'sid_soid_mapping' not in app.config: app.config['sid_soid_mapping'] = dict() - s = Server.query.filter_by(id=int(params['sid'])).first() + s = get_server(int(params['sid'])) if s: data = _get_database_role(params['sid'], params['did']) if data: diff --git a/web/pgadmin/tools/restore/tests/test_batch_process.py b/web/pgadmin/tools/restore/tests/test_batch_process.py index b0045d4e8..3f4f0bfa1 100644 --- a/web/pgadmin/tools/restore/tests/test_batch_process.py +++ b/web/pgadmin/tools/restore/tests/test_batch_process.py @@ -134,6 +134,7 @@ class BatchProcessTest(BaseTestGenerator): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by mock_result = process_mock.query.filter_by.return_value mock_result.first.return_value = TestMockProcess( restore_obj, self.class_params['args'], @@ -174,6 +175,7 @@ class BatchProcessTest(BaseTestGenerator): self.utility_pid = 123 self.server_id = None + process_mock.for_user = process_mock.query.filter_by process_mock.query.filter_by.return_value = [ TestMockProcess(restore_obj, self.class_params['args'], diff --git a/web/pgadmin/tools/schema_diff/__init__.py b/web/pgadmin/tools/schema_diff/__init__.py index bc244e0b2..2470a4db1 100644 --- a/web/pgadmin/tools/schema_diff/__init__.py +++ b/web/pgadmin/tools/schema_diff/__init__.py @@ -31,6 +31,8 @@ from sqlalchemy import or_ from pgadmin.authenticate import socket_login_required from pgadmin import socketio from pgadmin.tools.user_management.PgAdminPermissions import AllPermissionTypes +from pgadmin.utils.server_access import \ + get_server as get_server_access, get_user_server_query MODULE_NAME = 'schema_diff' COMPARE_MSG = gettext("Comparing objects...") @@ -283,18 +285,14 @@ def servers(): from pgadmin.browser.server_groups.servers import\ server_icon_and_background - for server in Server.query.filter( - or_(Server.user_id == current_user.id, Server.shared), + for server in get_user_server_query().filter( Server.is_adhoc == 0): shared_server = SharedServer.query.filter_by( - name=server.name, user_id=current_user.id, - servergroup_id=server.servergroup_id).first() + user_id=current_user.id, + osid=server.id).first() - if server.discovery_id: - auto_detected_server = server.name - - if shared_server and shared_server.name == auto_detected_server: + if server.discovery_id and shared_server: continue manager = driver.connection_manager(server.id) @@ -336,7 +334,13 @@ def get_server(sid, did): """Return a JSON document listing the server groups for the user""" driver = get_driver(PG_DEFAULT_DRIVER) - server = Server.query.filter_by(id=sid).first() + server = get_server_access(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext( + "Could not find the required server.") + ) manager = driver.connection_manager(sid) conn = manager.connection(did=did) connected = conn.connected() @@ -375,7 +379,12 @@ def connect_server(sid): data={} ) - server = Server.query.filter_by(id=sid).first() + server = get_server_access(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) view = SchemaDiffRegistry.get_node_view('server') return view.connect(server.servergroup_id, sid) @@ -387,7 +396,12 @@ def connect_server(sid): ) @pga_login_required def connect_database(sid, did): - server = Server.query.filter_by(id=sid).first() + server = get_server_access(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) view = SchemaDiffRegistry.get_node_view('database') return view.connect(server.servergroup_id, sid, did) @@ -407,7 +421,13 @@ def databases(sid): try: view = SchemaDiffRegistry.get_node_view('database') - server = Server.query.filter_by(id=sid).first() + server = get_server_access(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext( + "Could not find the required server.") + ) response = view.nodes(gid=server.servergroup_id, sid=sid, is_schema_diff=True) databases = json.loads(response.data)['data'] @@ -495,6 +515,15 @@ def compare_database(params): fetch_compare_schemas(params['source_sid'], params['source_did'], params['target_sid'], params['target_did']) + if schema_result is None: + socketio.emit( + 'compare_database_failed', + gettext( + "Failed to fetch schemas from the" + " server."), + namespace=SOCKETIO_NAMESPACE, to=request.sid) + return + total_schema = len(schema_result['source_only']) + len( schema_result['target_only']) + len( schema_result['in_both_database']) @@ -722,11 +751,15 @@ def check_version_compatibility(sid, tid): """Check the version compatibility of source and target servers.""" driver = get_driver(PG_DEFAULT_DRIVER) - src_server = Server.query.filter_by(id=sid).first() + src_server = get_server_access(sid) + if src_server is None: + return False, gettext("Could not find the source server.") src_manager = driver.connection_manager(src_server.id) src_conn = src_manager.connection() - tar_server = Server.query.filter_by(id=tid).first() + tar_server = get_server_access(tid) + if tar_server is None: + return False, gettext("Could not find the target server.") tar_manager = driver.connection_manager(tar_server.id) target_conn = tar_manager.connection() @@ -759,7 +792,9 @@ def get_schemas(sid, did): """ try: view = SchemaDiffRegistry.get_node_view('schema') - server = Server.query.filter_by(id=sid).first() + server = get_server_access(sid) + if server is None: + return None response = view.nodes(gid=server.servergroup_id, sid=sid, did=did, is_schema_diff=True) schemas = json.loads(response.data)['data'] @@ -912,6 +947,9 @@ def fetch_compare_schemas(source_sid, source_did, target_sid, target_did): source_schemas = get_schemas(source_sid, source_did) target_schemas = get_schemas(target_sid, target_did) + if source_schemas is None or target_schemas is None: + return None + src_schema_dict = {item['label']: item['_id'] for item in source_schemas} tar_schema_dict = {item['label']: item['_id'] for item in target_schemas} diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index fe69994f1..c9e26df2f 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -63,6 +63,8 @@ from pgadmin.utils.constants import MIMETYPE_APP_JS, \ ERROR_FETCHING_DATA, MY_STORAGE, ACCESS_DENIED_MESSAGE, \ ERROR_MSG_FAIL_TO_PROMOTE_QT from pgadmin.model import Server, ServerGroup +from pgadmin.utils.server_access import get_server, \ + get_server_groups_for_user, get_user_server_query from pgadmin.tools.schema_diff.node_registry import SchemaDiffRegistry from pgadmin.settings import get_setting from pgadmin.utils.preferences import Preferences @@ -225,7 +227,12 @@ def initialize_viewdata(trans_id, cmd_type, obj_type, sgid, sid, did, obj_id): 'password': _data['password'] if 'password' in _data else None } - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) if kwargs.get('password', None) is None: kwargs['encpass'] = server.password else: @@ -374,7 +381,7 @@ def panel(trans_id): params['bgcolor'] = None params['fgcolor'] = None - s = Server.query.filter_by(id=int(params['sid'])).first() + s = get_server(int(params['sid'])) if s: if s.shared and s.user_id != current_user.id: # Import here to avoid circular dependency @@ -512,7 +519,12 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, dbname=None, **kwargs): kwargs.pop('conn_id') conn_id_ac = str(secrets.choice(range(1, 9999999))) - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) + if server is None: + return True, internal_server_error( + errormsg=gettext( + "Could not find the required server.") + ), '', '' if server.shared and server.user_id != current_user.id: # Import here to avoid circular dependency from pgadmin.browser.server_groups.servers import ServerModule @@ -2344,8 +2356,13 @@ def _check_server_connection_status(sgid, sid=None): driver = get_driver(PG_DEFAULT_DRIVER) from pgadmin.browser.server_groups.servers import \ server_icon_and_background - server = Server.query.filter_by( - id=sid).first() + server = get_server(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext( + "Could not find the required server.") + ) manager = driver.connection_manager(server.id) conn = manager.connection() @@ -2393,11 +2410,10 @@ def get_new_connection_data(sgid=None, sid=None): driver = get_driver(PG_DEFAULT_DRIVER) from pgadmin.browser.server_groups.servers import \ server_icon_and_background - server_groups = ServerGroup.query.all() + server_groups = get_server_groups_for_user() server_group_data = {server_group.name: [] for server_group in server_groups} - servers = Server.query.filter( - or_(Server.user_id == current_user.id, Server.shared), + servers = get_user_server_query().filter( Server.is_adhoc == 0) for server in servers: @@ -2654,7 +2670,12 @@ def get_new_connection_role(sgid, sid=None): @pga_login_required def connect_server(sid): # Check if server is already connected then no need to reconnect again. - server = Server.query.filter_by(id=sid).first() + server = get_server(sid) + if server is None: + return make_json_response( + status=410, success=0, + errormsg=gettext("Could not find the required server.") + ) driver = get_driver(PG_DEFAULT_DRIVER) manager = driver.connection_manager(sid) diff --git a/web/pgadmin/tools/user_management/__init__.py b/web/pgadmin/tools/user_management/__init__.py index 361187315..52497d4b3 100644 --- a/web/pgadmin/tools/user_management/__init__.py +++ b/web/pgadmin/tools/user_management/__init__.py @@ -759,7 +759,7 @@ def delete_user(uid): ServerGroup.query.filter_by(user_id=uid).delete() - Process.query.filter_by(user_id=uid).delete() + Process.for_user(user_id=uid).delete() # Delete Shared servers for current user. SharedServer.query.filter_by(user_id=uid).delete() diff --git a/web/pgadmin/utils/__init__.py b/web/pgadmin/utils/__init__.py index 6e4f5b22c..7f2136d99 100644 --- a/web/pgadmin/utils/__init__.py +++ b/web/pgadmin/utils/__init__.py @@ -358,14 +358,14 @@ def does_utility_exist(file): return error_msg -def get_server(sid): +def get_server(sid, only_owned=False): + """Fetch a server by ID with access check. + + Delegates to server_access.get_server(). Kept here for backward + compatibility — existing callers import from pgadmin.utils. """ - # Fetch the server etc - :param sid: - :return: server - """ - server = Server.query.filter_by(id=sid).first() - return server + from pgadmin.utils.server_access import get_server as _get_server + return _get_server(sid, only_owned=only_owned) def get_binary_path_versions(binary_path: str) -> dict: diff --git a/web/pgadmin/utils/driver/psycopg3/__init__.py b/web/pgadmin/utils/driver/psycopg3/__init__.py index 5bb606c3d..0695e83f2 100644 --- a/web/pgadmin/utils/driver/psycopg3/__init__.py +++ b/web/pgadmin/utils/driver/psycopg3/__init__.py @@ -16,6 +16,7 @@ object. import datetime import re from flask import session +from flask_babel import gettext from flask_login import current_user from werkzeug.exceptions import InternalServerError import psycopg @@ -23,6 +24,9 @@ from threading import Lock import config from pgadmin.model import Server +from pgadmin.utils.server_access import get_server, \ + get_user_server_query +from pgadmin.utils.exception import ObjectGone from .keywords import scan_keyword from ..abstract import BaseDriver from .connection import Connection @@ -67,20 +71,29 @@ class Driver(BaseDriver): def _restore_connections_from_session(self): """ Used internally by connection_manager to restore connections - from sessions. + from sessions. Includes both owned and shared servers so + non-owner connections survive session restore. """ if session.sid not in self.managers: self.managers[session.sid] = managers = dict() if '__pgsql_server_managers' in session: session_managers = \ session['__pgsql_server_managers'].copy() - for server in \ - Server.query.filter_by( - user_id=current_user.id, is_adhoc=0): + servers = get_user_server_query().filter( + Server.is_adhoc == 0) + for server in servers: manager = managers[str(server.id)] = \ ServerManager(server) + # Suppress owner-only fields for non-owners + # of shared servers so passexec_cmd and + # post_connection_sql don't leak. + if server.shared and \ + server.user_id != current_user.id: + manager.passexec = None + manager.post_connection_sql = None if server.id in session_managers: - manager._restore(session_managers[server.id]) + manager._restore( + session_managers[server.id]) manager.update_session() return managers @@ -100,9 +113,27 @@ class Driver(BaseDriver): assert (sid is not None and isinstance(sid, int)) managers = None - server_data = Server.query.filter_by(id=sid).first() - if server_data is None: - return None + # In server mode, verify the current user has access to this + # server. This is the primary security boundary — all + # check_precondition decorators and tool endpoints flow + # through connection_manager(). + if config.SERVER_MODE: + if current_user and current_user.is_authenticated: + server_data = get_server(sid) + else: + raise ObjectGone( + gettext("Server not found.")) + if server_data is None: + raise ObjectGone( + gettext("Server not found.")) + else: + # Desktop mode — single user, no isolation needed. + # Return None instead of raising so callers that + # handle None gracefully (e.g., test teardown, + # cleanup paths) are not disrupted. + server_data = Server.query.filter_by(id=sid).first() + if server_data is None: + return None if session.sid not in self.managers: with connection_restore_lock: @@ -119,14 +150,18 @@ class Driver(BaseDriver): managers['pinged'] = datetime.datetime.now() if str(sid) not in managers: - s = Server.query.filter_by(id=sid).first() + # server_data was already access-checked above; + # it cannot be None at this point. + manager = ServerManager(server_data) + # Suppress owner-only fields for non-owners of + # shared servers. + if config.SERVER_MODE and server_data.shared and \ + server_data.user_id != current_user.id: + manager.passexec = None + manager.post_connection_sql = None + managers[str(sid)] = manager - if not s: - return None - - managers[str(sid)] = ServerManager(s) - - return managers[str(sid)] + return manager return managers[str(sid)] diff --git a/web/pgadmin/utils/server_access.py b/web/pgadmin/utils/server_access.py new file mode 100644 index 000000000..1e0c8fe6a --- /dev/null +++ b/web/pgadmin/utils/server_access.py @@ -0,0 +1,156 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2026, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""Centralized server access-checking utilities for data isolation. + +In server mode, multiple users share the same pgAdmin instance. These +helpers enforce that users can only access servers they own or that +have been explicitly shared with them via SharedServer entries. +""" + +from sqlalchemy import or_ +from flask_security import current_user + +from pgadmin.model import db, Server, ServerGroup +import config + + +def _is_admin(): + """Check if current user has Administrator role.""" + return current_user.has_role('Administrator') + + +def get_server(sid, only_owned=False): + """Fetch a server by ID, verifying the current user has access. + + Args: + sid: Server ID. + only_owned: If True, only return servers owned by the current + user. Use this for write operations (change_password, + clear_saved_password, etc.) that must not mutate another + user's server record via shared access. + + Returns the server if: + - Desktop mode (single user, no isolation needed), OR + - The user owns it, OR + - The server is shared AND only_owned is False, OR + - The user has the Administrator role. + + Returns None otherwise (caller should return 404). + + Note: In pgAdmin, Server.shared=True means the server is visible + to all authenticated users. SharedServer records are created + lazily for per-user customization, not for access control. + """ + if not config.SERVER_MODE: + return Server.query.filter_by(id=sid).first() + + if only_owned: + return Server.query.filter_by( + id=sid, user_id=current_user.id).first() + + # Single query: owned OR shared + server = Server.query.filter( + Server.id == sid, + or_( + Server.user_id == current_user.id, + Server.shared + ) + ).first() + + if server is not None: + return server + + # Administrators can access all servers + if _is_admin(): + return Server.query.filter_by(id=sid).first() + + return None + + +def get_server_group(gid): + """Fetch a server group by ID, verifying user access. + + Returns the group if: + - Desktop mode, OR + - The user owns it, OR + - It contains shared servers (Server.shared=True), OR + - The user has the Administrator role. + + Returns None otherwise. + """ + if not config.SERVER_MODE: + return ServerGroup.query.filter_by(id=gid).first() + + sg = ServerGroup.query.filter( + ServerGroup.id == gid, + or_( + ServerGroup.user_id == current_user.id, + ServerGroup.id.in_( + db.session.query(Server.servergroup_id).filter( + Server.shared + ) + ) + ) + ).first() + + if sg is not None: + return sg + + if _is_admin(): + return ServerGroup.query.filter_by(id=gid).first() + + return None + + +def get_server_groups_for_user(): + """Return server groups visible to the current user. + + Includes groups owned by the user plus groups containing shared + servers (Server.shared=True, visible to all authenticated users). + Administrators see all groups. + """ + if not config.SERVER_MODE: + return ServerGroup.query.filter_by( + user_id=current_user.id + ).all() + + if _is_admin(): + return ServerGroup.query.all() + + return ServerGroup.query.filter( + or_( + ServerGroup.user_id == current_user.id, + ServerGroup.id.in_( + db.session.query(Server.servergroup_id).filter( + Server.shared + ) + ) + ) + ).all() + + +def get_user_server_query(): + """Return a base query for servers accessible to the current user. + + Includes owned servers + shared servers (visible to all users). + Administrators see all servers. + """ + if not config.SERVER_MODE: + return Server.query + + if _is_admin(): + return Server.query + + return Server.query.filter( + or_( + Server.user_id == current_user.id, + Server.shared + ) + )