fix: enforce data isolation and harden shared servers in server mode (#9830)

pgAdmin 4 in server mode had no data isolation between users — any
authenticated user could access other users' private servers,
background processes, and debugger state by guessing object IDs.
The shared server feature had 21 vulnerabilities including credential
leaks, privilege escalation via passexec_cmd, and owner data
corruption via SQLAlchemy session mutations.

Centralized access control:
- New server_access.py with get_server(), get_server_group(),
  get_user_server_query() replacing ~20 unfiltered queries
- connection_manager() raises ObjectGone (HTTP 410) in server mode
  when access is denied — fixes 155+ unguarded callers
- UserScopedMixin.for_user() on 10 models replaces scattered
  user_id filters

Shared server isolation (all 21 audit issues):
- Expunge server from session before property merge to prevent
  owner data corruption
- Suppress passexec_cmd, post_connection_sql for non-owners in
  merge, API response, and ServerManager
- Override all 6 SSL/passfile connection_params keys from
  SharedServer; strip owner-only keys; sanitize on creation
- _is_non_owner() helper centralises 15+ inline ownership checks
- SharedServer lookup uses (osid, user_id) not name
- Unique constraint on SharedServer(osid, user_id)
- Tunnel/DB password save, change_password, clear_saved_password,
  clear_sshtunnel_password all branch on ownership
- Only owner can unshare (delete_shared_server guard)
- Session restore includes shared servers
- tunnel_port/tunnel_keep_alive copied from owner, not hardcoded

Tool/module hardening:
- All tool endpoints use get_server()
- Debugger function arguments scoped by user_id
- Background processes use Process.for_user()
- Workspace adhoc servers scoped to current user

Migration (schema version 49 -> 50):
- Add user_id to debugger_function_arguments composite PK
- Add indexes on server, sharedserver, servergroup
- Add unique constraint on sharedserver(osid, user_id)
custom_user_support_in_docker^2
Ashesh Vashi 2026-04-09 18:32:59 +05:30 committed by GitHub
parent 872d5ac0b3
commit 9a76ed80bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 1850 additions and 312 deletions

View File

@ -0,0 +1,149 @@
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2026, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
"""Add user_id to debugger_function_arguments and indexes for data isolation
Revision ID: add_user_id_dbg_args
Revises: add_tools_ai_perm
Create Date: 2026-04-08
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'add_user_id_dbg_args'
down_revision = 'add_tools_ai_perm'
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
dialect = conn.dialect.name
# --- DebuggerFunctionArguments: add user_id to composite PK ---
if dialect == 'sqlite':
# SQLite cannot ALTER composite PKs. Recreate the table.
# Existing debugger argument data is ephemeral (cached function
# args) so dropping is acceptable.
op.execute(
'DROP TABLE IF EXISTS debugger_function_arguments'
)
op.create_table(
'debugger_function_arguments',
sa.Column('user_id', sa.Integer(),
sa.ForeignKey('user.id'), nullable=False),
sa.Column('server_id', sa.Integer(), nullable=False),
sa.Column('database_id', sa.Integer(), nullable=False),
sa.Column('schema_id', sa.Integer(), nullable=False),
sa.Column('function_id', sa.Integer(), nullable=False),
sa.Column('arg_id', sa.Integer(), nullable=False),
sa.Column('is_null', sa.Integer(), nullable=False),
sa.Column('is_expression', sa.Integer(), nullable=False),
sa.Column('use_default', sa.Integer(), nullable=False),
sa.Column('value', sa.String(), nullable=True),
sa.PrimaryKeyConstraint(
'user_id', 'server_id', 'database_id',
'schema_id', 'function_id', 'arg_id'
),
sa.CheckConstraint('is_null >= 0 AND is_null <= 1'),
sa.CheckConstraint(
'is_expression >= 0 AND is_expression <= 1'),
sa.CheckConstraint(
'use_default >= 0 AND use_default <= 1'),
)
else:
# PostgreSQL: add column, backfill from server owner, recreate
# PK using batch_alter_table for portability.
op.add_column(
'debugger_function_arguments',
sa.Column('user_id', sa.Integer(),
sa.ForeignKey('user.id'), nullable=True)
)
# Backfill: assign user_id from the server's owner
op.execute(
'UPDATE debugger_function_arguments '
'SET user_id = s.user_id '
'FROM server s '
'WHERE debugger_function_arguments.server_id = s.id'
)
# Delete orphans (rows with no matching server)
op.execute(
'DELETE FROM debugger_function_arguments '
'WHERE user_id IS NULL'
)
op.alter_column(
'debugger_function_arguments', 'user_id', nullable=False
)
# Recreate PK with user_id using batch_alter_table
with op.batch_alter_table(
'debugger_function_arguments'
) as batch:
batch.drop_constraint(
'debugger_function_arguments_pkey', type_='primary'
)
batch.create_primary_key(
'debugger_function_arguments_pkey',
['user_id', 'server_id', 'database_id',
'schema_id', 'function_id', 'arg_id']
)
# --- Indexes for data isolation query performance ---
# Only create indexes on tables that exist (sharedserver may be
# absent in older schemas that haven't run all prior migrations).
inspector = sa.inspect(conn)
index_stmts = [
('server',
'CREATE INDEX IF NOT EXISTS ix_server_user_id '
'ON server (user_id)'),
('server',
'CREATE INDEX IF NOT EXISTS ix_server_servergroup_id '
'ON server (servergroup_id)'),
('sharedserver',
'CREATE INDEX IF NOT EXISTS ix_sharedserver_user_id '
'ON sharedserver (user_id)'),
('sharedserver',
'CREATE INDEX IF NOT EXISTS ix_sharedserver_osid '
'ON sharedserver (osid)'),
('servergroup',
'CREATE INDEX IF NOT EXISTS ix_servergroup_user_id '
'ON servergroup (user_id)'),
]
for table_name, stmt in index_stmts:
if inspector.has_table(table_name):
op.execute(stmt)
# --- Unique constraint on SharedServer(osid, user_id) ---
# Prevents duplicate SharedServer records from TOCTOU race.
# First remove duplicates (keep lowest id per osid+user_id).
if inspector.has_table('sharedserver'):
if dialect == 'sqlite':
op.execute(
'DELETE FROM sharedserver WHERE id NOT IN '
'(SELECT MIN(id) FROM sharedserver '
'GROUP BY osid, user_id)'
)
else:
op.execute(
'DELETE FROM sharedserver s1 USING '
'sharedserver s2 WHERE s1.osid = s2.osid '
'AND s1.user_id = s2.user_id '
'AND s1.id > s2.id'
)
with op.batch_alter_table('sharedserver') as batch:
batch.create_unique_constraint(
'uq_sharedserver_osid_user',
['osid', 'user_id']
)
def downgrade():
# pgAdmin only upgrades, downgrade not implemented.
pass

View File

@ -15,8 +15,6 @@ Create Date: 2018-08-29 15:33:57.855491
"""
from alembic import op
from sqlalchemy.orm.session import Session
from pgadmin.model import DebuggerFunctionArguments
# revision identifiers, used by Alembic.
revision = 'ca00ec32581b'
@ -26,11 +24,10 @@ depends_on = None
def upgrade():
session = Session(bind=op.get_bind())
debugger_records = session.query(DebuggerFunctionArguments).all()
if debugger_records:
session.delete(debugger_records)
# Use raw SQL instead of importing the model class, because
# model changes in later migrations (e.g. adding user_id) would
# cause this migration to fail on fresh databases.
op.execute('DELETE FROM debugger_function_arguments')
def downgrade():

View File

@ -25,6 +25,8 @@ from sqlalchemy import exc
from pgadmin.model import db, ServerGroup, Server
import config
from pgadmin.utils.preferences import Preferences
from pgadmin.utils.server_access import get_server_group, \
get_server_groups_for_user
def get_icon_css_class(group_id, group_user_id,
@ -286,7 +288,7 @@ class ServerGroupView(NodeView):
def properties(self, gid):
"""Update the server-group properties"""
sg = ServerGroup.query.filter(ServerGroup.id == gid).first()
sg = get_server_group(gid)
if sg is None:
return make_json_response(
@ -296,7 +298,8 @@ class ServerGroupView(NodeView):
)
else:
return ajax_response(
response={'id': sg.id, 'name': sg.name, 'user_id': sg.user_id},
response={'id': sg.id, 'name': sg.name,
'user_id': sg.user_id},
status=200
)
@ -373,8 +376,9 @@ class ServerGroupView(NodeView):
@staticmethod
def get_all_server_groups():
"""
Returns the list of server groups to show in server mode and
if there is any shared server in the group.
Returns the list of server groups to show in server mode.
Includes groups owned by the user and groups containing
shared servers accessible to this user.
:return: server groups
"""
@ -383,17 +387,18 @@ class ServerGroupView(NodeView):
pref = Preferences.module('browser')
hide_shared_server = pref.preference('hide_shared_server').get()
server_groups = ServerGroup.query.all()
groups = []
for group in server_groups:
if hide_shared_server and \
ServerGroupModule.has_shared_server(group.id) and \
group.user_id != current_user.id:
continue
if group.user_id == current_user.id or \
ServerGroupModule.has_shared_server(group.id):
server_groups = get_server_groups_for_user()
if hide_shared_server:
groups = []
for group in server_groups:
if group.user_id != current_user.id and \
ServerGroupModule.has_shared_server(group.id):
continue
groups.append(group)
return groups
return groups
return server_groups
@pga_login_required
def nodes(self, gid=None):
@ -421,7 +426,7 @@ class ServerGroupView(NodeView):
)
)
else:
group = ServerGroup.query.filter(ServerGroup.id == gid).first()
group = get_server_group(gid)
if not group:
return gone(

View File

@ -39,12 +39,30 @@ from pgadmin.browser.server_groups.servers.utils import \
from pgadmin.utils.constants import UNAUTH_REQ, MIMETYPE_APP_JS, \
SERVER_CONNECTION_CLOSED, RESTRICTION_TYPE_SQL
from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import flag_modified
from pgadmin.utils.preferences import Preferences
from .... import socketio as sio
from pgadmin.utils import get_complete_file_path
from pgadmin.settings.utils import with_object_filters
from pgadmin.utils.server_access import get_server, \
get_user_server_query, get_server_group
# File-path keys in connection_params that are per-user and must
# not be copied from the owner to a new SharedServer or leaked
# through the property merge.
SENSITIVE_CONN_KEYS = frozenset({
'passfile', 'sslcert', 'sslkey',
'sslrootcert', 'sslcrl', 'sslcrldir',
})
def _is_non_owner(server):
"""True if the server is shared and the current user is not
the owner. Centralises the check used in 15+ places."""
return server.shared and server.user_id != current_user.id
def has_any(data, keys):
@ -151,15 +169,30 @@ class ServerModule(sg.ServerGroupPluginModule):
@staticmethod
def get_shared_server_properties(server, sharedserver):
"""
Return shared server properties
Return shared server properties.
Overlays per-user SharedServer values onto the owner's Server
object. Security-sensitive fields that are absent from the
SharedServer model (passexec_cmd, post_connection_sql) are
suppressed for non-owners.
The server is expunged from the SQLAlchemy session before
mutation so that the owner's record is never dirtied.
:param server:
:param sharedserver:
:return: shared server
:return: shared server (detached)
"""
# Detach from session so in-place mutations are never
# flushed back to the owner's Server row.
sess = object_session(server)
if sess is not None:
sess.expunge(server)
server.bgcolor = sharedserver.bgcolor
server.fgcolor = sharedserver.fgcolor
server.name = sharedserver.name
server.role = sharedserver.role
server.service = sharedserver.service
server.use_ssh_tunnel = sharedserver.use_ssh_tunnel
server.tunnel_host = sharedserver.tunnel_host
server.tunnel_port = sharedserver.tunnel_port
@ -169,24 +202,36 @@ class ServerModule(sg.ServerGroupPluginModule):
server.save_password = sharedserver.save_password
server.tunnel_identity_file = sharedserver.tunnel_identity_file
server.tunnel_prompt_password = sharedserver.tunnel_prompt_password
if hasattr(server, 'connection_params') and \
hasattr(sharedserver, 'connection_params') and \
'passfile' in server.connection_params and \
'passfile' in sharedserver.connection_params:
server.connection_params['passfile'] = \
sharedserver.connection_params['passfile']
# Override per-user connection_params keys. Use the
# SharedServer value whenever it is present, regardless of
# whether the owner's Server has the same key.
s_conn = getattr(server, 'connection_params', None) \
or {}
ss_conn = getattr(sharedserver, 'connection_params',
None) or {}
for key in SENSITIVE_CONN_KEYS:
if key in ss_conn:
s_conn[key] = ss_conn[key]
elif key in s_conn:
# Owner has this key but non-owner doesn't —
# remove it so the owner's path doesn't leak.
del s_conn[key]
server.connection_params = s_conn
server.servergroup_id = sharedserver.servergroup_id
if hasattr(server, 'connection_params') and \
hasattr(sharedserver, 'connection_params') and \
'sslcert' in server.connection_params and \
'sslcert' in sharedserver.connection_params:
server.connection_params['sslcert'] = \
sharedserver.connection_params['sslcert']
server.username = sharedserver.username
server.server_owner = sharedserver.server_owner
server.password = sharedserver.password
server.prepare_threshold = sharedserver.prepare_threshold
# Suppress owner-only fields that are absent from SharedServer
# and dangerous when inherited (privilege escalation / code
# execution).
server.passexec_cmd = None
server.passexec_expiration = None
server.post_connection_sql = None
return server
def get_servers(self, all_servers, hide_shared_server, gid):
@ -203,12 +248,13 @@ class ServerModule(sg.ServerGroupPluginModule):
if server.discovery_id and \
not server.shared and \
config.SERVER_MODE and \
len(SharedServer.query.filter_by(
SharedServer.query.filter_by(
user_id=current_user.id,
name=server.name).all()) > 0 and not hide_shared_server:
osid=server.id).first() is not None \
and not hide_shared_server:
continue
if server.shared and server.user_id != current_user.id:
if _is_non_owner(server):
shared_server = self.get_shared_server(server, gid)
@ -245,8 +291,7 @@ class ServerModule(sg.ServerGroupPluginModule):
"""Return a JSON document listing the server groups for the user"""
hide_shared_server = get_preferences()
servers = Server.query.filter(
or_(Server.user_id == current_user.id, Server.shared),
servers = get_user_server_query().filter(
Server.servergroup_id == gid, Server.is_adhoc == 0)
driver = get_driver(PG_DEFAULT_DRIVER)
@ -392,6 +437,18 @@ class ServerModule(sg.ServerGroupPluginModule):
try:
db.session.rollback()
user = User.query.filter_by(id=data.user_id).first()
# Strip owner's sensitive file paths from
# connection_params — each user should configure
# their own SSL/passfile paths.
safe_conn_params = {}
if data.connection_params:
safe_conn_params = {
k: v for k, v in
data.connection_params.items()
if k not in SENSITIVE_CONN_KEYS
}
shared_server = SharedServer(
osid=data.id,
user_id=current_user.id,
@ -410,43 +467,57 @@ class ServerModule(sg.ServerGroupPluginModule):
service=data.service if data.service else None,
use_ssh_tunnel=data.use_ssh_tunnel,
tunnel_host=data.tunnel_host,
tunnel_port=22,
tunnel_port=data.tunnel_port
if data.tunnel_port else 22,
tunnel_username=None,
tunnel_authentication=0,
tunnel_identity_file=None,
tunnel_keep_alive=0,
tunnel_keep_alive=data.tunnel_keep_alive
if data.tunnel_keep_alive else 0,
tunnel_prompt_password=0,
shared=True,
connection_params=data.connection_params,
connection_params=safe_conn_params,
prepare_threshold=data.prepare_threshold
)
db.session.add(shared_server)
db.session.commit()
except Exception as e:
if shared_server:
db.session.delete(shared_server)
db.session.commit()
db.session.rollback()
raise e
@staticmethod
def get_shared_server(server, gid):
"""
return the shared server
Return the SharedServer record for the current user,
creating one lazily if it doesn't exist. The unique
constraint on (osid, user_id) prevents duplicates from
concurrent requests.
:param server:
:param gid:
:return: shared_server
:return: shared_server (never None)
:raises: Exception if SharedServer cannot be created
"""
shared_server = SharedServer.query.filter_by(
name=server.name, user_id=current_user.id,
servergroup_id=int(gid), osid=server.id).first()
user_id=current_user.id,
osid=server.id).first()
if shared_server is None:
ServerModule.create_shared_server(server, int(gid))
try:
ServerModule.create_shared_server(
server, int(gid))
except IntegrityError:
# Unique constraint violation from a concurrent
# request — the record now exists.
db.session.rollback()
shared_server = SharedServer.query.filter_by(
name=server.name, user_id=current_user.id,
servergroup_id=int(gid), osid=server.id).first()
user_id=current_user.id,
osid=server.id).first()
if shared_server is None:
raise Exception(
"Failed to create shared server record "
"for server {0}".format(server.id))
return shared_server
@ -495,17 +566,28 @@ class ServerNode(PGChildNodeView):
'clear_sshtunnel_password': [{'put': 'clear_sshtunnel_password'}],
})
def update_connection_parameter(self, data, server):
def update_connection_parameter(self, data, server, sharedserver=None):
"""
This function is used to update the connection parameters.
"""
if 'connection_params' in data and \
hasattr(server, 'connection_params'):
existing_conn_params = getattr(server, 'connection_params')
# For shared servers accessed by non-owners, apply changes
# to the SharedServer's connection_params (a copy) so we
# don't mutate the owner's Server record in-place.
if sharedserver is not None and \
server.shared and \
server.user_id != current_user.id:
existing_conn_params = dict(
sharedserver.connection_params or {})
else:
existing_conn_params = getattr(
server, 'connection_params')
new_conn_params = data['connection_params']
if 'deleted' in new_conn_params:
for item in new_conn_params['deleted']:
del existing_conn_params[item['name']]
if item['name'] in existing_conn_params:
del existing_conn_params[item['name']]
if 'added' in new_conn_params:
for item in new_conn_params['added']:
existing_conn_params[item['name']] = item['value']
@ -560,15 +642,13 @@ class ServerNode(PGChildNodeView):
Return a JSON document listing the servers under this server group
for the user.
"""
servers = Server.query.filter(
or_(Server.user_id == current_user.id,
Server.shared),
servers = get_user_server_query().filter(
Server.servergroup_id == gid, Server.is_adhoc == 0)
driver = get_driver(PG_DEFAULT_DRIVER)
for server in servers:
if server.shared and server.user_id != current_user.id:
if _is_non_owner(server):
shared_server = ServerModule.get_shared_server(server, gid)
server = \
ServerModule.get_shared_server_properties(server,
@ -627,24 +707,22 @@ class ServerNode(PGChildNodeView):
@pga_login_required
def node(self, gid, sid):
"""Return a JSON document listing the server groups for the user"""
server = Server.query.filter_by(id=sid).first()
if server.shared and server.user_id != current_user.id:
shared_server = ServerModule.get_shared_server(server, gid)
server = ServerModule.get_shared_server_properties(server,
shared_server)
server = get_server(sid)
if server is None:
return make_json_response(
status=410,
success=0,
errormsg=gettext(
gettext(
"Could not find the server with id# {0}."
).format(sid)
)
"Could not find the server with id# {0}."
).format(sid)
)
if _is_non_owner(server):
shared_server = ServerModule.get_shared_server(server, gid)
server = ServerModule.get_shared_server_properties(server,
shared_server)
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(server.id)
conn = manager.connection()
connected = conn.connected()
@ -693,16 +771,20 @@ class ServerNode(PGChildNodeView):
),
)
def delete_shared_server(self, server_name, gid, osid):
def delete_shared_server(self, gid, osid, user_id=None):
"""
Delete the shared server
:param server_name:
:return:
Delete SharedServer records for a given original server.
:param gid: Server group ID
:param osid: Original server ID
:param user_id: If set, only delete for this user.
If None, delete for ALL users (owner unshare/delete).
"""
try:
shared_server = SharedServer.query.filter_by(name=server_name,
servergroup_id=gid,
osid=osid)
filters = dict(servergroup_id=gid, osid=osid)
if user_id is not None:
filters['user_id'] = user_id
shared_server = SharedServer.query.filter_by(
**filters)
for s in shared_server:
get_driver(PG_DEFAULT_DRIVER).delete_manager(s.id)
db.session.delete(s)
@ -738,7 +820,7 @@ class ServerNode(PGChildNodeView):
get_driver(PG_DEFAULT_DRIVER).delete_manager(s.id)
db.session.delete(s)
db.session.commit()
self.delete_shared_server(server_name, gid, sid)
self.delete_shared_server(gid, sid)
QueryHistory.clear_history(current_user.id, sid)
except Exception as e:
@ -754,7 +836,7 @@ class ServerNode(PGChildNodeView):
@pga_login_required
def update(self, gid, sid):
"""Update the server settings"""
server = Server.query.filter_by(id=sid).first()
server = get_server(sid)
sharedserver = None
if server is None:
@ -821,7 +903,7 @@ class ServerNode(PGChildNodeView):
data['db_res'] = ','.join(data['db_res'])
# Update connection parameter if any.
self.update_connection_parameter(data, server)
self.update_connection_parameter(data, server, sharedserver)
self.update_tags(data, server)
if 'connection_params' in data and \
@ -878,7 +960,7 @@ class ServerNode(PGChildNodeView):
server.name,
server_icon_and_background(
connected, manager, sharedserver)
if server.shared and server.user_id != current_user.id
if _is_non_owner(server)
else server_icon_and_background(
connected, manager, server),
True,
@ -902,7 +984,7 @@ class ServerNode(PGChildNodeView):
if value == '':
value = None
if server.shared and server.user_id != current_user.id:
if _is_non_owner(server):
setattr(sharedserver, config_param_map[arg], value)
else:
setattr(server, config_param_map[arg], value)
@ -921,17 +1003,20 @@ class ServerNode(PGChildNodeView):
value = data[arg]
if arg == 'password':
value = encrypt(data[arg], crypt_key)
# sqlite3 do not have boolean type so we need to convert
# it manually to integer
if 'shared' in data and not data['shared']:
# Delete the shared server from DB if server
# owner uncheck shared property
self.delete_shared_server(server.name, gid, server.id)
# sqlite3 do not have boolean type so we need to
# convert it manually to integer.
# Only the owner may unshare — this deletes ALL
# users' SharedServer records.
if 'shared' in data and not data['shared'] \
and not _is_non_owner(server):
self.delete_shared_server(gid, server.id)
if arg in ('sslcompression', 'use_ssh_tunnel',
'tunnel_authentication', 'kerberos_conn', 'shared'):
'tunnel_authentication',
'kerberos_conn', 'shared'):
value = 1 if value else 0
self._update_server_details(server, sharedserver,
config_param_map, arg, value)
config_param_map, arg,
value)
idx += 1
return idx
@ -956,19 +1041,16 @@ class ServerNode(PGChildNodeView):
"""
Return list of attributes of all servers.
"""
servers = Server.query.filter(
or_(Server.user_id == current_user.id, Server.shared),
servers = get_user_server_query().filter(
Server.servergroup_id == gid,
Server.is_adhoc == 0).order_by(Server.name)
sg = ServerGroup.query.filter_by(
id=gid
).first()
sg = get_server_group(gid)
res = []
driver = get_driver(PG_DEFAULT_DRIVER)
for server in servers:
if server.shared and server.user_id != current_user.id:
if _is_non_owner(server):
shared_server = ServerModule.get_shared_server(server, gid)
server = \
ServerModule.get_shared_server_properties(server,
@ -1002,8 +1084,7 @@ class ServerNode(PGChildNodeView):
def properties(self, gid, sid):
"""Return list of attributes of a server"""
server = Server.query.filter_by(
id=sid).first()
server = get_server(sid)
if server is None:
return make_json_response(
@ -1026,7 +1107,7 @@ class ServerNode(PGChildNodeView):
# port and user when server is connected
display_connection_str = self.update_connection_string(manager, server)
if server.shared and server.user_id != current_user.id:
if _is_non_owner(server):
shared_server = ServerModule.get_shared_server(server, gid)
server = ServerModule.get_shared_server_properties(server,
shared_server)
@ -1079,10 +1160,13 @@ class ServerNode(PGChildNodeView):
'db_res': get_db_restriction(server.db_res_type, server.db_res),
'db_res_type': server.db_res_type,
'passexec_cmd':
server.passexec_cmd if server.passexec_cmd else None,
server.passexec_cmd
if server.passexec_cmd and
not _is_non_owner(server) else None,
'passexec_expiration':
server.passexec_expiration if server.passexec_expiration
else None,
server.passexec_expiration
if server.passexec_expiration and
not _is_non_owner(server) else None,
'service': server.service if server.service else None,
'use_ssh_tunnel': use_ssh_tunnel,
'tunnel_host': tunnel_host,
@ -1102,7 +1186,8 @@ class ServerNode(PGChildNodeView):
'connection_string': display_connection_str,
'prepare_threshold': server.prepare_threshold,
'tags': tags,
'post_connection_sql': server.post_connection_sql,
'post_connection_sql': server.post_connection_sql
if not _is_non_owner(server) else None,
}
return ajax_response(response)
@ -1395,7 +1480,12 @@ class ServerNode(PGChildNodeView):
def connect_status(self, gid, sid):
"""Check and return the connection status."""
server = Server.query.filter_by(id=sid).first()
server = get_server(sid)
if server is None:
return make_json_response(
status=410, success=0,
errormsg=self.not_found_error_msg()
)
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid)
conn = manager.connection()
connected = conn.connected()
@ -1464,19 +1554,17 @@ class ServerNode(PGChildNodeView):
# function in that case no need to fetch the server detail based on
# sid.
if server is None:
server = Server.query.filter_by(id=sid).first()
server = get_server(sid)
shared_server = None
if server.shared and server.user_id != current_user.id:
shared_server = ServerModule.get_shared_server(server, gid)
sess = object_session(server)
if sess is not None:
sess.expunge(server)
server = ServerModule.get_shared_server_properties(server,
shared_server)
if server is None:
return bad_request(self.not_found_error_msg())
shared_server = None
if _is_non_owner(server):
shared_server = ServerModule.get_shared_server(server, gid)
server = ServerModule.get_shared_server_properties(server,
shared_server)
# Return if username is blank and the server is shared
if server.username is None and not server.service and \
server.shared:
@ -1617,12 +1705,8 @@ class ServerNode(PGChildNodeView):
else:
if save_password and config.ALLOW_SAVE_PASSWORD:
try:
# If DB server is running in trust mode then password may
# not be available but we don't need to ask password
# every time user try to connect
# 1 is True in SQLite as no boolean type
setattr(server, 'save_password', 1)
if server.shared and server.user_id != current_user.id:
if _is_non_owner(server):
setattr(shared_server, 'save_password', 1)
else:
setattr(server, 'save_password', 1)
@ -1630,7 +1714,7 @@ class ServerNode(PGChildNodeView):
# Save the encrypted password using the user's login
# password key, if there is any password to save
if password:
if server.shared and server.user_id != current_user.id:
if _is_non_owner(server):
setattr(shared_server, 'password', password)
else:
setattr(server, 'password', password)
@ -1646,7 +1730,11 @@ class ServerNode(PGChildNodeView):
if save_tunnel_password and config.ALLOW_SAVE_TUNNEL_PASSWORD:
try:
# Save the encrypted tunnel password.
setattr(server, 'tunnel_password', tunnel_password)
if _is_non_owner(server):
setattr(shared_server, 'tunnel_password',
tunnel_password)
else:
setattr(server, 'tunnel_password', tunnel_password)
db.session.commit()
except Exception as e:
# Release Connection
@ -1693,7 +1781,7 @@ class ServerNode(PGChildNodeView):
def disconnect(self, gid, sid):
"""Disconnect the Server."""
server = Server.query.filter_by(id=sid).first()
server = get_server(sid)
if server is None:
return bad_request(self.not_found_error_msg())
@ -1818,7 +1906,7 @@ class ServerNode(PGChildNodeView):
raise CryptKeyMissing
# Fetch Server Details
server = Server.query.filter_by(id=sid).first()
server = get_server(sid, only_owned=False)
if server is None:
return bad_request(self.not_found_error_msg())
@ -1905,11 +1993,24 @@ class ServerNode(PGChildNodeView):
# Store password in sqlite only if no pgpass file
if not is_passfile:
password = encrypt(data['newPassword'], crypt_key)
# Check if old password was stored in pgadmin4 sqlite database.
# If yes then update that password.
if server.password is not None and config.ALLOW_SAVE_PASSWORD:
setattr(server, 'password', password)
db.session.commit()
# Check if old password was stored in pgadmin4
# sqlite database. If yes then update that password.
# For non-owners of shared servers, check the
# SharedServer record (not the owner's Server).
if config.ALLOW_SAVE_PASSWORD:
if server.shared and \
server.user_id != current_user.id:
shared_server = \
ServerModule.get_shared_server(
server, gid)
if shared_server and \
shared_server.password is not None:
setattr(shared_server, 'password',
password)
db.session.commit()
elif server.password is not None:
setattr(server, 'password', password)
db.session.commit()
# Also update password in connection manager.
manager.password = password
manager.update_session()
@ -1929,9 +2030,7 @@ class ServerNode(PGChildNodeView):
"""
Utility function for wal_replay for resume/pause.
"""
server = Server.query.filter_by(
user_id=current_user.id, id=sid
).first()
server = get_server(sid)
if server is None:
return make_json_response(
@ -2015,9 +2114,7 @@ class ServerNode(PGChildNodeView):
sid: Server id
"""
is_pgpass = False
server = Server.query.filter_by(
user_id=current_user.id, id=sid
).first()
server = get_server(sid)
if server is None:
return make_json_response(
@ -2108,38 +2205,22 @@ class ServerNode(PGChildNodeView):
:return:
"""
try:
server = Server.query.filter_by(id=sid).first()
shared_server = None
server = get_server(sid, only_owned=False)
if server is None:
return make_json_response(
success=0,
info=self.not_found_error_msg()
)
if server.shared and server.user_id != current_user.id:
shared_server = SharedServer.query.filter_by(
name=server.name, user_id=current_user.id,
servergroup_id=gid, osid=server.id).first()
if shared_server is None:
return make_json_response(
success=0,
info=gettext("Could not find the required server.")
)
server = ServerModule. \
get_shared_server_properties(server, shared_server)
if server.shared and server.user_id != current_user.id:
if _is_non_owner(server):
shared_server = ServerModule.get_shared_server(
server, gid)
setattr(shared_server, 'password', None)
if shared_server.save_password:
setattr(shared_server, 'save_password', 0)
else:
setattr(server, 'password', None)
# If password was saved then clear the flag also
# 0 is False in SQLite db
if server.save_password:
if server.shared and server.user_id != current_user.id:
setattr(shared_server, 'save_password', 0)
else:
if server.save_password:
setattr(server, 'save_password', 0)
db.session.commit()
except Exception as e:
@ -2165,13 +2246,19 @@ class ServerNode(PGChildNodeView):
:return:
"""
try:
server = Server.query.filter_by(id=sid).first()
server = get_server(sid, only_owned=False)
if server is None:
return make_json_response(
success=0,
info=self.not_found_error_msg()
)
setattr(server, 'tunnel_password', None)
if _is_non_owner(server):
shared_server = ServerModule.get_shared_server(
server, gid)
setattr(shared_server, 'tunnel_password', None)
else:
setattr(server, 'tunnel_password', None)
db.session.commit()
except Exception as e:
current_app.logger.error(

View File

@ -34,6 +34,7 @@ from pgadmin.tools.sqleditor.utils.query_history import QueryHistory
from pgadmin.tools.schema_diff.node_registry import SchemaDiffRegistry
from pgadmin.model import db, Server, Database
from pgadmin.utils.server_access import get_server
from pgadmin.browser.utils import underscore_escape
from pgadmin.utils.constants import TWO_PARAM_STRING
@ -579,7 +580,9 @@ class DatabaseView(PGChildNodeView):
'already_connected': already_connected,
'connected': True,
'info_prefix': TWO_PARAM_STRING.
format(Server.query.filter_by(id=sid)[0].name, conn.db)
format(getattr(
get_server(sid), 'name', None) or
_('Unknown'), conn.db)
}
)
@ -602,7 +605,9 @@ class DatabaseView(PGChildNodeView):
'icon': 'icon-database-not-connected',
'connected': False,
'info_prefix': TWO_PARAM_STRING.
format(Server.query.filter_by(id=sid)[0].name, conn.db)
format(getattr(
get_server(sid), 'name', None) or
_('Unknown'), conn.db)
}
)

