From 0a354055a960bb9070bf271cc23ba8556605e5c7 Mon Sep 17 00:00:00 2001 From: Akshay Joshi Date: Sat, 21 May 2016 16:04:05 +0530 Subject: [PATCH] Adding support for autocomplete in the SQL Editor. In Query editor, we can use the autocomplete feature by using keyword combination - 'Ctrl + Space'. --- requirements_py2.txt | 1 + requirements_py3.txt | 1 + .../sqlautocomplete/sql/9.1_plus/columns.sql | 29 + .../sql/9.1_plus/databases.sql | 4 + .../sql/9.1_plus/datatypes.sql | 9 + .../sql/9.1_plus/functions.sql | 30 + .../sqlautocomplete/sql/9.1_plus/keywords.sql | 2 + .../sqlautocomplete/sql/9.1_plus/schema.sql | 6 + .../sql/9.1_plus/tableview.sql | 17 + web/pgadmin/tools/sqleditor/__init__.py | 39 + .../tools/sqleditor/static/css/sqleditor.css | 47 + .../templates/sqleditor/js/sqleditor.js | 107 ++- web/pgadmin/utils/sqlautocomplete/__init__.py | 0 .../utils/sqlautocomplete/autocomplete.py | 863 ++++++++++++++++++ .../utils/sqlautocomplete/completion.py | 67 ++ web/pgadmin/utils/sqlautocomplete/counter.py | 189 ++++ .../sqlautocomplete/function_metadata.py | 149 +++ .../utils/sqlautocomplete/parseutils.py | 288 ++++++ .../utils/sqlautocomplete/prioritization.py | 49 + 19 files changed, 1895 insertions(+), 2 deletions(-) create mode 100644 web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/columns.sql create mode 100644 web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/databases.sql create mode 100644 web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/datatypes.sql create mode 100644 web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/functions.sql create mode 100644 web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/keywords.sql create mode 100644 web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/schema.sql create mode 100644 web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/tableview.sql create mode 100644 web/pgadmin/utils/sqlautocomplete/__init__.py create mode 100644 web/pgadmin/utils/sqlautocomplete/autocomplete.py create mode 
100644 web/pgadmin/utils/sqlautocomplete/completion.py create mode 100644 web/pgadmin/utils/sqlautocomplete/counter.py create mode 100644 web/pgadmin/utils/sqlautocomplete/function_metadata.py create mode 100644 web/pgadmin/utils/sqlautocomplete/parseutils.py create mode 100644 web/pgadmin/utils/sqlautocomplete/prioritization.py diff --git a/requirements_py2.txt b/requirements_py2.txt index 8fcafcd9e..a442e3665 100644 --- a/requirements_py2.txt +++ b/requirements_py2.txt @@ -43,3 +43,4 @@ traceback2==1.4.0 unittest2==1.1.0 Werkzeug==0.9.6 WTForms==2.0.2 +sqlparse==0.1.19 diff --git a/requirements_py3.txt b/requirements_py3.txt index 0f813e646..233b14f85 100644 --- a/requirements_py3.txt +++ b/requirements_py3.txt @@ -37,3 +37,4 @@ unittest2==1.1.0 Werkzeug==0.9.6 wheel==0.24.0 WTForms==2.0.2 +sqlparse==0.1.19 diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/columns.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/columns.sql new file mode 100644 index 000000000..d4a109484 --- /dev/null +++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/columns.sql @@ -0,0 +1,29 @@ +{# SQL query for getting columns #} +{% if object_name == 'table' %} +SELECT + att.attname column_name +FROM pg_catalog.pg_attribute att + INNER JOIN pg_catalog.pg_class cls + ON att.attrelid = cls.oid + INNER JOIN pg_catalog.pg_namespace nsp + ON cls.relnamespace = nsp.oid + WHERE cls.relkind = ANY(array['r']) + AND NOT att.attisdropped + AND att.attnum > 0 + AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}') + ORDER BY 1 +{% endif %} +{% if object_name == 'view' %} +SELECT + att.attname column_name +FROM pg_catalog.pg_attribute att + INNER JOIN pg_catalog.pg_class cls + ON att.attrelid = cls.oid + INNER JOIN pg_catalog.pg_namespace nsp + ON cls.relnamespace = nsp.oid + WHERE cls.relkind = ANY(array['v', 'm']) + AND NOT att.attisdropped + AND att.attnum > 0 + AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}') + ORDER BY 
1 +{% endif %} \ No newline at end of file diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/databases.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/databases.sql new file mode 100644 index 000000000..90a84e022 --- /dev/null +++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/databases.sql @@ -0,0 +1,4 @@ +{# SQL query for getting databases #} +SELECT d.datname + FROM pg_catalog.pg_database d + ORDER BY 1 diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/datatypes.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/datatypes.sql new file mode 100644 index 000000000..7a2fc9a4e --- /dev/null +++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/datatypes.sql @@ -0,0 +1,9 @@ +{# SQL query for getting datatypes #} +SELECT n.nspname schema_name, + t.typname object_name +FROM pg_catalog.pg_type t + INNER JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace +WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid)) + AND NOT EXISTS(SELECT 1 FROM pg_catalog.pg_type el WHERE el.oid = t.typelem AND el.typarray = t.oid) + AND n.nspname IN ({{schema_names}}) +ORDER BY 1, 2; diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/functions.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/functions.sql new file mode 100644 index 000000000..826da2d81 --- /dev/null +++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/functions.sql @@ -0,0 +1,30 @@ +{# ============= Fetch the list of functions based on given schema_names ============= #} +{% if func_name %} +SELECT n.nspname schema_name, + p.proname func_name, + pg_catalog.pg_get_function_arguments(p.oid) arg_list, + pg_catalog.pg_get_function_result(p.oid) return_type, + p.proisagg is_aggregate, + p.proiswindow is_window, + p.proretset is_set_returning +FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace + WHERE n.nspname = 
'{{schema_name}}' AND p.proname = '{{func_name}}' + AND p.proretset + ORDER BY 1, 2 +{% else %} +SELECT n.nspname schema_name, + p.proname object_name, + pg_catalog.pg_get_function_arguments(p.oid) arg_list, + pg_catalog.pg_get_function_result(p.oid) return_type, + p.proisagg is_aggregate, + p.proiswindow is_window, + p.proretset is_set_returning +FROM pg_catalog.pg_proc p + INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace + WHERE n.nspname IN ({{schema_names}}) +{% if is_set_returning %} + AND p.proretset +{% endif %} + ORDER BY 1, 2 +{% endif %} \ No newline at end of file diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/keywords.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/keywords.sql new file mode 100644 index 000000000..cd571e170 --- /dev/null +++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/keywords.sql @@ -0,0 +1,2 @@ +{# SQL query for getting keywords #} +SELECT upper(word) as word FROM pg_get_keywords() diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/schema.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/schema.sql new file mode 100644 index 000000000..48ac88e09 --- /dev/null +++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/schema.sql @@ -0,0 +1,6 @@ +{# SQL query for getting current_schemas #} +{% if search_path %} +SELECT * FROM unnest(current_schemas(true)) AS schema +{% else %} +SELECT nspname AS schema FROM pg_catalog.pg_namespace ORDER BY 1 +{% endif %} diff --git a/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/tableview.sql b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/tableview.sql new file mode 100644 index 000000000..395af2e24 --- /dev/null +++ b/web/pgadmin/misc/templates/sqlautocomplete/sql/9.1_plus/tableview.sql @@ -0,0 +1,17 @@ +{# ============= Fetch the list of tables/view based on given schema_names ============= #} +{% if object_name == 'tables' %} +SELECT n.nspname schema_name, + c.relname object_name 
+FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE c.relkind = ANY(array['r']) and n.nspname IN ({{schema_names}}) + ORDER BY 1,2 +{% endif %} +{% if object_name == 'views' %} +SELECT n.nspname schema_name, + c.relname object_name +FROM pg_catalog.pg_class c + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE c.relkind = ANY(array['v', 'm']) and n.nspname IN ({{schema_names}}) + ORDER BY 1,2 +{% endif %} diff --git a/web/pgadmin/tools/sqleditor/__init__.py b/web/pgadmin/tools/sqleditor/__init__.py index 6afdef58e..48fc8eca0 100644 --- a/web/pgadmin/tools/sqleditor/__init__.py +++ b/web/pgadmin/tools/sqleditor/__init__.py @@ -24,6 +24,7 @@ from pgadmin.utils.driver import get_driver from config import PG_DEFAULT_DRIVER from pgadmin.tools.sqleditor.command import QueryToolCommand from pgadmin.utils import get_storage_directory +from pgadmin.utils.sqlautocomplete.autocomplete import SQLAutoComplete # import unquote from urlib for python2.x and python3.x try: @@ -890,6 +891,44 @@ def set_auto_rollback(trans_id): return make_json_response(data={'status': status, 'result': res}) +@blueprint.route('/autocomplete/', methods=["PUT", "POST"]) +@login_required +def auto_complete(trans_id): + """ + This method implements the autocomplete feature. 
+ + Args: + trans_id: unique transaction id + """ + full_sql = '' + text_before_cursor = '' + + if request.data: + data = json.loads(request.data.decode()) + else: + data = request.args or request.form + + if len(data) > 0: + full_sql = data[0] + text_before_cursor = data[1] + + # Check the transaction and connection status + status, error_msg, conn, trans_obj, session_obj = check_transaction_status(trans_id) + if status and conn is not None \ + and trans_obj is not None and session_obj is not None: + + # Create object of SQLAutoComplete class and pass connection object + auto_complete_obj = SQLAutoComplete(sid=trans_obj.sid, did=trans_obj.did, conn=conn) + + # Get the auto completion suggestions. + res = auto_complete_obj.get_completions(full_sql, text_before_cursor) + else: + status = False + res = error_msg + + return make_json_response(data={'status': status, 'result': res}) + + @blueprint.route("/sqleditor.js") @login_required def script(): diff --git a/web/pgadmin/tools/sqleditor/static/css/sqleditor.css b/web/pgadmin/tools/sqleditor/static/css/sqleditor.css index 6db04cb0a..99d9b4f3b 100644 --- a/web/pgadmin/tools/sqleditor/static/css/sqleditor.css +++ b/web/pgadmin/tools/sqleditor/static/css/sqleditor.css @@ -243,3 +243,50 @@ width: 100%; overflow: auto; } + +.CodeMirror-hints { + position: absolute; + z-index: 10; + overflow: hidden; + list-style: none; + + margin: 0; + padding: 2px; + + -webkit-box-shadow: 2px 3px 5px rgba(0,0,0,.2); + -moz-box-shadow: 2px 3px 5px rgba(0,0,0,.2); + box-shadow: 2px 3px 5px rgba(0,0,0,.2); + border-radius: 3px; + border: 1px solid silver; + + background: white; + font-size: 90%; + font-family: monospace; + + max-height: 20em; + overflow-y: auto; +} + +.CodeMirror-hint { + margin: 0; + padding: 0 4px; + border-radius: 2px; + max-width: 19em; + overflow: hidden; + white-space: pre; + color: black; + cursor: pointer; +} + +li.CodeMirror-hint-active { + background: #08f; + color: white; +} + +.sqleditor-hint { + padding-left: 
20px; +} + +.CodeMirror-hint .fa::before { + padding-right: 7px; +} \ No newline at end of file diff --git a/web/pgadmin/tools/sqleditor/templates/sqleditor/js/sqleditor.js b/web/pgadmin/tools/sqleditor/templates/sqleditor/js/sqleditor.js index 7bb193213..b6ae5fe79 100644 --- a/web/pgadmin/tools/sqleditor/templates/sqleditor/js/sqleditor.js +++ b/web/pgadmin/tools/sqleditor/templates/sqleditor/js/sqleditor.js @@ -6,6 +6,7 @@ define( 'codemirror/mode/sql/sql', 'codemirror/addon/selection/mark-selection', 'codemirror/addon/selection/active-line', 'backbone.paginator', 'codemirror/addon/fold/foldgutter', 'codemirror/addon/fold/foldcode', + 'codemirror/addon/hint/show-hint', 'codemirror/addon/hint/sql-hint', 'codemirror/addon/fold/pgadmin-sqlfoldcode', 'backgrid.paginator', 'wcdocker', 'pgadmin.file_manager' ], @@ -238,7 +239,8 @@ define( rangeFinder: CodeMirror.fold.combine(CodeMirror.pgadminBeginRangeFinder, CodeMirror.pgadminIfRangeFinder, CodeMirror.pgadminLoopRangeFinder, CodeMirror.pgadminCaseRangeFinder) }, - gutters: ["CodeMirror-linenumbers", "CodeMirror-foldgutter"] + gutters: ["CodeMirror-linenumbers", "CodeMirror-foldgutter"], + extraKeys: {"Ctrl-Space": "autocomplete"} }); // Create panels for 'Data Output', 'Explain', 'Messages' and 'History' @@ -295,6 +297,107 @@ define( self.history_panel = main_docker.addPanel('history', wcDocker.DOCK.STACKED, self.data_output_panel); self.render_history_grid(); + + /* We have override/register the hint function of CodeMirror + * to provide our own hint logic. + */ + CodeMirror.registerHelper("hint", "sql", function(editor, options) { + var data = [], + result = []; + var doc = editor.getDoc(); + var cur = doc.getCursor(); + var current_line = cur.line; // gets the line number in the cursor position + var current_cur = cur.ch; // get the current cursor position + + /* Render function for hint to add our own class + * and icon as per the object type. 
+ */ + var hint_render = function(elt, data, cur) { + var el = document.createElement('span'); + + switch(cur.type) { + case 'database': + el.className = 'sqleditor-hint pg-icon-' + cur.type; + break; + case 'datatype': + el.className = 'sqleditor-hint icon-type'; + break; + case 'keyword': + el.className = 'fa fa-key'; + break; + case 'table alias': + el.className = 'fa fa-at'; + break; + default: + el.className = 'sqleditor-hint icon-' + cur.type; + } + + el.appendChild(document.createTextNode(cur.text)); + elt.appendChild(el); + }; + + var full_text = doc.getValue(); + // Get the text from start to the current cursor position. + var text_before_cursor = doc.getRange({ line: 0, ch: 0 }, + { line: current_line, ch: current_cur }); + + data.push(full_text); + data.push(text_before_cursor); + + // Make ajax call to find the autocomplete data + $.ajax({ + url: "{{ url_for('sqleditor.index') }}" + "autocomplete/" + self.transId, + method: 'POST', + async: false, + contentType: "application/json", + data: JSON.stringify(data), + success: function(res) { + _.each(res.data.result, function(obj, key) { + result.push({ + text: key, type: obj.object_type, + render: hint_render + }); + }); + + // Sort function to sort the suggestion's alphabetically. + result.sort(function(a, b){ + var textA = a.text.toLowerCase(), textB = b.text.toLowerCase() + if (textA < textB) //sort string ascending + return -1 + if (textA > textB) + return 1 + return 0 //default return value (no sorting) + }) + } + }); + + /* Below logic find the start and end point + * to replace the selected auto complete suggestion. 
+ */ + var token = editor.getTokenAt(cur), start, end, search; + if (token.end > cur.ch) { + token.end = cur.ch; + token.string = token.string.slice(0, cur.ch - token.start); + } + + if (token.string.match(/^[.`\w@]\w*$/)) { + search = token.string; + start = token.start; + end = token.end; + } else { + start = end = cur.ch; + search = ""; + } + + /* Added 1 in the start position if search string + * started with "." or "`" else auto complete of code mirror + * will remove the "." when user select any suggestion. + */ + if (search.charAt(0) == "." || search.charAt(0) == "``") + start += 1; + + return {list: result, from: {line: current_line, ch: start }, to: { line: current_line, ch: end }}; + }); }, /* This function is responsible to create and render the @@ -782,7 +885,7 @@ define( el: self.container, handler: self }); - self.transId = self.container.data('transId'); + self.transId = self.gridView.transId = self.container.data('transId'); self.gridView.editor_title = editor_title; self.gridView.current_file = undefined; diff --git a/web/pgadmin/utils/sqlautocomplete/__init__.py b/web/pgadmin/utils/sqlautocomplete/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/web/pgadmin/utils/sqlautocomplete/autocomplete.py b/web/pgadmin/utils/sqlautocomplete/autocomplete.py new file mode 100644 index 000000000..9ac002160 --- /dev/null +++ b/web/pgadmin/utils/sqlautocomplete/autocomplete.py @@ -0,0 +1,863 @@ +########################################################################## +# +# pgAdmin 4 - PostgreSQL Tools +# +# Copyright (C) 2013 - 2016, The pgAdmin Development Team +# This software is released under the PostgreSQL Licence +# +########################################################################## + +"""A blueprint module implementing the sql auto complete feature.""" + +import sys +import re +import sqlparse +import itertools +import operator +from collections import namedtuple +from sqlparse.sql import Comparison, Identifier, Where 
+from .parseutils import ( + last_word, extract_tables, find_prev_keyword, parse_partial_identifier) +from .prioritization import PrevalenceCounter +from .completion import Completion +from .function_metadata import FunctionMetadata +from flask import render_template +from pgadmin.utils.driver import get_driver +from config import PG_DEFAULT_DRIVER + +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 + +if PY3: + string_types = str +else: + string_types = basestring + +Database = namedtuple('Database', []) +Schema = namedtuple('Schema', []) +Table = namedtuple('Table', ['schema']) + +Function = namedtuple('Function', ['schema', 'filter']) +# For convenience, don't require the `filter` argument in Function constructor +Function.__new__.__defaults__ = (None, None) + +Column = namedtuple('Column', ['tables', 'drop_unique']) +Column.__new__.__defaults__ = (None, None) + +View = namedtuple('View', ['schema']) +Keyword = namedtuple('Keyword', []) +Datatype = namedtuple('Datatype', ['schema']) +Alias = namedtuple('Alias', ['aliases']) +Match = namedtuple('Match', ['completion', 'priority']) + +try: + from collections import Counter +except ImportError: + # python 2.6 + from .counter import Counter + +# Regex for finding "words" in documents. +_FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)') +_FIND_BIG_WORD_RE = re.compile(r'([^\s]+)') + + +class SQLAutoComplete(object): + """ + class SQLAutoComplete + + This class is used to provide the postgresql's autocomplete feature. + This class used sqlparse to parse the given sql and psycopg2 to make + the connection and get the tables, schemas, functions etc. based on + the query. + """ + + def __init__(self, **kwargs): + """ + This method is used to initialize the class. 
+ + Args: + **kwargs : N number of parameters + """ + + self.sid = kwargs['sid'] if 'sid' in kwargs else None + self.did = kwargs['did'] if 'did' in kwargs else None + self.conn = kwargs['conn'] if 'conn' in kwargs else None + self.keywords = [] + + manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid) + + ver = manager.version + # we will set template path for sql scripts + if ver >= 90100: + self.sql_path = 'sqlautocomplete/sql/9.1_plus' + + self.search_path = [] + # Fetch the search path + if self.conn.connected(): + query = render_template("/".join([self.sql_path, 'schema.sql']), search_path=True) + status, res = self.conn.execute_dict(query) + if status: + for record in res['rows']: + self.search_path.append(record['schema']) + + # Fetch the keywords + query = render_template("/".join([self.sql_path, 'keywords.sql'])) + status, res = self.conn.execute_dict(query) + if status: + for record in res['rows']: + self.keywords.append(record['word']) + + self.text_before_cursor = None + self.prioritizer = PrevalenceCounter(self.keywords) + + self.reserved_words = set() + for x in self.keywords: + self.reserved_words.update(x.split()) + self.name_pattern = re.compile("^[_a-z][_a-z0-9\$]*$") + + def escape_name(self, name): + if name and ((not self.name_pattern.match(name)) or + (name.upper() in self.reserved_words)): + name = '"%s"' % name + + return name + + def unescape_name(self, name): + if name and name[0] == '"' and name[-1] == '"': + name = name[1:-1] + + return name + + def escaped_names(self, names): + return [self.escape_name(name) for name in names] + + def find_matches(self, text, collection, mode='fuzzy', + meta=None, meta_collection=None): + """ + Find completion matches for the given text. + + Given the user's input text and a collection of available + completions, find completions matching the last word of the + text. 
+ + `mode` can be either 'fuzzy', or 'strict' + 'fuzzy': fuzzy matching, ties broken by name prevalance + `keyword`: start only matching, ties broken by keyword prevalance + + yields prompt_toolkit Completion instances for any matches found + in the collection of available completions. + + Args: + text: + collection: + mode: + meta: + meta_collection: + """ + + text = last_word(text, include='most_punctuations').lower() + text_len = len(text) + + if text and text[0] == '"': + # text starts with double quote; user is manually escaping a name + # Match on everything that follows the double-quote. Note that + # text_len is calculated before removing the quote, so the + # Completion.position value is correct + text = text[1:] + + if mode == 'fuzzy': + fuzzy = True + priority_func = self.prioritizer.name_count + else: + fuzzy = False + priority_func = self.prioritizer.keyword_count + + # Construct a `_match` function for either fuzzy or non-fuzzy matching + # The match function returns a 2-tuple used for sorting the matches, + # or None if the item doesn't match + # Note: higher priority values mean more important, so use negative + # signs to flip the direction of the tuple + if fuzzy: + regex = '.*?'.join(map(re.escape, text)) + pat = re.compile('(%s)' % regex) + + def _match(item): + r = pat.search(self.unescape_name(item.lower())) + if r: + return -len(r.group()), -r.start() + else: + match_end_limit = len(text) + + def _match(item): + match_point = item.lower().find(text, 0, match_end_limit) + if match_point >= 0: + # Use negative infinity to force keywords to sort after all + # fuzzy matches + return -float('Infinity'), -match_point + + if meta_collection: + # Each possible completion in the collection has a corresponding + # meta-display string + collection = zip(collection, meta_collection) + else: + # All completions have an identical meta + collection = zip(collection, itertools.repeat(meta)) + + matches = [] + + for item, meta in collection: + sort_key = 
_match(item) + if sort_key: + if meta and len(meta) > 50: + # Truncate meta-text to 50 characters, if necessary + meta = meta[:47] + u'...' + + # Lexical order of items in the collection, used for + # tiebreaking items with the same match group length and start + # position. Since we use *higher* priority to mean "more + # important," we use -ord(c) to prioritize "aa" > "ab" and end + # with 1 to prioritize shorter strings (ie "user" > "users"). + # We also use the unescape_name to make sure quoted names have + # the same priority as unquoted names. + lexical_priority = tuple(-ord(c) for c in self.unescape_name(item)) + (1,) + + priority = sort_key, priority_func(item), lexical_priority + + matches.append(Match( + completion=Completion(item, -text_len, display_meta=meta), + priority=priority)) + + return matches + + def get_completions(self, text, text_before_cursor): + self.text_before_cursor = text_before_cursor + + word_before_cursor = self.get_word_before_cursor(word=True) + matches = [] + suggestions = self.suggest_type(text, text_before_cursor) + + for suggestion in suggestions: + suggestion_type = type(suggestion) + + # Map suggestion type to method + # e.g. 'table' -> self.get_table_matches + matcher = self.suggestion_matchers[suggestion_type] + matches.extend(matcher(self, suggestion, word_before_cursor)) + + # Sort matches so highest priorities are first + matches = sorted(matches, key=operator.attrgetter('priority'), + reverse=True) + + result = dict() + for m in matches: + result[m.completion.display] = {'object_type': m.completion.display_meta} + + return result + + def get_column_matches(self, suggestion, word_before_cursor): + tables = suggestion.tables + scoped_cols = self.populate_scoped_cols(tables) + + if suggestion.drop_unique: + # drop_unique is used for 'tb11 JOIN tbl2 USING (...' 
which should + # suggest only columns that appear in more than one table + scoped_cols = [col for (col, count) + in Counter(scoped_cols).items() + if count > 1 and col != '*'] + + return self.find_matches(word_before_cursor, scoped_cols, mode='strict', meta='column') + + def get_function_matches(self, suggestion, word_before_cursor): + if suggestion.filter == 'is_set_returning': + # Only suggest set-returning functions + funcs = self.populate_functions(suggestion.schema) + else: + funcs = self.populate_schema_objects(suggestion.schema, 'functions') + + # Function overloading means we way have multiple functions of the same + # name at this point, so keep unique names only + funcs = set(funcs) + + funcs = self.find_matches(word_before_cursor, funcs, mode='strict', meta='function') + + return funcs + + def get_schema_matches(self, _, word_before_cursor): + schema_names = [] + + query = render_template("/".join([self.sql_path, 'schema.sql'])) + status, res = self.conn.execute_dict(query) + if status: + for record in res['rows']: + schema_names.append(record['schema']) + + # Unless we're sure the user really wants them, hide schema names + # starting with pg_, which are mostly temporary schemas + if not word_before_cursor.startswith('pg_'): + schema_names = [s for s in schema_names if not s.startswith('pg_')] + + return self.find_matches(word_before_cursor, schema_names, mode='strict', meta='schema') + + def get_table_matches(self, suggestion, word_before_cursor): + tables = self.populate_schema_objects(suggestion.schema, 'tables') + + # Unless we're sure the user really wants them, don't suggest the + # pg_catalog tables that are implicitly on the search path + if not suggestion.schema and ( + not word_before_cursor.startswith('pg_')): + tables = [t for t in tables if not t.startswith('pg_')] + + return self.find_matches(word_before_cursor, tables, mode='strict', meta='table') + + def get_view_matches(self, suggestion, word_before_cursor): + views = 
self.populate_schema_objects(suggestion.schema, 'views') + + if not suggestion.schema and ( + not word_before_cursor.startswith('pg_')): + views = [v for v in views if not v.startswith('pg_')] + + return self.find_matches(word_before_cursor, views, mode='strict', meta='view') + + def get_alias_matches(self, suggestion, word_before_cursor): + aliases = suggestion.aliases + return self.find_matches(word_before_cursor, aliases, mode='strict', + meta='table alias') + + def get_database_matches(self, _, word_before_cursor): + databases = [] + + query = render_template("/".join([self.sql_path, 'databases.sql'])) + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + if status: + for record in res['rows']: + databases.append(record['datname']) + + return self.find_matches(word_before_cursor, databases, mode='strict', + meta='database') + + def get_keyword_matches(self, _, word_before_cursor): + return self.find_matches(word_before_cursor, self.keywords, + mode='strict', meta='keyword') + + def get_datatype_matches(self, suggestion, word_before_cursor): + # suggest custom datatypes + types = self.populate_schema_objects(suggestion.schema, 'datatypes') + matches = self.find_matches(word_before_cursor, types, mode='strict', meta='datatype') + + return matches + + def get_word_before_cursor(self, word=False): + """ + Give the word before the cursor. + If we have whitespace before the cursor this returns an empty string. + + Args: + word: + """ + + if self.text_before_cursor[-1:].isspace(): + return '' + else: + return self.text_before_cursor[self.find_start_of_previous_word(word=word):] + + def find_start_of_previous_word(self, count=1, word=False): + """ + Return an index relative to the cursor position pointing to the start + of the previous word. Return `None` if nothing was found. + + Args: + count: + word: + """ + + # Reverse the text before the cursor, in order to do an efficient + # backwards search. 
+ text_before_cursor = self.text_before_cursor[::-1] + + regex = _FIND_BIG_WORD_RE if word else _FIND_WORD_RE + iterator = regex.finditer(text_before_cursor) + + try: + for i, match in enumerate(iterator): + if i + 1 == count: + return - match.end(1) + except StopIteration: + pass + + suggestion_matchers = { + Column: get_column_matches, + Function: get_function_matches, + Schema: get_schema_matches, + Table: get_table_matches, + View: get_view_matches, + Alias: get_alias_matches, + Database: get_database_matches, + Keyword: get_keyword_matches, + Datatype: get_datatype_matches, + } + + def populate_scoped_cols(self, scoped_tbls): + """ Find all columns in a set of scoped_tables + :param scoped_tbls: list of TableReference namedtuples + :return: list of column names + """ + + columns = [] + for tbl in scoped_tbls: + if tbl.schema: + # A fully qualified schema.relname reference + schema = self.escape_name(tbl.schema) + relname = self.escape_name(tbl.name) + + if tbl.is_function: + query = render_template("/".join([self.sql_path, 'functions.sql']), + schema_name=schema, + func_name=relname) + + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + func = None + if status: + for row in res['rows']: + func = FunctionMetadata(row['schema_name'], row['func_name'], + row['arg_list'], row['return_type'], + row['is_aggregate'], row['is_window'], + row['is_set_returning']) + if func: + columns.extend(func.fieldnames()) + else: + # We don't know if schema.relname is a table or view. 
Since + # tables and views cannot share the same name, we can check + # one at a time + + query = render_template("/".join([self.sql_path, 'columns.sql']), + object_name='table', + schema_name=schema, + rel_name=relname) + + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + if status: + if len(res['rows']) > 0: + # Table exists, so don't bother checking for a view + for record in res['rows']: + columns.append(record['column_name']) + else: + query = render_template("/".join([self.sql_path, 'columns.sql']), + object_name='view', + schema_name=schema, + rel_name=relname) + + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + if status: + for record in res['rows']: + columns.append(record['column_name']) + else: + # Schema not specified, so traverse the search path looking for + # a table or view that matches. Note that in order to get proper + # shadowing behavior, we need to check both views and tables for + # each schema before checking the next schema + for schema in self.search_path: + relname = self.escape_name(tbl.name) + + if tbl.is_function: + query = render_template("/".join([self.sql_path, 'functions.sql']), + schema_name=schema, + func_name=relname) + + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + func = None + if status: + for row in res['rows']: + func = FunctionMetadata(row['schema_name'], row['func_name'], + row['arg_list'], row['return_type'], + row['is_aggregate'], row['is_window'], + row['is_set_returning']) + if func: + columns.extend(func.fieldnames()) + else: + query = render_template("/".join([self.sql_path, 'columns.sql']), + object_name='table', + schema_name=schema, + rel_name=relname) + + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + if status: + if len(res['rows']) > 0: + # Table exists, so don't bother checking for a view + for record in res['rows']: + columns.append(record['column_name']) + else: + query = 
render_template("/".join([self.sql_path, 'columns.sql']), + object_name='view', + schema_name=schema, + rel_name=relname) + + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + if status: + for record in res['rows']: + columns.append(record['column_name']) + + return columns + + def populate_schema_objects(self, schema, obj_type): + """ + Returns list of tables or functions for a (optional) schema + + Args: + schema: + obj_type: + """ + + in_clause = '' + query = '' + objects = [] + + if schema: + in_clause = '\'' + schema + '\'' + else: + for r in self.search_path: + in_clause += '\'' + r + '\',' + + # Remove extra comma + if len(in_clause) > 0: + in_clause = in_clause[:-1] + + if obj_type == 'tables': + query = render_template("/".join([self.sql_path, 'tableview.sql']), + schema_names=in_clause, + object_name='tables') + elif obj_type == 'views': + query = render_template("/".join([self.sql_path, 'tableview.sql']), + schema_names=in_clause, + object_name='views') + elif obj_type == 'functions': + query = render_template("/".join([self.sql_path, 'functions.sql']), + schema_names=in_clause) + elif obj_type == 'datatypes': + query = render_template("/".join([self.sql_path, 'datatypes.sql']), + schema_names=in_clause) + + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + if status: + for record in res['rows']: + objects.append(record['object_name']) + + return objects + + def populate_functions(self, schema): + """ + Returns a list of function names + + filter_func is a function that accepts a FunctionMetadata namedtuple + and returns a boolean indicating whether that function should be + kept or discarded + + Args: + schema: + """ + + in_clause = '' + funcs = [] + + if schema: + in_clause = '\'' + schema + '\'' + else: + for r in self.search_path: + in_clause += '\'' + r + '\',' + + # Remove extra comma + if len(in_clause) > 0: + in_clause = in_clause[:-1] + + query = render_template("/".join([self.sql_path, 
'functions.sql']), + schema_names=in_clause, + is_set_returning=True) + + if self.conn.connected(): + status, res = self.conn.execute_dict(query) + if status: + for record in res['rows']: + funcs.append(record['object_name']) + + return funcs + + def suggest_type(self, full_text, text_before_cursor): + """ + Takes the full_text that is typed so far and also the text before the + cursor to suggest completion type and scope. + + Returns a tuple with a type of entity ('table', 'column' etc) and a scope. + A scope for a column category will be a list of tables. + + Args: + full_text: Contains complete query + text_before_cursor: Contains text before the cursor + """ + + word_before_cursor = last_word(text_before_cursor, include='many_punctuations') + + identifier = None + + def strip_named_query(txt): + """ + This will strip "save named query" command in the beginning of the line: + '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc' + + Args: + txt: + """ + + pattern = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+') + if pattern.match(txt): + txt = pattern.sub('', txt) + return txt + + full_text = strip_named_query(full_text) + text_before_cursor = strip_named_query(text_before_cursor) + + # If we've partially typed a word then word_before_cursor won't be an empty + # string. In that case we want to remove the partially typed string before + # sending it to the sqlparser. Otherwise the last token will always be the + # partially typed string which renders the smart completion useless because + # it will always return the list of keywords as completion. 
+ if word_before_cursor: + if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\': + parsed = sqlparse.parse(text_before_cursor) + else: + parsed = sqlparse.parse( + text_before_cursor[:-len(word_before_cursor)]) + + identifier = parse_partial_identifier(word_before_cursor) + else: + parsed = sqlparse.parse(text_before_cursor) + + statement = None + if len(parsed) > 1: + # Multiple statements being edited -- isolate the current one by + # cumulatively summing statement lengths to find the one that bounds the + # current position + current_pos = len(text_before_cursor) + stmt_start, stmt_end = 0, 0 + + for statement in parsed: + stmt_len = len(statement.to_unicode()) + stmt_start, stmt_end = stmt_end, stmt_end + stmt_len + + if stmt_end >= current_pos: + break + + text_before_cursor = full_text[stmt_start:current_pos] + full_text = full_text[stmt_start:] + elif parsed: + # A single statement + statement = parsed[0] + else: + # The empty string + statement = None + + last_token = statement and statement.token_prev(len(statement.tokens)) or '' + + return self.suggest_based_on_last_token(last_token, text_before_cursor, + full_text, identifier) + + def suggest_based_on_last_token(self, token, text_before_cursor, full_text, identifier): + if isinstance(token, string_types): + token_v = token.lower() + elif isinstance(token, Comparison): + # If 'token' is a Comparison type such as + # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling + # token.value on the comparison type will only return the lhs of the + # comparison. In this case a.id. So we need to do token.tokens to get + # both sides of the comparison and pick the last token out of that + # list. + token_v = token.tokens[-1].value.lower() + elif isinstance(token, Where): + # sqlparse groups all tokens from the where clause into a single token + # list. This means that token.value may be something like + # 'where foo > 5 and '. 
We need to look "inside" token.tokens to handle + # suggestions in complicated where clauses correctly + prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + return self.suggest_based_on_last_token( + prev_keyword, text_before_cursor, full_text, identifier) + elif isinstance(token, Identifier): + # If the previous token is an identifier, we can suggest datatypes if + # we're in a parenthesized column/field list, e.g.: + # CREATE TABLE foo (Identifier + # CREATE FUNCTION foo (Identifier + # If we're not in a parenthesized list, the most likely scenario is the + # user is about to specify an alias, e.g.: + # SELECT Identifier + # SELECT foo FROM Identifier + prev_keyword, _ = find_prev_keyword(text_before_cursor) + if prev_keyword and prev_keyword.value == '(': + # Suggest datatypes + return self.suggest_based_on_last_token( + 'type', text_before_cursor, full_text, identifier) + else: + return Keyword(), + else: + token_v = token.value.lower() + + if not token: + return Keyword(), + elif token_v.endswith('('): + p = sqlparse.parse(text_before_cursor)[0] + + if p.tokens and isinstance(p.tokens[-1], Where): + # Four possibilities: + # 1 - Parenthesized clause like "WHERE foo AND (" + # Suggest columns/functions + # 2 - Function call like "WHERE foo(" + # Suggest columns/functions + # 3 - Subquery expression like "WHERE EXISTS (" + # Suggest keywords, in order to do a subquery + # 4 - Subquery OR array comparison like "WHERE foo = ANY(" + # Suggest columns/functions AND keywords. (If we wanted to be + # really fancy, we could suggest only array-typed columns) + + column_suggestions = self.suggest_based_on_last_token( + 'where', text_before_cursor, full_text, identifier) + + # Check for a subquery expression (cases 3 & 4) + where = p.tokens[-1] + prev_tok = where.token_prev(len(where.tokens) - 1) + + if isinstance(prev_tok, Comparison): + # e.g. 
"SELECT foo FROM bar WHERE foo = ANY(" + prev_tok = prev_tok.tokens[-1] + + prev_tok = prev_tok.value.lower() + if prev_tok == 'exists': + return Keyword(), + else: + return column_suggestions + + # Get the token before the parens + prev_tok = p.token_prev(len(p.tokens) - 1) + if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using': + # tbl1 INNER JOIN tbl2 USING (col1, col2) + tables = extract_tables(full_text) + + # suggest columns that are present in more than one table + return Column(tables=tables, drop_unique=True), + + elif p.token_first().value.lower() == 'select': + # If the lparen is preceded by a space chances are we're about to + # do a sub-select. + if last_word(text_before_cursor, + 'all_punctuations').startswith('('): + return Keyword(), + # We're probably in a function argument list + return Column(tables=extract_tables(full_text)), + elif token_v in ('set', 'by', 'distinct'): + return Column(tables=extract_tables(full_text)), + elif token_v in ('select', 'where', 'having'): + # Check for a table alias or schema qualification + parent = (identifier and identifier.get_parent_name()) or [] + + if parent: + tables = extract_tables(full_text) + tables = tuple(t for t in tables if self.identifies(parent, t)) + return (Column(tables=tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent),) + else: + return (Column(tables=extract_tables(full_text)), + Function(schema=None), + Keyword(),) + + elif (token_v.endswith('join') and token.is_keyword) or \ + (token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate')): + + schema = (identifier and identifier.get_parent_name()) or None + + # Suggest tables from either the currently-selected schema or the + # public schema if no schema has been specified + suggest = [Table(schema=schema)] + + if not schema: + # Suggest schemas + suggest.insert(0, Schema()) + + # Only tables can be TRUNCATED, otherwise suggest views + if token_v != 'truncate': + 
suggest.append(View(schema=schema)) + + # Suggest set-returning functions in the FROM clause + if token_v == 'from' or (token_v.endswith('join') and token.is_keyword): + suggest.append(Function(schema=schema, filter='is_set_returning')) + + return tuple(suggest) + + elif token_v in ('table', 'view', 'function'): + # E.g. 'DROP FUNCTION ', 'ALTER TABLE ' + rel_type = {'table': Table, 'view': View, 'function': Function}[token_v] + schema = (identifier and identifier.get_parent_name()) or None + if schema: + return rel_type(schema=schema), + else: + return Schema(), rel_type(schema=schema) + elif token_v == 'on': + tables = extract_tables(full_text) # [(schema, table, alias), ...] + parent = (identifier and identifier.get_parent_name()) or None + if parent: + # "ON parent." + # parent can be either a schema name or table alias + tables = tuple(t for t in tables if self.identifies(parent, t)) + return (Column(tables=tables), + Table(schema=parent), + View(schema=parent), + Function(schema=parent)) + else: + # ON + # Use table alias if there is one, otherwise the table name + aliases = tuple(t.alias or t.name for t in tables) + return Alias(aliases=aliases), + + elif token_v in ('c', 'use', 'database', 'template'): + # "\c ", "DROP DATABASE ", + # "CREATE DATABASE WITH TEMPLATE " + return Database(), + elif token_v == 'schema': + # DROP SCHEMA schema_name + return Schema(), + elif token_v.endswith(',') or token_v in ('=', 'and', 'or'): + prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor) + if prev_keyword: + return self.suggest_based_on_last_token( + prev_keyword, text_before_cursor, full_text, identifier) + else: + return () + elif token_v in ('type', '::'): + # ALTER TABLE foo SET DATA TYPE bar + # SELECT foo::bar + # Note that tables are a form of composite type in postgresql, so + # they're suggested here as well + schema = (identifier and identifier.get_parent_name()) or None + suggestions = [Datatype(schema=schema), + Table(schema=schema)] + 
if not schema: + suggestions.append(Schema()) + return tuple(suggestions) + else: + return Keyword(), + + def identifies(self, id, ref): + """ + Returns true if string `id` matches TableReference `ref` + + Args: + id: + ref: + """ + return id == ref.alias or id == ref.name or ( + ref.schema and (id == ref.schema + '.' + ref.name)) diff --git a/web/pgadmin/utils/sqlautocomplete/completion.py b/web/pgadmin/utils/sqlautocomplete/completion.py new file mode 100644 index 000000000..cea14b5da --- /dev/null +++ b/web/pgadmin/utils/sqlautocomplete/completion.py @@ -0,0 +1,67 @@ +""" +Using Completion class from + https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/completion.py +""" + +from __future__ import unicode_literals +from abc import ABCMeta, abstractmethod +from six import with_metaclass + +__all__ = ( + 'Completion' +) + + +class Completion(object): + """ + :param text: The new string that will be inserted into the document. + :param start_position: Position relative to the cursor_position where the + new text will start. The text will be inserted between the + start_position and the original cursor position. + :param display: (optional string) If the completion has to be displayed + differently in the completion menu. + :param display_meta: (Optional string) Meta information about the + completion, e.g. the path or source where it's coming from. + :param get_display_meta: Lazy `display_meta`. Retrieve meta information + only when meta is displayed. 
+ """ + def __init__(self, text, start_position=0, display=None, display_meta=None, + get_display_meta=None): + self.text = text + self.start_position = start_position + self._display_meta = display_meta + self._get_display_meta = get_display_meta + + if display is None: + self.display = text + else: + self.display = display + + assert self.start_position <= 0 + + def __repr__(self): + return '%s(text=%r, start_position=%r)' % ( + self.__class__.__name__, self.text, self.start_position) + + def __eq__(self, other): + return ( + self.text == other.text and + self.start_position == other.start_position and + self.display == other.display and + self.display_meta == other.display_meta) + + def __hash__(self): + return hash((self.text, self.start_position, self.display, self.display_meta)) + + @property + def display_meta(self): + # Return meta-text. (This is lazy when using "get_display_meta".) + if self._display_meta is not None: + return self._display_meta + + elif self._get_display_meta: + self._display_meta = self._get_display_meta() + return self._display_meta + + else: + return '' diff --git a/web/pgadmin/utils/sqlautocomplete/counter.py b/web/pgadmin/utils/sqlautocomplete/counter.py new file mode 100644 index 000000000..4dc896a14 --- /dev/null +++ b/web/pgadmin/utils/sqlautocomplete/counter.py @@ -0,0 +1,189 @@ +""" +Copied from http://code.activestate.com/recipes/576611-counter-class/ +""" + +from operator import itemgetter +from heapq import nlargest +from itertools import repeat, ifilter + + +class Counter(dict): + '''Dict subclass for counting hashable objects. Sometimes called a bag + or multiset. Elements are stored as dictionary keys and their counts + are stored as dictionary values. + + >>> Counter('zyzygy') + Counter({'y': 3, 'z': 2, 'g': 1}) + + ''' + + def __init__(self, iterable=None, **kwds): + '''Create a new, empty Counter object. And if given, count elements + from an input iterable. 
Or, initialize the count from another mapping + of elements to their counts. + + >>> c = Counter() # a new, empty counter + >>> c = Counter('gallahad') # a new counter from an iterable + >>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping + >>> c = Counter(a=4, b=2) # a new counter from keyword args + + ''' + self.update(iterable, **kwds) + + def __missing__(self, key): + return 0 + + def most_common(self, n=None): + '''List the n most common elements and their counts from the most + common to the least. If n is None, then list all element counts. + + >>> Counter('abracadabra').most_common(3) + [('a', 5), ('r', 2), ('b', 2)] + + ''' + if n is None: + return sorted(self.iteritems(), key=itemgetter(1), reverse=True) + return nlargest(n, self.iteritems(), key=itemgetter(1)) + + def elements(self): + '''Iterator over elements repeating each as many times as its count. + + >>> c = Counter('ABCABC') + >>> sorted(c.elements()) + ['A', 'A', 'B', 'B', 'C', 'C'] + + If an element's count has been set to zero or is a negative number, + elements() will ignore it. + + ''' + for elem, count in self.iteritems(): + for _ in repeat(None, count): + yield elem + + # Override dict methods where the meaning changes for Counter objects. + + @classmethod + def fromkeys(cls, iterable, v=None): + raise NotImplementedError( + 'Counter.fromkeys() is undefined. Use Counter(iterable) instead.') + + def update(self, iterable=None, **kwds): + '''Like dict.update() but add counts instead of replacing them. + + Source can be an iterable, a dictionary, or another Counter instance. 
+ + >>> c = Counter('which') + >>> c.update('witch') # add elements from another iterable + >>> d = Counter('watch') + >>> c.update(d) # add elements from another counter + >>> c['h'] # four 'h' in which, witch, and watch + 4 + + ''' + if iterable is not None: + if hasattr(iterable, 'iteritems'): + if self: + self_get = self.get + for elem, count in iterable.iteritems(): + self[elem] = self_get(elem, 0) + count + else: + dict.update(self, iterable) # fast path when counter is empty + else: + self_get = self.get + for elem in iterable: + self[elem] = self_get(elem, 0) + 1 + if kwds: + self.update(kwds) + + def copy(self): + 'Like dict.copy() but returns a Counter instance instead of a dict.' + return Counter(self) + + def __delitem__(self, elem): + 'Like dict.__delitem__() but does not raise KeyError for missing values.' + if elem in self: + dict.__delitem__(self, elem) + + def __repr__(self): + if not self: + return '%s()' % self.__class__.__name__ + items = ', '.join(map('%r: %r'.__mod__, self.most_common())) + return '%s({%s})' % (self.__class__.__name__, items) + + # Multiset-style mathematical operations discussed in: + # Knuth TAOCP Volume II section 4.6.3 exercise 19 + # and at http://en.wikipedia.org/wiki/Multiset + # + # Outputs guaranteed to only include positive counts. + # + # To strip negative and zero counts, add-in an empty counter: + # c += Counter() + + def __add__(self, other): + '''Add counts from two counters. + + >>> Counter('abbb') + Counter('bcc') + Counter({'b': 4, 'c': 2, 'a': 1}) + + + ''' + if not isinstance(other, Counter): + return NotImplemented + result = Counter() + for elem in set(self) | set(other): + newcount = self[elem] + other[elem] + if newcount > 0: + result[elem] = newcount + return result + + def __sub__(self, other): + ''' Subtract count, but keep only results with positive counts. 
+ + >>> Counter('abbbc') - Counter('bccd') + Counter({'b': 2, 'a': 1}) + + ''' + if not isinstance(other, Counter): + return NotImplemented + result = Counter() + for elem in set(self) | set(other): + newcount = self[elem] - other[elem] + if newcount > 0: + result[elem] = newcount + return result + + def __or__(self, other): + '''Union is the maximum of value in either of the input counters. + + >>> Counter('abbb') | Counter('bcc') + Counter({'b': 3, 'c': 2, 'a': 1}) + + ''' + if not isinstance(other, Counter): + return NotImplemented + _max = max + result = Counter() + for elem in set(self) | set(other): + newcount = _max(self[elem], other[elem]) + if newcount > 0: + result[elem] = newcount + return result + + def __and__(self, other): + ''' Intersection is the minimum of corresponding counts. + + >>> Counter('abbb') & Counter('bcc') + Counter({'b': 1}) + + ''' + if not isinstance(other, Counter): + return NotImplemented + _min = min + result = Counter() + if len(self) < len(other): + self, other = other, self + for elem in ifilter(self.__contains__, other): + newcount = _min(self[elem], other[elem]) + if newcount > 0: + result[elem] = newcount + return result diff --git a/web/pgadmin/utils/sqlautocomplete/function_metadata.py b/web/pgadmin/utils/sqlautocomplete/function_metadata.py new file mode 100644 index 000000000..d11f33b58 --- /dev/null +++ b/web/pgadmin/utils/sqlautocomplete/function_metadata.py @@ -0,0 +1,149 @@ +import re +import sqlparse +from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation + + +table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE) + + +class FunctionMetadata(object): + + def __init__(self, schema_name, func_name, arg_list, return_type, is_aggregate, + is_window, is_set_returning): + """Class for describing a postgresql function""" + + self.schema_name = schema_name + self.func_name = func_name + self.arg_list = arg_list.strip() + self.return_type = return_type.strip() + self.is_aggregate = 
is_aggregate + self.is_window = is_window + self.is_set_returning = is_set_returning + + def __eq__(self, other): + return (isinstance(other, self.__class__) + and self.__dict__ == other.__dict__) + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.schema_name, self.func_name, self.arg_list, + self.return_type, self.is_aggregate, self.is_window, + self.is_set_returning)) + + def __repr__(self): + return (('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,' + ' is_aggregate=%r, is_window=%r, is_set_returning=%r)') + % (self.__class__.__name__, self.schema_name, self.func_name, + self.arg_list, self.return_type, self.is_aggregate, + self.is_window, self.is_set_returning)) + + def fieldnames(self): + """Returns a list of output field names""" + + if self.return_type.lower() == 'void': + return [] + + match = table_def_regex.match(self.return_type) + if match: + # Function returns a table -- get the column names + return list(field_names(match.group(1), mode_filter=None)) + + # Function may have named output arguments -- find them and return + # their names + return list(field_names(self.arg_list, mode_filter=('OUT', 'INOUT'))) + + +class TypedFieldMetadata(object): + """Describes typed field from a function signature or table definition + + Attributes are: + name The name of the argument/column + mode 'IN', 'OUT', 'INOUT', 'VARIADIC' + type A list of tokens denoting the type + default A list of tokens denoting the default value + unknown A list of tokens not assigned to type or default + """ + + __slots__ = ['name', 'mode', 'type', 'default', 'unknown'] + + def __init__(self): + self.name = None + self.mode = 'IN' + self.type = [] + self.default = [] + self.unknown = [] + + def __getitem__(self, attr): + return getattr(self, attr) + + +def parse_typed_field_list(tokens): + """Parses a argument/column list, yielding TypedFieldMetadata objects + + Field/column lists are used in function signatures and 
table + definitions. This function parses a flattened list of sqlparse tokens + and yields one metadata argument per argument / column. + """ + + # postgres function argument list syntax: + # " ( [ [ argmode ] [ argname ] argtype + # [ { DEFAULT | = } default_expr ] [, ...] ] )" + + mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC')) + parse_state = 'type' + parens = 0 + field = TypedFieldMetadata() + + for tok in tokens: + if tok.ttype in Whitespace or tok.ttype in Comment: + continue + elif tok.ttype in Punctuation: + if parens == 0 and tok.value == ',': + # End of the current field specification + if field.type: + yield field + # Initialize metadata holder for the next field + field, parse_state = TypedFieldMetadata(), 'type' + elif parens == 0 and tok.value == '=': + parse_state = 'default' + else: + field[parse_state].append(tok) + if tok.value == '(': + parens += 1 + elif tok.value == ')': + parens -= 1 + elif parens == 0: + if tok.ttype in Keyword: + if not field.name and tok.value.upper() in mode_names: + # No other keywords allowed before arg name + field.mode = tok.value.upper() + elif tok.value.upper() == 'DEFAULT': + parse_state = 'default' + else: + parse_state = 'unknown' + elif tok.ttype == Name and not field.name: + # note that `ttype in Name` would also match Name.Builtin + field.name = tok.value + else: + field[parse_state].append(tok) + else: + field[parse_state].append(tok) + + # Final argument won't be followed by a comma, so make sure it gets yielded + if field.type: + yield field + + +def field_names(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')): + """Yields field names from a table declaration""" + + if not sql: + return + + # sql is something like "x int, y text, ..." 
+ tokens = sqlparse.parse(sql)[0].flatten() + for f in parse_typed_field_list(tokens): + if f.name and (not mode_filter or f.mode in mode_filter): + yield f.name diff --git a/web/pgadmin/utils/sqlautocomplete/parseutils.py b/web/pgadmin/utils/sqlautocomplete/parseutils.py new file mode 100644 index 000000000..61f3cdc14 --- /dev/null +++ b/web/pgadmin/utils/sqlautocomplete/parseutils.py @@ -0,0 +1,288 @@ + +import re +import sqlparse +from collections import namedtuple +from sqlparse.sql import IdentifierList, Identifier, Function +from sqlparse.tokens import Keyword, DML, Punctuation, Token, Error + +cleanup_regex = { + # This matches only alphanumerics and underscores. + 'alphanum_underscore': re.compile(r'(\w+)$'), + # This matches everything except spaces, parens, colon, and comma + 'many_punctuations': re.compile(r'([^():,\s]+)$'), + # This matches everything except spaces, parens, colon, comma, and period + 'most_punctuations': re.compile(r'([^\.():,\s]+)$'), + # This matches everything except a space. + 'all_punctuations': re.compile('([^\s]+)$'), + } + + +def last_word(text, include='alphanum_underscore'): + """ + Find the last word in a sentence. 
+ + >>> last_word('abc') + 'abc' + >>> last_word(' abc') + 'abc' + >>> last_word('') + '' + >>> last_word(' ') + '' + >>> last_word('abc ') + '' + >>> last_word('abc def') + 'def' + >>> last_word('abc def ') + '' + >>> last_word('abc def;') + '' + >>> last_word('bac $def') + 'def' + >>> last_word('bac $def', include='most_punctuations') + '$def' + >>> last_word('bac \def', include='most_punctuations') + '\\\\def' + >>> last_word('bac \def;', include='most_punctuations') + '\\\\def;' + >>> last_word('bac::def', include='most_punctuations') + 'def' + >>> last_word('"foo*bar', include='most_punctuations') + '"foo*bar' + """ + + if not text: # Empty string + return '' + + if text[-1].isspace(): + return '' + else: + regex = cleanup_regex[include] + matches = regex.search(text) + if matches: + return matches.group(0) + else: + return '' + + +TableReference = namedtuple('TableReference', ['schema', 'name', 'alias', + 'is_function']) + + +# This code is borrowed from sqlparse example script. +# +def is_subselect(parsed): + if not parsed.is_group(): + return False + for item in parsed.tokens: + if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT', + 'UPDATE', 'CREATE', 'DELETE'): + return True + return False + + +def _identifier_is_function(identifier): + return any(isinstance(t, Function) for t in identifier.tokens) + + +def extract_from_part(parsed, stop_at_punctuation=True): + tbl_prefix_seen = False + for item in parsed.tokens: + if tbl_prefix_seen: + if is_subselect(item): + for x in extract_from_part(item, stop_at_punctuation): + yield x + elif stop_at_punctuation and item.ttype is Punctuation: + raise StopIteration + # An incomplete nested select won't be recognized correctly as a + # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes + # the second FROM to trigger this elif condition resulting in a + # StopIteration. So we need to ignore the keyword if the keyword + # FROM. 
+ # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN and its variants + # INNER JOIN, FULL OUTER JOIN, etc. + elif item.ttype is Keyword and ( + not item.value.upper() == 'FROM') and ( + not item.value.upper().endswith('JOIN')): + tbl_prefix_seen = False + else: + yield item + elif item.ttype is Keyword or item.ttype is Keyword.DML: + item_val = item.value.upper() + if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or + item_val.endswith('JOIN')): + tbl_prefix_seen = True + # 'SELECT a, FROM abc' will detect FROM as part of the column list. + # So this check here is necessary. + elif isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + if (identifier.ttype is Keyword and + identifier.value.upper() == 'FROM'): + tbl_prefix_seen = True + break + + +def extract_table_identifiers(token_stream, allow_functions=True): + """yields tuples of TableReference namedtuples""" + + for item in token_stream: + if isinstance(item, IdentifierList): + for identifier in item.get_identifiers(): + # Sometimes Keywords (such as FROM ) are classified as + # identifiers which don't have the get_real_name() method. 
+ try: + schema_name = identifier.get_parent_name() + real_name = identifier.get_real_name() + is_function = (allow_functions and + _identifier_is_function(identifier)) + except AttributeError: + continue + if real_name: + yield TableReference(schema_name, real_name, + identifier.get_alias(), is_function) + elif isinstance(item, Identifier): + real_name = item.get_real_name() + schema_name = item.get_parent_name() + is_function = allow_functions and _identifier_is_function(item) + + if real_name: + yield TableReference(schema_name, real_name, item.get_alias(), + is_function) + else: + name = item.get_name() + yield TableReference(None, name, item.get_alias() or name, + is_function) + elif isinstance(item, Function): + yield TableReference(None, item.get_real_name(), item.get_alias(), + allow_functions) + + +# extract_tables is inspired from examples in the sqlparse lib. +def extract_tables(sql): + """Extract the table names from an SQL statement. + + Returns a list of TableReference namedtuples + + """ + parsed = sqlparse.parse(sql) + if not parsed: + return () + + # INSERT statements must stop looking for tables at the sign of first + # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2) + # abc is the table name, but if we don't stop at the first lparen, then + # we'll identify abc, col1 and col2 as table names. + insert_stmt = parsed[0].token_first().value.lower() == 'insert' + stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt) + + # Kludge: sqlparse mistakenly identifies insert statements as + # function calls due to the parenthesized column list, e.g. interprets + # "insert into foo (bar, baz)" as a function call to foo with arguments + # (bar, baz). 
So don't allow any identifiers in insert statements + # to have is_function=True + identifiers = extract_table_identifiers(stream, + allow_functions=not insert_stmt) + return tuple(identifiers) + + +def find_prev_keyword(sql): + """ Find the last sql keyword in an SQL statement + + Returns the value of the last keyword, and the text of the query with + everything after the last keyword stripped + """ + if not sql.strip(): + return None, '' + + parsed = sqlparse.parse(sql)[0] + flattened = list(parsed.flatten()) + + logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN') + + for t in reversed(flattened): + if t.value == '(' or (t.is_keyword and ( + t.value.upper() not in logical_operators)): + # Find the location of token t in the original parsed statement + # We can't use parsed.token_index(t) because t may be a child token + # inside a TokenList, in which case token_index throws an error + # Minimal example: + # p = sqlparse.parse('select * from foo where bar') + # t = list(p.flatten())[-3] # The "Where" token + # p.token_index(t) # Throws ValueError: not in list + idx = flattened.index(t) + + # Combine the string values of all tokens in the original list + # up to and including the target keyword token t, to produce a + # query string with everything after the keyword token removed + text = ''.join(tok.value for tok in flattened[:idx+1]) + return t, text + + return None, '' + + +# Postgresql dollar quote signs look like `$$` or `$tag$` +dollar_quote_regex = re.compile(r'^\$[^$]*\$$') + + +def is_open_quote(sql): + """Returns true if the query contains an unclosed quote""" + + # parsed can contain one or more semi-colon separated commands + parsed = sqlparse.parse(sql) + return any(_parsed_is_open_quote(p) for p in parsed) + + +def _parsed_is_open_quote(parsed): + tokens = list(parsed.flatten()) + + i = 0 + while i < len(tokens): + tok = tokens[i] + if tok.match(Token.Error, "'"): + # An unmatched single quote + return True + elif (tok.ttype in Token.Name.Builtin + and 
dollar_quote_regex.match(tok.value)): + # Find the matching closing dollar quote sign + for (j, tok2) in enumerate(tokens[i+1:], i+1): + if tok2.match(Token.Name.Builtin, tok.value): + # Found the matching closing quote - continue our scan for + # open quotes thereafter + i = j + break + else: + # No matching dollar sign quote + return True + + i += 1 + + return False + + +def parse_partial_identifier(word): + """Attempt to parse a (partially typed) word as an identifier + + word may include a schema qualification, like `schema_name.partial_name` + or `schema_name.` There may also be unclosed quotation marks, like + `"schema`, or `schema."partial_name` + + :param word: string representing a (partially complete) identifier + :return: sqlparse.sql.Identifier, or None + """ + + p = sqlparse.parse(word)[0] + n_tok = len(p.tokens) + if n_tok == 1 and isinstance(p.tokens[0], Identifier): + return p.tokens[0] + elif p.token_next_match(0, Error, '"'): + # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar' + # Close the double quote, then reparse + return parse_partial_identifier(word + '"') + else: + return None + + +if __name__ == '__main__': + sql = 'select * from (select t. 
from tabl t' + print (extract_tables(sql)) diff --git a/web/pgadmin/utils/sqlautocomplete/prioritization.py b/web/pgadmin/utils/sqlautocomplete/prioritization.py new file mode 100644 index 000000000..c9c99e0d9 --- /dev/null +++ b/web/pgadmin/utils/sqlautocomplete/prioritization.py @@ -0,0 +1,49 @@ +import re +import sqlparse +from sqlparse.tokens import Name +from collections import defaultdict + +white_space_regex = re.compile('\\s+', re.MULTILINE) + + +def _compile_regex(keyword): + # Surround the keyword with word boundaries and replace interior whitespace + # with whitespace wildcards + pattern = '\\b' + re.sub(white_space_regex, '\\s+', keyword) + '\\b' + return re.compile(pattern, re.MULTILINE | re.IGNORECASE) + + +class PrevalenceCounter(object): + def __init__(self, keywords): + self.keyword_counts = defaultdict(int) + self.name_counts = defaultdict(int) + self.keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords) + + def update(self, text): + self.update_keywords(text) + self.update_names(text) + + def update_names(self, text): + for parsed in sqlparse.parse(text): + for token in parsed.flatten(): + if token.ttype in Name: + self.name_counts[token.value] += 1 + + def clear_names(self): + self.name_counts = defaultdict(int) + + def update_keywords(self, text): + # Count keywords. Can't rely for sqlparse for this, because it's + # database agnostic + for keyword, regex in self.keyword_regexs.items(): + for _ in regex.finditer(text): + self.keyword_counts[keyword] += 1 + + def keyword_count(self, keyword): + return self.keyword_counts[keyword] + + def name_count(self, name): + return self.name_counts[name] + + +