diff --git a/web/pgadmin/browser/server_groups/servers/pgagent/__init__.py b/web/pgadmin/browser/server_groups/servers/pgagent/__init__.py index be175cda3..c30852c37 100644 --- a/web/pgadmin/browser/server_groups/servers/pgagent/__init__.py +++ b/web/pgadmin/browser/server_groups/servers/pgagent/__init__.py @@ -25,7 +25,7 @@ from pgadmin.utils.ajax import make_json_response, internal_server_error, \ from pgadmin.utils.driver import get_driver from pgadmin.utils.preferences import Preferences from pgadmin.browser.server_groups.servers.pgagent.utils \ - import format_schedule_data + import format_schedule_data, format_step_data class JobModule(CollectionNodeModule): @@ -566,39 +566,22 @@ SELECT EXISTS( :return: """ # Format the schedule data. Convert the boolean array - for key in ['added', 'changed']: - jschedules = data.get('jschedules', {}) - if key in jschedules: - for schedule in jschedules.get(key, []): - format_schedule_data(schedule) + jschedules = data.get('jschedules', {}) + if type(jschedules) == dict: + for schedule in jschedules.get('added', []): + format_schedule_data(schedule) + for schedule in jschedules.get('changed', []): + format_schedule_data(schedule) has_connection_str = self.manager.db_info['pgAgent']['has_connstr'] jssteps = data.get('jsteps', {}) - if 'changed' in jschedules: + if type(jssteps) == dict: for changed_step in jssteps.get('changed', []): - if 'jstconntype' not in changed_step and \ - ('jstdbname' in changed_step or - 'jstconnstr' in changed_step) and has_connection_str: - status, rset = self.conn.execute_dict( - render_template( - "/".join([self.template_path, 'steps.sql']), - jid=data['jobid'], - jstid=changed_step['jstid'], - conn=self.conn, - has_connstr=has_connection_str - ) - ) - if not status: - return internal_server_error(errormsg=rset) - - row = rset['rows'][0] - changed_step['jstconntype'] = row['jstconntype'] - if row['jstconntype']: - changed_step['jstdbname'] = changed_step.get( - 'jstdbname', row['jstdbname']) - else: - changed_step['jstconnstr'] = changed_step.get( - 'jstconnstr', row['jstconnstr']) + status, res = format_step_data( + data['jobid'], changed_step, has_connection_str, + self.conn, self.template_path) + if not status: + internal_server_error(errormsg=res) JobView.register_node_view(blueprint) diff --git a/web/pgadmin/browser/server_groups/servers/pgagent/utils.py b/web/pgadmin/browser/server_groups/servers/pgagent/utils.py index 72701ef3c..2c6fcdc37 100644 --- a/web/pgadmin/browser/server_groups/servers/pgagent/utils.py +++ b/web/pgadmin/browser/server_groups/servers/pgagent/utils.py @@ -8,6 +8,7 @@ ########################################################################## """pgagent helper utilities""" +from flask import render_template def format_boolean_array(value): @@ -44,3 +45,40 @@ def format_schedule_data(data): data['jscmonths'] = format_boolean_array(data['jscmonths']) return data + + +def format_step_data(job_id, data, has_connection_str, conn, template_path): + """ + This function is used to format the step data. If data is not an + instance of list then format + :param job_id: Job ID + :param data: a step data + :param has_connection_str: has pgagent connection str + :param conn: Connection obj + :param conn: SQL template path + """ + if 'jstconntype' not in data and \ + ('jstdbname' in data or + 'jstconnstr' in data) and has_connection_str: + status, rset = conn.execute_dict( + render_template( + "/".join([template_path, 'steps.sql']), + jid=job_id, + jstid=data['jstid'], + conn=conn, + has_connstr=has_connection_str + ) + ) + if not status: + return False, rset + + row = rset['rows'][0] + data['jstconntype'] = row['jstconntype'] + if row['jstconntype']: + data['jstdbname'] = data.get( + 'jstdbname', row['jstdbname']) + else: + data['jstconnstr'] = data.get( + 'jstconnstr', row['jstconnstr']) + + return True, None diff --git a/web/pgadmin/utils/driver/psycopg2/__init__.py b/web/pgadmin/utils/driver/psycopg2/__init__.py index 2edbd64cd..6f61b81fe 100644 --- a/web/pgadmin/utils/driver/psycopg2/__init__.py +++ b/web/pgadmin/utils/driver/psycopg2/__init__.py @@ -14,6 +14,7 @@ object. """ import datetime +import re from flask import session from flask_login import current_user from werkzeug.exceptions import InternalServerError @@ -64,6 +65,28 @@ class Driver(BaseDriver): super(Driver, self).__init__() + def _restore_connections_from_session(self): + """ + Used internally by connection_manager to restore connections + from sessions. + """ + if session.sid not in self.managers: + self.managers[session.sid] = managers = dict() + if '__pgsql_server_managers' in session: + session_managers = \ + session['__pgsql_server_managers'].copy() + for server in \ + Server.query.filter_by( + user_id=current_user.id): + manager = managers[str(server.id)] = \ + ServerManager(server) + if server.id in session_managers: + manager._restore(session_managers[server.id]) + manager.update_session() + return managers + + return {} + def connection_manager(self, sid=None): """ connection_manager(...) @@ -86,20 +109,7 @@ class Driver(BaseDriver): with connection_restore_lock: # The wait is over but the object might have been loaded # by some other thread check again - if session.sid not in self.managers: - self.managers[session.sid] = managers = dict() - if '__pgsql_server_managers' in session: - session_managers =\ - session['__pgsql_server_managers'].copy() - for server in \ - Server.query.filter_by( - user_id=current_user.id): - manager = managers[str(server.id)] =\ - ServerManager(server) - if server.id in session_managers: - manager._restore(session_managers[server.id]) - manager.update_session() - + managers = self._restore_connections_from_session() else: managers = self.managers[session.sid] if str(sid) in managers: @@ -331,10 +341,8 @@ class Driver(BaseDriver): if '0' <= val_noarray[0] <= '9': return True - for c in val_noarray: - if (not ('a' <= c <= 'z') and c != '_' and - not ('0' <= c <= '9')): - return True + if re.search('[^a-z_0-9]+', val_noarray): + return True # check string is keywaord or not category = Driver.ScanKeywordExtraLookup(value) diff --git a/web/setup.py b/web/setup.py index c88c4e1fa..e9c42decd 100644 --- a/web/setup.py +++ b/web/setup.py @@ -154,6 +154,45 @@ def dump_servers(args): (servers_dumped, args.dump_servers)) +def _validate_servers_data(data): + """ + Used internally by load_servers to validate servers data. + :param data: servers data + :return: error message if any + """ + # Loop through the servers... + if "Servers" not in data: + return ("'Servers' attribute not found in file '%s'" % + args.load_servers) + + for server in data["Servers"]: + obj = data["Servers"][server] + + def check_attrib(attrib): + if attrib not in obj: + return ("'%s' attribute not found for server '%s'" % + (attrib, server)) + + check_attrib("Name") + check_attrib("Group") + + is_service_attrib_available = obj.get("Service", None) is not None + + if not is_service_attrib_available: + check_attrib("Port") + check_attrib("Username") + + check_attrib("SSLMode") + check_attrib("MaintenanceDB") + + if "Host" not in obj and "HostAddr" not in obj and not \ + is_service_attrib_available: + return ("'Host', 'HostAddr' or 'Service' attribute " + "not found for server '%s'" % server) + + return None + + def load_servers(args): """Load server groups and servers. @@ -162,10 +201,7 @@ def load_servers(args): """ # What user? - if args.user is not None: - load_user = args.user - else: - load_user = config.DESKTOP_USER + load_user = args.user if args.user is not None else config.DESKTOP_USER # And the sqlite path if args.sqlite_path is not None: @@ -212,48 +248,19 @@ def load_servers(args): def print_summary(): print("Added %d Server Group(s) and %d Server(s)." % (groups_added, servers_added)) - # Loop through the servers... - if "Servers" not in data: - print("'Servers' attribute not found in file '%s'" % - args.load_servers) + + err_msg = _validate_servers_data(data) + if err_msg is not None: + print(err_msg) print_summary() sys.exit(1) for server in data["Servers"]: obj = data["Servers"][server] - def check_attrib(attrib): - if attrib not in obj: - print("'%s' attribute not found for server '%s'" % - (attrib, server)) - print_summary() - sys.exit(1) - - check_attrib("Name") - check_attrib("Group") - - is_service_attrib_available = True if "Service" in obj else False - - if not is_service_attrib_available: - check_attrib("Port") - check_attrib("Username") - - check_attrib("SSLMode") - check_attrib("MaintenanceDB") - - if "Host" not in obj and "HostAddr" not in obj: - if is_service_attrib_available is False: - print("'Host', 'HostAddr' or 'Service' attribute " - "not found for server '%s'" % server) - print_summary() - sys.exit(1) - # Get the group. Create if necessary - group_id = -1 - for g in groups: - if g.name == obj["Group"]: - group_id = g.id - break + group_id = next( + (g.id for g in groups if g.name == obj["Group"]), -1) if group_id == -1: new_group = ServerGroup() @@ -281,71 +288,52 @@ def load_servers(args): new_server.ssl_mode = obj["SSLMode"] new_server.maintenance_db = obj["MaintenanceDB"] - if "Host" in obj: - new_server.host = obj["Host"] + new_server.host = obj.get("Host", None) - if "HostAddr" in obj: - new_server.hostaddr = obj["HostAddr"] + new_server.hostaddr = obj.get("HostAddr", None) - if "Port" in obj: - new_server.port = obj["Port"] + new_server.port = obj.get("Port", None) - if "Username" in obj: - new_server.username = obj["Username"] + new_server.username = obj.get("Username", None) - if "Role" in obj: - new_server.role = obj["Role"] + new_server.role = obj.get("Role", None) - if "Comment" in obj: - new_server.comment = obj["Comment"] + new_server.ssl_mode = obj["SSLMode"] - if "DBRestriction" in obj: - new_server.db_res = obj["DBRestriction"] + new_server.comment = obj.get("Comment", None) - if "PassFile" in obj: - new_server.passfile = obj["PassFile"] + new_server.db_res = obj.get("DBRestriction", None) - if "SSLCert" in obj: - new_server.sslcert = obj["SSLCert"] + new_server.passfile = obj.get("PassFile", None) - if "SSLKey" in obj: - new_server.sslkey = obj["SSLKey"] + new_server.sslcert = obj.get("SSLCert", None) - if "SSLRootCert" in obj: - new_server.sslrootcert = obj["SSLRootCert"] + new_server.sslkey = obj.get("SSLKey", None) - if "SSLCrl" in obj: - new_server.sslcrl = obj["SSLCrl"] + new_server.sslrootcert = obj.get("SSLRootCert", None) - if "SSLCompression" in obj: - new_server.sslcompression = obj["SSLCompression"] + new_server.sslcrl = obj.get("SSLCrl", None) - if "BGColor" in obj: - new_server.bgcolor = obj["BGColor"] + new_server.sslcompression = obj.get("SSLCompression", None) - if "FGColor" in obj: - new_server.fgcolor = obj["FGColor"] + new_server.bgcolor = obj.get("BGColor", None) - if is_service_attrib_available: - new_server.service = obj["Service"] + new_server.fgcolor = obj.get("FGColor", None) - if "Timeout" in obj: - new_server.connect_timeout = obj["Timeout"] + new_server.service = obj.get("Service", None) - if "UseSSHTunnel" in obj: - new_server.use_ssh_tunnel = obj["UseSSHTunnel"] + new_server.connect_timeout = obj.get("Timeout", None) - if "TunnelHost" in obj: - new_server.tunnel_host = obj["TunnelHost"] + new_server.use_ssh_tunnel = obj.get("UseSSHTunnel", None) - if "TunnelPort" in obj: - new_server.tunnel_port = obj["TunnelPort"] + new_server.tunnel_host = obj.get("TunnelHost", None) - if "TunnelUsername" in obj: - new_server.tunnel_username = obj["TunnelUsername"] + new_server.tunnel_port = obj.get("TunnelPort", None) - if "TunnelAuthentication" in obj: - new_server.tunnel_authentication = obj["TunnelAuthentication"] + new_server.tunnel_username = obj.get("TunnelUsername", None) + + new_server.tunnel_authentication = \ + obj.get("TunnelAuthentication", None) db.session.add(new_server)