View File

@ -29,8 +29,9 @@ from pgadmin.utils.ajax import make_json_response, internal_server_error, \
from pgadmin.utils.driver import get_driver
from pgadmin.tools.schema_diff.node_registry import SchemaDiffRegistry
from .schema_diff_view_utils import SchemaDiffViewCompare
from pgadmin.utils import does_utility_exist, get_server
from pgadmin.utils import does_utility_exist
from pgadmin.model import Server
from pgadmin.utils.server_access import get_server
from pgadmin.misc.bgprocess.processes import BatchProcess, IProcessDesc
from pgadmin.utils.constants import SERVER_NOT_FOUND
@ -2317,8 +2318,7 @@ class MViewNode(ViewNode, VacuumSettings):
res['rows'][0]['name'])
# Fetch the server details like hostname, port, roles etc
server = Server.query.filter_by(
id=sid).first()
server = get_server(sid)
if server is None:
return make_json_response(
@ -2436,9 +2436,7 @@ class MViewNode(ViewNode, VacuumSettings):
Returns:
None
"""
server = Server.query.filter_by(
id=sid, user_id=current_user.id
).first()
server = get_server(sid)
if server is None:
return make_json_response(

View File

@ -0,0 +1,352 @@
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2026, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
"""Tests for server data isolation between users in server mode."""
import json
import config
from pgadmin.utils.route import BaseTestGenerator
from regression.python_test_utils import test_utils as utils
from regression.test_setup import config_data
from regression.python_test_utils.test_utils import \
create_user_wise_test_client
test_user_details = None
if config.SERVER_MODE:
test_user_details = \
config_data['pgAdmin4_test_non_admin_credentials']
class ServerDataIsolationGetTestCase(BaseTestGenerator):
"""Verify that a non-admin user cannot access another user's
private (non-shared) server by ID."""
scenarios = [
('User B gets 410 for User A private server',
dict(is_positive_test=False)),
]
def setUp(self):
self.server_id = None
if not config.SERVER_MODE:
self.skipTest(
'Data isolation tests only apply to server mode.'
)
# Create a private (non-shared) server as the admin user
self.server['shared'] = False
url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP)
response = self.tester.post(
url,
data=json.dumps(self.server),
content_type='html/json'
)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.data.decode('utf-8'))
self.assertIn('node', response_data)
self.server_id = response_data['node']['_id']
@create_user_wise_test_client(test_user_details)
def runTest(self):
"""Non-admin user should NOT be able to GET another user's
private server."""
if not self.server_id:
raise Exception("Server not found to test isolation")
url = '/browser/server/obj/{0}/{1}'.format(
utils.SERVER_GROUP, self.server_id)
response = self.tester.get(url, follow_redirects=True)
# Expect 410 Gone (server not accessible to this user)
self.assertEqual(
response.status_code, 410,
'Non-admin user should not access another user\'s '
'private server. Got status {0}'.format(
response.status_code)
)
def tearDown(self):
if self.server_id is None:
return
# Clean up with the admin tester (which owns the server)
utils.delete_server_with_api(
self.__class__.tester, self.server_id)
class SharedServerAccessTestCase(BaseTestGenerator):
"""Verify that a shared server IS accessible by a non-admin
user (positive test shared servers should work after the
isolation fixes)."""
scenarios = [
('User B can access shared server from User A',
dict(is_positive_test=True)),
]
def setUp(self):
self.server_id = None
if not config.SERVER_MODE:
self.skipTest(
'Data isolation tests only apply to server mode.'
)
# Create a shared server as the admin user
self.server['shared'] = True
url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP)
response = self.tester.post(
url,
data=json.dumps(self.server),
content_type='html/json'
)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.data.decode('utf-8'))
self.assertIn('node', response_data)
self.server_id = response_data['node']['_id']
@create_user_wise_test_client(test_user_details)
def runTest(self):
"""Non-admin user SHOULD be able to GET a shared server."""
if not self.server_id:
raise Exception("Server not found to test shared access")
url = '/browser/server/obj/{0}/{1}'.format(
utils.SERVER_GROUP, self.server_id)
response = self.tester.get(url, follow_redirects=True)
self.assertEqual(
response.status_code, 200,
'Non-admin user should be able to access shared server.'
' Got status {0}'.format(response.status_code)
)
def tearDown(self):
if self.server_id is None:
return
utils.delete_server_with_api(
self.__class__.tester, self.server_id)
class SharedServerFieldSuppressionTestCase(BaseTestGenerator):
"""Verify that owner-only sensitive fields are suppressed
when a non-owner accesses a shared server's properties."""
scenarios = [
('Shared server suppresses passexec_cmd and '
'post_connection_sql for non-owner',
dict(is_positive_test=True)),
]
def setUp(self):
self.server_id = None
if not config.SERVER_MODE:
self.skipTest(
'Data isolation tests only apply to server mode.'
)
# Create a shared server with sensitive owner-only fields
self.server['shared'] = True
self.server['passexec_cmd'] = '/usr/bin/get-secret'
self.server['passexec_expiration'] = 100
self.server['post_connection_sql'] = 'SET role admin;'
url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP)
response = self.tester.post(
url,
data=json.dumps(self.server),
content_type='html/json'
)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.data.decode('utf-8'))
self.assertIn('node', response_data)
self.server_id = response_data['node']['_id']
@create_user_wise_test_client(test_user_details)
def runTest(self):
"""Non-owner should NOT see passexec_cmd or
post_connection_sql in properties response."""
if not self.server_id:
raise Exception("Server not found to test suppression")
url = '/browser/server/obj/{0}/{1}'.format(
utils.SERVER_GROUP, self.server_id)
response = self.tester.get(url, follow_redirects=True)
self.assertEqual(response.status_code, 200)
data = json.loads(response.data.decode('utf-8'))
# passexec_cmd must be None/null for non-owners
self.assertIsNone(
data.get('passexec_cmd'),
'passexec_cmd should be suppressed for non-owners.'
' Got: {0}'.format(data.get('passexec_cmd'))
)
self.assertIsNone(
data.get('passexec_expiration'),
'passexec_expiration should be suppressed for '
'non-owners.'
)
# post_connection_sql must be None/null for non-owners
self.assertIsNone(
data.get('post_connection_sql'),
'post_connection_sql should be suppressed for '
'non-owners. Got: {0}'.format(
data.get('post_connection_sql'))
)
def tearDown(self):
if self.server_id is None:
return
utils.delete_server_with_api(
self.__class__.tester, self.server_id)
class SharedServerConnectionParamsIsolationTestCase(
BaseTestGenerator):
"""Verify that owner's SSL file paths in connection_params
are not leaked to non-owners of shared servers."""
scenarios = [
('Shared server strips owner SSL paths for non-owner',
dict(is_positive_test=True)),
]
def setUp(self):
self.server_id = None
if not config.SERVER_MODE:
self.skipTest(
'Data isolation tests only apply to server mode.'
)
# Create shared server with owner SSL paths
self.server['shared'] = True
# Set connection_params with owner-specific paths
conn_params = self.server.get('connection_params', {})
conn_params['sslcert'] = '/home/owner/.ssl/cert.pem'
conn_params['sslkey'] = '/home/owner/.ssl/key.pem'
conn_params['sslrootcert'] = '/home/owner/.ssl/ca.pem'
self.server['connection_params'] = conn_params
url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP)
response = self.tester.post(
url,
data=json.dumps(self.server),
content_type='html/json'
)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.data.decode('utf-8'))
self.assertIn('node', response_data)
self.server_id = response_data['node']['_id']
@create_user_wise_test_client(test_user_details)
def runTest(self):
"""Non-owner should NOT see owner's SSL file paths
in connection_params."""
if not self.server_id:
raise Exception("Server not found")
url = '/browser/server/obj/{0}/{1}'.format(
utils.SERVER_GROUP, self.server_id)
response = self.tester.get(url, follow_redirects=True)
self.assertEqual(response.status_code, 200)
data = json.loads(response.data.decode('utf-8'))
conn_params = data.get('connection_params', {})
# Owner SSL paths should be stripped for non-owners
# (non-owner has no SharedServer SSL paths configured,
# so keys should be absent)
for key in ('sslcert', 'sslkey', 'sslrootcert',
'sslcrl', 'sslcrldir'):
val = None
if isinstance(conn_params, list):
for item in conn_params:
if item.get('name') == key:
val = item.get('value')
break
elif isinstance(conn_params, dict):
val = conn_params.get(key)
self.assertIsNone(
val,
'Owner SSL path "{0}" should not leak to '
'non-owner. Got: {1}'.format(key, val)
)
def tearDown(self):
if self.server_id is None:
return
utils.delete_server_with_api(
self.__class__.tester, self.server_id)
class SharedServerRenameDoesNotOrphanTestCase(BaseTestGenerator):
"""Verify that renaming a shared server does not create
orphan SharedServer records (Issue 20 fix lookup uses
osid, not name)."""
scenarios = [
('Rename shared server preserves non-owner access',
dict(is_positive_test=True)),
]
def setUp(self):
self.server_id = None
if not config.SERVER_MODE:
self.skipTest(
'Data isolation tests only apply to server mode.'
)
# Save admin tester BEFORE the decorator replaces it.
self.admin_tester = self.tester
self.server['shared'] = True
url = "/browser/server/obj/{0}/".format(utils.SERVER_GROUP)
response = self.tester.post(
url,
data=json.dumps(self.server),
content_type='html/json'
)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.data.decode('utf-8'))
self.assertIn('node', response_data)
self.server_id = response_data['node']['_id']
@create_user_wise_test_client(test_user_details)
def runTest(self):
"""After owner renames the shared server, non-owner
should still be able to access it."""
if not self.server_id:
raise Exception("Server not found")
# First access as non-owner to create SharedServer record
url = '/browser/server/obj/{0}/{1}'.format(
utils.SERVER_GROUP, self.server_id)
response = self.tester.get(url, follow_redirects=True)
self.assertEqual(response.status_code, 200)
# Rename the server as admin (saved in setUp before
# the decorator replaced self.tester).
response = self.admin_tester.put(
'/browser/server/obj/{0}/{1}'.format(
utils.SERVER_GROUP, self.server_id),
data=json.dumps(
{'name': 'renamed_shared_server'}),
content_type='html/json'
)
self.assertIn(
response.status_code, [200],
'Admin should be able to rename shared server.'
)
# Access again as non-owner — should still work
response = self.tester.get(url, follow_redirects=True)
self.assertEqual(
response.status_code, 200,
'Non-owner should still access shared server after '
'rename. Got status {0}'.format(response.status_code)
)
def tearDown(self):
if self.server_id is None:
return
utils.delete_server_with_api(
self.__class__.tester, self.server_id)

