diff --git a/docs/en_US/release_notes_3_3.rst b/docs/en_US/release_notes_3_3.rst
index 77b625822..3b77b9305 100644
--- a/docs/en_US/release_notes_3_3.rst
+++ b/docs/en_US/release_notes_3_3.rst
@@ -21,6 +21,7 @@ Bug fixes
 | `Bug #3325 <https://redmine.postgresql.org/issues/3325>`_ - Fix sort/filter dialog issue where it incorrectly requires ASC/DESC.
 | `Bug #3347 <https://redmine.postgresql.org/issues/3347>`_ - Ensure backup should work with '--data-only' and '--schema-only' for any format.
 | `Bug #3407 <https://redmine.postgresql.org/issues/3407>`_ - Fix keyboard shortcuts layout in the preferences panel.
+| `Bug #3420 <https://redmine.postgresql.org/issues/3420>`_ - Merge pgcli code with version 1.10.3, which is used for the auto-complete feature.
 | `Bug #3461 <https://redmine.postgresql.org/issues/3461>`_ - Ensure that refreshing a node also updates the Property list.
 | `Bug #3528 <https://redmine.postgresql.org/issues/3528>`_ - Handle connection errors properly in the query tool.
 | `Bug #3547 <https://redmine.postgresql.org/issues/3547>`_ - Make session implementation thread safe
diff --git a/web/pgadmin/feature_tests/query_tool_auto_complete_tests.py b/web/pgadmin/feature_tests/query_tool_auto_complete_tests.py
new file mode 100644
index 000000000..2e0dd4a2b
--- /dev/null
+++ b/web/pgadmin/feature_tests/query_tool_auto_complete_tests.py
@@ -0,0 +1,156 @@
+##########################################################################
+#
+# pgAdmin 4 - PostgreSQL Tools
+#
+# Copyright (C) 2013 - 2018, The pgAdmin Development Team
+# This software is released under the PostgreSQL Licence
+#
+##########################################################################
+
+import sys
+import random
+
+from selenium.webdriver import ActionChains
+from selenium.webdriver.common.keys import Keys
+from selenium.webdriver.support.ui import WebDriverWait
+from selenium.webdriver.support import expected_conditions as EC
+from regression.python_test_utils import test_utils
+from regression.feature_utils.base_feature_test import BaseFeatureTest
+
+
+class QueryToolAutoCompleteFeatureTest(BaseFeatureTest):
+    """
+    This feature test exercises the query tool's auto-complete feature.
+    """
+
+    first_table_name = ""
+    second_table_name = ""
+
+    scenarios = [
+        ("Query tool auto complete feature test", dict())
+    ]
+
+    def before(self):
+        self.page.wait_for_spinner_to_disappear()
+
+        self.page.add_server(self.server)
+        self.first_table_name = "auto_comp_" + \
+                                str(random.randint(1000, 3000))
+        test_utils.create_table(self.server, self.test_db,
+                                self.first_table_name)
+
+        self.second_table_name = "auto_comp_" + \
+                                 str(random.randint(1000, 3000))
+        test_utils.create_table(self.server, self.test_db,
+                                self.second_table_name)
+
+        self._locate_database_tree_node()
+        self.page.open_query_tool()
+        self.page.wait_for_spinner_to_disappear()
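The JOIN scenarios below depend on test_utils.create_table producing tables that contain a column named some_column. A minimal stand-in with that property might look like this (a sketch only; the real helper lives in regression/python_test_utils/test_utils.py, creates additional columns, and the server-dict keys shown are assumptions):

import psycopg2


def create_table(server, db_name, table_name):
    # Hypothetical stand-in: the auto-complete JOIN tests only rely on
    # a column named some_column being present in both test tables.
    conn = psycopg2.connect(host=server['host'], port=server['port'],
                            user=server['username'],
                            password=server['db_password'],
                            dbname=db_name)
    with conn.cursor() as cur:
        cur.execute('CREATE TABLE "%s" (some_column VARCHAR(63))'
                    % table_name)
    conn.commit()
    conn.close()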
", + file=sys.stderr, end="") + self._auto_complete("SELECT current_", "current_query()") + print("OK.", file=sys.stderr) + self._clear_query_tool() + + print("Auto complete function with argument ... ", + file=sys.stderr, end="") + self._auto_complete("SELECT pg_st", "pg_stat_file(filename)") + print("OK.", file=sys.stderr) + self._clear_query_tool() + + print("Auto complete first table in public schema ... ", + file=sys.stderr, end="") + self._auto_complete("SELECT * FROM public.", self.first_table_name) + print("OK.", file=sys.stderr) + self._clear_query_tool() + + print("Auto complete second table in public schema ... ", + file=sys.stderr, end="") + self._auto_complete("SELECT * FROM public.", self.second_table_name) + print("OK.", file=sys.stderr) + self._clear_query_tool() + + print("Auto complete JOIN second table with after schema name ... ", + file=sys.stderr, end="") + query = "SELECT * FROM public." + self.first_table_name + \ + " JOIN public." + self._auto_complete(query, self.second_table_name) + print("OK.", file=sys.stderr) + self._clear_query_tool() + + print("Auto complete JOIN ON some columns ... ", + file=sys.stderr, end="") + query = "SELECT * FROM public." + self.first_table_name + \ + " JOIN public." + self.second_table_name + " ON " + \ + self.second_table_name + "." + expected_string = "some_column = " + self.first_table_name + \ + ".some_column" + self._auto_complete(query, expected_string) + print("OK.", file=sys.stderr) + self._clear_query_tool() + + print("Auto complete JOIN ON some columns using tabel alias ... ", + file=sys.stderr, end="") + query = "SELECT * FROM public." + self.first_table_name + \ + " t1 JOIN public." + self.second_table_name + " t2 ON t2." + self._auto_complete(query, "some_column = t1.some_column") + print("OK.", file=sys.stderr) + self._clear_query_tool() + + def after(self): + self.page.remove_server(self.server) + + def _locate_database_tree_node(self): + self.page.toggle_open_tree_item(self.server['name']) + self.page.toggle_open_tree_item('Databases') + self.page.toggle_open_tree_item(self.test_db) + + def _clear_query_tool(self): + self.page.click_element( + self.page.find_by_xpath("//*[@id='btn-clear-dropdown']") + ) + ActionChains(self.driver) \ + .move_to_element(self.page.find_by_xpath("//*[@id='btn-clear']")) \ + .perform() + self.page.click_element( + self.page.find_by_xpath("//*[@id='btn-clear']") + ) + self.page.click_modal('Yes') + + def _auto_complete(self, word, expected_string): + self.page.fill_codemirror_area_with(word) + ActionChains(self.page.driver).key_down( + Keys.CONTROL).send_keys(Keys.SPACE).key_up(Keys.CONTROL).perform() + self.page.find_by_xpath( + "//ul[contains(@class, 'CodeMirror-hints') and " + "contains(., '" + expected_string + "')]") diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/11_plus/functions.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/11_plus/functions.sql index 1190ad0fd..894bf438b 100644 --- a/web/pgadmin/misc/templates/sqlautocomplete/sql/11_plus/functions.sql +++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/11_plus/functions.sql @@ -1,30 +1,16 @@ {# ============= Fetch the list of functions based on given schema_names ============= #} -{% if func_name %} SELECT n.nspname schema_name, p.proname func_name, - pg_catalog.pg_get_function_arguments(p.oid) arg_list, - pg_catalog.pg_get_function_result(p.oid) return_type, + p.proargnames arg_names, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[] arg_types, + p.proargmodes arg_modes, + 
diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/11_plus/functions.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/11_plus/functions.sql
index 1190ad0fd..894bf438b 100644
--- a/web/pgadmin/misc/templates/sqlautocomplete/sql/11_plus/functions.sql
+++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/11_plus/functions.sql
@@ -1,30 +1,16 @@
 {# ============= Fetch the list of functions based on given schema_names ============= #}
-{% if func_name %}
 SELECT n.nspname schema_name,
        p.proname func_name,
-       pg_catalog.pg_get_function_arguments(p.oid) arg_list,
-       pg_catalog.pg_get_function_result(p.oid) return_type,
+       p.proargnames arg_names,
+       COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[] arg_types,
+       p.proargmodes arg_modes,
+       prorettype::regtype::text return_type,
        CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
        CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
-       p.proretset is_set_returning
+       p.proretset is_set_returning,
+       pg_get_expr(proargdefaults, 0) AS arg_defaults
 FROM pg_catalog.pg_proc p
     INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
-    WHERE n.nspname = '{{schema_name}}' AND p.proname = '{{func_name}}'
-    AND p.proretset
-    ORDER BY 1, 2
-{% else %}
-SELECT n.nspname schema_name,
-       p.proname object_name,
-       pg_catalog.pg_get_function_arguments(p.oid) arg_list,
-       pg_catalog.pg_get_function_result(p.oid) return_type,
-       CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
-       CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
-       p.proretset is_set_returning
-FROM pg_catalog.pg_proc p
-    INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
-    WHERE n.nspname IN ({{schema_names}})
-{% if is_set_returning %}
-    AND p.proretset
-{% endif %}
-    ORDER BY 1, 2
-{% endif %}
+WHERE p.prorettype::regtype != 'trigger'::regtype
+    AND n.nspname IN ({{schema_names}})
+ORDER BY 1, 2
diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/default/columns.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/default/columns.sql
index d4a109484..580c3622f 100644
--- a/web/pgadmin/misc/templates/sqlautocomplete/sql/default/columns.sql
+++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/default/columns.sql
@@ -1,29 +1,43 @@
 {# SQL query for getting columns #}
 {% if object_name == 'table' %}
-SELECT
-    att.attname column_name
-FROM pg_catalog.pg_attribute att
-    INNER JOIN pg_catalog.pg_class cls
-        ON att.attrelid = cls.oid
-    INNER JOIN pg_catalog.pg_namespace nsp
-        ON cls.relnamespace = nsp.oid
-    WHERE cls.relkind = ANY(array['r'])
+SELECT nsp.nspname schema_name,
+       cls.relname table_name,
+       att.attname column_name,
+       att.atttypid::regtype::text type_name,
+       att.atthasdef AS has_default,
+       def.adsrc AS default
+FROM pg_catalog.pg_attribute att
+    INNER JOIN pg_catalog.pg_class cls
+        ON att.attrelid = cls.oid
+    INNER JOIN pg_catalog.pg_namespace nsp
+        ON cls.relnamespace = nsp.oid
+    LEFT OUTER JOIN pg_attrdef def
+        ON def.adrelid = att.attrelid
+        AND def.adnum = att.attnum
+WHERE nsp.nspname IN ({{schema_names}})
+    AND cls.relkind = ANY(array['r'])
     AND NOT att.attisdropped
     AND att.attnum > 0
-    AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
-    ORDER BY 1
+ORDER BY 1, 2, att.attnum
 {% endif %}
 {% if object_name == 'view' %}
-SELECT
-    att.attname column_name
-FROM pg_catalog.pg_attribute att
-    INNER JOIN pg_catalog.pg_class cls
-        ON att.attrelid = cls.oid
-    INNER JOIN pg_catalog.pg_namespace nsp
-        ON cls.relnamespace = nsp.oid
-    WHERE cls.relkind = ANY(array['v', 'm'])
+SELECT nsp.nspname schema_name,
+       cls.relname table_name,
+       att.attname column_name,
+       att.atttypid::regtype::text type_name,
+       att.atthasdef AS has_default,
+       def.adsrc AS default
+FROM pg_catalog.pg_attribute att
+    INNER JOIN pg_catalog.pg_class cls
+        ON att.attrelid = cls.oid
+    INNER JOIN pg_catalog.pg_namespace nsp
+        ON cls.relnamespace = nsp.oid
+    LEFT OUTER JOIN pg_attrdef def
+        ON def.adrelid = att.attrelid
+        AND def.adnum = att.attnum
+WHERE nsp.nspname IN ({{schema_names}})
+    AND cls.relkind = ANY(array['v', 'm'])
    AND NOT att.attisdropped
    AND att.attnum > 0
-    AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
-    ORDER BY 1
-{% endif %}
\ No newline at end of file
+ORDER BY 1, 2, att.attnum
+{% endif %}
diff --git 
a/web/pgadmin/misc/templates/sqlautocomplete/sql/default/foreign_keys.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/default/foreign_keys.sql new file mode 100644 index 000000000..e635c343a --- /dev/null +++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/default/foreign_keys.sql @@ -0,0 +1,27 @@ +{# SQL query for getting foreign keys #} +SELECT s_p.nspname AS parentschema, + t_p.relname AS parenttable, + unnest(( + select + array_agg(attname ORDER BY i) + from + (select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x + JOIN pg_catalog.pg_attribute c USING(attnum) + WHERE c.attrelid = fk.confrelid + )) AS parentcolumn, + s_c.nspname AS childschema, + t_c.relname AS childtable, + unnest(( + select + array_agg(attname ORDER BY i) + from + (select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x + JOIN pg_catalog.pg_attribute c USING(attnum) + WHERE c.attrelid = fk.conrelid + )) AS childcolumn +FROM pg_catalog.pg_constraint fk +JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid +JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace +JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid +JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace +WHERE fk.contype = 'f' AND s_p.nspname IN ({{schema_names}}) diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/default/functions.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/default/functions.sql index 826da2d81..82e01a44b 100644 --- a/web/pgadmin/misc/templates/sqlautocomplete/sql/default/functions.sql +++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/default/functions.sql @@ -1,30 +1,16 @@ {# ============= Fetch the list of functions based on given schema_names ============= #} -{% if func_name %} SELECT n.nspname schema_name, p.proname func_name, - pg_catalog.pg_get_function_arguments(p.oid) arg_list, - pg_catalog.pg_get_function_result(p.oid) return_type, + p.proargnames arg_names, + COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[] arg_types, + p.proargmodes arg_modes, + prorettype::regtype::text return_type, p.proisagg is_aggregate, p.proiswindow is_window, - p.proretset is_set_returning + p.proretset is_set_returning, + pg_get_expr(proargdefaults, 0) AS arg_defaults FROM pg_catalog.pg_proc p INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace - WHERE n.nspname = '{{schema_name}}' AND p.proname = '{{func_name}}' - AND p.proretset - ORDER BY 1, 2 -{% else %} -SELECT n.nspname schema_name, - p.proname object_name, - pg_catalog.pg_get_function_arguments(p.oid) arg_list, - pg_catalog.pg_get_function_result(p.oid) return_type, - p.proisagg is_aggregate, - p.proiswindow is_window, - p.proretset is_set_returning -FROM pg_catalog.pg_proc p - INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace - WHERE n.nspname IN ({{schema_names}}) -{% if is_set_returning %} - AND p.proretset -{% endif %} - ORDER BY 1, 2 -{% endif %} \ No newline at end of file +WHERE p.prorettype::regtype != 'trigger'::regtype + AND n.nspname IN ({{schema_names}}) +ORDER BY 1, 2 diff --git a/web/pgadmin/utils/sqlautocomplete/autocomplete.py b/web/pgadmin/utils/sqlautocomplete/autocomplete.py index 8a25c1324..3efac773d 100644 --- a/web/pgadmin/utils/sqlautocomplete/autocomplete.py +++ b/web/pgadmin/utils/sqlautocomplete/autocomplete.py @@ -9,60 +9,71 @@ """A blueprint module implementing the sql auto complete feature.""" -import itertools -import operator import re -import sys -from collections import namedtuple - -import sqlparse +import operator +from itertools 
import count, repeat, chain +from .completion import Completion +from collections import namedtuple, defaultdict, OrderedDict +from .sqlcompletion import ( + FromClauseItem, suggest_type, Database, Schema, Table, + Function, Column, View, Keyword, Datatype, Alias, JoinCondition, Join) +from .parseutils.meta import FunctionMetadata, ColumnMetadata, ForeignKey +from .parseutils.utils import last_word +from .parseutils.tables import TableReference +from .prioritization import PrevalenceCounter from flask import render_template from pgadmin.utils.driver import get_driver -from sqlparse.sql import Comparison, Identifier, Where - from config import PG_DEFAULT_DRIVER -from .completion import Completion -from .function_metadata import FunctionMetadata -from .parseutils import ( - last_word, extract_tables, find_prev_keyword, parse_partial_identifier) -from .prioritization import PrevalenceCounter from pgadmin.utils.preferences import Preferences -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 - -if PY3: - string_types = str -else: - string_types = basestring - -Database = namedtuple('Database', []) -Schema = namedtuple('Schema', []) -Table = namedtuple('Table', ['schema']) - -Function = namedtuple('Function', ['schema', 'filter']) -# For convenience, don't require the `filter` argument in Function constructor -Function.__new__.__defaults__ = (None, None) - -Column = namedtuple('Column', ['tables', 'drop_unique']) -Column.__new__.__defaults__ = (None, None) - -View = namedtuple('View', ['schema']) -Keyword = namedtuple('Keyword', []) -Datatype = namedtuple('Datatype', ['schema']) -Alias = namedtuple('Alias', ['aliases']) Match = namedtuple('Match', ['completion', 'priority']) -try: - from collections import Counter -except ImportError: - # python 2.6 - from .counter import Counter +_SchemaObject = namedtuple('SchemaObject', 'name schema meta') + + +def SchemaObject(name, schema=None, meta=None): + return _SchemaObject(name, schema, meta) + # Regex for finding "words" in documents. _FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)') _FIND_BIG_WORD_RE = re.compile(r'([^\s]+)') +_Candidate = namedtuple( + 'Candidate', 'completion prio meta synonyms prio2 display' +) + + +def Candidate( + completion, prio=None, meta=None, synonyms=None, prio2=None, + display=None +): + return _Candidate( + completion, prio, meta, synonyms or [completion], prio2, + display or completion + ) + + +# Used to strip trailing '::some_type' from default-value expressions +arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$') + + +def normalize_ref(ref): + return ref if ref[0] == '"' else '"' + ref.lower() + '"' + + +def generate_alias(tbl): + """ Generate a table alias, consisting of all upper-case letters in + the table name, or, if there are no upper-case letters, the first letter + + all letters preceded by _ + param tbl - unescaped name of the table to alias + """ + return ''.join( + [letter for letter in tbl if letter.isupper()] or + [letter for letter, prev in zip(tbl, '_' + tbl) + if prev == '_' and letter != '_'] + ) + class SQLAutoComplete(object): """ @@ -73,7 +84,6 @@ class SQLAutoComplete(object): the connection and get the tables, schemas, functions etc. based on the query. """ - def __init__(self, **kwargs): """ This method is used to initialize the class. 
@@ -83,9 +93,15 @@ class SQLAutoComplete(object):
         """
         self.sid = kwargs['sid'] if 'sid' in kwargs else None
-        self.did = kwargs['did'] if 'did' in kwargs else None
         self.conn = kwargs['conn'] if 'conn' in kwargs else None
         self.keywords = []
+        self.databases = []
+        self.functions = []
+        self.datatypes = []
+        self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
+                           'datatypes': {}}
+        self.text_before_cursor = None
+        self.name_pattern = re.compile("^[_a-z][_a-z0-9\$]*$")
 
         manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid)
 
@@ -93,6 +109,7 @@
         self.sql_path = 'sqlautocomplete/sql/#{0}#'.format(manager.version)
 
         self.search_path = []
+        schema_names = []
         # Fetch the search path
         if self.conn.connected():
             query = render_template(
@@ -100,6 +117,7 @@
             status, res = self.conn.execute_dict(query)
             if status:
                 for record in res['rows']:
+                    schema_names.append(record['schema'])
                     self.search_path.append(record['schema'])
 
         pref = Preferences.module('sqleditor')
@@ -118,22 +136,44 @@
                 for record in res['rows']:
                     self.keywords.append(record['word'])
 
-        self.text_before_cursor = None
         self.prioritizer = PrevalenceCounter(self.keywords)
 
         self.reserved_words = set()
         for x in self.keywords:
            self.reserved_words.update(x.split())
-        self.name_pattern = re.compile("^[_a-z][_a-z0-9\$]*$")
+
+        self.all_completions = set(self.keywords)
+        self.extend_schemata(schema_names)
+
+        # Below are the configurable options of pgcli which we don't have
+        # in pgAdmin4 at the moment. Set the default values from pgcli's
+        # config file.
+        self.signature_arg_style = '{arg_name} {arg_type}'
+        self.call_arg_style = '{arg_name: <{max_arg_len}} := {arg_default}'
+        self.call_arg_display_style = '{arg_name}'
+        self.call_arg_oneliner_max = 2
+        self.search_path_filter = True
+        self.generate_aliases = False
+        self.insert_col_skip_patterns = [
+            re.compile(r'^now\(\)$'),
+            re.compile(r'^nextval\(')]
+        self.qualify_columns = 'if_more_than_one_table'
+        self.asterisk_column_order = 'table_order'
 
     def escape_name(self, name):
-        if name and ((not self.name_pattern.match(name)) or
-                     (name.upper() in self.reserved_words)):
+        if name and (
+            (not self.name_pattern.match(name)) or
+            (name.upper() in self.reserved_words)
+        ):
             name = '"%s"' % name
 
         return name
 
+    def escape_schema(self, name):
+        return "'{}'".format(self.unescape_name(name))
+
     def unescape_name(self, name):
+        """ Unquote a string."""
         if name and name[0] == '"' and name[-1] == '"':
             name = name[1:-1]
 
@@ -142,15 +182,176 @@
     def escaped_names(self, names):
         return [self.escape_name(name) for name in names]
 
-    def find_matches(self, text, collection, mode='fuzzy',
-                     meta=None, meta_collection=None):
+    def extend_database_names(self, databases):
+        self.databases.extend(databases)
+
+    def extend_keywords(self, additional_keywords):
+        self.keywords.extend(additional_keywords)
+        self.all_completions.update(additional_keywords)
+
+    def extend_schemata(self, schemata):
+
+        # schemata is a list of schema names
+        schemata = self.escaped_names(schemata)
+        metadata = self.dbmetadata['tables']
+        for schema in schemata:
+            metadata[schema] = {}
+
+        # dbmetadata.values() are the 'tables' and 'functions' dicts
+        for metadata in self.dbmetadata.values():
+            for schema in schemata:
+                metadata[schema] = {}
+
+        self.all_completions.update(schemata)
+
+    def extend_casing(self, words):
+        """ extend casing data
+
+        :return:
         """
-        Find completion matches for the given text.
+        # casing should be a dict {lowercasename:PreferredCasingName}
+        self.casing = dict((word.lower(), word) for word in words)
+
+    def extend_relations(self, data, kind):
+        """extend metadata for tables or views.
+
+        :param data: list of (schema_name, rel_name) tuples
+        :param kind: either 'tables' or 'views'
+
+        :return:
+
+        """
+
+        data = [self.escaped_names(d) for d in data]
+
+        # dbmetadata['tables']['schema_name']['table_name'] should be an
+        # OrderedDict {column_name:ColumnMetaData}.
+        metadata = self.dbmetadata[kind]
+        for schema, relname in data:
+            try:
+                metadata[schema][relname] = OrderedDict()
+            except KeyError:
+                print('%r %r listed in unrecognized schema %r'
+                      % (kind, relname, schema))
+
+            self.all_completions.add(relname)
+
+    def extend_columns(self, column_data, kind):
+        """extend column metadata.
+
+        :param column_data: list of (schema_name, rel_name, column_name,
+            column_type, has_default, default) tuples
+        :param kind: either 'tables' or 'views'
+
+        :return:
+
+        """
+        metadata = self.dbmetadata[kind]
+        for schema, relname, colname, datatype, \
+                has_default, default in column_data:
+            (schema, relname, colname) = self.escaped_names(
+                [schema, relname, colname])
+            column = ColumnMetadata(
+                name=colname,
+                datatype=datatype,
+                has_default=has_default,
+                default=default
+            )
+            metadata[schema][relname][colname] = column
+            self.all_completions.add(colname)
+
+    def extend_functions(self, func_data):
+
+        # func_data is a list of function metadata namedtuples
+
+        # dbmetadata['functions']['schema_name']['function_name'] should
+        # return the function metadata namedtuple for the corresponding
+        # function
+        metadata = self.dbmetadata['functions']
+
+        for f in func_data:
+            schema, func = self.escaped_names([f.schema_name, f.func_name])
+
+            if func in metadata[schema]:
+                metadata[schema][func].append(f)
+            else:
+                metadata[schema][func] = [f]
+
+            self.all_completions.add(func)
+
+        self._refresh_arg_list_cache()
+
+    def _refresh_arg_list_cache(self):
+        # We keep a cache of
+        # {function_usage:{function_metadata: function_arg_list_string}}
+        # This is used when suggesting functions, to avoid the latency that
+        # would result if we recalculated the arg lists each time we suggest
+        # functions (in large DBs)
+        self._arg_list_cache = {
+            usage: {
+                meta: self._arg_list(meta, usage)
+                for sch, funcs in self.dbmetadata['functions'].items()
+                for func, metas in funcs.items()
+                for meta in metas
+            }
+            for usage in ('call', 'call_display', 'signature')
+        }
+
+    def extend_foreignkeys(self, fk_data):
+
+        # fk_data is a list of ForeignKey namedtuples, with fields
+        # parentschema, childschema, parenttable, childtable,
+        # parentcolumns, childcolumns
+
+        # These are added as a list of ForeignKey namedtuples to the
+        # ColumnMetadata namedtuple for both the child and parent
+        meta = self.dbmetadata['tables']
+
+        for fk in fk_data:
+            e = self.escaped_names
+            parentschema, childschema = e([fk.parentschema, fk.childschema])
+            parenttable, childtable = e([fk.parenttable, fk.childtable])
+            childcol, parcol = e([fk.childcolumn, fk.parentcolumn])
+            childcolmeta = meta[childschema][childtable][childcol]
+            parcolmeta = meta[parentschema][parenttable][parcol]
+            fk = ForeignKey(
+                parentschema, parenttable, parcol,
+                childschema, childtable, childcol
+            )
+            childcolmeta.foreignkeys.append(fk)
+            parcolmeta.foreignkeys.append(fk)
+    def extend_datatypes(self, type_data):
+
+        # dbmetadata['datatypes'][schema_name][type_name] should store type
+        # metadata, such as composite type field names. Currently, we're not
+        # storing any metadata beyond the type name, so just store None.
+        meta = self.dbmetadata['datatypes']
+
+        for t in type_data:
+            schema, type_name = self.escaped_names(t)
+            meta[schema][type_name] = None
+            self.all_completions.add(type_name)
+
+    def set_search_path(self, search_path):
+        self.search_path = self.escaped_names(search_path)
+
+    def reset_completions(self):
+        self.databases = []
+        self.special_commands = []
+        self.search_path = []
+        self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
+                           'datatypes': {}}
+        self.all_completions = set(self.keywords + self.functions)
+
+    def find_matches(self, text, collection, mode='fuzzy', meta=None):
+        """Find completion matches for the given text.
 
         Given the user's input text and a collection of available
         completions, find completions matching the last word of the
         text.
 
+        `collection` can be either a list of strings or a list of Candidate
+        namedtuples.
         `mode` can be either 'fuzzy', or 'strict'
             'fuzzy': fuzzy matching, ties broken by name prevalence
             `keyword`: start-only matching, ties broken by keyword prevalence
@@ -165,7 +366,14 @@
             meta:
         """
-
+        if not collection:
+            return []
+        prio_order = [
+            'keyword', 'function', 'view', 'table', 'datatype', 'database',
+            'schema', 'column', 'table alias', 'join', 'name join', 'fk join',
+            'table format'
+        ]
+        type_priority = prio_order.index(meta) if meta in prio_order else -1
         text = last_word(text, include='most_punctuations').lower()
         text_len = len(text)
 
@@ -193,6 +401,12 @@
         pat = re.compile('(%s)' % regex)
 
         def _match(item):
+            if item.lower()[:len(text) + 1] in (text, text + ' '):
+                # Exact match of first word in suggestion
+                # This is to get exact alias matches to the top
+                # E.g. for input `e`, 'Entries E' should be on top
+                # (before e.g. `EndUsers EU`)
+                return float('Infinity'), -1
             r = pat.search(self.unescape_name(item.lower()))
             if r:
                 return -len(r.group()), -r.start()
@@ -206,39 +420,56 @@
             # fuzzy matches
             return -float('Infinity'), -match_point
 
-        if meta_collection:
-            # Each possible completion in the collection has a corresponding
-            # meta-display string
-            collection = zip(collection, meta_collection)
-        else:
-            # All completions have an identical meta
-            collection = zip(collection, itertools.repeat(meta))
-
         matches = []
+        for cand in collection:
+            if isinstance(cand, _Candidate):
+                item, prio, display_meta, synonyms, prio2, display = cand
+                if display_meta is None:
+                    display_meta = meta
+                syn_matches = (_match(x) for x in synonyms)
+                # Nones need to be removed to avoid max() crashing in Python 3
+                syn_matches = [m for m in syn_matches if m]
+                sort_key = max(syn_matches) if syn_matches else None
+            else:
+                item, display_meta, prio, prio2, display = \
+                    cand, meta, 0, 0, cand
+                sort_key = _match(cand)
 
-        for item, meta in collection:
-            sort_key = _match(item)
             if sort_key:
-                if meta and len(meta) > 50:
+                if display_meta and len(display_meta) > 50:
                     # Truncate meta-text to 50 characters, if necessary
-                    meta = meta[:47] + u'...'
+                    display_meta = display_meta[:47] + u'...'
 
                 # Lexical order of items in the collection, used for
                 # tiebreaking items with the same match group length and start
                 # position. Since we use *higher* priority to mean "more
                 # important," we use -ord(c) to prioritize "aa" > "ab" and end
                 # with 1 to prioritize shorter strings (ie "user" > "users").
+                # We first do a case-insensitive sort and then a
+                # case-sensitive one as a tie breaker.
                # We also use the unescape_name to make sure quoted names have
                # the same priority as unquoted names.
-                lexical_priority = tuple(
-                    -ord(c) for c in self.unescape_name(item)) + (1,)
-
-                priority = sort_key, priority_func(item), lexical_priority
-
-                matches.append(Match(
-                    completion=Completion(item, -text_len, display_meta=meta),
-                    priority=priority))
+                lexical_priority = (
+                    tuple(0 if c in (' _') else -ord(c)
+                          for c in self.unescape_name(item.lower())) + (1,) +
+                    tuple(c for c in item)
+                )
+                priority = (
+                    sort_key, type_priority, prio, priority_func(item),
+                    prio2, lexical_priority
+                )
+                matches.append(
+                    Match(
+                        completion=Completion(
+                            text=item,
+                            start_position=-text_len,
+                            display_meta=display_meta,
+                            display=display
+                        ),
+                        priority=priority
+                    )
+                )
 
         return matches
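A quick illustration of the two collection forms find_matches accepts (a sketch; completer stands for an initialized SQLAutoComplete instance):

# Plain strings -- every match gets the same meta label:
completer.find_matches('us', ['users', 'orders'],
                       mode='strict', meta='table')
# Candidate namedtuples -- used to attach synonyms (e.g. aliases) and a
# separate display string to a completion:
cand = Candidate('users u', meta='table', synonyms=['users', 'u'])
completer.find_matches('u', [cand], mode='strict', meta='table')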
     def get_completions(self, text, text_before_cursor):
@@ -246,7 +477,7 @@
         word_before_cursor = self.get_word_before_cursor(word=True)
         matches = []
 
-        suggestions = self.suggest_type(text, text_before_cursor)
+        suggestions = suggest_type(text, text_before_cursor)
 
         for suggestion in suggestions:
             suggestion_type = type(suggestion)
@@ -262,119 +493,383 @@
         result = dict()
         for m in matches:
-            # Escape name only if meta type is not a keyword and datatype.
-            if m.completion.display_meta != 'keyword' and \
-                    m.completion.display_meta != 'datatype':
-                name = self.escape_name(m.completion.display)
-            else:
-                name = m.completion.display
-
+            name = m.completion.display
             result[name] = {'object_type': m.completion.display_meta}
 
         return result
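End to end, the query tool hands get_completions the full query text and the text before the cursor, and receives a dictionary keyed by display text (a sketch; values shown are illustrative):

result = completer.get_completions('SELECT * FROM pub',
                                   'SELECT * FROM pub')
# e.g. {'public': {'object_type': 'schema'}, ...}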
     def get_column_matches(self, suggestion, word_before_cursor):
-        tables = suggestion.tables
-        scoped_cols = self.populate_scoped_cols(tables)
-        if suggestion.drop_unique:
-            # drop_unique is used for 'tb11 JOIN tbl2 USING (...' which should
-            # suggest only columns that appear in more than one table
-            scoped_cols = [col for (col, count)
-                           in Counter(scoped_cols).items()
-                           if count > 1 and col != '*']
+        # Tables and Views should be populated first.
+        self.fetch_schema_objects(None, 'tables')
+        self.fetch_schema_objects(None, 'views')
 
-        return self.find_matches(
-            word_before_cursor, scoped_cols, mode='strict', meta='column'
+        tables = suggestion.table_refs
+        do_qualify = suggestion.qualifiable and {
+            'always': True, 'never': False,
+            'if_more_than_one_table': len(tables) > 1}[self.qualify_columns]
+
+        def qualify(col, tbl):
+            return (tbl + '.' + col) if do_qualify else col
+
+        scoped_cols = self.populate_scoped_cols(
+            tables, suggestion.local_tables
         )
 
-    def get_function_matches(self, suggestion, word_before_cursor):
-        if suggestion.filter == 'is_set_returning':
-            # Only suggest set-returning functions
-            funcs = self.populate_functions(suggestion.schema)
-        else:
-            funcs = self.populate_schema_objects(
-                suggestion.schema, 'functions')
+        def make_cand(name, ref):
+            synonyms = (name, generate_alias(name))
+            return Candidate(qualify(name, ref), 0, 'column', synonyms)
 
+        def flat_cols():
+            return [make_cand(c.name, t.ref) for t, cols in scoped_cols.items()
+                    for c in cols]
+        if suggestion.require_last_table:
+            # require_last_table is used for 'tbl1 JOIN tbl2 USING (...',
+            # which should suggest only columns that appear in the last
+            # table and in at least one other table
+            ltbl = tables[-1].ref
+            other_tbl_cols = set(
+                c.name for t, cs in scoped_cols.items() if t.ref != ltbl
+                for c in cs
+            )
+            scoped_cols = {
+                t: [col for col in cols if col.name in other_tbl_cols]
+                for t, cols in scoped_cols.items()
+                if t.ref == ltbl
+            }
+        lastword = last_word(word_before_cursor, include='most_punctuations')
+        if lastword == '*':
+            if suggestion.context == 'insert':
+                def filter(col):
+                    if not col.has_default:
+                        return True
+                    return not any(
+                        p.match(col.default)
+                        for p in self.insert_col_skip_patterns
+                    )
+                scoped_cols = {
+                    t: [col for col in cols if filter(col)]
+                    for t, cols in scoped_cols.items()
+                }
+            if self.asterisk_column_order == 'alphabetic':
+                for cols in scoped_cols.values():
+                    cols.sort(key=operator.attrgetter('name'))
+            if (
+                lastword != word_before_cursor and
+                len(tables) == 1 and
+                word_before_cursor[-len(lastword) - 1] == '.'
+            ):
+                # User typed x.*; replicate "x." for all columns except the
+                # first, which gets the original (as we only replace the "*")
+                sep = ', ' + word_before_cursor[:-1]
+                collist = sep.join(c.completion for c in flat_cols())
+            else:
+                collist = ', '.join(qualify(c.name, t.ref)
+                                    for t, cs in scoped_cols.items()
+                                    for c in cs)
+
+            return [Match(
+                completion=Completion(
+                    collist,
+                    -1,
+                    display_meta='columns',
+                    display='*'
+                ),
+                priority=(1, 1, 1)
+            )]
+
+        return self.find_matches(word_before_cursor, flat_cols(),
+                                 mode='strict', meta='column')
+
+    def alias(self, tbl, tbls):
+        """ Generate a unique table alias
+        tbl - name of the table to alias, quoted if it needs to be
+        tbls - TableReference iterable of tables already in query
+        """
+        tbls = set(normalize_ref(t.ref) for t in tbls)
+        if self.generate_aliases:
+            tbl = generate_alias(self.unescape_name(tbl))
+        if normalize_ref(tbl) not in tbls:
+            return tbl
+        elif tbl[0] == '"':
+            aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2))
+        else:
+            aliases = (tbl + str(i) for i in count(2))
+        return next(a for a in aliases if normalize_ref(a) not in tbls)
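To make the aliasing rules concrete, generate_alias (defined near the top of this module) and the alias method above behave roughly as follows:

generate_alias('InvoiceItems')     # -> 'II' (upper-case letters win)
generate_alias('customer_orders')  # -> 'co' (first letter of each _-word)
# alias() falls back to numbered variants when the name is taken, e.g.
# 'orders' -> 'orders2' if 'orders' already appears in the query.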
+    def get_join_matches(self, suggestion, word_before_cursor):
+        tbls = suggestion.table_refs
+        cols = self.populate_scoped_cols(tbls)
+        # Set up some data structures for efficient access
+        qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
+        ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
+        refs = set(normalize_ref(t.ref) for t in tbls)
+        other_tbls = set((t.schema, t.name) for t in list(cols)[:-1])
+        joins = []
+        # Iterate over FKs in existing tables to find potential joins
+        fks = (
+            (fk, rtbl, rcol) for rtbl, rcols in cols.items()
+            for rcol in rcols for fk in rcol.foreignkeys
+        )
+        col = namedtuple('col', 'schema tbl col')
+        for fk, rtbl, rcol in fks:
+            right = col(rtbl.schema, rtbl.name, rcol.name)
+            child = col(fk.childschema, fk.childtable, fk.childcolumn)
+            parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
+            left = child if parent == right else parent
+            if suggestion.schema and left.schema != suggestion.schema:
+                continue
+            if self.generate_aliases or normalize_ref(left.tbl) in refs:
+                lref = self.alias(left.tbl, suggestion.table_refs)
+                join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
+                    left.tbl, left.col, rtbl.ref, right.col, lref)
+            else:
+                join = '{0} ON {0}.{1} = {2}.{3}'.format(
+                    left.tbl, left.col, rtbl.ref, right.col)
+            alias = generate_alias(left.tbl)
+            synonyms = [join, '{0} ON {0}.{1} = {2}.{3}'.format(
+                alias, left.col, rtbl.ref, right.col)]
+            # Schema-qualify if (1) the new table is in the same schema as
+            # the old one and the old one is schema-qualified, or (2) the
+            # new table is in another schema, except public
+            if not suggestion.schema and \
+                    (qualified[normalize_ref(rtbl.ref)] and
+                     left.schema == right.schema or
+                     left.schema not in (right.schema, 'public')):
+                join = left.schema + '.' + join
+            prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
+                0 if (left.schema, left.tbl) in other_tbls else 1)
+            joins.append(Candidate(join, prio, 'join', synonyms=synonyms))
+
+        return self.find_matches(word_before_cursor, joins,
+                                 mode='strict', meta='join')
+
+    def get_join_condition_matches(self, suggestion, word_before_cursor):
+        col = namedtuple('col', 'schema tbl col')
+        tbls = self.populate_scoped_cols(suggestion.table_refs).items
+        cols = [(t, c) for t, cs in tbls() for c in cs]
+        try:
+            lref = (suggestion.parent or suggestion.table_refs[-1]).ref
+            ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
+        except IndexError:  # The user typed an incorrect table qualifier
+            return []
+        conds, found_conds = [], set()
+
+        def add_cond(lcol, rcol, rref, prio, meta):
+            prefix = '' if suggestion.parent else ltbl.ref + '.'
+            cond = prefix + lcol + ' = ' + rref + '.' + rcol
+            if cond not in found_conds:
+                found_conds.add(cond)
+                conds.append(Candidate(cond, prio + ref_prio[rref], meta))
+
+        def list_dict(pairs):  # Turns [(a, b), (a, c)] into {a: [b, c]}
+            d = defaultdict(list)
+            for pair in pairs:
+                d[pair[0]].append(pair[1])
+            return d
+
+        # Tables that are closer to the cursor get higher prio
+        ref_prio = dict((tbl.ref, num)
+                        for num, tbl in enumerate(suggestion.table_refs))
+        # Map (schema, table, col) to tables
+        coldict = list_dict(
+            ((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
+        )
+        # For each fk from the left table, generate a join condition if
+        # the other table is also in the scope
+        fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
+        for fk, lcol in fks:
+            left = col(ltbl.schema, ltbl.name, lcol)
+            child = col(fk.childschema, fk.childtable, fk.childcolumn)
+            par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
+            left, right = (child, par) if left == child else (par, child)
+            for rtbl in coldict[right]:
+                add_cond(left.col, right.col, rtbl.ref, 2000, 'fk join')
+        # For name matching, use a {(colname, coltype): TableReference} dict
+        coltyp = namedtuple('coltyp', 'name datatype')
+        col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
+        # Find all name-match join conditions
+        for c in (coltyp(c.name, c.datatype) for c in lcols):
+            for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
+                prio = 1000 if c.datatype in (
+                    'integer', 'bigint', 'smallint') else 0
+                add_cond(c.name, c.name, rtbl.ref, prio, 'name join')
+
+        return self.find_matches(word_before_cursor, conds,
+                                 mode='strict', meta='join')
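For example, given a foreign key from orders.customer_id to customers.id and the input `SELECT * FROM orders o JOIN `, the FK loop in get_join_matches above would build a candidate along these lines (table, column, and priority values are illustrative):

join = 'customers ON customers.id = o.customer_id'
Candidate(join, prio=1, meta='join',
          synonyms=[join, 'c ON c.id = o.customer_id'])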
+    def get_function_matches(self, suggestion, word_before_cursor,
+                             alias=False):
+        if suggestion.usage == 'from':
+            # Only suggest functions allowed in FROM clause
+            def filt(f):
+                return not f.is_aggregate and not f.is_window
+        else:
+            alias = False
+
+            def filt(f):
+                return True
+
+        arg_mode = {
+            'signature': 'signature',
+            'special': None,
+        }.get(suggestion.usage, 'call')
 
         # Function overloading means we may have multiple functions of the
         # same name at this point, so keep unique names only
-        funcs = set(funcs)
+        funcs = set(
+            self._make_cand(f, alias, suggestion, arg_mode)
+            for f in self.populate_functions(suggestion.schema, filt)
+        )
 
-        funcs = self.find_matches(
-            word_before_cursor, funcs, mode='strict', meta='function')
+        matches = self.find_matches(word_before_cursor, funcs,
+                                    mode='strict', meta='function')
 
-        return funcs
+        return matches
 
-    def get_schema_matches(self, _, word_before_cursor):
-        schema_names = []
-
-        query = render_template("/".join([self.sql_path, 'schema.sql']))
-        status, res = self.conn.execute_dict(query)
-        if status:
-            for record in res['rows']:
-                schema_names.append(record['schema'])
+    def get_schema_matches(self, suggestion, word_before_cursor):
+        schema_names = self.dbmetadata['tables'].keys()
 
         # Unless we're sure the user really wants them, hide schema names
         # starting with pg_, which are mostly temporary schemas
         if not word_before_cursor.startswith('pg_'):
-            schema_names = [s for s in schema_names if not s.startswith('pg_')]
+            schema_names = [s
+                            for s in schema_names
+                            if not s.startswith('pg_')]
 
-        return self.find_matches(
-            word_before_cursor, schema_names, mode='strict', meta='schema'
+        if suggestion.quoted:
+            schema_names = [self.escape_schema(s) for s in schema_names]
+
+        return self.find_matches(word_before_cursor, schema_names,
+                                 mode='strict', meta='schema')
+
+    def get_from_clause_item_matches(self, suggestion, word_before_cursor):
+        alias = self.generate_aliases
+        s = suggestion
+        t_sug = Table(s.schema, s.table_refs, s.local_tables)
+        v_sug = View(s.schema, s.table_refs)
+        f_sug = Function(s.schema, s.table_refs, usage='from')
+        return (
+            self.get_table_matches(t_sug, word_before_cursor, alias) +
+            self.get_view_matches(v_sug, word_before_cursor, alias) +
+            self.get_function_matches(f_sug, word_before_cursor, alias)
         )
 
-    def get_table_matches(self, suggestion, word_before_cursor):
+    def _arg_list(self, func, usage):
+        """Returns an arg list string, e.g. `(_foo:=23)`, for a func.
+
+        :param func is a FunctionMetadata object
+        :param usage is 'call', 'call_display' or 'signature'
+
+        """
+        template = {
+            'call': self.call_arg_style,
+            'call_display': self.call_arg_display_style,
+            'signature': self.signature_arg_style
+        }[usage]
+        args = func.args()
+        if not template:
+            return '()'
+        elif usage == 'call' and len(args) < 2:
+            return '()'
+        elif usage == 'call' and func.has_variadic():
+            return '()'
+        multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max
+        max_arg_len = max(len(a.name) for a in args) if multiline else 0
+        args = (
+            self._format_arg(template, arg, arg_num + 1, max_arg_len)
+            for arg_num, arg in enumerate(args)
+        )
+        if multiline:
+            return '(' + ','.join('\n    ' + a for a in args if a) + '\n)'
+        else:
+            return '(' + ', '.join(a for a in args if a) + ')'
+
+    def _format_arg(self, template, arg, arg_num, max_arg_len):
+        if not template:
+            return None
+        if arg.has_default:
+            arg_default = 'NULL' if arg.default is None else arg.default
+            # Remove trailing ::(schema.)type
+            arg_default = arg_default_type_strip_regex.sub('', arg_default)
+        else:
+            arg_default = ''
+        return template.format(
+            max_arg_len=max_arg_len,
+            arg_name=arg.name,
+            arg_num=arg_num,
+            arg_type=arg.datatype,
+            arg_default=arg_default
+        )
+
+    def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
+        """Returns a Candidate namedtuple.
+
+        :param tbl is a SchemaObject
+        :param arg_mode determines what type of arg list to suffix for
+            functions. 
+ Possible values: call, signature + + """ + cased_tbl = tbl.name + if do_alias: + alias = self.alias(cased_tbl, suggestion.table_refs) + synonyms = (cased_tbl, generate_alias(cased_tbl)) + maybe_alias = (' ' + alias) if do_alias else '' + maybe_schema = (tbl.schema + '.') if tbl.schema else '' + suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else '' + if arg_mode == 'call': + display_suffix = self._arg_list_cache['call_display'][tbl.meta] + elif arg_mode == 'signature': + display_suffix = self._arg_list_cache['signature'][tbl.meta] + else: + display_suffix = '' + item = maybe_schema + cased_tbl + suffix + maybe_alias + display = maybe_schema + cased_tbl + display_suffix + maybe_alias + prio2 = 0 if tbl.schema else 1 + return Candidate(item, synonyms=synonyms, prio2=prio2, display=display) + + def get_table_matches(self, suggestion, word_before_cursor, alias=False): tables = self.populate_schema_objects(suggestion.schema, 'tables') + tables.extend( + SchemaObject(tbl.name) for tbl in suggestion.local_tables) # Unless we're sure the user really wants them, don't suggest the # pg_catalog tables that are implicitly on the search path if not suggestion.schema and ( not word_before_cursor.startswith('pg_')): - tables = [t for t in tables if not t.startswith('pg_')] + tables = [t for t in tables if not t.name.startswith('pg_')] + tables = [self._make_cand(t, alias, suggestion) for t in tables] + return self.find_matches(word_before_cursor, tables, + mode='strict', meta='table') - return self.find_matches( - word_before_cursor, tables, mode='strict', meta='table' - ) - - def get_view_matches(self, suggestion, word_before_cursor): + def get_view_matches(self, suggestion, word_before_cursor, alias=False): views = self.populate_schema_objects(suggestion.schema, 'views') if not suggestion.schema and ( not word_before_cursor.startswith('pg_')): - views = [v for v in views if not v.startswith('pg_')] - - return self.find_matches( - word_before_cursor, views, mode='strict', meta='view' - ) + views = [v for v in views if not v.name.startswith('pg_')] + views = [self._make_cand(v, alias, suggestion) for v in views] + return self.find_matches(word_before_cursor, views, + mode='strict', meta='view') def get_alias_matches(self, suggestion, word_before_cursor): aliases = suggestion.aliases - return self.find_matches(word_before_cursor, aliases, mode='strict', - meta='table alias') + return self.find_matches(word_before_cursor, aliases, + mode='strict', meta='table alias') def get_database_matches(self, _, word_before_cursor): - databases = [] + return self.find_matches(word_before_cursor, self.databases, + mode='strict', meta='database') - query = render_template("/".join([self.sql_path, 'databases.sql'])) - if self.conn.connected(): - status, res = self.conn.execute_dict(query) - if status: - for record in res['rows']: - databases.append(record['datname']) - - return self.find_matches(word_before_cursor, databases, mode='strict', - meta='database') - - def get_keyword_matches(self, _, word_before_cursor): + def get_keyword_matches(self, suggestion, word_before_cursor): return self.find_matches(word_before_cursor, self.keywords, mode='strict', meta='keyword') def get_datatype_matches(self, suggestion, word_before_cursor): # suggest custom datatypes types = self.populate_schema_objects(suggestion.schema, 'datatypes') - matches = self.find_matches( - word_before_cursor, types, mode='strict', meta='datatype') - + types = [self._make_cand(t, False, suggestion) for t in types] + matches = 
self.find_matches(word_before_cursor, types,
+                                    mode='strict', meta='datatype')
         return matches
 
     def get_word_before_cursor(self, word=False):
@@ -418,6 +913,9 @@
             pass
 
     suggestion_matchers = {
+        FromClauseItem: get_from_clause_item_matches,
+        JoinCondition: get_join_condition_matches,
+        Join: get_join_matches,
         Column: get_column_matches,
         Function: get_function_matches,
         Schema: get_schema_matches,
@@ -429,158 +927,126 @@
         Datatype: get_datatype_matches,
     }
 
-    def populate_scoped_cols(self, scoped_tbls):
-        """ Find all columns in a set of scoped_tables
+    def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
+        """Find all columns in a set of scoped_tables.
+
         :param scoped_tbls: list of TableReference namedtuples
-        :return: list of column names
+        :param local_tbls: tuple(TableMetadata)
+        :return: {TableReference: [ColumnMetadata]}
+        """
+        ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
+        columns = OrderedDict()
+        meta = self.dbmetadata
+
+        def addcols(schema, rel, alias, reltype, cols):
+            tbl = TableReference(schema, rel, alias, reltype == 'functions')
+            if tbl not in columns:
+                columns[tbl] = []
+            columns[tbl].extend(cols)
 
-        columns = []
         for tbl in scoped_tbls:
-            if tbl.schema:
-                # A fully qualified schema.relname reference
-                schema = self.escape_name(tbl.schema)
+            # Local tables should shadow database tables
+            if tbl.schema is None and normalize_ref(tbl.name) in ctes:
+                cols = ctes[normalize_ref(tbl.name)]
+                addcols(None, tbl.name, tbl.alias, 'CTE', cols)
+                continue
+            schemas = [tbl.schema] if tbl.schema else self.search_path
+            for schema in schemas:
                 relname = self.escape_name(tbl.name)
-
+                schema = self.escape_name(schema)
                 if tbl.is_function:
-                    query = render_template(
-                        "/".join([self.sql_path, 'functions.sql']),
-                        schema_name=schema,
-                        func_name=relname
-                    )
-
-                    if self.conn.connected():
-                        status, res = self.conn.execute_dict(query)
-                        func = None
-                        if status:
-                            for row in res['rows']:
-                                func = FunctionMetadata(
-                                    row['schema_name'], row['func_name'],
-                                    row['arg_list'], row['return_type'],
-                                    row['is_aggregate'], row['is_window'],
-                                    row['is_set_returning']
-                                )
-                        if func:
-                            columns.extend(func.fieldnames())
+                    # Return column names from a set-returning function
+                    # Get an array of FunctionMetadata objects
+                    functions = meta['functions'].get(schema, {}).get(relname)
+                    for func in (functions or []):
+                        # func is a FunctionMetadata object
+                        cols = func.fields()
+                        addcols(schema, relname, tbl.alias, 'functions', cols)
                 else:
-                    # We don't know if schema.relname is a table or view. Since
-                    # tables and views cannot share the same name, we can check
-                    # one at a time
-
-                    query = render_template(
-                        "/".join([self.sql_path, 'columns.sql']),
-                        object_name='table',
-                        schema_name=schema,
-                        rel_name=relname
-                    )
-
-                    if self.conn.connected():
-                        status, res = self.conn.execute_dict(query)
-                        if status:
-                            if len(res['rows']) > 0:
-                                # Table exists, so don't bother checking for a
-                                # view
-                                for record in res['rows']:
-                                    columns.append(record['column_name'])
-                            else:
-                                query = render_template(
-                                    "/".join([self.sql_path, 'columns.sql']),
-                                    object_name='view',
-                                    schema_name=schema,
-                                    rel_name=relname
-                                )
-
-                                if self.conn.connected():
-                                    status, res = self.conn.execute_dict(query)
-                                    if status:
-                                        for record in res['rows']:
-                                            columns.append(
-                                                record['column_name'])
-            else:
-                # Schema not specified, so traverse the search path looking for
-                # a table or view that matches. Note that in order to get
-                # proper shadowing behavior, we need to check both views and
-                # tables for each schema before checking the next schema
-                for schema in self.search_path:
-                    relname = self.escape_name(tbl.name)
-
-                    if tbl.is_function:
-                        query = render_template(
-                            "/".join([self.sql_path, 'functions.sql']),
-                            schema_name=schema,
-                            func_name=relname
-                        )
-
-                        if self.conn.connected():
-                            status, res = self.conn.execute_dict(query)
-                            func = None
-                            if status:
-                                for row in res['rows']:
-                                    func = FunctionMetadata(
-                                        row['schema_name'], row['func_name'],
-                                        row['arg_list'], row['return_type'],
-                                        row['is_aggregate'], row['is_window'],
-                                        row['is_set_returning']
-                                    )
-                            if func:
-                                columns.extend(func.fieldnames())
-                    else:
-                        query = render_template(
-                            "/".join([self.sql_path, 'columns.sql']),
-                            object_name='table',
-                            schema_name=schema,
-                            rel_name=relname
-                        )
-
-                        if self.conn.connected():
-                            status, res = self.conn.execute_dict(query)
-                            if status:
-                                if len(res['rows']) > 0:
-                                    # Table exists, so don't bother checking
-                                    # for a view
-                                    for record in res['rows']:
-                                        columns.append(record['column_name'])
-                                else:
-                                    query = render_template(
-                                        "/".join(
-                                            [self.sql_path, 'columns.sql']
-                                        ),
-                                        object_name='view',
-                                        schema_name=schema,
-                                        rel_name=relname
-                                    )
-
-                                    if self.conn.connected():
-                                        status, res = self.conn.execute_dict(
-                                            query
-                                        )
-                                        if status:
-                                            for record in res['rows']:
-                                                columns.append(
-                                                    record['column_name']
-                                                )
+                else:
+                    for reltype in ('tables', 'views'):
+                        cols = meta[reltype].get(schema, {}).get(relname)
+                        if cols:
+                            cols = cols.values()
+                            addcols(schema, relname, tbl.alias, reltype, cols)
+                            break
 
         return columns
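The mapping returned above groups ColumnMetadata entries by the table reference they came from; for `SELECT * FROM public.orders o` it would look roughly like this (illustrative):

scoped = completer.populate_scoped_cols(
    [TableReference('public', 'orders', 'o', False)])
# OrderedDict keyed by the reference, e.g.:
# {TableReference(schema='public', name='orders', alias='o',
#                 is_function=False): [ColumnMetadata(name='id', ...), ...]}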
Note that in order to get - # proper shadowing behavior, we need to check both views and - # tables for each schema before checking the next schema - for schema in self.search_path: - relname = self.escape_name(tbl.name) - - if tbl.is_function: - query = render_template( - "/".join([self.sql_path, 'functions.sql']), - schema_name=schema, - func_name=relname - ) - - if self.conn.connected(): - status, res = self.conn.execute_dict(query) - func = None - if status: - for row in res['rows']: - func = FunctionMetadata( - row['schema_name'], row['func_name'], - row['arg_list'], row['return_type'], - row['is_aggregate'], row['is_window'], - row['is_set_returning'] - ) - if func: - columns.extend(func.fieldnames()) - else: - query = render_template( - "/".join([self.sql_path, 'columns.sql']), - object_name='table', - schema_name=schema, - rel_name=relname - ) - - if self.conn.connected(): - status, res = self.conn.execute_dict(query) - if status: - if len(res['rows']) > 0: - # Table exists, so don't bother checking - # for a view - for record in res['rows']: - columns.append(record['column_name']) - else: - query = render_template( - "/".join( - [self.sql_path, 'columns.sql'] - ), - object_name='view', - schema_name=schema, - rel_name=relname - ) - - if self.conn.connected(): - status, res = self.conn.execute_dict( - query - ) - if status: - for record in res['rows']: - columns.append( - record['column_name'] - ) + for reltype in ('tables', 'views'): + cols = meta[reltype].get(schema, {}).get(relname) + if cols: + cols = cols.values() + addcols(schema, relname, tbl.alias, reltype, cols) + break return columns + def _get_schemas(self, obj_typ, schema): + """Returns a list of schemas from which to suggest objects. + + :param schema is the schema qualification input by the user (if any) + + """ + metadata = self.dbmetadata[obj_typ] + if schema: + schema = self.escape_name(schema) + return [schema] if schema in metadata else [] + return self.search_path if self.search_path_filter else metadata.keys() + + def _maybe_schema(self, schema, parent): + return None if parent or schema in self.search_path else schema + def populate_schema_objects(self, schema, obj_type): - """ - Returns list of tables or functions for a (optional) schema + """Returns a list of SchemaObjects representing tables or views. + + :param schema is the schema qualification input by the user (if any) + + """ + # Fetch the schema objects first + self.fetch_schema_objects(schema, obj_type) + + return [ + SchemaObject( + name=obj, + schema=(self._maybe_schema(schema=sch, parent=schema)) + ) + for sch in self._get_schemas(obj_type, schema) + for obj in self.dbmetadata[obj_type][sch].keys() + ] + + def populate_functions(self, schema, filter_func): + """Returns a list of function SchemaObjects. 
+ + :param filter_func is a function that accepts a FunctionMetadata + namedtuple and returns a boolean indicating whether that + function should be kept or discarded - Args: - schema: - obj_type: """ + # Fetch the functions list + self.fetch_functions(schema) + + # Because of multiple dispatch, we can have multiple functions + # with the same name, which is why `for meta in metas` is necessary + # in the comprehensions below + return [ + SchemaObject( + name=func, + schema=(self._maybe_schema(schema=sch, parent=schema)), + meta=meta + ) + for sch in self._get_schemas('functions', schema) + for (func, metas) in self.dbmetadata['functions'][sch].items() + for meta in metas + if filter_func(meta) + ] + + def fetch_schema_objects(self, schema, obj_type): + """ + This function is used to fetch schema objects like tables, views, etc.. + :return: + """ in_clause = '' query = '' - objects = [] + data = [] if schema: in_clause = '\'' + schema + '\'' else: for r in self.search_path: in_clause += '\'' + r + '\',' - # Remove extra comma if len(in_clause) > 0: in_clause = in_clause[:-1] @@ -593,9 +1059,6 @@ class SQLAutoComplete(object): query = render_template("/".join([self.sql_path, 'tableview.sql']), schema_names=in_clause, object_name='views') - elif obj_type == 'functions': - query = render_template("/".join([self.sql_path, 'functions.sql']), - schema_names=in_clause) elif obj_type == 'datatypes': query = render_template("/".join([self.sql_path, 'datatypes.sql']), schema_names=in_clause) @@ -604,342 +1067,118 @@ class SQLAutoComplete(object): status, res = self.conn.execute_dict(query) if status: for record in res['rows']: - objects.append(record['object_name']) + data.append( + (record['schema_name'], record['object_name']) + ) - return objects + if (obj_type == 'tables' or obj_type == 'views') and len(data) > 0: + self.extend_relations(data, obj_type) + self.extend_columns( + self.fetch_columns(in_clause, obj_type), obj_type + ) + if obj_type == 'tables': + self.extend_foreignkeys( + self.fetch_foreign_keys(in_clause, obj_type) + ) + elif obj_type == 'datatypes' and len(data) > 0: + self.extend_datatypes(data) - def populate_functions(self, schema): + def fetch_functions(self, schema): """ - Returns a list of function names - - filter_func is a function that accepts a FunctionMetadata namedtuple - and returns a boolean indicating whether that function should be - kept or discarded - - Args: - schema: + This function is used to fecth the list of functions. 
+ :param schema: + :return: """ - in_clause = '' - funcs = [] + data = [] if schema: in_clause = '\'' + schema + '\'' else: for r in self.search_path: in_clause += '\'' + r + '\',' - # Remove extra comma if len(in_clause) > 0: in_clause = in_clause[:-1] query = render_template("/".join([self.sql_path, 'functions.sql']), - schema_names=in_clause, - is_set_returning=True) + schema_names=in_clause) if self.conn.connected(): status, res = self.conn.execute_dict(query) if status: - for record in res['rows']: - funcs.append(record['object_name']) + for row in res['rows']: + data.append(FunctionMetadata( + row['schema_name'], + row['func_name'], + row['arg_names'].strip('{}').split(',') + if row['arg_names'] is not None + else row['arg_names'], + row['arg_types'].strip('{}').split(',') + if row['arg_types'] is not None + else row['arg_types'], + row['arg_modes'].strip('{}').split(',') + if row['arg_modes'] is not None + else row['arg_modes'], + row['return_type'], + row['is_aggregate'], + row['is_window'], + row['is_set_returning'], + row['arg_defaults'].strip('{}').split(',') + if row['arg_defaults'] is not None + else row['arg_defaults'] + )) - return funcs + if len(data) > 0: + self.extend_functions(data) - def suggest_type(self, full_text, text_before_cursor): + def fetch_columns(self, schemas, obj_type): """ - Takes the full_text that is typed so far and also the text before the - cursor to suggest completion type and scope. - - Returns a tuple with a type of entity ('table', 'column' etc) and a - scope. A scope for a column category will be a list of tables. - - Args: - full_text: Contains complete query - text_before_cursor: Contains text before the cursor + This function is used to fetch the columns for the given schema name + :param schemas: + :param obj_type: + :return: """ - word_before_cursor = last_word( - text_before_cursor, include='many_punctuations') + data = [] + query = render_template("/".join([self.sql_path, 'columns.sql']), + schema_names=schemas, + object_name='table') + if obj_type == 'views': + query = render_template("/".join([self.sql_path, 'columns.sql']), + schema_names=schemas, + object_name='view') + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + if status: + for row in res['rows']: + data.append(( + row['schema_name'], row['table_name'], + row['column_name'], row['type_name'], + row['has_default'], row['default'] + )) + return data - identifier = None - - def strip_named_query(txt): - """ - This will strip "save named query" command in the beginning of - the line: - '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' - ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' - - Args: - txt: - """ - - pattern = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+') - if pattern.match(txt): - txt = pattern.sub('', txt) - return txt - - full_text = strip_named_query(full_text) - text_before_cursor = strip_named_query(text_before_cursor) - - # If we've partially typed a word then word_before_cursor won't be an - # empty string. In that case we want to remove the partially typed - # string before sending it to the sqlparser. Otherwise the last token - # will always be the partially typed string which renders the smart - # completion useless because it will always return the list of - # keywords as completion. 
- if word_before_cursor: - if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\': - parsed = sqlparse.parse(text_before_cursor) - else: - parsed = sqlparse.parse( - text_before_cursor[:-len(word_before_cursor)]) - - identifier = parse_partial_identifier(word_before_cursor) - else: - parsed = sqlparse.parse(text_before_cursor) - - statement = None - if len(parsed) > 1: - # Multiple statements being edited -- isolate the current one by - # cumulatively summing statement lengths to find the one that - # bounds the current position - current_pos = len(text_before_cursor) - stmt_start, stmt_end = 0, 0 - - for statement in parsed: - stmt_len = len(statement.to_unicode()) - stmt_start, stmt_end = stmt_end, stmt_end + stmt_len - - if stmt_end >= current_pos: - break - - text_before_cursor = full_text[stmt_start:current_pos] - full_text = full_text[stmt_start:] - elif parsed: - # A single statement - statement = parsed[0] - else: - # The empty string - statement = None - - last_token = statement and statement.token_prev( - len(statement.tokens) - ) or '' - - return self.suggest_based_on_last_token( - last_token, text_before_cursor, full_text, identifier - ) - - def suggest_based_on_last_token(self, token, text_before_cursor, full_text, - identifier): - # New version of sqlparse sends tuple, we need to make it - # compatible with our logic - if isinstance(token, tuple) and len(token) > 1: - token = token[1] - - if isinstance(token, string_types): - token_v = token.lower() - elif isinstance(token, Comparison): - # If 'token' is a Comparison type such as - # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling - # token.value on the comparison type will only return the lhs of - # the comparison. In this case a.id. So we need to do token.tokens - # to get both sides of the comparison and pick the last token out - # of that list. - token_v = token.tokens[-1].value.lower() - elif isinstance(token, Where): - # sqlparse groups all tokens from the where clause into a single - # token list. This means that token.value may be something like - # 'where foo > 5 and '. 
We need to look "inside" token.tokens to - # handle suggestions in complicated where clauses correctly - prev_keyword, text_before_cursor = find_prev_keyword( - text_before_cursor - ) - return self.suggest_based_on_last_token( - prev_keyword, text_before_cursor, full_text, identifier - ) - elif isinstance(token, Identifier): - # If the previous token is an identifier, we can suggest datatypes - # if we're in a parenthesized column/field list, e.g.: - # CREATE TABLE foo (Identifier - # CREATE FUNCTION foo (Identifier - # If we're not in a parenthesized list, the most likely scenario - # is the user is about to specify an alias, e.g.: - # SELECT Identifier - # SELECT foo FROM Identifier - prev_keyword, _ = find_prev_keyword(text_before_cursor) - if prev_keyword and prev_keyword.value == '(': - # Suggest datatypes - return self.suggest_based_on_last_token( - 'type', text_before_cursor, full_text, identifier - ) - else: - return Keyword(), - else: - token_v = token.value.lower() - - if not token: - return Keyword(), - elif token_v.endswith('('): - p = sqlparse.parse(text_before_cursor)[0] - - if p.tokens and isinstance(p.tokens[-1], Where): - # Four possibilities: - # 1 - Parenthesized clause like "WHERE foo AND (" - # Suggest columns/functions - # 2 - Function call like "WHERE foo(" - # Suggest columns/functions - # 3 - Subquery expression like "WHERE EXISTS (" - # Suggest keywords, in order to do a subquery - # 4 - Subquery OR array comparison like "WHERE foo = ANY(" - # Suggest columns/functions AND keywords. (If we wanted - # to be really fancy, we could suggest only array-typed - # columns) - - column_suggestions = self.suggest_based_on_last_token( - 'where', text_before_cursor, full_text, identifier - ) - - # Check for a subquery expression (cases 3 & 4) - where = p.tokens[-1] - prev_tok = where.token_prev(len(where.tokens) - 1) - - if isinstance(prev_tok, Comparison): - # e.g. "SELECT foo FROM bar WHERE foo = ANY(" - prev_tok = prev_tok.tokens[-1] - - prev_tok = prev_tok.value.lower() - if prev_tok == 'exists': - return Keyword(), - else: - return column_suggestions - - # Get the token before the parens - prev_tok = p.token_prev(len(p.tokens) - 1) - if prev_tok and prev_tok.value and \ - prev_tok.value.lower() == 'using': - # tbl1 INNER JOIN tbl2 USING (col1, col2) - tables = extract_tables(full_text) - - # suggest columns that are present in more than one table - return Column(tables=tables, drop_unique=True), - - elif p.token_first().value.lower() == 'select': - # If the lparen is preceeded by a space chances are we're - # about to do a sub-select. 
- if last_word(text_before_cursor, - 'all_punctuations').startswith('('): - return Keyword(), - # We're probably in a function argument list - return Column(tables=extract_tables(full_text)), - elif token_v in ('set', 'by', 'distinct'): - return Column(tables=extract_tables(full_text)), - elif token_v in ('select', 'where', 'having'): - # Check for a table alias or schema qualification - parent = (identifier and identifier.get_parent_name()) or [] - - if parent: - tables = extract_tables(full_text) - tables = tuple(t for t in tables if self.identifies(parent, t)) - return (Column(tables=tables), - Table(schema=parent), - View(schema=parent), - Function(schema=parent),) - else: - return (Column(tables=extract_tables(full_text)), - Function(schema=None), - Keyword(),) - - elif (token_v.endswith('join') and token.is_keyword) or \ - (token_v in ('copy', 'from', 'update', 'into', - 'describe', 'truncate')): - - schema = (identifier and identifier.get_parent_name()) or None - - # Suggest tables from either the currently-selected schema or the - # public schema if no schema has been specified - suggest = [Table(schema=schema)] - - if not schema: - # Suggest schemas - suggest.insert(0, Schema()) - - # Only tables can be TRUNCATED, otherwise suggest views - if token_v != 'truncate': - suggest.append(View(schema=schema)) - - # Suggest set-returning functions in the FROM clause - if token_v == 'from' or \ - (token_v.endswith('join') and token.is_keyword): - suggest.append( - Function(schema=schema, filter='is_set_returning') - ) - - return tuple(suggest) - - elif token_v in ('table', 'view', 'function'): - # E.g. 'DROP FUNCTION ', 'ALTER TABLE ' - rel_type = {'table': Table, 'view': View, - 'function': Function}[token_v] - schema = (identifier and identifier.get_parent_name()) or None - if schema: - return rel_type(schema=schema), - else: - return Schema(), rel_type(schema=schema) - elif token_v == 'on': - tables = extract_tables(full_text) # [(schema, table, alias), ...] - parent = (identifier and identifier.get_parent_name()) or None - if parent: - # "ON parent." 
- # parent can be either a schema name or table alias - tables = tuple(t for t in tables if self.identifies(parent, t)) - return (Column(tables=tables), - Table(schema=parent), - View(schema=parent), - Function(schema=parent)) - else: - # ON - # Use table alias if there is one, otherwise the table name - aliases = tuple(t.alias or t.name for t in tables) - return Alias(aliases=aliases), - - elif token_v in ('c', 'use', 'database', 'template'): - # "\c ", "DROP DATABASE ", - # "CREATE DATABASE WITH TEMPLATE " - return Database(), - elif token_v == 'schema': - # DROP SCHEMA schema_name - return Schema(), - elif token_v.endswith(',') or token_v in ('=', 'and', 'or'): - prev_keyword, text_before_cursor = find_prev_keyword( - text_before_cursor) - if prev_keyword: - return self.suggest_based_on_last_token( - prev_keyword, text_before_cursor, full_text, identifier) - else: - return () - elif token_v in ('type', '::'): - # ALTER TABLE foo SET DATA TYPE bar - # SELECT foo::bar - # Note that tables are a form of composite type in postgresql, so - # they're suggested here as well - schema = (identifier and identifier.get_parent_name()) or None - suggestions = [Datatype(schema=schema), - Table(schema=schema)] - if not schema: - suggestions.append(Schema()) - return tuple(suggestions) - else: - return Keyword(), - - def identifies(self, id, ref): + def fetch_foreign_keys(self, schemas, obj_type): """ - Returns true if string `id` matches TableReference `ref` - - Args: - id: - ref: + This function is used to fetch the foreign_keys for the given + schema name + :param schemas: + :param obj_type: + :return: """ - return id == ref.alias or id == ref.name or ( - ref.schema and (id == ref.schema + '.' + ref.name)) + + data = [] + query = render_template("/".join([self.sql_path, 'foreign_keys.sql']), + schema_names=schemas) + + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + if status: + for row in res['rows']: + data.append(ForeignKey( + row['parentschema'], row['parenttable'], + row['parentcolumn'], row['childschema'], + row['childtable'], row['childcolumn'] + )) + return data diff --git a/web/pgadmin/utils/sqlautocomplete/counter.py b/web/pgadmin/utils/sqlautocomplete/counter.py deleted file mode 100644 index 3211527fb..000000000 --- a/web/pgadmin/utils/sqlautocomplete/counter.py +++ /dev/null @@ -1,197 +0,0 @@ -""" -Copied from http://code.activestate.com/recipes/576611-counter-class/ -""" - -from heapq import nlargest -from itertools import repeat -try: - from itertools import ifilter -except ImportError: - # ifilter is in-built function in Python3 as filter - ifilter = filter -from operator import itemgetter - - -class Counter(dict): - '''Dict subclass for counting hashable objects. Sometimes called a bag - or multiset. Elements are stored as dictionary keys and their counts - are stored as dictionary values. - - >>> Counter('zyzygy') - Counter({'y': 3, 'z': 2, 'g': 1}) - - ''' - - def __init__(self, iterable=None, **kwds): - '''Create a new, empty Counter object. And if given, count elements - from an input iterable. Or, initialize the count from another mapping - of elements to their counts. 
- - >>> c = Counter() # a new, empty counter - >>> c = Counter('gallahad') # a new counter from an iterable - >>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping - >>> c = Counter(a=4, b=2) # a new counter from keyword args - - ''' - self.update(iterable, **kwds) - - def __missing__(self, key): - return 0 - - def most_common(self, n=None): - '''List the n most common elements and their counts from the most - common to the least. If n is None, then list all element counts. - - >>> Counter('abracadabra').most_common(3) - [('a', 5), ('r', 2), ('b', 2)] - - ''' - if n is None: - return sorted(self.iteritems(), key=itemgetter(1), reverse=True) - return nlargest(n, self.iteritems(), key=itemgetter(1)) - - def elements(self): - '''Iterator over elements repeating each as many times as its count. - - >>> c = Counter('ABCABC') - >>> sorted(c.elements()) - ['A', 'A', 'B', 'B', 'C', 'C'] - - If an element's count has been set to zero or is a negative number, - elements() will ignore it. - - ''' - for elem, count in self.iteritems(): - for _ in repeat(None, count): - yield elem - - # Override dict methods where the meaning changes for Counter objects. - - @classmethod - def fromkeys(cls, iterable, v=None): - raise NotImplementedError( - 'Counter.fromkeys() is undefined. Use Counter(iterable) instead.' - ) - - def update(self, iterable=None, **kwds): - '''Like dict.update() but add counts instead of replacing them. - - Source can be an iterable, a dictionary, or another Counter instance. - - >>> c = Counter('which') - >>> c.update('witch') # add elements from another iterable - >>> d = Counter('watch') - >>> c.update(d) # add elements from another counter - >>> c['h'] # four 'h' in which, witch, and watch - 4 - - ''' - if iterable is not None: - if hasattr(iterable, 'iteritems'): - if self: - self_get = self.get - for elem, count in iterable.iteritems(): - self[elem] = self_get(elem, 0) + count - else: - # fast path when counter is empty - dict.update(self, iterable) - else: - self_get = self.get - for elem in iterable: - self[elem] = self_get(elem, 0) + 1 - if kwds: - self.update(kwds) - - def copy(self): - 'Like dict.copy() but returns a Counter instance instead of a dict.' - return Counter(self) - - def __delitem__(self, elem): - """Like dict.__delitem__() but does not raise KeyError for missing - values.""" - if elem in self: - dict.__delitem__(self, elem) - - def __repr__(self): - if not self: - return '%s()' % self.__class__.__name__ - items = ', '.join(map('%r: %r'.__mod__, self.most_common())) - return '%s({%s})' % (self.__class__.__name__, items) - - # Multiset-style mathematical operations discussed in: - # Knuth TAOCP Volume II section 4.6.3 exercise 19 - # and at http://en.wikipedia.org/wiki/Multiset - # - # Outputs guaranteed to only include positive counts. - # - # To strip negative and zero counts, add-in an empty counter: - # c += Counter() - - def __add__(self, other): - '''Add counts from two counters. - - >>> Counter('abbb') + Counter('bcc') - Counter({'b': 4, 'c': 2, 'a': 1}) - - - ''' - if not isinstance(other, Counter): - return NotImplemented - result = Counter() - for elem in set(self) | set(other): - newcount = self[elem] + other[elem] - if newcount > 0: - result[elem] = newcount - return result - - def __sub__(self, other): - ''' Subtract count, but keep only results with positive counts. 
- - >>> Counter('abbbc') - Counter('bccd') - Counter({'b': 2, 'a': 1}) - - ''' - if not isinstance(other, Counter): - return NotImplemented - result = Counter() - for elem in set(self) | set(other): - newcount = self[elem] - other[elem] - if newcount > 0: - result[elem] = newcount - return result - - def __or__(self, other): - '''Union is the maximum of value in either of the input counters. - - >>> Counter('abbb') | Counter('bcc') - Counter({'b': 3, 'c': 2, 'a': 1}) - - ''' - if not isinstance(other, Counter): - return NotImplemented - _max = max - result = Counter() - for elem in set(self) | set(other): - newcount = _max(self[elem], other[elem]) - if newcount > 0: - result[elem] = newcount - return result - - def __and__(self, other): - ''' Intersection is the minimum of corresponding counts. - - >>> Counter('abbb') & Counter('bcc') - Counter({'b': 1}) - - ''' - if not isinstance(other, Counter): - return NotImplemented - _min = min - result = Counter() - if len(self) < len(other): - self, other = other, self - for elem in ifilter(self.__contains__, other): - newcount = _min(self[elem], other[elem]) - if newcount > 0: - result[elem] = newcount - return result diff --git a/web/pgadmin/utils/sqlautocomplete/function_metadata.py b/web/pgadmin/utils/sqlautocomplete/function_metadata.py deleted file mode 100644 index 2c3292bd9..000000000 --- a/web/pgadmin/utils/sqlautocomplete/function_metadata.py +++ /dev/null @@ -1,151 +0,0 @@ -import re - -import sqlparse -from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation - -table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE) - - -class FunctionMetadata(object): - def __init__(self, schema_name, func_name, arg_list, return_type, - is_aggregate, is_window, is_set_returning): - """Class for describing a postgresql function""" - - self.schema_name = schema_name - self.func_name = func_name - self.arg_list = arg_list.strip() - self.return_type = return_type.strip() - self.is_aggregate = is_aggregate - self.is_window = is_window - self.is_set_returning = is_set_returning - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash((self.schema_name, self.func_name, self.arg_list, - self.return_type, self.is_aggregate, self.is_window, - self.is_set_returning)) - - def __repr__(self): - return ( - ('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,' - ' is_aggregate=%r, is_window=%r, is_set_returning=%r)') - % (self.__class__.__name__, self.schema_name, self.func_name, - self.arg_list, self.return_type, self.is_aggregate, - self.is_window, self.is_set_returning) - ) - - def fieldnames(self): - """Returns a list of output field names""" - - if self.return_type.lower() == 'void': - return [] - - match = table_def_regex.match(self.return_type) - if match: - # Function returns a table -- get the column names - return list(field_names(match.group(1), mode_filter=None)) - - # Function may have named output arguments -- find them and return - # their names - return list(field_names(self.arg_list, mode_filter=('OUT', 'INOUT'))) - - -class TypedFieldMetadata(object): - """Describes typed field from a function signature or table definition - - Attributes are: - name The name of the argument/column - mode 'IN', 'OUT', 'INOUT', 'VARIADIC' - type A list of tokens denoting the type - default A list of tokens denoting the default value - unknown A list of tokens not assigned to 
type or default - """ - - __slots__ = ['name', 'mode', 'type', 'default', 'unknown'] - - def __init__(self): - self.name = None - self.mode = 'IN' - self.type = [] - self.default = [] - self.unknown = [] - - def __getitem__(self, attr): - return getattr(self, attr) - - -def parse_typed_field_list(tokens): - """Parses a argument/column list, yielding TypedFieldMetadata objects - - Field/column lists are used in function signatures and table - definitions. This function parses a flattened list of sqlparse tokens - and yields one metadata argument per argument / column. - """ - - # postgres function argument list syntax: - # " ( [ [ argmode ] [ argname ] argtype - # [ { DEFAULT | = } default_expr ] [, ...] ] )" - - mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC')) - parse_state = 'type' - parens = 0 - field = TypedFieldMetadata() - - for tok in tokens: - if tok.ttype in Whitespace or tok.ttype in Comment: - continue - elif tok.ttype in Punctuation: - if parens == 0 and tok.value == ',': - # End of the current field specification - if field.type: - yield field - # Initialize metadata holder for the next field - field, parse_state = TypedFieldMetadata(), 'type' - elif parens == 0 and tok.value == '=': - parse_state = 'default' - else: - field[parse_state].append(tok) - if tok.value == '(': - parens += 1 - elif tok.value == ')': - parens -= 1 - elif parens == 0: - if tok.ttype in Keyword: - if not field.name and tok.value.upper() in mode_names: - # No other keywords allowed before arg name - field.mode = tok.value.upper() - elif tok.value.upper() == 'DEFAULT': - parse_state = 'default' - else: - parse_state = 'unknown' - elif tok.ttype == Name and not field.name: - # note that `ttype in Name` would also match Name.Builtin - field.name = tok.value - else: - field[parse_state].append(tok) - else: - field[parse_state].append(tok) - - # Final argument won't be followed by a comma, so make sure it gets - # yielded - if field.type: - yield field - - -def field_names(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')): - """Yields field names from a table declaration""" - - if not sql: - return - - # sql is something like "x int, y text, ..." - tokens = sqlparse.parse(sql)[0].flatten() - for f in parse_typed_field_list(tokens): - if f.name and (not mode_filter or f.mode in mode_filter): - yield f.name diff --git a/web/pgadmin/utils/sqlautocomplete/parseutils.py b/web/pgadmin/utils/sqlautocomplete/parseutils.py index b08a3927c..85257b0a7 100644 --- a/web/pgadmin/utils/sqlautocomplete/parseutils.py +++ b/web/pgadmin/utils/sqlautocomplete/parseutils.py @@ -73,7 +73,7 @@ TableReference = namedtuple( # This code is borrowed from sqlparse example script. # def is_subselect(parsed): - if not parsed.is_group: + if not parsed.is_group(): return False sql_type = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE') for item in parsed.tokens: @@ -104,7 +104,7 @@ def extract_from_part(parsed, stop_at_punctuation=True): # condition. So we need to ignore the keyword JOIN and its # variants INNER JOIN, FULL OUTER JOIN, etc. 
             elif item.ttype is Keyword and (
-                not item.value.upper() == 'FROM') and (
+                    not item.value.upper() == 'FROM') and (
                 not item.value.upper().endswith('JOIN')):
                 tbl_prefix_seen = False
             else:
diff --git a/web/pgadmin/utils/sqlautocomplete/parseutils/__init__.py b/web/pgadmin/utils/sqlautocomplete/parseutils/__init__.py
new file mode 100644
index 000000000..7dafbdbf5
--- /dev/null
+++ b/web/pgadmin/utils/sqlautocomplete/parseutils/__init__.py
@@ -0,0 +1,22 @@
+import sqlparse
+
+
+def query_starts_with(query, prefixes):
+    """Check if the query starts with any item from *prefixes*."""
+    prefixes = [prefix.lower() for prefix in prefixes]
+    formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
+    return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
+
+
+def queries_start_with(queries, prefixes):
+    """Check if any queries start with any item from *prefixes*."""
+    for query in sqlparse.split(queries):
+        if query and query_starts_with(query, prefixes) is True:
+            return True
+    return False
+
+
+def is_destructive(queries):
+    """Returns True if any of the queries in *queries* is destructive."""
+    keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
+    return queries_start_with(queries, keywords)
diff --git a/web/pgadmin/utils/sqlautocomplete/parseutils/ctes.py b/web/pgadmin/utils/sqlautocomplete/parseutils/ctes.py
new file mode 100644
index 000000000..ec8838934
--- /dev/null
+++ b/web/pgadmin/utils/sqlautocomplete/parseutils/ctes.py
@@ -0,0 +1,143 @@
+from sqlparse import parse
+from sqlparse.tokens import Keyword, CTE, DML
+from sqlparse.sql import Identifier, IdentifierList, Parenthesis
+from collections import namedtuple
+from .meta import TableMetadata, ColumnMetadata
+
+
+# TableExpression is a namedtuple representing a CTE, used internally
+# name: cte alias assigned in the query
+# columns: list of column names
+# start: index into the original string of the left parens starting the CTE
+# stop: index into the original string of the right parens ending the CTE
+TableExpression = namedtuple('TableExpression', 'name columns start stop')
+
+
+def isolate_query_ctes(full_text, text_before_cursor):
+    """Simplify a query by converting CTEs into table metadata objects
+    """
+
+    if not full_text:
+        return full_text, text_before_cursor, tuple()
+
+    ctes, remainder = extract_ctes(full_text)
+    if not ctes:
+        return full_text, text_before_cursor, ()
+
+    current_position = len(text_before_cursor)
+    meta = []
+
+    for cte in ctes:
+        if cte.start < current_position < cte.stop:
+            # Currently editing a cte - treat its body as the current full_text
+            text_before_cursor = full_text[cte.start:current_position]
+            full_text = full_text[cte.start:cte.stop]
+            return full_text, text_before_cursor, meta
+
+        # Append this cte to the list of available table metadata
+        cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
+        meta.append(TableMetadata(cte.name, cols))
+
+    # Editing past the last cte (ie the main body of the query)
+    full_text = full_text[ctes[-1].stop:]
+    text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]
+
+    return full_text, text_before_cursor, tuple(meta)
+
+
+def extract_ctes(sql):
+    """ Extract common table expressions from a query
+
+    Returns tuple (ctes, remainder_sql)
+
+    ctes is a list of TableExpression namedtuples
+    remainder_sql is the text from the original query after the CTEs have
+    been stripped.
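+
+    A rough example (offsets illustrative, not exact):
+
+        extract_ctes('WITH x AS (SELECT 1) SELECT * FROM x')
+        # -> ([TableExpression(name='x', columns=(), start=..., stop=...)],
+        #     ' SELECT * FROM x')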
+ """ + + p = parse(sql)[0] + + # Make sure the first meaningful token is "WITH" which is necessary to + # define CTEs + idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True) + if not (tok and tok.ttype == CTE): + return [], sql + + # Get the next (meaningful) token, which should be the first CTE + idx, tok = p.token_next(idx) + if not tok: + return ([], '') + start_pos = token_start_pos(p.tokens, idx) + ctes = [] + + if isinstance(tok, IdentifierList): + # Multiple ctes + for t in tok.get_identifiers(): + cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t)) + cte = get_cte_from_token(t, start_pos + cte_start_offset) + if not cte: + continue + ctes.append(cte) + elif isinstance(tok, Identifier): + # A single CTE + cte = get_cte_from_token(tok, start_pos) + if cte: + ctes.append(cte) + + idx = p.token_index(tok) + 1 + + # Collapse everything after the ctes into a remainder query + remainder = u''.join(str(tok) for tok in p.tokens[idx:]) + + return ctes, remainder + + +def get_cte_from_token(tok, pos0): + cte_name = tok.get_real_name() + if not cte_name: + return None + + # Find the start position of the opening parens enclosing the cte body + idx, parens = tok.token_next_by(Parenthesis) + if not parens: + return None + + start_pos = pos0 + token_start_pos(tok.tokens, idx) + cte_len = len(str(parens)) # includes parens + stop_pos = start_pos + cte_len + + column_names = extract_column_names(parens) + + return TableExpression(cte_name, column_names, start_pos, stop_pos) + + +def extract_column_names(parsed): + # Find the first DML token to check if it's a SELECT or + # INSERT/UPDATE/DELETE + idx, tok = parsed.token_next_by(t=DML) + tok_val = tok and tok.value.lower() + + if tok_val in ('insert', 'update', 'delete'): + # Jump ahead to the RETURNING clause where the list of column names is + idx, tok = parsed.token_next_by(idx, (Keyword, 'returning')) + elif not tok_val == 'select': + # Must be invalid CTE + return () + + # The next token should be either a column name, or a list of column names + idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True) + return tuple(t.get_name() for t in _identifiers(tok)) + + +def token_start_pos(tokens, idx): + return sum(len(str(t)) for t in tokens[:idx]) + + +def _identifiers(tok): + if isinstance(tok, IdentifierList): + for t in tok.get_identifiers(): + # NB: IdentifierList.get_identifiers() can return non-identifiers! 
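+            # For example (assumed sqlparse behaviour, worth re-checking
+            # against the installed version): in 'SELECT a, 1, b' the
+            # literal 1 comes back as a plain Token and is filtered here.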
+ if isinstance(t, Identifier): + yield t + elif isinstance(tok, Identifier): + yield tok diff --git a/web/pgadmin/utils/sqlautocomplete/parseutils/meta.py b/web/pgadmin/utils/sqlautocomplete/parseutils/meta.py new file mode 100644 index 000000000..eab5f2d5c --- /dev/null +++ b/web/pgadmin/utils/sqlautocomplete/parseutils/meta.py @@ -0,0 +1,151 @@ +from collections import namedtuple + +_ColumnMetadata = namedtuple( + 'ColumnMetadata', + ['name', 'datatype', 'foreignkeys', 'default', 'has_default'] +) + + +def ColumnMetadata( + name, datatype, foreignkeys=None, default=None, has_default=False +): + return _ColumnMetadata( + name, datatype, foreignkeys or [], default, has_default + ) + + +ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable', + 'parentcolumn', 'childschema', + 'childtable', 'childcolumn']) +TableMetadata = namedtuple('TableMetadata', 'name columns') + + +def parse_defaults(defaults_string): + """Yields default values for a function, given the string provided by + pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)""" + if not defaults_string: + return + current = '' + in_quote = None + for char in defaults_string: + if current == '' and char == ' ': + # Skip space after comma separating default expressions + continue + if char == '"' or char == '\'': + if in_quote and char == in_quote: + # End quote + in_quote = None + elif not in_quote: + # Begin quote + in_quote = char + elif char == ',' and not in_quote: + # End of expression + yield current + current = '' + continue + current += char + yield current + + +class FunctionMetadata(object): + + def __init__( + self, schema_name, func_name, arg_names, arg_types, arg_modes, + return_type, is_aggregate, is_window, is_set_returning, + arg_defaults + ): + """Class for describing a postgresql function""" + + self.schema_name = schema_name + self.func_name = func_name + + self.arg_modes = tuple(arg_modes) if arg_modes else None + self.arg_names = tuple(arg_names) if arg_names else None + + # Be flexible in not requiring arg_types -- use None as a placeholder + # for each arg. (Used for compatibility with old versions of postgresql + # where such info is hard to get. 
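+        # For example (values assumed): arg_names=('a', 'b') with
+        # arg_types=None becomes arg_types=(None, None), keeping the
+        # zips over names, types and modes below aligned.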
+ if arg_types: + self.arg_types = tuple(arg_types) + elif arg_modes: + self.arg_types = tuple([None] * len(arg_modes)) + elif arg_names: + self.arg_types = tuple([None] * len(arg_names)) + else: + self.arg_types = None + + self.arg_defaults = tuple(parse_defaults(arg_defaults)) + + self.return_type = return_type.strip() + self.is_aggregate = is_aggregate + self.is_window = is_window + self.is_set_returning = is_set_returning + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.__dict__ == other.__dict__) + + def __ne__(self, other): + return not self.__eq__(other) + + def _signature(self): + return ( + self.schema_name, self.func_name, self.arg_names, self.arg_types, + self.arg_modes, self.return_type, self.is_aggregate, + self.is_window, self.is_set_returning, self.arg_defaults + ) + + def __hash__(self): + return hash(self._signature()) + + def __repr__(self): + return ( + ( + '%s(schema_name=%r, func_name=%r, arg_names=%r, ' + 'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, ' + 'is_window=%r, is_set_returning=%r, arg_defaults=%r)' + ) % (self.__class__.__name__,) + self._signature() + ) + + def has_variadic(self): + return self.arg_modes and any( + arg_mode == 'v' for arg_mode in self.arg_modes) + + def args(self): + """Returns a list of input-parameter ColumnMetadata namedtuples.""" + if not self.arg_names: + return [] + modes = self.arg_modes or ['i'] * len(self.arg_names) + args = [ + (name, typ) + for name, typ, mode in zip(self.arg_names, self.arg_types, modes) + if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC + ] + + def arg(name, typ, num): + num_args = len(args) + num_defaults = len(self.arg_defaults) + has_default = num + num_defaults >= num_args + default = ( + self.arg_defaults[num - num_args + num_defaults] if has_default + else None + ) + return ColumnMetadata(name, typ, [], default, has_default) + + return [arg(name, typ, num) for num, (name, typ) in enumerate(args)] + + def fields(self): + """Returns a list of output-field ColumnMetadata namedtuples""" + + if self.return_type.lower() == 'void': + return [] + elif not self.arg_modes: + # For functions without output parameters, the function name + # is used as the name of the output column. + # E.g. 'SELECT unnest FROM unnest(...);' + return [ColumnMetadata(self.func_name, self.return_type, [])] + + return [ColumnMetadata(name, typ, []) + for name, typ, mode in zip( + self.arg_names, self.arg_types, self.arg_modes) + if mode in ('o', 'b', 't')] # OUT, INOUT, TABLE diff --git a/web/pgadmin/utils/sqlautocomplete/parseutils/tables.py b/web/pgadmin/utils/sqlautocomplete/parseutils/tables.py new file mode 100644 index 000000000..94ffc7f46 --- /dev/null +++ b/web/pgadmin/utils/sqlautocomplete/parseutils/tables.py @@ -0,0 +1,149 @@ +from __future__ import print_function +import sqlparse +from collections import namedtuple +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, Punctuation + +TableReference = namedtuple('TableReference', ['schema', 'name', 'alias', + 'is_function']) +TableReference.ref = property( + lambda self: self.alias or ( + self.name if self.name.islower() or self.name[0] == '"' + else '"' + self.name + '"') +) + + +# This code is borrowed from sqlparse example script. 
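+# For instance, is_subselect() below returns True for the parenthesised
+# group in 'SELECT * FROM (SELECT id FROM t) x', since that group
+# contains a SELECT DML token (an illustration, not an exhaustive spec).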
+# +def is_subselect(parsed): + if not parsed.is_group: + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT', + 'UPDATE', 'CREATE', + 'DELETE'): + return True + return False + + +def _identifier_is_function(identifier): + return any(isinstance(t, Function) for t in identifier.tokens) + + +def extract_from_part(parsed, stop_at_punctuation=True): + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if is_subselect(item): + for x in extract_from_part(item, stop_at_punctuation): + yield x + elif stop_at_punctuation and item.ttype is Punctuation: + raise StopIteration + # An incomplete nested select won't be recognized correctly as a + # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes + # the second FROM to trigger this elif condition resulting in a + # StopIteration. So we need to ignore the keyword if the keyword + # FROM. + # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN and its variants + # INNER JOIN, FULL OUTER JOIN, etc. + elif item.ttype is Keyword and ( + not item.value.upper() == 'FROM') and ( + not item.value.upper().endswith('JOIN')): + tbl_prefix_seen = False + else: + yield item + elif item.ttype is Keyword or item.ttype is Keyword.DML: + item_val = item.value.upper() + if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or + item_val.endswith('JOIN')): + tbl_prefix_seen = True + # 'SELECT a, FROM abc' will detect FROM as part of the column list. + # So this check here is necessary. + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if (identifier.ttype is Keyword and + identifier.value.upper() == 'FROM'): + tbl_prefix_seen = True + break + + +def extract_table_identifiers(token_stream, allow_functions=True): + """yields tuples of TableReference namedtuples""" + + # We need to do some massaging of the names because postgres is case- + # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is) + def parse_identifier(item): + name = item.get_real_name() + schema_name = item.get_parent_name() + alias = item.get_alias() + if not name: + schema_name = None + name = item.get_name() + alias = alias or name + schema_quoted = schema_name and item.value[0] == '"' + if schema_name and not schema_quoted: + schema_name = schema_name.lower() + quote_count = item.value.count('"') + name_quoted = quote_count > 2 or (quote_count and not schema_quoted) + alias_quoted = alias and item.value[-1] == '"' + if alias_quoted or name_quoted and not alias and name.islower(): + alias = '"' + (alias or name) + '"' + if name and not name_quoted and not name.islower(): + if not alias: + alias = name + name = name.lower() + return schema_name, name, alias + + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. 
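+                # For example (assumed sqlparse behaviour): in
+                # 'SELECT a, FROM b' the stray FROM keyword can land in
+                # the IdentifierList; a plain keyword token has no
+                # get_real_name(), hence the except below.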
+                try:
+                    schema_name = identifier.get_parent_name()
+                    real_name = identifier.get_real_name()
+                    is_function = (allow_functions and
+                                   _identifier_is_function(identifier))
+                except AttributeError:
+                    continue
+                if real_name:
+                    yield TableReference(schema_name, real_name,
+                                         identifier.get_alias(), is_function)
+        elif isinstance(item, Identifier):
+            schema_name, real_name, alias = parse_identifier(item)
+            is_function = allow_functions and _identifier_is_function(item)
+
+            yield TableReference(schema_name, real_name, alias, is_function)
+        elif isinstance(item, Function):
+            schema_name, real_name, alias = parse_identifier(item)
+            yield TableReference(None, real_name, alias, allow_functions)
+
+
+# extract_tables is inspired from examples in the sqlparse lib.
+def extract_tables(sql):
+    """Extract the table names from an SQL statement.
+
+    Returns a list of TableReference namedtuples
+
+    """
+    parsed = sqlparse.parse(sql)
+    if not parsed:
+        return ()
+
+    # INSERT statements must stop looking for tables at the first sign of
+    # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
+    # abc is the table name, but if we don't stop at the first lparen, then
+    # we'll identify abc, col1 and col2 as table names.
+    insert_stmt = parsed[0].token_first().value.lower() == 'insert'
+    stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
+
+    # Kludge: sqlparse mistakenly identifies insert statements as
+    # function calls due to the parenthesized column list, e.g. interprets
+    # "insert into foo (bar, baz)" as a function call to foo with arguments
+    # (bar, baz). So don't allow any identifiers in insert statements
+    # to have is_function=True
+    identifiers = extract_table_identifiers(stream,
+                                            allow_functions=not insert_stmt)
+    # In the case 'sche.', we get an empty TableReference; remove that
+    return tuple(i for i in identifiers if i.name)
diff --git a/web/pgadmin/utils/sqlautocomplete/parseutils/utils.py b/web/pgadmin/utils/sqlautocomplete/parseutils/utils.py
new file mode 100644
index 000000000..6bfe7cbf5
--- /dev/null
+++ b/web/pgadmin/utils/sqlautocomplete/parseutils/utils.py
@@ -0,0 +1,140 @@
+from __future__ import print_function
+import re
+import sqlparse
+from sqlparse.sql import Identifier
+from sqlparse.tokens import Token, Error
+
+cleanup_regex = {
+    # This matches only alphanumerics and underscores.
+    'alphanum_underscore': re.compile(r'(\w+)$'),
+    # This matches everything except spaces, parens, colon, and comma
+    'many_punctuations': re.compile(r'([^():,\s]+)$'),
+    # This matches everything except spaces, parens, colon, comma, and period
+    'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
+    # This matches everything except a space.
+    'all_punctuations': re.compile(r'([^\s]+)$'),
+}
+
+
+def last_word(text, include='alphanum_underscore'):
+    r"""
+    Find the last word in a sentence.
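+
+    The `include` argument picks one of the cleanup_regex patterns above;
+    for example, the default pattern stops at anything that is not a word
+    character:
+
+    >>> last_word('SELECT * FROM foo.bar')
+    'bar'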
+
+    >>> last_word('abc')
+    'abc'
+    >>> last_word(' abc')
+    'abc'
+    >>> last_word('')
+    ''
+    >>> last_word(' ')
+    ''
+    >>> last_word('abc ')
+    ''
+    >>> last_word('abc def')
+    'def'
+    >>> last_word('abc def ')
+    ''
+    >>> last_word('abc def;')
+    ''
+    >>> last_word('bac $def')
+    'def'
+    >>> last_word('bac $def', include='most_punctuations')
+    '$def'
+    >>> last_word('bac \def', include='most_punctuations')
+    '\\\\def'
+    >>> last_word('bac \def;', include='most_punctuations')
+    '\\\\def;'
+    >>> last_word('bac::def', include='most_punctuations')
+    'def'
+    >>> last_word('"foo*bar', include='most_punctuations')
+    '"foo*bar'
+    """
+
+    if not text:  # Empty string
+        return ''
+
+    if text[-1].isspace():
+        return ''
+    else:
+        regex = cleanup_regex[include]
+        matches = regex.search(text)
+        if matches:
+            return matches.group(0)
+        else:
+            return ''
+
+
+def find_prev_keyword(sql, n_skip=0):
+    """ Find the last sql keyword in an SQL statement
+
+    Returns the value of the last keyword, and the text of the query with
+    everything after the last keyword stripped
+    """
+    if not sql.strip():
+        return None, ''
+
+    parsed = sqlparse.parse(sql)[0]
+    flattened = list(parsed.flatten())
+    flattened = flattened[:len(flattened) - n_skip]
+
+    logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
+
+    for t in reversed(flattened):
+        if t.value == '(' or (t.is_keyword and (
+                t.value.upper() not in logical_operators)):
+            # Find the location of token t in the original parsed statement
+            # We can't use parsed.token_index(t) because t may be a child token
+            # inside a TokenList, in which case token_index throws an error
+            # Minimal example:
+            # p = sqlparse.parse('select * from foo where bar')
+            # t = list(p.flatten())[-3]  # The "Where" token
+            # p.token_index(t)  # Throws ValueError: not in list
+            idx = flattened.index(t)
+
+            # Combine the string values of all tokens in the original list
+            # up to and including the target keyword token t, to produce a
+            # query string with everything after the keyword token removed
+            text = ''.join(tok.value for tok in flattened[:idx + 1])
+            return t, text
+
+    return None, ''
+
+
+# Postgresql dollar quote signs look like `$$` or `$tag$`
+dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
+
+
+def is_open_quote(sql):
+    """Returns true if the query contains an unclosed quote"""
+
+    # parsed can contain one or more semi-colon separated commands
+    parsed = sqlparse.parse(sql)
+    return any(_parsed_is_open_quote(p) for p in parsed)
+
+
+def _parsed_is_open_quote(parsed):
+    # Look for unmatched single quotes, or unmatched dollar sign quotes
+    return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten())
+
+
+def parse_partial_identifier(word):
+    """Attempt to parse a (partially typed) word as an identifier
+
+    word may include a schema qualification, like `schema_name.partial_name`
+    or `schema_name.` There may also be unclosed quotation marks, like
+    `"schema`, or `schema."partial_name`
+
+    :param word: string representing a (partially complete) identifier
+    :return: sqlparse.sql.Identifier, or None
+    """
+
+    p = sqlparse.parse(word)[0]
+    n_tok = len(p.tokens)
+    if n_tok == 1 and isinstance(p.tokens[0], Identifier):
+        return p.tokens[0]
+    elif p.token_next_by(m=(Error, '"'))[1]:
+        # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
+        # Close the double quote, then reparse
+        return parse_partial_identifier(word + '"')
+    else:
+        return None
diff --git a/web/pgadmin/utils/sqlautocomplete/prioritization.py b/web/pgadmin/utils/sqlautocomplete/prioritization.py
index 84b9bd366..ab38d4947 100644
--- a/web/pgadmin/utils/sqlautocomplete/prioritization.py
+++ b/web/pgadmin/utils/sqlautocomplete/prioritization.py
@@ -1,8 +1,8 @@
 import re
-from collections import defaultdict
-
 import sqlparse
 from sqlparse.tokens import Name
+from collections import defaultdict
+

 white_space_regex = re.compile(r'\\s+', re.MULTILINE)
@@ -10,7 +10,7 @@ white_space_regex = re.compile(r'\\s+', re.MULTILINE)
 def _compile_regex(keyword):
     # Surround the keyword with word boundaries and replace interior whitespace
     # with whitespace wildcards
-    pattern = r'\\b' + re.sub(white_space_regex, r'\\s+', keyword) + r'\\b'
+    pattern = r'\\b' + white_space_regex.sub(r'\\s+', keyword) + r'\\b'
     return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
diff --git a/web/pgadmin/utils/sqlautocomplete/sqlcompletion.py b/web/pgadmin/utils/sqlautocomplete/sqlcompletion.py
new file mode 100644
index 000000000..5d18ce928
--- /dev/null
+++ b/web/pgadmin/utils/sqlautocomplete/sqlcompletion.py
@@ -0,0 +1,521 @@
+from __future__ import print_function
+import sys
+import re
+import sqlparse
+from collections import namedtuple
+from sqlparse.sql import Comparison, Identifier, Where
+from .parseutils.utils import (
+    last_word, find_prev_keyword, parse_partial_identifier)
+from .parseutils.tables import extract_tables
+from .parseutils.ctes import isolate_query_ctes
+
+PY2 = sys.version_info[0] == 2
+PY3 = sys.version_info[0] == 3
+
+if PY3:
+    string_types = str
+else:
+    string_types = basestring
+
+
+Special = namedtuple('Special', [])
+Database = namedtuple('Database', [])
+Schema = namedtuple('Schema', ['quoted'])
+Schema.__new__.__defaults__ = (False,)
+# FromClauseItem is a table/view/function used in the FROM clause
+# `table_refs` contains the list of tables/... already in the statement,
+# used to ensure that the alias we suggest is unique
+FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables')
+Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables'])
+TableFormat = namedtuple('TableFormat', [])
+View = namedtuple('View', ['schema', 'table_refs'])
+# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
+JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent'])
+# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
+Join = namedtuple('Join', ['table_refs', 'schema'])
+
+Function = namedtuple('Function', ['schema', 'table_refs', 'usage'])
+# For convenience, don't require the `usage` argument in Function constructor
+Function.__new__.__defaults__ = (None, tuple(), None)
+Table.__new__.__defaults__ = (None, tuple(), tuple())
+View.__new__.__defaults__ = (None, tuple())
+FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
+
+Column = namedtuple(
+    'Column',
+    ['table_refs', 'require_last_table', 'local_tables',
+     'qualifiable', 'context']
+)
+Column.__new__.__defaults__ = (None, None, tuple(), False, None)
+
+Keyword = namedtuple('Keyword', ['last_token'])
+Keyword.__new__.__defaults__ = (None,)
+NamedQuery = namedtuple('NamedQuery', [])
+Datatype = namedtuple('Datatype', ['schema'])
+Alias = namedtuple('Alias', ['aliases'])
+
+Path = namedtuple('Path', [])
+
+
+class SqlStatement(object):
+    def __init__(self, full_text, text_before_cursor):
+        self.identifier = None
+        self.word_before_cursor = word_before_cursor = last_word(
+            text_before_cursor, include='many_punctuations')
+        full_text = _strip_named_query(full_text)
+        text_before_cursor = _strip_named_query(text_before_cursor)
+
+        full_text, text_before_cursor, self.local_tables = \
+            isolate_query_ctes(full_text, text_before_cursor)
+
+        self.text_before_cursor_including_last_word = text_before_cursor
+
+        # If we've partially typed a word then word_before_cursor won't be an
+        # empty string. In that case we want to remove the partially typed
+        # string before sending it to the sqlparser. Otherwise the last token
+        # will always be the partially typed string which renders the smart
+        # completion useless because it will always return the list of
+        # keywords as completion.
+        if self.word_before_cursor:
+            if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
+                parsed = sqlparse.parse(text_before_cursor)
+            else:
+                text_before_cursor = \
+                    text_before_cursor[:-len(word_before_cursor)]
+                parsed = sqlparse.parse(text_before_cursor)
+            self.identifier = parse_partial_identifier(word_before_cursor)
+        else:
+            parsed = sqlparse.parse(text_before_cursor)
+
+        full_text, text_before_cursor, parsed = \
+            _split_multiple_statements(full_text, text_before_cursor, parsed)
+
+        self.full_text = full_text
+        self.text_before_cursor = text_before_cursor
+        self.parsed = parsed
+
+        self.last_token = \
+            parsed and parsed.token_prev(len(parsed.tokens))[1] or ''
+
+    def is_insert(self):
+        return self.parsed.token_first().value.lower() == 'insert'
+
+    def get_tables(self, scope='full'):
+        """ Gets the tables available in the statement.
+        :param scope: possible values: 'full', 'insert', 'before'
+        If 'insert', only the first table is returned.
+        If 'before', only tables before the cursor are returned.
+        If not 'insert' and the stmt is an insert, the first table is skipped.
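+
+        An illustration of the intent (example statement assumed): for
+        'INSERT INTO foo SELECT * FROM bar', scope='insert' yields only
+        foo, while the default scope skips foo and yields bar.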
+        """
+        tables = extract_tables(
+            self.full_text if scope == 'full' else self.text_before_cursor)
+        if scope == 'insert':
+            tables = tables[:1]
+        elif self.is_insert():
+            tables = tables[1:]
+        return tables
+
+    def get_previous_token(self, token):
+        return self.parsed.token_prev(self.parsed.token_index(token))[1]
+
+    def get_identifier_schema(self):
+        schema = \
+            (self.identifier and self.identifier.get_parent_name()) or None
+        # If schema name is unquoted, lower-case it
+        if schema and self.identifier.value[0] != '"':
+            schema = schema.lower()
+
+        return schema
+
+    def reduce_to_prev_keyword(self, n_skip=0):
+        prev_keyword, self.text_before_cursor = \
+            find_prev_keyword(self.text_before_cursor, n_skip=n_skip)
+        return prev_keyword
+
+
+def suggest_type(full_text, text_before_cursor):
+    """Takes the full_text that is typed so far and also the text before the
+    cursor to suggest completion type and scope.
+
+    Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
+    A scope for a column category will be a list of tables.
+    """
+
+    if full_text.startswith('\\i '):
+        return (Path(),)
+
+    # This is a temporary hack; the exception handling
+    # here should be removed once sqlparse has been fixed
+    try:
+        stmt = SqlStatement(full_text, text_before_cursor)
+    except (TypeError, AttributeError):
+        return []
+
+    return suggest_based_on_last_token(stmt.last_token, stmt)
+
+
+named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
+
+
+def _strip_named_query(txt):
+    """
+    This strips the "save named query" command at the beginning of the line:
+    '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
+    '  \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
+    """
+
+    if named_query_regex.match(txt):
+        txt = named_query_regex.sub('', txt)
+    return txt
+
+
+function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M)
+
+
+def _find_function_body(text):
+    split = function_body_pattern.search(text)
+    return (split.start(2), split.end(2)) if split else (None, None)
+
+
+def _statement_from_function(full_text, text_before_cursor, statement):
+    current_pos = len(text_before_cursor)
+    body_start, body_end = _find_function_body(full_text)
+    if body_start is None:
+        return full_text, text_before_cursor, statement
+    if not body_start <= current_pos < body_end:
+        return full_text, text_before_cursor, statement
+    full_text = full_text[body_start:body_end]
+    text_before_cursor = text_before_cursor[body_start:]
+    parsed = sqlparse.parse(text_before_cursor)
+    return _split_multiple_statements(full_text, text_before_cursor, parsed)
+
+
+def _split_multiple_statements(full_text, text_before_cursor, parsed):
+    if len(parsed) > 1:
+        # Multiple statements being edited -- isolate the current one by
+        # cumulatively summing statement lengths to find the one that bounds
+        # the current position
+        current_pos = len(text_before_cursor)
+        stmt_start, stmt_end = 0, 0
+
+        for statement in parsed:
+            stmt_len = len(str(statement))
+            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
+
+            if stmt_end >= current_pos:
+                text_before_cursor = full_text[stmt_start:current_pos]
+                full_text = full_text[stmt_start:]
+                break
+
+    elif parsed:
+        # A single statement
+        statement = parsed[0]
+    else:
+        # The empty string
+        return full_text, text_before_cursor, None
+
+    token2 = None
+    if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'):
+        token1 = statement.token_first()
+        if token1:
+            token1_idx = statement.token_index(token1)
+            token2 = statement.token_next(token1_idx)[1]
+    if token2 and token2.value.upper() == 'FUNCTION':
+        full_text, text_before_cursor, statement = _statement_from_function(
+            full_text, text_before_cursor, statement
+        )
+    return full_text, text_before_cursor, statement
+
+
+def suggest_based_on_last_token(token, stmt):
+
+    if isinstance(token, string_types):
+        token_v = token.lower()
+    elif isinstance(token, Comparison):
+        # If 'token' is a Comparison type such as
+        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
+        # token.value on the comparison type will only return the lhs of the
+        # comparison. In this case a.id. So we need to do token.tokens to get
+        # both sides of the comparison and pick the last token out of that
+        # list.
+        token_v = token.tokens[-1].value.lower()
+    elif isinstance(token, Where):
+        # sqlparse groups all tokens from the where clause into a single token
+        # list. This means that token.value may be something like
+        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
+        # suggestions in complicated where clauses correctly
+        prev_keyword = stmt.reduce_to_prev_keyword()
+        return suggest_based_on_last_token(prev_keyword, stmt)
+    elif isinstance(token, Identifier):
+        # If the previous token is an identifier, we can suggest datatypes if
+        # we're in a parenthesized column/field list, e.g.:
+        #       CREATE TABLE foo (Identifier
+        #       CREATE FUNCTION foo (Identifier
+        # If we're not in a parenthesized list, the most likely scenario is the
+        # user is about to specify an alias, e.g.:
+        #       SELECT Identifier
+        #       SELECT foo FROM Identifier
+        prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
+        if prev_keyword and prev_keyword.value == '(':
+            # Suggest datatypes
+            return suggest_based_on_last_token('type', stmt)
+        else:
+            return (Keyword(),)
+    else:
+        token_v = token.value.lower()
+
+    if not token:
+        return (Keyword(),)
+    elif token_v.endswith('('):
+        p = sqlparse.parse(stmt.text_before_cursor)[0]
+
+        if p.tokens and isinstance(p.tokens[-1], Where):
+            # Four possibilities:
+            #  1 - Parenthesized clause like "WHERE foo AND ("
+            #      Suggest columns/functions
+            #  2 - Function call like "WHERE foo("
+            #      Suggest columns/functions
+            #  3 - Subquery expression like "WHERE EXISTS ("
+            #      Suggest keywords, in order to do a subquery
+            #  4 - Subquery OR array comparison like "WHERE foo = ANY("
+            #      Suggest columns/functions AND keywords. (If we wanted to be
+            #      really fancy, we could suggest only array-typed columns)
+
+            column_suggestions = suggest_based_on_last_token('where', stmt)
+
+            # Check for a subquery expression (cases 3 & 4)
+            where = p.tokens[-1]
+            prev_tok = where.token_prev(len(where.tokens) - 1)[1]
+
+            if isinstance(prev_tok, Comparison):
+                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
+                prev_tok = prev_tok.tokens[-1]
+
+            prev_tok = prev_tok.value.lower()
+            if prev_tok == 'exists':
+                return (Keyword(),)
+            else:
+                return column_suggestions
+
+        # Get the token before the parens
+        prev_tok = p.token_prev(len(p.tokens) - 1)[1]
+
+        if (
+            prev_tok and prev_tok.value and
+            prev_tok.value.lower().split(' ')[-1] == 'using'
+        ):
+            # tbl1 INNER JOIN tbl2 USING (col1, col2)
+            tables = stmt.get_tables('before')
+
+            # suggest columns that are present in more than one table
+            return (Column(table_refs=tables,
+                           require_last_table=True,
+                           local_tables=stmt.local_tables),)
+
+        elif p.token_first().value.lower() == 'select':
+            # If the lparen is preceded by a space chances are we're about to
+            # do a sub-select.
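+            # e.g. 'SELECT (' suggests keywords (a likely sub-select),
+            # while 'SELECT max(' falls through to the column suggestions
+            # below (an illustration of this branch, assumed behaviour).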
+ if last_word(stmt.text_before_cursor, + 'all_punctuations').startswith('('): + return (Keyword(),) + prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1] + if prev_prev_tok and prev_prev_tok.normalized == 'INTO': + return ( + Column(table_refs=stmt.get_tables('insert'), context='insert'), + ) + # We're probably in a function argument list + return (Column(table_refs=extract_tables(stmt.full_text), + local_tables=stmt.local_tables, qualifiable=True),) + elif token_v == 'set': + return (Column(table_refs=stmt.get_tables(), + local_tables=stmt.local_tables),) + elif token_v in ('select', 'where', 'having', 'by', 'distinct'): + # Check for a table alias or schema qualification + parent = (stmt.identifier and stmt.identifier.get_parent_name()) or [] + tables = stmt.get_tables() + if parent: + tables = tuple(t for t in tables if identifies(parent, t)) + return (Column(table_refs=tables, local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent),) + else: + return (Column(table_refs=tables, local_tables=stmt.local_tables, + qualifiable=True), + Function(schema=None), + Keyword(token_v.upper()),) + elif token_v == 'as': + # Don't suggest anything for aliases + return () + elif ( + (token_v.endswith('join') and token.is_keyword) or + (token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate')) + ): + + schema = stmt.get_identifier_schema() + tables = extract_tables(stmt.text_before_cursor) + is_join = token_v.endswith('join') and token.is_keyword + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + if token_v == 'from' or is_join: + suggest.append(FromClauseItem(schema=schema, + table_refs=tables, + local_tables=stmt.local_tables)) + elif token_v == 'truncate': + suggest.append(Table(schema)) + else: + suggest.extend((Table(schema), View(schema))) + + if is_join and _allow_join(stmt.parsed): + tables = stmt.get_tables('before') + suggest.append(Join(table_refs=tables, schema=schema)) + + return tuple(suggest) + + elif token_v == 'function': + schema = stmt.get_identifier_schema() + # stmt.get_previous_token will fail for e.g. + # `SELECT 1 FROM functions WHERE function:` + try: + prev = stmt.get_previous_token(token).value.lower() + if prev in('drop', 'alter', 'create', 'create or replace'): + return (Function(schema=schema, usage='signature'),) + except ValueError: + pass + return tuple() + + elif token_v in ('table', 'view'): + # E.g. 'ALTER TABLE ' + rel_type = \ + {'table': Table, 'view': View, 'function': Function}[token_v] + schema = stmt.get_identifier_schema() + if schema: + return (rel_type(schema=schema),) + else: + return (Schema(), rel_type(schema=schema)) + + elif token_v == 'column': + # E.g. 'ALTER TABLE foo ALTER COLUMN bar + return (Column(table_refs=stmt.get_tables()),) + + elif token_v == 'on': + tables = stmt.get_tables('before') + parent = \ + (stmt.identifier and stmt.identifier.get_parent_name()) or None + if parent: + # "ON parent." 
+ # parent can be either a schema name or table alias + filteredtables = tuple(t for t in tables if identifies(parent, t)) + sugs = [Column(table_refs=filteredtables, + local_tables=stmt.local_tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent)] + if filteredtables and _allow_join_condition(stmt.parsed): + sugs.append(JoinCondition(table_refs=tables, + parent=filteredtables[-1])) + return tuple(sugs) + else: + # ON + # Use table alias if there is one, otherwise the table name + aliases = tuple(t.ref for t in tables) + if _allow_join_condition(stmt.parsed): + return (Alias(aliases=aliases), JoinCondition( + table_refs=tables, parent=None)) + else: + return (Alias(aliases=aliases),) + + elif token_v in ('c', 'use', 'database', 'template'): + # "\c ", "DROP DATABASE ", + # "CREATE DATABASE WITH TEMPLATE " + return (Database(),) + elif token_v == 'schema': + # DROP SCHEMA schema_name, SET SCHEMA schema name + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2) + quoted = prev_keyword and prev_keyword.value.lower() == 'set' + return (Schema(quoted),) + elif token_v.endswith(',') or token_v in ('=', 'and', 'or'): + prev_keyword = stmt.reduce_to_prev_keyword() + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return () + elif token_v in ('type', '::'): + # ALTER TABLE foo SET DATA TYPE bar + # SELECT foo::bar + # Note that tables are a form of composite type in postgresql, so + # they're suggested here as well + schema = stmt.get_identifier_schema() + suggestions = [Datatype(schema=schema), + Table(schema=schema)] + if not schema: + suggestions.append(Schema()) + return tuple(suggestions) + elif token_v in {'alter', 'create', 'drop'}: + return (Keyword(token_v.upper()),) + elif token.is_keyword: + # token is a keyword we haven't implemented any special handling for + # go backwards in the query until we find one we do recognize + prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1) + if prev_keyword: + return suggest_based_on_last_token(prev_keyword, stmt) + else: + return (Keyword(token_v.upper()),) + else: + return (Keyword(),) + + +def identifies(id, ref): + """Returns true if string `id` matches TableReference `ref`""" + + return id == ref.alias or id == ref.name or ( + ref.schema and (id == ref.schema + '.' + ref.name)) + + +def _allow_join_condition(statement): + """ + Tests if a join condition should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b on a.id = + So check that the preceding token is a ON, AND, or OR keyword, instead of + e.g. an equals sign. + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return last_tok.value.lower() in ('on', 'and', 'or') + + +def _allow_join(statement): + """ + Tests if a join should be suggested + + We need this to avoid bad suggestions when entering e.g. + select * from tbl1 a join tbl2 b + So check that the preceding token is a JOIN keyword + + :param statement: an sqlparse.sql.Statement + :return: boolean + """ + + if not statement or not statement.tokens: + return False + + last_tok = statement.token_prev(len(statement.tokens))[1] + return ( + last_tok.value.lower().endswith('join') and + last_tok.value.lower() not in('cross join', 'natural join') + )
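
A quick way to exercise the dispatch above once the patch is applied (a minimal sketch: the import path assumes the module layout added in this change, and the exact namedtuple contents may vary with the installed sqlparse):

    from pgadmin.utils.sqlautocomplete.sqlcompletion import suggest_type

    # After FROM the completer should offer schemas plus from-clause items
    # (tables, views and set-returning functions).
    for sug in suggest_type('SELECT * FROM ', 'SELECT * FROM '):
        print(type(sug).__name__, sug)

    # After ON, table aliases and join conditions become candidates.
    sql = 'SELECT * FROM tbl1 t1 JOIN tbl2 t2 ON '
    for sug in suggest_type(sql, sql):
        print(type(sug).__name__, sug)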