View File

@ -0,0 +1,491 @@
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2026, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
"""Unit tests for shared server isolation logic using mocks.
These tests verify the security-critical merge, suppression, and
sanitization logic without requiring a running PostgreSQL server
or HTTP infrastructure.
"""
from unittest.mock import MagicMock, patch, call
from pgadmin.utils.route import BaseTestGenerator
SRV_MODULE = 'pgadmin.browser.server_groups.servers'
def _make_server(**overrides):
"""Create a mock Server object with sensible defaults."""
defaults = dict(
id=1, user_id=100, name='OwnerServer',
shared=True, host='db.owner.com', port=5432,
maintenance_db='postgres', username='owner',
password=b'enc_owner_pass', role=None,
bgcolor=None, fgcolor=None, service=None,
use_ssh_tunnel=0, tunnel_host=None,
tunnel_port=5522, tunnel_authentication=0,
tunnel_username=None, tunnel_password=None,
tunnel_identity_file=None,
tunnel_prompt_password=0, tunnel_keep_alive=30,
save_password=1, servergroup_id=1,
server_owner='owner_user', prepare_threshold=5,
passexec_cmd='/usr/bin/vault-get-secret',
passexec_expiration=300,
post_connection_sql='SET role admin;',
connection_params={
'sslmode': 'verify-full',
'sslcert': '/home/owner/.ssl/cert.pem',
'sslkey': '/home/owner/.ssl/key.pem',
'sslrootcert': '/home/owner/.ssl/ca.pem',
'sslcrl': '/home/owner/.ssl/crl.pem',
'sslcrldir': '/home/owner/.ssl/crl.d',
'passfile': '/home/owner/.pgpass',
'connect_timeout': '10',
},
discovery_id=None, db_res=None, db_res_type=None,
kerberos_conn=False, cloud_status=0,
shared_username='shared_user', tags=None,
is_adhoc=0,
)
defaults.update(overrides)
server = MagicMock()
for k, v in defaults.items():
setattr(server, k, v)
return server
def _make_shared_server(**overrides):
"""Create a mock SharedServer object."""
defaults = dict(
id=10, osid=1, user_id=200,
server_owner='owner_user', servergroup_id=2,
name='MySharedView', host='db.owner.com',
port=5432, maintenance_db='postgres',
username='nonowner', password=b'enc_nonowner',
save_password=0, role='readonly',
bgcolor='#ff0000', fgcolor='#ffffff',
service='my_pg_service',
use_ssh_tunnel=1, tunnel_host='bastion.local',
tunnel_port=2222, tunnel_authentication=1,
tunnel_username='tunneluser',
tunnel_password=b'enc_tunnel',
tunnel_identity_file='/home/user/.ssh/id_rsa',
tunnel_prompt_password=0,
tunnel_keep_alive=60, shared=True,
prepare_threshold=10,
connection_params={
'sslmode': 'verify-full',
'sslcert': '/home/nonowner/.ssl/cert.pem',
'connect_timeout': '10',
},
)
defaults.update(overrides)
ss = MagicMock()
for k, v in defaults.items():
setattr(ss, k, v)
return ss
class TestGetSharedServerProperties(BaseTestGenerator):
"""Unit tests for ServerModule.get_shared_server_properties()
using mock objects."""
scenarios = [
('Merge suppresses passexec_cmd',
dict(test_method='test_suppresses_passexec')),
('Merge suppresses post_connection_sql',
dict(test_method='test_suppresses_post_sql')),
('Merge strips owner SSL paths not in SharedServer',
dict(test_method='test_strips_owner_ssl_paths')),
('Merge applies SharedServer SSL paths',
dict(test_method='test_applies_ss_ssl_paths')),
('Merge overrides service from SharedServer',
dict(test_method='test_overrides_service')),
('Merge overrides tunnel fields',
dict(test_method='test_overrides_tunnel')),
('Merge handles None connection_params',
dict(test_method='test_none_conn_params')),
]
@patch('pgadmin.browser.server_groups.servers.'
'object_session', return_value=None)
def runTest(self, mock_sess):
getattr(self, self.test_method)()
def _merge(self, server=None, ss=None):
from pgadmin.browser.server_groups.servers import \
ServerModule
if server is None:
server = _make_server()
if ss is None:
ss = _make_shared_server()
return ServerModule.get_shared_server_properties(
server, ss)
def test_suppresses_passexec(self):
result = self._merge()
self.assertIsNone(result.passexec_cmd)
self.assertIsNone(result.passexec_expiration)
def test_suppresses_post_sql(self):
result = self._merge()
self.assertIsNone(result.post_connection_sql)
def test_strips_owner_ssl_paths(self):
result = self._merge()
cp = result.connection_params
# Owner had sslkey, sslrootcert, sslcrl, sslcrldir,
# passfile — SharedServer did not — should be removed.
self.assertNotIn('sslkey', cp)
self.assertNotIn('sslcrl', cp)
self.assertNotIn('sslcrldir', cp)
self.assertNotIn('sslrootcert', cp)
self.assertNotIn('passfile', cp)
def test_applies_ss_ssl_paths(self):
result = self._merge()
cp = result.connection_params
# SharedServer had sslcert -- should override.
self.assertEqual(
cp['sslcert'],
'/home/nonowner/.ssl/cert.pem')
# Non-sensitive params preserved from owner.
self.assertEqual(cp['sslmode'], 'verify-full')
self.assertEqual(cp['connect_timeout'], '10')
def test_overrides_service(self):
result = self._merge()
self.assertEqual(result.service, 'my_pg_service')
def test_overrides_tunnel(self):
result = self._merge()
self.assertEqual(result.tunnel_host, 'bastion.local')
self.assertEqual(result.tunnel_port, 2222)
self.assertEqual(result.tunnel_username, 'tunneluser')
self.assertEqual(result.tunnel_authentication, 1)
self.assertEqual(
result.tunnel_identity_file,
'/home/user/.ssh/id_rsa')
def test_none_conn_params(self):
server = _make_server(connection_params=None)
ss = _make_shared_server(connection_params=None)
result = self._merge(server, ss)
# Should not crash; connection_params becomes {}
self.assertEqual(result.connection_params, {})
class TestCreateSharedServerSanitization(BaseTestGenerator):
"""Verify create_shared_server() strips sensitive
connection_params keys."""
scenarios = [
('Sanitizes connection_params on creation',
dict(test_method='test_sanitizes_conn_params')),
('Copies tunnel_port from owner',
dict(test_method='test_copies_tunnel_port')),
('Copies tunnel_keep_alive from owner',
dict(test_method='test_copies_tunnel_keep_alive')),
('Handles None connection_params',
dict(test_method='test_none_conn_params')),
]
@patch('pgadmin.browser.server_groups.servers.db')
@patch('pgadmin.browser.server_groups.servers.User')
@patch('pgadmin.browser.server_groups.servers.current_user')
@patch('pgadmin.browser.server_groups.servers.SharedServer')
def runTest(self, mock_ss_cls, mock_cu, mock_user,
mock_db):
mock_cu.id = 200
mock_user.query.filter_by.return_value \
.first.return_value = MagicMock(username='owner')
# Capture the SharedServer() constructor call
self.captured_kwargs = {}
def capture_init(**kwargs):
self.captured_kwargs = kwargs
return MagicMock()
mock_ss_cls.side_effect = capture_init
getattr(self, self.test_method)()
def _create(self, server=None):
from pgadmin.browser.server_groups.servers import \
ServerModule
if server is None:
server = _make_server()
ServerModule.create_shared_server(server, 1)
def test_sanitizes_conn_params(self):
self._create()
cp = self.captured_kwargs.get('connection_params', {})
# Sensitive keys must be stripped
for key in ('sslcert', 'sslkey', 'sslrootcert',
'sslcrl', 'sslcrldir', 'passfile'):
self.assertNotIn(
key, cp,
'Sensitive key "{0}" should be stripped '
'on SharedServer creation'.format(key))
# Non-sensitive keys preserved
self.assertEqual(cp.get('sslmode'), 'verify-full')
self.assertEqual(cp.get('connect_timeout'), '10')
def test_copies_tunnel_port(self):
server = _make_server(tunnel_port=2222)
self._create(server)
self.assertEqual(
self.captured_kwargs.get('tunnel_port'), 2222)
def test_copies_tunnel_keep_alive(self):
server = _make_server(tunnel_keep_alive=45)
self._create(server)
self.assertEqual(
self.captured_kwargs.get('tunnel_keep_alive'), 45)
def test_none_conn_params(self):
server = _make_server(connection_params=None)
self._create(server)
cp = self.captured_kwargs.get('connection_params', {})
self.assertEqual(cp, {})
class TestMergeExpungesServer(BaseTestGenerator):
"""Verify get_shared_server_properties() expunges the server
from the SQLAlchemy session before mutation."""
scenarios = [
('Expunge called when server is in session',
dict(test_method='test_expunge_called')),
('No crash when server not in session',
dict(test_method='test_no_session')),
]
def runTest(self):
getattr(self, self.test_method)()
def test_expunge_called(self):
from pgadmin.browser.server_groups.servers import \
ServerModule
server = _make_server()
ss = _make_shared_server()
mock_session = MagicMock()
with patch(SRV_MODULE + '.object_session',
return_value=mock_session):
ServerModule.get_shared_server_properties(
server, ss)
mock_session.expunge.assert_called_once_with(server)
def test_no_session(self):
from pgadmin.browser.server_groups.servers import \
ServerModule
server = _make_server()
ss = _make_shared_server()
with patch(SRV_MODULE + '.object_session',
return_value=None):
# Should not crash
result = ServerModule.get_shared_server_properties(
server, ss)
self.assertIsNone(result.passexec_cmd)
class TestUpdateConnectionParameter(BaseTestGenerator):
"""Verify update_connection_parameter() routes changes
to SharedServer for non-owners."""
scenarios = [
('Non-owner changes go to SharedServer copy',
dict(test_method='test_nonowner_routing')),
('Owner changes go to Server directly',
dict(test_method='test_owner_routing')),
]
def runTest(self):
getattr(self, self.test_method)()
@patch(SRV_MODULE + '.current_user')
def test_nonowner_routing(self, mock_cu):
mock_cu.id = 200 # Non-owner
from pgadmin.browser.server_groups.servers import \
ServerNode
server = _make_server(
connection_params={'sslmode': 'require'})
ss = _make_shared_server(
connection_params={'sslmode': 'require'})
data = {'connection_params': {
'changed': [{'name': 'sslmode', 'value': 'verify'}]
}}
node = ServerNode.__new__(ServerNode)
node.update_connection_parameter(data, server, ss)
# The result should be in data, not mutating server
self.assertEqual(
data['connection_params']['sslmode'], 'verify')
# Owner's server should NOT be mutated
self.assertEqual(
server.connection_params['sslmode'], 'require')
@patch(SRV_MODULE + '.current_user')
def test_owner_routing(self, mock_cu):
mock_cu.id = 100 # Owner
from pgadmin.browser.server_groups.servers import \
ServerNode
server = _make_server(
connection_params={'sslmode': 'require'})
data = {'connection_params': {
'changed': [{'name': 'sslmode', 'value': 'verify'}]
}}
node = ServerNode.__new__(ServerNode)
node.update_connection_parameter(data, server, None)
# Owner path mutates server directly
self.assertEqual(
data['connection_params']['sslmode'], 'verify')
class TestUpdateServerDetails(BaseTestGenerator):
"""Verify _update_server_details routes writes to
SharedServer for non-owners."""
scenarios = [
('Non-owner write goes to SharedServer',
dict(test_method='test_nonowner_write')),
('Owner write goes to Server',
dict(test_method='test_owner_write')),
]
def runTest(self):
getattr(self, self.test_method)()
@patch(SRV_MODULE + '.current_user')
def test_nonowner_write(self, mock_cu):
mock_cu.id = 200
from pgadmin.browser.server_groups.servers import \
ServerNode
server = _make_server()
ss = _make_shared_server()
config_map = {'name': 'name'}
ServerNode._update_server_details(
server, ss, config_map, 'name', 'NewName')
self.assertEqual(ss.name, 'NewName')
# Server should not be modified
self.assertEqual(server.name, 'OwnerServer')
@patch(SRV_MODULE + '.current_user')
def test_owner_write(self, mock_cu):
mock_cu.id = 100
from pgadmin.browser.server_groups.servers import \
ServerNode
server = _make_server()
config_map = {'name': 'name'}
ServerNode._update_server_details(
server, None, config_map, 'name', 'NewName')
self.assertEqual(server.name, 'NewName')
class TestDeleteSharedServerOwnerGuard(BaseTestGenerator):
"""Verify that only the owner can trigger
delete_shared_server via _set_valid_attr_value."""
scenarios = [
('Non-owner shared=false does not delete',
dict(test_method='test_nonowner_no_delete')),
('Owner shared=false triggers delete',
dict(test_method='test_owner_deletes')),
]
def runTest(self):
getattr(self, self.test_method)()
@patch(SRV_MODULE + '.get_crypt_key',
return_value=(True, b'key'))
@patch(SRV_MODULE + '.current_user')
def test_nonowner_no_delete(self, mock_cu, mock_ck):
mock_cu.id = 200
from pgadmin.browser.server_groups.servers import \
ServerNode
server = _make_server()
ss = _make_shared_server()
node = ServerNode.__new__(ServerNode)
node.delete_shared_server = MagicMock()
data = {'shared': False}
config_map = {'shared': 'shared'}
node._set_valid_attr_value(
1, data, config_map, server, ss)
node.delete_shared_server.assert_not_called()
@patch(SRV_MODULE + '.get_crypt_key',
return_value=(True, b'key'))
@patch(SRV_MODULE + '.current_user')
def test_owner_deletes(self, mock_cu, mock_ck):
mock_cu.id = 100 # Owner
from pgadmin.browser.server_groups.servers import \
ServerNode
server = _make_server()
node = ServerNode.__new__(ServerNode)
node.delete_shared_server = MagicMock()
data = {'shared': False}
config_map = {'shared': 'shared'}
node._set_valid_attr_value(
1, data, config_map, server, None)
node.delete_shared_server.assert_called_once_with(
1, server.id)
class TestGetSharedServerRaisesOnNone(BaseTestGenerator):
"""Verify get_shared_server() raises if SharedServer
cannot be created."""
scenarios = [
('Raises when SharedServer is None after create',
dict(test_method='test_raises_on_none')),
]
def runTest(self):
getattr(self, self.test_method)()
@patch(SRV_MODULE + '.SharedServer')
@patch(SRV_MODULE + '.current_user')
def test_raises_on_none(self, mock_cu, mock_ss):
mock_cu.id = 200
# Both queries return None
mock_ss.query.filter_by.return_value \
.first.return_value = None
from pgadmin.browser.server_groups.servers import \
ServerModule
server = _make_server()
with patch.object(ServerModule, 'create_shared_server'):
with self.assertRaises(Exception) as ctx:
ServerModule.get_shared_server(server, 1)
self.assertIn(
'Failed to create shared server',
str(ctx.exception))

View File

@ -13,13 +13,14 @@ from ipaddress import ip_address
import keyring
from flask_login import current_user
from werkzeug.exceptions import InternalServerError
from flask import render_template
from flask import render_template, has_request_context
from pgadmin.utils.constants import (
KEY_RING_USERNAME_FORMAT, KEY_RING_SERVICE_NAME, KEY_RING_TUNNEL_FORMAT,
KEY_RING_DESKTOP_USER, SSL_MODES, RESTRICTION_TYPE_DATABASES,
RESTRICTION_TYPE_SQL)
from pgadmin.utils.crypto import encrypt, decrypt
from pgadmin.model import db, Server, SharedServer
from pgadmin.utils.server_access import get_user_server_query
from flask import current_app
from pgadmin.utils.master_password import set_masterpass_check_text
from pgadmin.utils.driver import get_driver
@ -324,7 +325,10 @@ def migrate_passwords_from_pgadmin_db(servers, old_key, enc_key):
def get_servers_with_saved_passwords():
all_server = Server.query.filter(Server.is_adhoc == 0)
all_server = Server.query.filter(
Server.user_id == current_user.id,
Server.is_adhoc == 0
)
servers_with_pwd_in_os_secret = []
servers_with_pwd_in_pgadmin_db = []
saved_password_servers = []
@ -648,32 +652,56 @@ def check_ssl_fields(data):
def disconnect_from_all_servers():
"""
This function is used to disconnect all the servers
This function is used to disconnect all the servers for the
current user (owned + shared).
"""
all_servers = Server.query.all()
all_servers = get_user_server_query().all()
for server in all_servers:
manager = get_driver(config.PG_DEFAULT_DRIVER).connection_manager(
server.id)
# Check if any psql terminal is running for the current disconnecting
# server. If any terminate the psql tool connection.
if 'sid_soid_mapping' in current_app.config and str(server.id) in \
current_app.config['sid_soid_mapping'] and \
str(server.id) in current_app.config['sid_soid_mapping']:
for i in current_app.config['sid_soid_mapping'][str(server.id)]:
sio.emit('disconnect-psql', namespace='/pty', to=i)
manager.release()
try:
manager = get_driver(
config.PG_DEFAULT_DRIVER
).connection_manager(server.id)
# Only emit disconnect-psql for servers owned by the
# current user — shared servers may have other users'
# PSQL sessions mapped to the same sid.
if server.user_id == current_user.id and \
'sid_soid_mapping' in current_app.config \
and str(server.id) in \
current_app.config['sid_soid_mapping']:
for i in current_app.config[
'sid_soid_mapping'][str(server.id)]:
sio.emit(
'disconnect-psql',
namespace='/pty', to=i
)
manager.release()
except Exception:
current_app.logger.warning(
'Failed to disconnect server %s',
server.id, exc_info=True
)
def delete_adhoc_servers(sid=None):
"""
This function will remove all the adhoc servers.
This function will remove adhoc servers. When called with a
current_user context, scopes to the current user. When called
during app startup (no user context), cleans all adhoc servers.
"""
try:
has_user = (has_request_context() and
current_user and current_user.is_authenticated)
if sid is not None:
db.session.query(Server).filter(Server.id == sid).delete()
q = db.session.query(Server).filter(
Server.id == sid, Server.is_adhoc == 1)
if has_user:
q = q.filter(Server.user_id == current_user.id)
q.delete()
else:
db.session.query(Server).filter(Server.is_adhoc == 1).delete()
q = db.session.query(Server).filter(Server.is_adhoc == 1)
if has_user:
q = q.filter(Server.user_id == current_user.id)
q.delete()
db.session.commit()
# Reset the sequence again

View File

@ -0,0 +1,78 @@
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2026, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
"""Tests for ServerGroup data isolation between users in server mode."""
import json
import config
from pgadmin.utils.route import BaseTestGenerator
from regression.python_test_utils import test_utils as utils
from regression.test_setup import config_data
from regression.python_test_utils.test_utils import \
create_user_wise_test_client
from pgadmin.model import db, ServerGroup
test_user_details = None
if config.SERVER_MODE:
test_user_details = \
config_data['pgAdmin4_test_non_admin_credentials']
class ServerGroupIsolationTestCase(BaseTestGenerator):
"""Verify that a non-admin user cannot fetch another user's
server group properties by ID."""
scenarios = [
('User B cannot fetch User A server group properties',
dict(is_positive_test=False)),
]
def setUp(self):
self.sg_id = None
if not config.SERVER_MODE:
self.skipTest(
'Data isolation tests only apply to server mode.'
)
# Create a server group as the admin user
url = '/browser/server_group/obj/'
response = self.tester.post(
url,
data=json.dumps({'name': 'isolation_test_group'}),
content_type='html/json'
)
self.assertEqual(response.status_code, 200)
response_data = json.loads(response.data.decode('utf-8'))
self.assertIn('node', response_data)
self.sg_id = response_data['node']['_id']
@create_user_wise_test_client(test_user_details)
def runTest(self):
"""Non-admin user should NOT see another user's server
group properties."""
if not self.sg_id:
raise Exception("Server group not created")
url = '/browser/server_group/obj/{0}'.format(self.sg_id)
response = self.tester.get(url, content_type='html/json')
self.assertEqual(
response.status_code, 410,
'Non-admin user should not access another user\'s '
'server group. Got status {0}'.format(
response.status_code)
)
def tearDown(self):
# Clean up with admin
if self.sg_id is None:
return
sg = ServerGroup.query.filter_by(id=self.sg_id).first()
if sg:
db.session.delete(sg)
db.session.commit()

View File

@ -153,7 +153,7 @@ class BatchProcess:
self.manager_obj = kwargs['manager_obj']
def _retrieve_process(self, _id):
p = Process.query.filter_by(pid=_id, user_id=current_user.id).first()
p = Process.for_user(pid=_id).first()
if p is None:
raise LookupError(PROCESS_NOT_FOUND)
@ -372,9 +372,7 @@ class BatchProcess:
# There is no way to find out the error message from this process
# as standard output, and standard error were redirected to
# devnull.
p = Process.query.filter_by(
pid=self.id, user_id=current_user.id
).first()
p = Process.for_user(pid=self.id).first()
p.start_time = p.end_time = get_current_time()
if not p.exit_code:
p.exit_code = self.ecode
@ -382,9 +380,7 @@ class BatchProcess:
db.session.commit()
else:
# Update the process state to "Started"
p = Process.query.filter_by(
pid=self.id, user_id=current_user.id
).first()
p = Process.for_user(pid=self.id).first()
p.process_state = PROCESS_STARTED
db.session.commit()
@ -530,9 +526,7 @@ class BatchProcess:
"""
_pid = self.id
_process = Process.query.filter_by(
user_id=current_user.id, pid=_pid
).first()
_process = Process.for_user(pid=_pid).first()
if _process is None:
raise LookupError(PROCESS_NOT_FOUND)
@ -588,9 +582,7 @@ class BatchProcess:
out_completed = err_completed = False
process_output = (out != -1 and err != -1)
j = Process.query.filter_by(
pid=self.id, user_id=current_user.id
).first()
j = Process.for_user(pid=self.id).first()
enc = sys.getdefaultencoding()
if enc == 'ascii':
enc = 'utf-8'
@ -739,7 +731,7 @@ class BatchProcess:
@staticmethod
def list():
processes = Process.query.filter_by(user_id=current_user.id)
processes = Process.for_user()
changed = False
browser_preference = Preferences.module('browser')
@ -812,9 +804,7 @@ class BatchProcess:
And, delete the process information from the configuration, and the log
files related to the process, if it has already been completed.
"""
p = Process.query.filter_by(
user_id=current_user.id, pid=_pid
).first()
p = Process.for_user(pid=_pid).first()
if p is None:
raise LookupError(PROCESS_NOT_FOUND)
@ -886,9 +876,7 @@ class BatchProcess:
def stop_process(_pid):
"""
"""
p = Process.query.filter_by(
user_id=current_user.id, pid=_pid
).first()
p = Process.for_user(pid=_pid).first()
if p is None:
raise LookupError(PROCESS_NOT_FOUND)
@ -910,9 +898,7 @@ class BatchProcess:
@staticmethod
def update_server_id(_pid, _sid):
p = Process.query.filter_by(
user_id=current_user.id, pid=_pid
).first()
p = Process.for_user(pid=_pid).first()
if p is None:
raise LookupError(PROCESS_NOT_FOUND)

View File

@ -212,8 +212,9 @@ def clear_cloud_session(pid=None):
@pga_login_required
def update_cloud_process(sid):
"""Update Cloud Server Process"""
_process = Process.query.filter_by(user_id=current_user.id,
server_id=sid).first()
_process = Process.for_user(server_id=sid).first()
if _process is None:
return success_return()
_process.acknowledge = None
db.session.commit()
return success_return()

View File

@ -17,6 +17,7 @@ from flask_babel import gettext
from flask_security import current_user
from pgadmin.utils import PgAdminModule
from pgadmin.model import db, Server
from pgadmin.utils.server_access import get_server
from pgadmin.utils.driver import get_driver
from pgadmin.utils.ajax import bad_request, make_json_response
from pgadmin.browser.server_groups.servers.utils import (
@ -132,7 +133,8 @@ def adhoc_connect_server():
username=new_username,
name=new_server_name,
role=new_role,
service=new_service
service=new_service,
user_id=current_user.id
).all()
# If found matching servers then compare the connection_params as
@ -143,22 +145,27 @@ def adhoc_connect_server():
server = existing_server
break
else:
server = Server.query.filter_by(host=new_host,
port=new_port,
maintenance_db=new_db,
username=new_username,
name=new_server_name,
role=new_role,
service=new_service,
connection_params=connection_params
).first()
server = Server.query.filter_by(
host=new_host, port=new_port,
maintenance_db=new_db,
username=new_username,
name=new_server_name,
role=new_role,
service=new_service,
connection_params=connection_params,
user_id=current_user.id
).first()
# If server is none then no server with the above combination is found.
if server is None:
# Check if sid is present in data if it is then used that sid.
if ('sid' in data and data['sid'] is not None and
int(data['sid']) > 0):
server = Server.query.filter_by(id=data['sid']).first()
server = get_server(data['sid'])
if server is None:
return bad_request(gettext(
"Could not find the required server."
))
# Clone the server object
server = server.clone()
@ -220,23 +227,30 @@ def check_and_delete_adhoc_server(sid):
This function is used to check for adhoc server and if all Query Tool
and PSQL connections are closed then delete that server.
"""
server = Server.query.filter_by(id=sid).first()
if server.is_adhoc:
# Check PSQL connections. If more connections are open for
# the given sid return from the function.
psql_connections = get_open_psql_connections()
if sid in psql_connections.values():
server = get_server(sid)
if server is None:
# Server may be deleted or inaccessible; still attempt
# best-effort cleanup of adhoc state.
delete_adhoc_servers(sid)
return
if not server.is_adhoc:
return
# Check PSQL connections. If more connections are open for
# the given sid return from the function.
psql_connections = get_open_psql_connections()
if sid in psql_connections.values():
return
# Check Query Tool connections for the given sid
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid)
for key, value in manager.connections.items():
if key.startswith('CONN') and value.connected():
return
# Check Query Tool connections for the given sid
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(sid)
for key, value in manager.connections.items():
if key.startswith('CONN') and value.connected():
return
# Assumption at this point all the Query Tool and PSQL connections
# is closed, so now we can release the manager
manager.release()
# Assumption at this point all the Query Tool and PSQL connections
# is closed, so now we can release the manager
manager.release()
# Delete the adhoc server from the pgadmin database
delete_adhoc_servers(sid)
# Delete the adhoc server from the pgadmin database
delete_adhoc_servers(sid)

View File

@ -33,7 +33,7 @@ import config
#
##########################################################################
SCHEMA_VERSION = 49
SCHEMA_VERSION = 50
##########################################################################
#
@ -51,6 +51,60 @@ USER_ID = 'user.id'
SERVER_ID = 'server.id'
CASCADE_STR = "all, delete-orphan"
class UserScopedMixin:
"""Mixin for models that store per-user data.
Provides for_user() as the default scoped query entry point.
Models with a 'user_id' column or a 'uid' column are supported
automatically the mixin detects which column name is used.
Usage:
# Instead of:
Process.query.filter_by(user_id=current_user.id, pid=pid)
# Use:
Process.for_user(pid=pid)
"""
@classmethod
def _user_column(cls):
"""Return the user-scoping column for this model."""
if hasattr(cls, 'user_id'):
return cls.user_id
if hasattr(cls, 'uid'):
return cls.uid
raise AttributeError(
f"{cls.__name__} has no user_id or uid column"
)
@classmethod
def _user_column_name(cls):
"""Return the column name string ('user_id' or 'uid')."""
if hasattr(cls, 'user_id'):
return 'user_id'
if hasattr(cls, 'uid'):
return 'uid'
raise AttributeError(
f"{cls.__name__} has no user_id or uid column"
)
@classmethod
def for_user(cls, user_id=None, **kwargs):
"""Query scoped to a specific user (defaults to current_user).
Args:
user_id: Explicit user ID. If None, uses current_user.id.
**kwargs: Additional filter_by arguments.
Returns:
A SQLAlchemy query filtered by the user's ID.
"""
from flask_security import current_user as cu
uid = user_id if user_id is not None else cu.id
kwargs[cls._user_column_name()] = uid
return cls.query.filter_by(**kwargs)
# Define models
roles_users = db.Table(
'roles_users',
@ -158,7 +212,7 @@ class User(db.Model, UserMixin):
locked = db.Column(db.Boolean(), default=False)
class Setting(db.Model):
class Setting(db.Model, UserScopedMixin):
"""Define a setting object"""
__tablename__ = 'setting'
user_id = db.Column(db.Integer, db.ForeignKey(USER_ID), primary_key=True)
@ -166,7 +220,7 @@ class Setting(db.Model):
value = db.Column(db.Text())
class ServerGroup(db.Model):
class ServerGroup(db.Model, UserScopedMixin):
"""Define a server group for the treeview"""
__tablename__ = 'servergroup'
id = db.Column(db.Integer, primary_key=True)
@ -185,7 +239,7 @@ class ServerGroup(db.Model):
}
class Server(db.Model):
class Server(db.Model, UserScopedMixin):
"""Define a registered Postgres server"""
__tablename__ = 'server'
id = db.Column(db.Integer, primary_key=True)
@ -306,7 +360,7 @@ class Preferences(db.Model):
name = db.Column(db.String(1024), nullable=False)
class UserPreference(db.Model):
class UserPreference(db.Model, UserScopedMixin):
"""Define the preference for a particular user."""
__tablename__ = 'user_preferences'
pid = db.Column(
@ -318,9 +372,13 @@ class UserPreference(db.Model):
value = db.Column(db.String(1024), nullable=False)
class DebuggerFunctionArguments(db.Model):
class DebuggerFunctionArguments(db.Model, UserScopedMixin):
"""Define the debugger input function arguments."""
__tablename__ = 'debugger_function_arguments'
user_id = db.Column(
db.Integer, db.ForeignKey(USER_ID),
nullable=False, primary_key=True
)
server_id = db.Column(db.Integer(), nullable=False, primary_key=True)
database_id = db.Column(db.Integer(), nullable=False, primary_key=True)
schema_id = db.Column(db.Integer(), nullable=False, primary_key=True)
@ -349,7 +407,7 @@ class DebuggerFunctionArguments(db.Model):
value = db.Column(db.String(), nullable=True)
class Process(db.Model):
class Process(db.Model, UserScopedMixin):
"""Define the Process table."""
__tablename__ = 'process'
pid = db.Column(db.String(), nullable=False, primary_key=True)
@ -382,7 +440,7 @@ class Keys(db.Model):
value = db.Column(db.String(), nullable=False)
class QueryHistoryModel(db.Model):
class QueryHistoryModel(db.Model, UserScopedMixin):
"""Define the history SQL table."""
__tablename__ = 'query_history'
srno = db.Column(db.Integer(), nullable=False, primary_key=True)
@ -397,7 +455,7 @@ class QueryHistoryModel(db.Model):
last_updated_flag = db.Column(db.String(), nullable=False)
class ApplicationState(db.Model):
class ApplicationState(db.Model, UserScopedMixin):
"""Define the application state SQL table."""
__tablename__ = 'application_state'
uid = db.Column(db.Integer(), db.ForeignKey(USER_ID), nullable=False,
@ -422,10 +480,14 @@ class Database(db.Model):
)
class SharedServer(db.Model):
class SharedServer(db.Model, UserScopedMixin):
"""Define a shared Postgres server"""
__tablename__ = 'sharedserver'
__table_args__ = (
db.UniqueConstraint('osid', 'user_id',
name='uq_sharedserver_osid_user'),
)
id = db.Column(db.Integer, primary_key=True)
osid = db.Column(
db.Integer,
@ -510,7 +572,7 @@ class Macros(db.Model):
key_code = db.Column(db.Integer, nullable=False)
class UserMacros(db.Model):
class UserMacros(db.Model, UserScopedMixin):
"""Define the macro for a particular user."""
__tablename__ = 'user_macros'
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
@ -524,7 +586,7 @@ class UserMacros(db.Model):
sql = db.Column(db.Text(), nullable=False)
class UserMFA(db.Model):
class UserMFA(db.Model, UserScopedMixin):
"""Stores the options for the MFA for a particular user."""
__tablename__ = 'user_mfa'
user_id = db.Column(db.Integer, db.ForeignKey(USER_ID), primary_key=True)

View File

@ -195,6 +195,7 @@ class BatchProcessTest(BaseTestGenerator):
self.utility_pid = 123
self.server_id = None
process_mock.for_user = process_mock.query.filter_by
mock_result = process_mock.query.filter_by.return_value
mock_result.first.return_value = TestMockProcess(
backup_obj, self.class_params['args'], self.class_params['cmd'])
@ -239,6 +240,7 @@ class BatchProcessTest(BaseTestGenerator):
self.utility_pid = 123
self.server_id = None
process_mock.for_user = process_mock.query.filter_by
process_mock.query.filter_by.return_value = [
TestMockProcess(backup_obj,
self.class_params['args'],

View File

@ -16,7 +16,7 @@ import copy
from flask import render_template, request, current_app
from flask_babel import gettext
from flask_security import permissions_required
from flask_security import permissions_required, current_user
from pgadmin.user_login_check import pga_login_required
from werkzeug.user_agent import UserAgent
@ -35,7 +35,9 @@ from pgadmin.browser.server_groups.servers.databases.extensions.utils \
import get_extension_details
from pgadmin.utils.constants import PREF_LABEL_KEYBOARD_SHORTCUTS, \
SERVER_CONNECTION_CLOSED
from pgadmin.tools.user_management.PgAdminPermissions import AllPermissionTypes
from pgadmin.tools.user_management.PgAdminPermissions \
import AllPermissionTypes
from pgadmin.utils.server_access import get_server
from pgadmin.preferences import preferences
MODULE_NAME = 'debugger'
@ -1803,12 +1805,19 @@ def get_arguments_sqlite(sid, did, scid, func_id):
- Function Id
"""
if get_server(sid) is None:
return make_json_response(
status=410, success=0,
errormsg=gettext("Could not find the required server.")
)
"""Get the count of the existing data available in sqlite database"""
dbg_func_args_count = int(DebuggerFunctionArguments.query.filter_by(
server_id=sid,
database_id=did,
schema_id=scid,
function_id=func_id
function_id=func_id,
user_id=current_user.id
).count())
args_data = []
@ -1819,7 +1828,8 @@ def get_arguments_sqlite(sid, did, scid, func_id):
server_id=sid,
database_id=did,
schema_id=scid,
function_id=func_id
function_id=func_id,
user_id=current_user.id
)
args_list = dbg_func_args.all()
@ -1888,6 +1898,12 @@ def set_arguments_sqlite(sid, did, scid, func_id):
- Function Id
"""
if get_server(sid) is None:
return make_json_response(
status=410, success=0,
errormsg=gettext("Could not find the required server.")
)
if request.data:
data = json.loads(request.data)
@ -1899,7 +1915,8 @@ def set_arguments_sqlite(sid, did, scid, func_id):
database_id=data[i]['database_id'],
schema_id=data[i]['schema_id'],
function_id=data[i]['function_id'],
arg_id=data[i]['arg_id']).count())
arg_id=data[i]['arg_id'],
user_id=current_user.id).count())
# handle the Array list sent from the client
array_string = ''
@ -1918,7 +1935,8 @@ def set_arguments_sqlite(sid, did, scid, func_id):
database_id=data[i]['database_id'],
schema_id=data[i]['schema_id'],
function_id=data[i]['function_id'],
arg_id=data[i]['arg_id']
arg_id=data[i]['arg_id'],
user_id=current_user.id
).first()
dbg_func_args.is_null = data[i]['is_null']
@ -1932,6 +1950,7 @@ def set_arguments_sqlite(sid, did, scid, func_id):
schema_id=data[i]['schema_id'],
function_id=data[i]['function_id'],
arg_id=data[i]['arg_id'],
user_id=current_user.id,
is_null=data[i]['is_null'],
is_expression=data[i]['is_expression'],
use_default=data[i]['use_default'],
@ -1977,12 +1996,20 @@ def clear_arguments_sqlite(sid, did, scid, func_id):
- Function Id
"""
if get_server(sid) is None:
return make_json_response(
status=410, success=0,
errormsg=gettext("Could not find the required server.")
)
try:
db.session.query(DebuggerFunctionArguments) \
.filter(DebuggerFunctionArguments.server_id == sid,
DebuggerFunctionArguments.database_id == did,
DebuggerFunctionArguments.schema_id == scid,
DebuggerFunctionArguments.function_id == func_id) \
DebuggerFunctionArguments.function_id == func_id,
DebuggerFunctionArguments.user_id ==
current_user.id) \
.delete()
db.session.commit()

View File

@ -20,6 +20,7 @@ from pgadmin.utils import PgAdminModule, \
SHORTCUT_FIELDS as shortcut_fields
from pgadmin.utils.ajax import make_json_response, internal_server_error
from pgadmin.model import Server
from pgadmin.utils.server_access import get_server
from config import PG_DEFAULT_DRIVER, ALLOW_SAVE_PASSWORD
from pgadmin.utils.driver import get_driver
from pgadmin.browser.utils import underscore_unescape
@ -556,7 +557,7 @@ def panel(trans_id):
if "linux" in _platform:
is_linux_platform = True
s = Server.query.filter_by(id=int(params['sid'])).first()
s = get_server(int(params['sid']))
if s:
params.update({

View File

@ -23,6 +23,7 @@ from pgadmin.utils.ajax import make_json_response, bad_request, unauthorized
from config import PG_DEFAULT_DRIVER
from pgadmin.model import Server
from pgadmin.utils.server_access import get_server
from pgadmin.utils.constants import SERVER_NOT_FOUND
from pgadmin.settings import get_setting, store_setting
from pgadmin.tools.user_management.PgAdminPermissions import AllPermissionTypes
@ -97,9 +98,7 @@ class IEMessage(IProcessDesc):
def get_server_name(self):
# Fetch the server details like hostname, port, roles etc
s = Server.query.filter_by(
id=self.sid, user_id=current_user.id
).first()
s = get_server(self.sid)
if s is None:
return _("Not available")
@ -293,8 +292,7 @@ def create_import_export_job(sid):
data = json.loads(request.data)
# Fetch the server details like hostname, port, roles etc
server = Server.query.filter_by(
id=sid).first()
server = get_server(sid)
if server is None:
return bad_request(errormsg=_("Could not find the specified server."))

View File

@ -204,6 +204,7 @@ class BatchProcessTest(BaseTestGenerator):
self.utility_pid = 123
self.server_id = None
process_mock.for_user = process_mock.query.filter_by
mock_result = process_mock.query.filter_by.return_value
mock_result.first.return_value = TestMockProcess(
import_export_obj, self.class_params['args'],
@ -250,6 +251,7 @@ class BatchProcessTest(BaseTestGenerator):
self.utility_pid = 123
self.server_id = None
process_mock.for_user = process_mock.query.filter_by
process_mock.query.filter_by.return_value = [
TestMockProcess(import_export_obj,
self.class_params['args'],

View File

@ -137,6 +137,7 @@ class BatchProcessTest(BaseTestGenerator):
self.utility_pid = 123
self.server_id = None
process_mock.for_user = process_mock.query.filter_by
mock_result = process_mock.query.filter_by.return_value
mock_result.first.return_value = TestMockProcess(
maintenance_obj, self.class_params['args'],
@ -177,6 +178,7 @@ class BatchProcessTest(BaseTestGenerator):
self.utility_pid = 123
self.server_id = None
process_mock.for_user = process_mock.query.filter_by
process_mock.query.filter_by.return_value = [
TestMockProcess(maintenance_obj,
self.class_params['args'],

View File

@ -29,6 +29,7 @@ from ... import socketio as sio
from pgadmin.utils import get_complete_file_path
from pgadmin.authenticate import socket_login_required
from pgadmin.model import Server
from pgadmin.utils.server_access import get_server
if _platform == 'win32':
# Check Windows platform support for WinPty api, Disable psql
@ -98,7 +99,7 @@ def panel(trans_id):
if 'sid_soid_mapping' not in app.config:
app.config['sid_soid_mapping'] = dict()
s = Server.query.filter_by(id=int(params['sid'])).first()
s = get_server(int(params['sid']))
if s:
data = _get_database_role(params['sid'], params['did'])
if data:

View File

@ -134,6 +134,7 @@ class BatchProcessTest(BaseTestGenerator):
self.utility_pid = 123
self.server_id = None
process_mock.for_user = process_mock.query.filter_by
mock_result = process_mock.query.filter_by.return_value
mock_result.first.return_value = TestMockProcess(
restore_obj, self.class_params['args'],
@ -174,6 +175,7 @@ class BatchProcessTest(BaseTestGenerator):
self.utility_pid = 123
self.server_id = None
process_mock.for_user = process_mock.query.filter_by
process_mock.query.filter_by.return_value = [
TestMockProcess(restore_obj,
self.class_params['args'],

View File

@ -31,6 +31,8 @@ from sqlalchemy import or_
from pgadmin.authenticate import socket_login_required
from pgadmin import socketio
from pgadmin.tools.user_management.PgAdminPermissions import AllPermissionTypes
from pgadmin.utils.server_access import \
get_server as get_server_access, get_user_server_query
MODULE_NAME = 'schema_diff'
COMPARE_MSG = gettext("Comparing objects...")
@ -283,18 +285,14 @@ def servers():
from pgadmin.browser.server_groups.servers import\
server_icon_and_background
for server in Server.query.filter(
or_(Server.user_id == current_user.id, Server.shared),
for server in get_user_server_query().filter(
Server.is_adhoc == 0):
shared_server = SharedServer.query.filter_by(
name=server.name, user_id=current_user.id,
servergroup_id=server.servergroup_id).first()
user_id=current_user.id,
osid=server.id).first()
if server.discovery_id:
auto_detected_server = server.name
if shared_server and shared_server.name == auto_detected_server:
if server.discovery_id and shared_server:
continue
manager = driver.connection_manager(server.id)
@ -336,7 +334,13 @@ def get_server(sid, did):
"""Return a JSON document listing the server groups for the user"""
driver = get_driver(PG_DEFAULT_DRIVER)
server = Server.query.filter_by(id=sid).first()
server = get_server_access(sid)
if server is None:
return make_json_response(
status=410, success=0,
errormsg=gettext(
"Could not find the required server.")
)
manager = driver.connection_manager(sid)
conn = manager.connection(did=did)
connected = conn.connected()
@ -375,7 +379,12 @@ def connect_server(sid):
data={}
)
server = Server.query.filter_by(id=sid).first()
server = get_server_access(sid)
if server is None:
return make_json_response(
status=410, success=0,
errormsg=gettext("Could not find the required server.")
)
view = SchemaDiffRegistry.get_node_view('server')
return view.connect(server.servergroup_id, sid)
@ -387,7 +396,12 @@ def connect_server(sid):
)
@pga_login_required
def connect_database(sid, did):
server = Server.query.filter_by(id=sid).first()
server = get_server_access(sid)
if server is None:
return make_json_response(
status=410, success=0,
errormsg=gettext("Could not find the required server.")
)
view = SchemaDiffRegistry.get_node_view('database')
return view.connect(server.servergroup_id, sid, did)
@ -407,7 +421,13 @@ def databases(sid):
try:
view = SchemaDiffRegistry.get_node_view('database')
server = Server.query.filter_by(id=sid).first()
server = get_server_access(sid)
if server is None:
return make_json_response(
status=410, success=0,
errormsg=gettext(
"Could not find the required server.")
)
response = view.nodes(gid=server.servergroup_id, sid=sid,
is_schema_diff=True)
databases = json.loads(response.data)['data']
@ -495,6 +515,15 @@ def compare_database(params):
fetch_compare_schemas(params['source_sid'], params['source_did'],
params['target_sid'], params['target_did'])
if schema_result is None:
socketio.emit(
'compare_database_failed',
gettext(
"Failed to fetch schemas from the"
" server."),
namespace=SOCKETIO_NAMESPACE, to=request.sid)
return
total_schema = len(schema_result['source_only']) + len(
schema_result['target_only']) + len(
schema_result['in_both_database'])
@ -722,11 +751,15 @@ def check_version_compatibility(sid, tid):
"""Check the version compatibility of source and target servers."""
driver = get_driver(PG_DEFAULT_DRIVER)
src_server = Server.query.filter_by(id=sid).first()
src_server = get_server_access(sid)
if src_server is None:
return False, gettext("Could not find the source server.")
src_manager = driver.connection_manager(src_server.id)
src_conn = src_manager.connection()
tar_server = Server.query.filter_by(id=tid).first()
tar_server = get_server_access(tid)
if tar_server is None:
return False, gettext("Could not find the target server.")
tar_manager = driver.connection_manager(tar_server.id)
target_conn = tar_manager.connection()
@ -759,7 +792,9 @@ def get_schemas(sid, did):
"""
try:
view = SchemaDiffRegistry.get_node_view('schema')
server = Server.query.filter_by(id=sid).first()
server = get_server_access(sid)
if server is None:
return None
response = view.nodes(gid=server.servergroup_id, sid=sid, did=did,
is_schema_diff=True)
schemas = json.loads(response.data)['data']
@ -912,6 +947,9 @@ def fetch_compare_schemas(source_sid, source_did, target_sid, target_did):
source_schemas = get_schemas(source_sid, source_did)
target_schemas = get_schemas(target_sid, target_did)
if source_schemas is None or target_schemas is None:
return None
src_schema_dict = {item['label']: item['_id'] for item in source_schemas}
tar_schema_dict = {item['label']: item['_id'] for item in target_schemas}

View File

@ -63,6 +63,8 @@ from pgadmin.utils.constants import MIMETYPE_APP_JS, \
ERROR_FETCHING_DATA, MY_STORAGE, ACCESS_DENIED_MESSAGE, \
ERROR_MSG_FAIL_TO_PROMOTE_QT
from pgadmin.model import Server, ServerGroup
from pgadmin.utils.server_access import get_server, \
get_server_groups_for_user, get_user_server_query
from pgadmin.tools.schema_diff.node_registry import SchemaDiffRegistry
from pgadmin.settings import get_setting
from pgadmin.utils.preferences import Preferences
@ -225,7 +227,12 @@ def initialize_viewdata(trans_id, cmd_type, obj_type, sgid, sid, did, obj_id):
'password': _data['password'] if 'password' in _data else None
}
server = Server.query.filter_by(id=sid).first()
server = get_server(sid)
if server is None:
return make_json_response(
status=410, success=0,
errormsg=gettext("Could not find the required server.")
)
if kwargs.get('password', None) is None:
kwargs['encpass'] = server.password
else:
@ -374,7 +381,7 @@ def panel(trans_id):
params['bgcolor'] = None
params['fgcolor'] = None
s = Server.query.filter_by(id=int(params['sid'])).first()
s = get_server(int(params['sid']))
if s:
if s.shared and s.user_id != current_user.id:
# Import here to avoid circular dependency
@ -512,7 +519,12 @@ def _init_sqleditor(trans_id, connect, sgid, sid, did, dbname=None, **kwargs):
kwargs.pop('conn_id')
conn_id_ac = str(secrets.choice(range(1, 9999999)))
server = Server.query.filter_by(id=sid).first()
server = get_server(sid)
if server is None:
return True, internal_server_error(
errormsg=gettext(
"Could not find the required server.")
), '', ''
if server.shared and server.user_id != current_user.id:
# Import here to avoid circular dependency
from pgadmin.browser.server_groups.servers import ServerModule
@ -2344,8 +2356,13 @@ def _check_server_connection_status(sgid, sid=None):
driver = get_driver(PG_DEFAULT_DRIVER)
from pgadmin.browser.server_groups.servers import \
server_icon_and_background
server = Server.query.filter_by(
id=sid).first()
server = get_server(sid)
if server is None:
return make_json_response(
status=410, success=0,
errormsg=gettext(
"Could not find the required server.")
)
manager = driver.connection_manager(server.id)
conn = manager.connection()
@ -2393,11 +2410,10 @@ def get_new_connection_data(sgid=None, sid=None):
driver = get_driver(PG_DEFAULT_DRIVER)
from pgadmin.browser.server_groups.servers import \
server_icon_and_background
server_groups = ServerGroup.query.all()
server_groups = get_server_groups_for_user()
server_group_data = {server_group.name: [] for server_group in
server_groups}
servers = Server.query.filter(
or_(Server.user_id == current_user.id, Server.shared),
servers = get_user_server_query().filter(
Server.is_adhoc == 0)
for server in servers:
@ -2654,7 +2670,12 @@ def get_new_connection_role(sgid, sid=None):
@pga_login_required
def connect_server(sid):
# Check if server is already connected then no need to reconnect again.
server = Server.query.filter_by(id=sid).first()
server = get_server(sid)
if server is None:
return make_json_response(
status=410, success=0,
errormsg=gettext("Could not find the required server.")
)
driver = get_driver(PG_DEFAULT_DRIVER)
manager = driver.connection_manager(sid)

View File

@ -759,7 +759,7 @@ def delete_user(uid):
ServerGroup.query.filter_by(user_id=uid).delete()
Process.query.filter_by(user_id=uid).delete()
Process.for_user(user_id=uid).delete()
# Delete Shared servers for current user.
SharedServer.query.filter_by(user_id=uid).delete()

View File

@ -358,14 +358,14 @@ def does_utility_exist(file):
return error_msg
def get_server(sid):
def get_server(sid, only_owned=False):
"""Fetch a server by ID with access check.
Delegates to server_access.get_server(). Kept here for backward
compatibility existing callers import from pgadmin.utils.
"""
# Fetch the server etc
:param sid:
:return: server
"""
server = Server.query.filter_by(id=sid).first()
return server
from pgadmin.utils.server_access import get_server as _get_server
return _get_server(sid, only_owned=only_owned)
def get_binary_path_versions(binary_path: str) -> dict:

View File

@ -16,6 +16,7 @@ object.
import datetime
import re
from flask import session
from flask_babel import gettext
from flask_login import current_user
from werkzeug.exceptions import InternalServerError
import psycopg
@ -23,6 +24,9 @@ from threading import Lock
import config
from pgadmin.model import Server
from pgadmin.utils.server_access import get_server, \
get_user_server_query
from pgadmin.utils.exception import ObjectGone
from .keywords import scan_keyword
from ..abstract import BaseDriver
from .connection import Connection
@ -67,20 +71,29 @@ class Driver(BaseDriver):
def _restore_connections_from_session(self):
"""
Used internally by connection_manager to restore connections
from sessions.
from sessions. Includes both owned and shared servers so
non-owner connections survive session restore.
"""
if session.sid not in self.managers:
self.managers[session.sid] = managers = dict()
if '__pgsql_server_managers' in session:
session_managers = \
session['__pgsql_server_managers'].copy()
for server in \
Server.query.filter_by(
user_id=current_user.id, is_adhoc=0):
servers = get_user_server_query().filter(
Server.is_adhoc == 0)
for server in servers:
manager = managers[str(server.id)] = \
ServerManager(server)
# Suppress owner-only fields for non-owners
# of shared servers so passexec_cmd and
# post_connection_sql don't leak.
if server.shared and \
server.user_id != current_user.id:
manager.passexec = None
manager.post_connection_sql = None
if server.id in session_managers:
manager._restore(session_managers[server.id])
manager._restore(
session_managers[server.id])
manager.update_session()
return managers
@ -100,9 +113,27 @@ class Driver(BaseDriver):
assert (sid is not None and isinstance(sid, int))
managers = None
server_data = Server.query.filter_by(id=sid).first()
if server_data is None:
return None
# In server mode, verify the current user has access to this
# server. This is the primary security boundary — all
# check_precondition decorators and tool endpoints flow
# through connection_manager().
if config.SERVER_MODE:
if current_user and current_user.is_authenticated:
server_data = get_server(sid)
else:
raise ObjectGone(
gettext("Server not found."))
if server_data is None:
raise ObjectGone(
gettext("Server not found."))
else:
# Desktop mode — single user, no isolation needed.
# Return None instead of raising so callers that
# handle None gracefully (e.g., test teardown,
# cleanup paths) are not disrupted.
server_data = Server.query.filter_by(id=sid).first()
if server_data is None:
return None
if session.sid not in self.managers:
with connection_restore_lock:
@ -119,14 +150,18 @@ class Driver(BaseDriver):
managers['pinged'] = datetime.datetime.now()
if str(sid) not in managers:
s = Server.query.filter_by(id=sid).first()
# server_data was already access-checked above;
# it cannot be None at this point.
manager = ServerManager(server_data)
# Suppress owner-only fields for non-owners of
# shared servers.
if config.SERVER_MODE and server_data.shared and \
server_data.user_id != current_user.id:
manager.passexec = None
manager.post_connection_sql = None
managers[str(sid)] = manager
if not s:
return None
managers[str(sid)] = ServerManager(s)
return managers[str(sid)]
return manager
return managers[str(sid)]

View File

@ -0,0 +1,156 @@
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2026, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
"""Centralized server access-checking utilities for data isolation.
In server mode, multiple users share the same pgAdmin instance. These
helpers enforce that users can only access servers they own or that
have been explicitly shared with them via SharedServer entries.
"""
from sqlalchemy import or_
from flask_security import current_user
from pgadmin.model import db, Server, ServerGroup
import config
def _is_admin():
"""Check if current user has Administrator role."""
return current_user.has_role('Administrator')
def get_server(sid, only_owned=False):
"""Fetch a server by ID, verifying the current user has access.
Args:
sid: Server ID.
only_owned: If True, only return servers owned by the current
user. Use this for write operations (change_password,
clear_saved_password, etc.) that must not mutate another
user's server record via shared access.
Returns the server if:
- Desktop mode (single user, no isolation needed), OR
- The user owns it, OR
- The server is shared AND only_owned is False, OR
- The user has the Administrator role.
Returns None otherwise (caller should return 404).
Note: In pgAdmin, Server.shared=True means the server is visible
to all authenticated users. SharedServer records are created
lazily for per-user customization, not for access control.
"""
if not config.SERVER_MODE:
return Server.query.filter_by(id=sid).first()
if only_owned:
return Server.query.filter_by(
id=sid, user_id=current_user.id).first()
# Single query: owned OR shared
server = Server.query.filter(
Server.id == sid,
or_(
Server.user_id == current_user.id,
Server.shared
)
).first()
if server is not None:
return server
# Administrators can access all servers
if _is_admin():
return Server.query.filter_by(id=sid).first()
return None
def get_server_group(gid):
"""Fetch a server group by ID, verifying user access.
Returns the group if:
- Desktop mode, OR
- The user owns it, OR
- It contains shared servers (Server.shared=True), OR
- The user has the Administrator role.
Returns None otherwise.
"""
if not config.SERVER_MODE:
return ServerGroup.query.filter_by(id=gid).first()
sg = ServerGroup.query.filter(
ServerGroup.id == gid,
or_(
ServerGroup.user_id == current_user.id,
ServerGroup.id.in_(
db.session.query(Server.servergroup_id).filter(
Server.shared
)
)
)
).first()
if sg is not None:
return sg
if _is_admin():
return ServerGroup.query.filter_by(id=gid).first()
return None
def get_server_groups_for_user():
"""Return server groups visible to the current user.
Includes groups owned by the user plus groups containing shared
servers (Server.shared=True, visible to all authenticated users).
Administrators see all groups.
"""
if not config.SERVER_MODE:
return ServerGroup.query.filter_by(
user_id=current_user.id
).all()
if _is_admin():
return ServerGroup.query.all()
return ServerGroup.query.filter(
or_(
ServerGroup.user_id == current_user.id,
ServerGroup.id.in_(
db.session.query(Server.servergroup_id).filter(
Server.shared
)
)
)
).all()
def get_user_server_query():
"""Return a base query for servers accessible to the current user.
Includes owned servers + shared servers (visible to all users).
Administrators see all servers.
"""
if not config.SERVER_MODE:
return Server.query
if _is_admin():
return Server.query
return Server.query.filter(
or_(
Server.user_id == current_user.id,
Server.shared
)
)