Adding support for autocomplete in the SQL Editor.
In Query editor, we can use the autocomplete feature by using keyword combination - 'Ctrl + Space'.pull/3/head
parent
da28dc8507
commit
0a354055a9
|
@ -43,3 +43,4 @@ traceback2==1.4.0
|
|||
unittest2==1.1.0
|
||||
Werkzeug==0.9.6
|
||||
WTForms==2.0.2
|
||||
sqlparse==0.1.19
|
||||
|
|
|
@ -37,3 +37,4 @@ unittest2==1.1.0
|
|||
Werkzeug==0.9.6
|
||||
wheel==0.24.0
|
||||
WTForms==2.0.2
|
||||
sqlparse==0.1.19
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
{# SQL query for getting columns #}
|
||||
{% if object_name == 'table' %}
|
||||
SELECT
|
||||
att.attname column_name
|
||||
FROM pg_catalog.pg_attribute att
|
||||
INNER JOIN pg_catalog.pg_class cls
|
||||
ON att.attrelid = cls.oid
|
||||
INNER JOIN pg_catalog.pg_namespace nsp
|
||||
ON cls.relnamespace = nsp.oid
|
||||
WHERE cls.relkind = ANY(array['r'])
|
||||
AND NOT att.attisdropped
|
||||
AND att.attnum > 0
|
||||
AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
|
||||
ORDER BY 1
|
||||
{% endif %}
|
||||
{% if object_name == 'view' %}
|
||||
SELECT
|
||||
att.attname column_name
|
||||
FROM pg_catalog.pg_attribute att
|
||||
INNER JOIN pg_catalog.pg_class cls
|
||||
ON att.attrelid = cls.oid
|
||||
INNER JOIN pg_catalog.pg_namespace nsp
|
||||
ON cls.relnamespace = nsp.oid
|
||||
WHERE cls.relkind = ANY(array['v', 'm'])
|
||||
AND NOT att.attisdropped
|
||||
AND att.attnum > 0
|
||||
AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
|
||||
ORDER BY 1
|
||||
{% endif %}
|
|
@ -0,0 +1,4 @@
|
|||
{# SQL query for getting databases #}
|
||||
SELECT d.datname
|
||||
FROM pg_catalog.pg_database d
|
||||
ORDER BY 1
|
|
@ -0,0 +1,9 @@
|
|||
{# SQL query for getting datatypes #}
|
||||
SELECT n.nspname schema_name,
|
||||
t.typname object_name
|
||||
FROM pg_catalog.pg_type t
|
||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
|
||||
WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
|
||||
AND NOT EXISTS(SELECT 1 FROM pg_catalog.pg_type el WHERE el.oid = t.typelem AND el.typarray = t.oid)
|
||||
AND n.nspname IN ({{schema_names}})
|
||||
ORDER BY 1, 2;
|
|
@ -0,0 +1,30 @@
|
|||
{# ============= Fetch the list of functions based on given schema_names ============= #}
|
||||
{% if func_name %}
|
||||
SELECT n.nspname schema_name,
|
||||
p.proname func_name,
|
||||
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
|
||||
pg_catalog.pg_get_function_result(p.oid) return_type,
|
||||
p.proisagg is_aggregate,
|
||||
p.proiswindow is_window,
|
||||
p.proretset is_set_returning
|
||||
FROM pg_catalog.pg_proc p
|
||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||
WHERE n.nspname = '{{schema_name}}' AND p.proname = '{{func_name}}'
|
||||
AND p.proretset
|
||||
ORDER BY 1, 2
|
||||
{% else %}
|
||||
SELECT n.nspname schema_name,
|
||||
p.proname object_name,
|
||||
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
|
||||
pg_catalog.pg_get_function_result(p.oid) return_type,
|
||||
p.proisagg is_aggregate,
|
||||
p.proiswindow is_window,
|
||||
p.proretset is_set_returning
|
||||
FROM pg_catalog.pg_proc p
|
||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||
WHERE n.nspname IN ({{schema_names}})
|
||||
{% if is_set_returning %}
|
||||
AND p.proretset
|
||||
{% endif %}
|
||||
ORDER BY 1, 2
|
||||
{% endif %}
|
|
@ -0,0 +1,2 @@
|
|||
{# SQL query for getting keywords #}
|
||||
SELECT upper(word) as word FROM pg_get_keywords()
|
|
@ -0,0 +1,6 @@
|
|||
{# SQL query for getting current_schemas #}
|
||||
{% if search_path %}
|
||||
SELECT * FROM unnest(current_schemas(true)) AS schema
|
||||
{% else %}
|
||||
SELECT nspname AS schema FROM pg_catalog.pg_namespace ORDER BY 1
|
||||
{% endif %}
|
|
@ -0,0 +1,17 @@
|
|||
{# ============= Fetch the list of tables/view based on given schema_names ============= #}
|
||||
{% if object_name == 'tables' %}
|
||||
SELECT n.nspname schema_name,
|
||||
c.relname object_name
|
||||
FROM pg_catalog.pg_class c
|
||||
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE c.relkind = ANY(array['r']) and n.nspname IN ({{schema_names}})
|
||||
ORDER BY 1,2
|
||||
{% endif %}
|
||||
{% if object_name == 'views' %}
|
||||
SELECT n.nspname schema_name,
|
||||
c.relname object_name
|
||||
FROM pg_catalog.pg_class c
|
||||
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
||||
WHERE c.relkind = ANY(array['v', 'm']) and n.nspname IN ({{schema_names}})
|
||||
ORDER BY 1,2
|
||||
{% endif %}
|
|
@ -24,6 +24,7 @@ from pgadmin.utils.driver import get_driver
|
|||
from config import PG_DEFAULT_DRIVER
|
||||
from pgadmin.tools.sqleditor.command import QueryToolCommand
|
||||
from pgadmin.utils import get_storage_directory
|
||||
from pgadmin.utils.sqlautocomplete.autocomplete import SQLAutoComplete
|
||||
|
||||
# import unquote from urlib for python2.x and python3.x
|
||||
try:
|
||||
|
@ -890,6 +891,44 @@ def set_auto_rollback(trans_id):
|
|||
return make_json_response(data={'status': status, 'result': res})
|
||||
|
||||
|
||||
@blueprint.route('/autocomplete/<int:trans_id>', methods=["PUT", "POST"])
|
||||
@login_required
|
||||
def auto_complete(trans_id):
|
||||
"""
|
||||
This method implements the autocomplete feature.
|
||||
|
||||
Args:
|
||||
trans_id: unique transaction id
|
||||
"""
|
||||
full_sql = ''
|
||||
text_before_cursor = ''
|
||||
|
||||
if request.data:
|
||||
data = json.loads(request.data.decode())
|
||||
else:
|
||||
data = request.args or request.form
|
||||
|
||||
if len(data) > 0:
|
||||
full_sql = data[0]
|
||||
text_before_cursor = data[1]
|
||||
|
||||
# Check the transaction and connection status
|
||||
status, error_msg, conn, trans_obj, session_obj = check_transaction_status(trans_id)
|
||||
if status and conn is not None \
|
||||
and trans_obj is not None and session_obj is not None:
|
||||
|
||||
# Create object of SQLAutoComplete class and pass connection object
|
||||
auto_complete_obj = SQLAutoComplete(sid=trans_obj.sid, did=trans_obj.did, conn=conn)
|
||||
|
||||
# Get the auto completion suggestions.
|
||||
res = auto_complete_obj.get_completions(full_sql, text_before_cursor)
|
||||
else:
|
||||
status = False
|
||||
res = error_msg
|
||||
|
||||
return make_json_response(data={'status': status, 'result': res})
|
||||
|
||||
|
||||
@blueprint.route("/sqleditor.js")
|
||||
@login_required
|
||||
def script():
|
||||
|
|
|
@ -243,3 +243,50 @@
|
|||
width: 100%;
|
||||
overflow: auto;
|
||||
}
|
||||
|
||||
.CodeMirror-hints {
|
||||
position: absolute;
|
||||
z-index: 10;
|
||||
overflow: hidden;
|
||||
list-style: none;
|
||||
|
||||
margin: 0;
|
||||
padding: 2px;
|
||||
|
||||
-webkit-box-shadow: 2px 3px 5px rgba(0,0,0,.2);
|
||||
-moz-box-shadow: 2px 3px 5px rgba(0,0,0,.2);
|
||||
box-shadow: 2px 3px 5px rgba(0,0,0,.2);
|
||||
border-radius: 3px;
|
||||
border: 1px solid silver;
|
||||
|
||||
background: white;
|
||||
font-size: 90%;
|
||||
font-family: monospace;
|
||||
|
||||
max-height: 20em;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.CodeMirror-hint {
|
||||
margin: 0;
|
||||
padding: 0 4px;
|
||||
border-radius: 2px;
|
||||
max-width: 19em;
|
||||
overflow: hidden;
|
||||
white-space: pre;
|
||||
color: black;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
li.CodeMirror-hint-active {
|
||||
background: #08f;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.sqleditor-hint {
|
||||
padding-left: 20px;
|
||||
}
|
||||
|
||||
.CodeMirror-hint .fa::before {
|
||||
padding-right: 7px;
|
||||
}
|
|
@ -6,6 +6,7 @@ define(
|
|||
'codemirror/mode/sql/sql', 'codemirror/addon/selection/mark-selection',
|
||||
'codemirror/addon/selection/active-line', 'backbone.paginator',
|
||||
'codemirror/addon/fold/foldgutter', 'codemirror/addon/fold/foldcode',
|
||||
'codemirror/addon/hint/show-hint', 'codemirror/addon/hint/sql-hint',
|
||||
'codemirror/addon/fold/pgadmin-sqlfoldcode', 'backgrid.paginator',
|
||||
'wcdocker', 'pgadmin.file_manager'
|
||||
],
|
||||
|
@ -238,7 +239,8 @@ define(
|
|||
rangeFinder: CodeMirror.fold.combine(CodeMirror.pgadminBeginRangeFinder, CodeMirror.pgadminIfRangeFinder,
|
||||
CodeMirror.pgadminLoopRangeFinder, CodeMirror.pgadminCaseRangeFinder)
|
||||
},
|
||||
gutters: ["CodeMirror-linenumbers", "CodeMirror-foldgutter"]
|
||||
gutters: ["CodeMirror-linenumbers", "CodeMirror-foldgutter"],
|
||||
extraKeys: {"Ctrl-Space": "autocomplete"}
|
||||
});
|
||||
|
||||
// Create panels for 'Data Output', 'Explain', 'Messages' and 'History'
|
||||
|
@ -295,6 +297,107 @@ define(
|
|||
self.history_panel = main_docker.addPanel('history', wcDocker.DOCK.STACKED, self.data_output_panel);
|
||||
|
||||
self.render_history_grid();
|
||||
|
||||
/* We have override/register the hint function of CodeMirror
|
||||
* to provide our own hint logic.
|
||||
*/
|
||||
CodeMirror.registerHelper("hint", "sql", function(editor, options) {
|
||||
var data = [],
|
||||
result = [];
|
||||
var doc = editor.getDoc();
|
||||
var cur = doc.getCursor();
|
||||
var current_line = cur.line; // gets the line number in the cursor position
|
||||
var current_cur = cur.ch; // get the current cursor position
|
||||
|
||||
/* Render function for hint to add our own class
|
||||
* and icon as per the object type.
|
||||
*/
|
||||
var hint_render = function(elt, data, cur) {
|
||||
var el = document.createElement('span');
|
||||
|
||||
switch(cur.type) {
|
||||
case 'database':
|
||||
el.className = 'sqleditor-hint pg-icon-' + cur.type;
|
||||
break;
|
||||
case 'datatype':
|
||||
el.className = 'sqleditor-hint icon-type';
|
||||
break;
|
||||
case 'keyword':
|
||||
el.className = 'fa fa-key';
|
||||
break;
|
||||
case 'table alias':
|
||||
el.className = 'fa fa-at';
|
||||
break;
|
||||
default:
|
||||
el.className = 'sqleditor-hint icon-' + cur.type;
|
||||
}
|
||||
|
||||
el.appendChild(document.createTextNode(cur.text));
|
||||
elt.appendChild(el);
|
||||
};
|
||||
|
||||
var full_text = doc.getValue();
|
||||
// Get the text from start to the current cursor position.
|
||||
var text_before_cursor = doc.getRange({ line: 0, ch: 0 },
|
||||
{ line: current_line, ch: current_cur });
|
||||
|
||||
data.push(full_text);
|
||||
data.push(text_before_cursor);
|
||||
|
||||
// Make ajax call to find the autocomplete data
|
||||
$.ajax({
|
||||
url: "{{ url_for('sqleditor.index') }}" + "autocomplete/" + self.transId,
|
||||
method: 'POST',
|
||||
async: false,
|
||||
contentType: "application/json",
|
||||
data: JSON.stringify(data),
|
||||
success: function(res) {
|
||||
_.each(res.data.result, function(obj, key) {
|
||||
result.push({
|
||||
text: key, type: obj.object_type,
|
||||
render: hint_render
|
||||
});
|
||||
});
|
||||
|
||||
// Sort function to sort the suggestion's alphabetically.
|
||||
result.sort(function(a, b){
|
||||
var textA = a.text.toLowerCase(), textB = b.text.toLowerCase()
|
||||
if (textA < textB) //sort string ascending
|
||||
return -1
|
||||
if (textA > textB)
|
||||
return 1
|
||||
return 0 //default return value (no sorting)
|
||||
})
|
||||
}
|
||||
});
|
||||
|
||||
/* Below logic find the start and end point
|
||||
* to replace the selected auto complete suggestion.
|
||||
*/
|
||||
var token = editor.getTokenAt(cur), start, end, search;
|
||||
if (token.end > cur.ch) {
|
||||
token.end = cur.ch;
|
||||
token.string = token.string.slice(0, cur.ch - token.start);
|
||||
}
|
||||
|
||||
if (token.string.match(/^[.`\w@]\w*$/)) {
|
||||
search = token.string;
|
||||
start = token.start;
|
||||
end = token.end;
|
||||
} else {
|
||||
start = end = cur.ch;
|
||||
search = "";
|
||||
}
|
||||
|
||||
/* Added 1 in the start position if search string
|
||||
* started with "." or "`" else auto complete of code mirror
|
||||
* will remove the "." when user select any suggestion.
|
||||
*/
|
||||
if (search.charAt(0) == "." || search.charAt(0) == "``")
|
||||
start += 1;
|
||||
|
||||
return {list: result, from: {line: current_line, ch: start }, to: { line: current_line, ch: end }};
|
||||
});
|
||||
},
|
||||
|
||||
/* This function is responsible to create and render the
|
||||
|
@ -782,7 +885,7 @@ define(
|
|||
el: self.container,
|
||||
handler: self
|
||||
});
|
||||
self.transId = self.container.data('transId');
|
||||
self.transId = self.gridView.transId = self.container.data('transId');
|
||||
|
||||
self.gridView.editor_title = editor_title;
|
||||
self.gridView.current_file = undefined;
|
||||
|
|
|
@ -0,0 +1,863 @@
|
|||
##########################################################################
|
||||
#
|
||||
# pgAdmin 4 - PostgreSQL Tools
|
||||
#
|
||||
# Copyright (C) 2013 - 2016, The pgAdmin Development Team
|
||||
# This software is released under the PostgreSQL Licence
|
||||
#
|
||||
##########################################################################
|
||||
|
||||
"""A blueprint module implementing the sql auto complete feature."""
|
||||
|
||||
import sys
|
||||
import re
|
||||
import sqlparse
|
||||
import itertools
|
||||
import operator
|
||||
from collections import namedtuple
|
||||
from sqlparse.sql import Comparison, Identifier, Where
|
||||
from .parseutils import (
|
||||
last_word, extract_tables, find_prev_keyword, parse_partial_identifier)
|
||||
from .prioritization import PrevalenceCounter
|
||||
from .completion import Completion
|
||||
from .function_metadata import FunctionMetadata
|
||||
from flask import render_template
|
||||
from pgadmin.utils.driver import get_driver
|
||||
from config import PG_DEFAULT_DRIVER
|
||||
|
||||
PY2 = sys.version_info[0] == 2
|
||||
PY3 = sys.version_info[0] == 3
|
||||
|
||||
if PY3:
|
||||
string_types = str
|
||||
else:
|
||||
string_types = basestring
|
||||
|
||||
Database = namedtuple('Database', [])
|
||||
Schema = namedtuple('Schema', [])
|
||||
Table = namedtuple('Table', ['schema'])
|
||||
|
||||
Function = namedtuple('Function', ['schema', 'filter'])
|
||||
# For convenience, don't require the `filter` argument in Function constructor
|
||||
Function.__new__.__defaults__ = (None, None)
|
||||
|
||||
Column = namedtuple('Column', ['tables', 'drop_unique'])
|
||||
Column.__new__.__defaults__ = (None, None)
|
||||
|
||||
View = namedtuple('View', ['schema'])
|
||||
Keyword = namedtuple('Keyword', [])
|
||||
Datatype = namedtuple('Datatype', ['schema'])
|
||||
Alias = namedtuple('Alias', ['aliases'])
|
||||
Match = namedtuple('Match', ['completion', 'priority'])
|
||||
|
||||
try:
|
||||
from collections import Counter
|
||||
except ImportError:
|
||||
# python 2.6
|
||||
from .counter import Counter
|
||||
|
||||
# Regex for finding "words" in documents.
|
||||
_FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)')
|
||||
_FIND_BIG_WORD_RE = re.compile(r'([^\s]+)')
|
||||
|
||||
|
||||
class SQLAutoComplete(object):
|
||||
"""
|
||||
class SQLAutoComplete
|
||||
|
||||
This class is used to provide the postgresql's autocomplete feature.
|
||||
This class used sqlparse to parse the given sql and psycopg2 to make
|
||||
the connection and get the tables, schemas, functions etc. based on
|
||||
the query.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
This method is used to initialize the class.
|
||||
|
||||
Args:
|
||||
**kwargs : N number of parameters
|
||||
"""
|
||||
|
||||
self.sid = kwargs['sid'] if 'sid' in kwargs else None
|
||||
self.did = kwargs['did'] if 'did' in kwargs else None
|
||||
self.conn = kwargs['conn'] if 'conn' in kwargs else None
|
||||
self.keywords = []
|
||||
|
||||
manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid)
|
||||
|
||||
ver = manager.version
|
||||
# we will set template path for sql scripts
|
||||
if ver >= 90100:
|
||||
self.sql_path = 'sqlautocomplete/sql/9.1_plus'
|
||||
|
||||
self.search_path = []
|
||||
# Fetch the search path
|
||||
if self.conn.connected():
|
||||
query = render_template("/".join([self.sql_path, 'schema.sql']), search_path=True)
|
||||
status, res = self.conn.execute_dict(query)
|
||||
if status:
|
||||
for record in res['rows']:
|
||||
self.search_path.append(record['schema'])
|
||||
|
||||
# Fetch the keywords
|
||||
query = render_template("/".join([self.sql_path, 'keywords.sql']))
|
||||
status, res = self.conn.execute_dict(query)
|
||||
if status:
|
||||
for record in res['rows']:
|
||||
self.keywords.append(record['word'])
|
||||
|
||||
self.text_before_cursor = None
|
||||
self.prioritizer = PrevalenceCounter(self.keywords)
|
||||
|
||||
self.reserved_words = set()
|
||||
for x in self.keywords:
|
||||
self.reserved_words.update(x.split())
|
||||
self.name_pattern = re.compile("^[_a-z][_a-z0-9\$]*$")
|
||||
|
||||
def escape_name(self, name):
|
||||
if name and ((not self.name_pattern.match(name)) or
|
||||
(name.upper() in self.reserved_words)):
|
||||
name = '"%s"' % name
|
||||
|
||||
return name
|
||||
|
||||
def unescape_name(self, name):
|
||||
if name and name[0] == '"' and name[-1] == '"':
|
||||
name = name[1:-1]
|
||||
|
||||
return name
|
||||
|
||||
def escaped_names(self, names):
|
||||
return [self.escape_name(name) for name in names]
|
||||
|
||||
def find_matches(self, text, collection, mode='fuzzy',
|
||||
meta=None, meta_collection=None):
|
||||
"""
|
||||
Find completion matches for the given text.
|
||||
|
||||
Given the user's input text and a collection of available
|
||||
completions, find completions matching the last word of the
|
||||
text.
|
||||
|
||||
`mode` can be either 'fuzzy', or 'strict'
|
||||
'fuzzy': fuzzy matching, ties broken by name prevalance
|
||||
`keyword`: start only matching, ties broken by keyword prevalance
|
||||
|
||||
yields prompt_toolkit Completion instances for any matches found
|
||||
in the collection of available completions.
|
||||
|
||||
Args:
|
||||
text:
|
||||
collection:
|
||||
mode:
|
||||
meta:
|
||||
meta_collection:
|
||||
"""
|
||||
|
||||
text = last_word(text, include='most_punctuations').lower()
|
||||
text_len = len(text)
|
||||
|
||||
if text and text[0] == '"':
|
||||
# text starts with double quote; user is manually escaping a name
|
||||
# Match on everything that follows the double-quote. Note that
|
||||
# text_len is calculated before removing the quote, so the
|
||||
# Completion.position value is correct
|
||||
text = text[1:]
|
||||
|
||||
if mode == 'fuzzy':
|
||||
fuzzy = True
|
||||
priority_func = self.prioritizer.name_count
|
||||
else:
|
||||
fuzzy = False
|
||||
priority_func = self.prioritizer.keyword_count
|
||||
|
||||
# Construct a `_match` function for either fuzzy or non-fuzzy matching
|
||||
# The match function returns a 2-tuple used for sorting the matches,
|
||||
# or None if the item doesn't match
|
||||
# Note: higher priority values mean more important, so use negative
|
||||
# signs to flip the direction of the tuple
|
||||
if fuzzy:
|
||||
regex = '.*?'.join(map(re.escape, text))
|
||||
pat = re.compile('(%s)' % regex)
|
||||
|
||||
def _match(item):
|
||||
r = pat.search(self.unescape_name(item.lower()))
|
||||
if r:
|
||||
return -len(r.group()), -r.start()
|
||||
else:
|
||||
match_end_limit = len(text)
|
||||
|
||||
def _match(item):
|
||||
match_point = item.lower().find(text, 0, match_end_limit)
|
||||
if match_point >= 0:
|
||||
# Use negative infinity to force keywords to sort after all
|
||||
# fuzzy matches
|
||||
return -float('Infinity'), -match_point
|
||||
|
||||
if meta_collection:
|
||||
# Each possible completion in the collection has a corresponding
|
||||
# meta-display string
|
||||
collection = zip(collection, meta_collection)
|
||||
else:
|
||||
# All completions have an identical meta
|
||||
collection = zip(collection, itertools.repeat(meta))
|
||||
|
||||
matches = []
|
||||
|
||||
for item, meta in collection:
|
||||
sort_key = _match(item)
|
||||
if sort_key:
|
||||
if meta and len(meta) > 50:
|
||||
# Truncate meta-text to 50 characters, if necessary
|
||||
meta = meta[:47] + u'...'
|
||||
|
||||
# Lexical order of items in the collection, used for
|
||||
# tiebreaking items with the same match group length and start
|
||||
# position. Since we use *higher* priority to mean "more
|
||||
# important," we use -ord(c) to prioritize "aa" > "ab" and end
|
||||
# with 1 to prioritize shorter strings (ie "user" > "users").
|
||||
# We also use the unescape_name to make sure quoted names have
|
||||
# the same priority as unquoted names.
|
||||
lexical_priority = tuple(-ord(c) for c in self.unescape_name(item)) + (1,)
|
||||
|
||||
priority = sort_key, priority_func(item), lexical_priority
|
||||
|
||||
matches.append(Match(
|
||||
completion=Completion(item, -text_len, display_meta=meta),
|
||||
priority=priority))
|
||||
|
||||
return matches
|
||||
|
||||
def get_completions(self, text, text_before_cursor):
|
||||
self.text_before_cursor = text_before_cursor
|
||||
|
||||
word_before_cursor = self.get_word_before_cursor(word=True)
|
||||
matches = []
|
||||
suggestions = self.suggest_type(text, text_before_cursor)
|
||||
|
||||
for suggestion in suggestions:
|
||||
suggestion_type = type(suggestion)
|
||||
|
||||
# Map suggestion type to method
|
||||
# e.g. 'table' -> self.get_table_matches
|
||||
matcher = self.suggestion_matchers[suggestion_type]
|
||||
matches.extend(matcher(self, suggestion, word_before_cursor))
|
||||
|
||||
# Sort matches so highest priorities are first
|
||||
matches = sorted(matches, key=operator.attrgetter('priority'),
|
||||
reverse=True)
|
||||
|
||||
result = dict()
|
||||
for m in matches:
|
||||
result[m.completion.display] = {'object_type': m.completion.display_meta}
|
||||
|
||||
return result
|
||||
|
||||
def get_column_matches(self, suggestion, word_before_cursor):
|
||||
tables = suggestion.tables
|
||||
scoped_cols = self.populate_scoped_cols(tables)
|
||||
|
||||
if suggestion.drop_unique:
|
||||
# drop_unique is used for 'tb11 JOIN tbl2 USING (...' which should
|
||||
# suggest only columns that appear in more than one table
|
||||
scoped_cols = [col for (col, count)
|
||||
in Counter(scoped_cols).items()
|
||||
if count > 1 and col != '*']
|
||||
|
||||
return self.find_matches(word_before_cursor, scoped_cols, mode='strict', meta='column')
|
||||
|
||||
def get_function_matches(self, suggestion, word_before_cursor):
|
||||
if suggestion.filter == 'is_set_returning':
|
||||
# Only suggest set-returning functions
|
||||
funcs = self.populate_functions(suggestion.schema)
|
||||
else:
|
||||
funcs = self.populate_schema_objects(suggestion.schema, 'functions')
|
||||
|
||||
# Function overloading means we way have multiple functions of the same
|
||||
# name at this point, so keep unique names only
|
||||
funcs = set(funcs)
|
||||
|
||||
funcs = self.find_matches(word_before_cursor, funcs, mode='strict', meta='function')
|
||||
|
||||
return funcs
|
||||
|
||||
def get_schema_matches(self, _, word_before_cursor):
|
||||
schema_names = []
|
||||
|
||||
query = render_template("/".join([self.sql_path, 'schema.sql']))
|
||||
status, res = self.conn.execute_dict(query)
|
||||
if status:
|
||||
for record in res['rows']:
|
||||
schema_names.append(record['schema'])
|
||||
|
||||
# Unless we're sure the user really wants them, hide schema names
|
||||
# starting with pg_, which are mostly temporary schemas
|
||||
if not word_before_cursor.startswith('pg_'):
|
||||
schema_names = [s for s in schema_names if not s.startswith('pg_')]
|
||||
|
||||
return self.find_matches(word_before_cursor, schema_names, mode='strict', meta='schema')
|
||||
|
||||
def get_table_matches(self, suggestion, word_before_cursor):
|
||||
tables = self.populate_schema_objects(suggestion.schema, 'tables')
|
||||
|
||||
# Unless we're sure the user really wants them, don't suggest the
|
||||
# pg_catalog tables that are implicitly on the search path
|
||||
if not suggestion.schema and (
|
||||
not word_before_cursor.startswith('pg_')):
|
||||
tables = [t for t in tables if not t.startswith('pg_')]
|
||||
|
||||
return self.find_matches(word_before_cursor, tables, mode='strict', meta='table')
|
||||
|
||||
def get_view_matches(self, suggestion, word_before_cursor):
|
||||
views = self.populate_schema_objects(suggestion.schema, 'views')
|
||||
|
||||
if not suggestion.schema and (
|
||||
not word_before_cursor.startswith('pg_')):
|
||||
views = [v for v in views if not v.startswith('pg_')]
|
||||
|
||||
return self.find_matches(word_before_cursor, views, mode='strict', meta='view')
|
||||
|
||||
def get_alias_matches(self, suggestion, word_before_cursor):
|
||||
aliases = suggestion.aliases
|
||||
return self.find_matches(word_before_cursor, aliases, mode='strict',
|
||||
meta='table alias')
|
||||
|
||||
def get_database_matches(self, _, word_before_cursor):
|
||||
databases = []
|
||||
|
||||
query = render_template("/".join([self.sql_path, 'databases.sql']))
|
||||
if self.conn.connected():
|
||||
status, res = self.conn.execute_dict(query)
|
||||
if status:
|
||||
for record in res['rows']:
|
||||
databases.append(record['datname'])
|
||||
|
||||
return self.find_matches(word_before_cursor, databases, mode='strict',
|
||||
meta='database')
|
||||
|
||||
def get_keyword_matches(self, _, word_before_cursor):
|
||||
return self.find_matches(word_before_cursor, self.keywords,
|
||||
mode='strict', meta='keyword')
|
||||
|
||||
def get_datatype_matches(self, suggestion, word_before_cursor):
|
||||
# suggest custom datatypes
|
||||
types = self.populate_schema_objects(suggestion.schema, 'datatypes')
|
||||
matches = self.find_matches(word_before_cursor, types, mode='strict', meta='datatype')
|
||||
|
||||
return matches
|
||||
|
||||
def get_word_before_cursor(self, word=False):
|
||||
"""
|
||||
Give the word before the cursor.
|
||||
If we have whitespace before the cursor this returns an empty string.
|
||||
|
||||
Args:
|
||||
word:
|
||||
"""
|
||||
|
||||
if self.text_before_cursor[-1:].isspace():
|
||||
return ''
|
||||
else:
|
||||
return self.text_before_cursor[self.find_start_of_previous_word(word=word):]
|
||||
|
||||
def find_start_of_previous_word(self, count=1, word=False):
|
||||
"""
|
||||
Return an index relative to the cursor position pointing to the start
|
||||
of the previous word. Return `None` if nothing was found.
|
||||
|
||||
Args:
|
||||
count:
|
||||
word:
|
||||
"""
|
||||
|
||||
# Reverse the text before the cursor, in order to do an efficient
|
||||
# backwards search.
|
||||
text_before_cursor = self.text_before_cursor[::-1]
|
||||
|
||||
regex = _FIND_BIG_WORD_RE if word else _FIND_WORD_RE
|
||||
iterator = regex.finditer(text_before_cursor)
|
||||
|
||||
try:
|
||||
for i, match in enumerate(iterator):
|
||||
if i + 1 == count:
|
||||
return - match.end(1)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
suggestion_matchers = {
|
||||
Column: get_column_matches,
|
||||
Function: get_function_matches,
|
||||
Schema: get_schema_matches,
|
||||
Table: get_table_matches,
|
||||
View: get_view_matches,
|
||||
Alias: get_alias_matches,
|
||||
Database: get_database_matches,
|
||||
Keyword: get_keyword_matches,
|
||||
Datatype: get_datatype_matches,
|
||||
}
|
||||
|
||||
def populate_scoped_cols(self, scoped_tbls):
|
||||
""" Find all columns in a set of scoped_tables
|
||||
:param scoped_tbls: list of TableReference namedtuples
|
||||
:return: list of column names
|
||||
"""
|
||||
|
||||
columns = []
|
||||
for tbl in scoped_tbls:
|
||||
if tbl.schema:
|
||||
# A fully qualified schema.relname reference
|
||||
schema = self.escape_name(tbl.schema)
|
||||
relname = self.escape_name(tbl.name)
|
||||
|
||||
if tbl.is_function:
|
||||
query = render_template("/".join([self.sql_path, 'functions.sql']),
|
||||
schema_name=schema,
|
||||
func_name=relname)
|
||||
|
||||
if self.conn.connected():
|
||||
status, res = self.conn.execute_dict(query)
|
||||
func = None
|
||||
if status:
|
||||
for row in res['rows']:
|
||||
func = FunctionMetadata(row['schema_name'], row['func_name'],
|
||||
row['arg_list'], row['return_type'],
|
||||
row['is_aggregate'], row['is_window'],
|
||||
row['is_set_returning'])
|
||||
if func:
|
||||
columns.extend(func.fieldnames())
|
||||
else:
|
||||
# We don't know if schema.relname is a table or view. Since
|
||||
# tables and views cannot share the same name, we can check
|
||||
# one at a time
|
||||
|
||||
query = render_template("/".join([self.sql_path, 'columns.sql']),
|
||||
object_name='table',
|
||||
schema_name=schema,
|
||||
rel_name=relname)
|
||||
|
||||
if self.conn.connected():
|
||||
status, res = self.conn.execute_dict(query)
|
||||
if status:
|
||||
if len(res['rows']) > 0:
|
||||
# Table exists, so don't bother checking for a view
|
||||
for record in res['rows']:
|
||||
columns.append(record['column_name'])
|
||||
else:
|
||||
query = render_template("/".join([self.sql_path, 'columns.sql']),
|
||||
object_name='view',
|
||||
schema_name=schema,
|
||||
rel_name=relname)
|
||||
|
||||
if self.conn.connected():
|
||||
status, res = self.conn.execute_dict(query)
|
||||
if status:
|
||||
for record in res['rows']:
|
||||
columns.append(record['column_name'])
|
||||
else:
|
||||
# Schema not specified, so traverse the search path looking for
|
||||
# a table or view that matches. Note that in order to get proper
|
||||
# shadowing behavior, we need to check both views and tables for
|
||||
# each schema before checking the next schema
|
||||
for schema in self.search_path:
|
||||
relname = self.escape_name(tbl.name)
|
||||
|
||||
if tbl.is_function:
|
||||
query = render_template("/".join([self.sql_path, 'functions.sql']),
|
||||
schema_name=schema,
|
||||
func_name=relname)
|
||||
|
||||
if self.conn.connected():
|
||||
status, res = self.conn.execute_dict(query)
|
||||
func = None
|
||||
if status:
|
||||
for row in res['rows']:
|
||||
func = FunctionMetadata(row['schema_name'], row['func_name'],
|
||||
row['arg_list'], row['return_type'],
|
||||
row['is_aggregate'], row['is_window'],
|
||||
row['is_set_returning'])
|
||||
if func:
|
||||
columns.extend(func.fieldnames())
|
||||
else:
|
||||
query = render_template("/".join([self.sql_path, 'columns.sql']),
|
||||
object_name='table',
|
||||
schema_name=schema,
|
||||
rel_name=relname)
|
||||
|
||||
if self.conn.connected():
|
||||
status, res = self.conn.execute_dict(query)
|
||||
if status:
|
||||
if len(res['rows']) > 0:
|
||||
# Table exists, so don't bother checking for a view
|
||||
for record in res['rows']:
|
||||
columns.append(record['column_name'])
|
||||
else:
|
||||
query = render_template("/".join([self.sql_path, 'columns.sql']),
|
||||
object_name='view',
|
||||
schema_name=schema,
|
||||
rel_name=relname)
|
||||
|
||||
if self.conn.connected():
|
||||
status, res = self.conn.execute_dict(query)
|
||||
if status:
|
||||
for record in res['rows']:
|
||||
columns.append(record['column_name'])
|
||||
|
||||
return columns
|
||||
|
||||
def populate_schema_objects(self, schema, obj_type):
|
||||
"""
|
||||
Returns list of tables or functions for a (optional) schema
|
||||
|
||||
Args:
|
||||
schema:
|
||||
obj_type:
|
||||
"""
|
||||
|
||||
in_clause = ''
|
||||
query = ''
|
||||
objects = []
|
||||
|
||||
if schema:
|
||||
in_clause = '\'' + schema + '\''
|
||||
else:
|
||||
for r in self.search_path:
|
||||
in_clause += '\'' + r + '\','
|
||||
|
||||
# Remove extra comma
|
||||
if len(in_clause) > 0:
|
||||
in_clause = in_clause[:-1]
|
||||
|
||||
if obj_type == 'tables':
|
||||
query = render_template("/".join([self.sql_path, 'tableview.sql']),
|
||||
schema_names=in_clause,
|
||||
object_name='tables')
|
||||
elif obj_type == 'views':
|
||||
query = render_template("/".join([self.sql_path, 'tableview.sql']),
|
||||
schema_names=in_clause,
|
||||
object_name='views')
|
||||
elif obj_type == 'functions':
|
||||
query = render_template("/".join([self.sql_path, 'functions.sql']),
|
||||
schema_names=in_clause)
|
||||
elif obj_type == 'datatypes':
|
||||
query = render_template("/".join([self.sql_path, 'datatypes.sql']),
|
||||
schema_names=in_clause)
|
||||
|
||||
if self.conn.connected():
|
||||
status, res = self.conn.execute_dict(query)
|
||||
if status:
|
||||
for record in res['rows']:
|
||||
objects.append(record['object_name'])
|
||||
|
||||
return objects
|
||||
|
||||
    def populate_functions(self, schema):
        """
        Returns a list of set-returning function names for the given schema
        (or for every schema on the search path when no schema is given).

        The functions.sql template is rendered with is_set_returning=True so
        only functions usable in a FROM clause are listed.

        Args:
            schema: schema name to restrict the lookup to; when falsy, the
                current search path is used instead
        """

        in_clause = ''
        funcs = []

        if schema:
            # A single schema was requested -- quote it for the template.
            in_clause = '\'' + schema + '\''
        else:
            # Build a comma-separated list of quoted search-path schemas.
            for r in self.search_path:
                in_clause += '\'' + r + '\','

            # Remove extra comma
            if len(in_clause) > 0:
                in_clause = in_clause[:-1]

        query = render_template("/".join([self.sql_path, 'functions.sql']),
                                schema_names=in_clause,
                                is_set_returning=True)

        # Silently return an empty list when disconnected or on query failure.
        if self.conn.connected():
            status, res = self.conn.execute_dict(query)
            if status:
                for record in res['rows']:
                    funcs.append(record['object_name'])

        return funcs
|
||||
|
||||
    def suggest_type(self, full_text, text_before_cursor):
        """
        Takes the full_text that is typed so far and also the text before the
        cursor to suggest completion type and scope.

        Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
        A scope for a column category will be a list of tables.

        Args:
            full_text: Contains complete query
            text_before_cursor: Contains text before the cursor
        """

        word_before_cursor = last_word(text_before_cursor, include='many_punctuations')

        identifier = None

        def strip_named_query(txt):
            """
            This will strip "save named query" command in the beginning of the line:
            '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
            ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'

            Args:
                txt: query text to strip
            """

            # NOTE(review): the class [A-z] also matches '[', '\', ']', '^',
            # '_' and '`' -- probably meant [A-Za-z]; confirm before changing.
            pattern = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
            if pattern.match(txt):
                txt = pattern.sub('', txt)
            return txt

        full_text = strip_named_query(full_text)
        text_before_cursor = strip_named_query(text_before_cursor)

        # If we've partially typed a word then word_before_cursor won't be an empty
        # string. In that case we want to remove the partially typed string before
        # sending it to the sqlparser. Otherwise the last token will always be the
        # partially typed string which renders the smart completion useless because
        # it will always return the list of keywords as completion.
        if word_before_cursor:
            if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
                # An open parenthesis or a backslash command: parse as-is.
                parsed = sqlparse.parse(text_before_cursor)
            else:
                # Parse with the partial word removed ...
                parsed = sqlparse.parse(
                    text_before_cursor[:-len(word_before_cursor)])

                # ... and keep the partial word (possibly schema-qualified)
                # around as an identifier for scope resolution.
                identifier = parse_partial_identifier(word_before_cursor)
        else:
            parsed = sqlparse.parse(text_before_cursor)

        statement = None
        if len(parsed) > 1:
            # Multiple statements being edited -- isolate the current one by
            # cumulatively summing statement lengths to find the one that bounds the
            # current position
            current_pos = len(text_before_cursor)
            stmt_start, stmt_end = 0, 0

            for statement in parsed:
                stmt_len = len(statement.to_unicode())
                stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

                if stmt_end >= current_pos:
                    break

            text_before_cursor = full_text[stmt_start:current_pos]
            full_text = full_text[stmt_start:]
        elif parsed:
            # A single statement
            statement = parsed[0]
        else:
            # The empty string
            statement = None

        # Last significant token of the current statement ('' when empty).
        last_token = statement and statement.token_prev(len(statement.tokens)) or ''

        return self.suggest_based_on_last_token(last_token, text_before_cursor,
                                                full_text, identifier)
|
||||
|
||||
    def suggest_based_on_last_token(self, token, text_before_cursor, full_text, identifier):
        """
        Maps the last significant token of the statement to a tuple of
        suggestion objects (Keyword, Column, Table, View, Function, Schema,
        Database, Datatype, Alias) describing what to offer next.

        Args:
            token: last parsed token -- a plain string, a sqlparse token
                group, or '' for an empty statement
            text_before_cursor: statement text up to the cursor
            full_text: the complete current statement
            identifier: partially-typed (possibly schema-qualified)
                identifier before the cursor, or None
        """
        if isinstance(token, string_types):
            token_v = token.lower()
        elif isinstance(token, Comparison):
            # If 'token' is a Comparison type such as
            # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
            # token.value on the comparison type will only return the lhs of the
            # comparison. In this case a.id. So we need to do token.tokens to get
            # both sides of the comparison and pick the last token out of that
            # list.
            token_v = token.tokens[-1].value.lower()
        elif isinstance(token, Where):
            # sqlparse groups all tokens from the where clause into a single token
            # list. This means that token.value may be something like
            # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
            # suggestions in complicated where clauses correctly
            prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
            return self.suggest_based_on_last_token(
                prev_keyword, text_before_cursor, full_text, identifier)
        elif isinstance(token, Identifier):
            # If the previous token is an identifier, we can suggest datatypes if
            # we're in a parenthesized column/field list, e.g.:
            #       CREATE TABLE foo (Identifier <CURSOR>
            #       CREATE FUNCTION foo (Identifier <CURSOR>
            # If we're not in a parenthesized list, the most likely scenario is the
            # user is about to specify an alias, e.g.:
            #       SELECT Identifier <CURSOR>
            #       SELECT foo FROM Identifier <CURSOR>
            prev_keyword, _ = find_prev_keyword(text_before_cursor)
            if prev_keyword and prev_keyword.value == '(':
                # Suggest datatypes
                return self.suggest_based_on_last_token(
                    'type', text_before_cursor, full_text, identifier)
            else:
                return Keyword(),
        else:
            token_v = token.value.lower()

        if not token:
            # Empty statement -- only keywords make sense.
            return Keyword(),
        elif token_v.endswith('('):
            p = sqlparse.parse(text_before_cursor)[0]

            if p.tokens and isinstance(p.tokens[-1], Where):
                # Four possibilities:
                #  1 - Parenthesized clause like "WHERE foo AND ("
                #        Suggest columns/functions
                #  2 - Function call like "WHERE foo("
                #        Suggest columns/functions
                #  3 - Subquery expression like "WHERE EXISTS ("
                #        Suggest keywords, in order to do a subquery
                #  4 - Subquery OR array comparison like "WHERE foo = ANY("
                #        Suggest columns/functions AND keywords. (If we wanted to be
                #        really fancy, we could suggest only array-typed columns)

                column_suggestions = self.suggest_based_on_last_token(
                    'where', text_before_cursor, full_text, identifier)

                # Check for a subquery expression (cases 3 & 4)
                where = p.tokens[-1]
                prev_tok = where.token_prev(len(where.tokens) - 1)

                if isinstance(prev_tok, Comparison):
                    # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                    prev_tok = prev_tok.tokens[-1]

                prev_tok = prev_tok.value.lower()
                if prev_tok == 'exists':
                    return Keyword(),
                else:
                    return column_suggestions

            # Get the token before the parens
            prev_tok = p.token_prev(len(p.tokens) - 1)
            if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using':
                # tbl1 INNER JOIN tbl2 USING (col1, col2)
                tables = extract_tables(full_text)

                # suggest columns that are present in more than one table
                return Column(tables=tables, drop_unique=True),

            elif p.token_first().value.lower() == 'select':
                # If the lparen is preceeded by a space chances are we're about to
                # do a sub-select.
                if last_word(text_before_cursor,
                             'all_punctuations').startswith('('):
                    return Keyword(),
            # We're probably in a function argument list
            return Column(tables=extract_tables(full_text)),
        elif token_v in ('set', 'by', 'distinct'):
            return Column(tables=extract_tables(full_text)),
        elif token_v in ('select', 'where', 'having'):
            # Check for a table alias or schema qualification
            parent = (identifier and identifier.get_parent_name()) or []

            if parent:
                tables = extract_tables(full_text)
                tables = tuple(t for t in tables if self.identifies(parent, t))
                return (Column(tables=tables),
                        Table(schema=parent),
                        View(schema=parent),
                        Function(schema=parent),)
            else:
                return (Column(tables=extract_tables(full_text)),
                        Function(schema=None),
                        Keyword(),)

        elif (token_v.endswith('join') and token.is_keyword) or \
                (token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate')):

            schema = (identifier and identifier.get_parent_name()) or None

            # Suggest tables from either the currently-selected schema or the
            # public schema if no schema has been specified
            suggest = [Table(schema=schema)]

            if not schema:
                # Suggest schemas
                suggest.insert(0, Schema())

            # Only tables can be TRUNCATED, otherwise suggest views
            if token_v != 'truncate':
                suggest.append(View(schema=schema))

            # Suggest set-returning functions in the FROM clause
            if token_v == 'from' or (token_v.endswith('join') and token.is_keyword):
                suggest.append(Function(schema=schema, filter='is_set_returning'))

            return tuple(suggest)

        elif token_v in ('table', 'view', 'function'):
            # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
            rel_type = {'table': Table, 'view': View, 'function': Function}[token_v]
            schema = (identifier and identifier.get_parent_name()) or None
            if schema:
                return rel_type(schema=schema),
            else:
                return Schema(), rel_type(schema=schema)
        elif token_v == 'on':
            tables = extract_tables(full_text)  # [(schema, table, alias), ...]
            parent = (identifier and identifier.get_parent_name()) or None
            if parent:
                # "ON parent.<suggestion>"
                # parent can be either a schema name or table alias
                tables = tuple(t for t in tables if self.identifies(parent, t))
                return (Column(tables=tables),
                        Table(schema=parent),
                        View(schema=parent),
                        Function(schema=parent))
            else:
                # ON <suggestion>
                # Use table alias if there is one, otherwise the table name
                aliases = tuple(t.alias or t.name for t in tables)
                return Alias(aliases=aliases),

        elif token_v in ('c', 'use', 'database', 'template'):
            # "\c <db", "use <db>", "DROP DATABASE <db>",
            # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
            return Database(),
        elif token_v == 'schema':
            # DROP SCHEMA schema_name
            return Schema(),
        elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
            # Mid-expression -- defer to whatever keyword opened the clause.
            prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
            if prev_keyword:
                return self.suggest_based_on_last_token(
                    prev_keyword, text_before_cursor, full_text, identifier)
            else:
                return ()
        elif token_v in ('type', '::'):
            # ALTER TABLE foo SET DATA TYPE bar
            # SELECT foo::bar
            # Note that tables are a form of composite type in postgresql, so
            # they're suggested here as well
            schema = (identifier and identifier.get_parent_name()) or None
            suggestions = [Datatype(schema=schema),
                           Table(schema=schema)]
            if not schema:
                suggestions.append(Schema())
            return tuple(suggestions)
        else:
            return Keyword(),
|
||||
|
||||
def identifies(self, id, ref):
|
||||
"""
|
||||
Returns true if string `id` matches TableReference `ref`
|
||||
|
||||
Args:
|
||||
id:
|
||||
ref:
|
||||
"""
|
||||
return id == ref.alias or id == ref.name or (
|
||||
ref.schema and (id == ref.schema + '.' + ref.name))
|
|
@ -0,0 +1,67 @@
|
|||
"""
|
||||
Using Completion class from
|
||||
https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/completion.py
|
||||
"""
|
||||
|
||||
from __future__ import unicode_literals
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from six import with_metaclass
|
||||
|
||||
# Public API of this module. NOTE: the trailing comma is required -- a bare
# ('Completion') is just a parenthesized string, and a string __all__ makes
# "from <module> import *" iterate it character by character.
__all__ = (
    'Completion',
)
|
||||
|
||||
|
||||
class Completion(object):
    """
    One autocomplete suggestion, modeled on prompt_toolkit's Completion.

    :param text: The new string that will be inserted into the document.
    :param start_position: Position relative to the cursor_position where the
        new text will start. The text will be inserted between the
        start_position and the original cursor position. Must be <= 0.
    :param display: (optional string) If the completion has to be displayed
        differently in the completion menu.
    :param display_meta: (Optional string) Meta information about the
        completion, e.g. the path or source where it's coming from.
    :param get_display_meta: Lazy `display_meta`. Retrieve meta information
        only when meta is displayed.
    """
    def __init__(self, text, start_position=0, display=None, display_meta=None,
                 get_display_meta=None):
        self.text = text
        self.start_position = start_position
        # Fall back to the inserted text when no display override is given.
        self.display = text if display is None else display
        self._display_meta = display_meta
        self._get_display_meta = get_display_meta

        # Completions insert at or before the cursor, never after it.
        assert self.start_position <= 0

    def __repr__(self):
        return '{0}(text={1!r}, start_position={2!r})'.format(
            self.__class__.__name__, self.text, self.start_position)

    def __eq__(self, other):
        # Equality covers the visible fields; display_meta resolves lazily.
        if self.text != other.text:
            return False
        if self.start_position != other.start_position:
            return False
        return (self.display == other.display and
                self.display_meta == other.display_meta)

    def __hash__(self):
        return hash((self.text, self.start_position, self.display, self.display_meta))

    @property
    def display_meta(self):
        # Return meta-text; resolve the lazy callable at most once and cache
        # the result so repeated menu redraws don't recompute it.
        if self._display_meta is not None:
            return self._display_meta
        if self._get_display_meta:
            self._display_meta = self._get_display_meta()
            return self._display_meta
        return ''
|
|
@ -0,0 +1,189 @@
|
|||
"""
|
||||
Copied from http://code.activestate.com/recipes/576611-counter-class/
|
||||
"""
|
||||
|
||||
from operator import itemgetter
|
||||
from heapq import nlargest
|
||||
from itertools import repeat, ifilter
|
||||
|
||||
|
||||
class Counter(dict):
    '''Dict subclass for counting hashable objects. Sometimes called a bag
    or multiset. Elements are stored as dictionary keys and their counts
    are stored as dictionary values.

    >>> Counter('zyzygy')
    Counter({'y': 3, 'z': 2, 'g': 1})

    '''
    # NOTE(review): vendored Python 2 backport (activestate recipe 576611).
    # It relies on dict.iteritems() and itertools.ifilter, neither of which
    # exists on Python 3, where collections.Counter should be used instead.

    def __init__(self, iterable=None, **kwds):
        '''Create a new, empty Counter object. And if given, count elements
        from an input iterable. Or, initialize the count from another mapping
        of elements to their counts.

        >>> c = Counter() # a new, empty counter
        >>> c = Counter('gallahad') # a new counter from an iterable
        >>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
        >>> c = Counter(a=4, b=2) # a new counter from keyword args

        '''
        self.update(iterable, **kwds)

    def __missing__(self, key):
        # Absent elements count as zero instead of raising KeyError.
        return 0

    def most_common(self, n=None):
        '''List the n most common elements and their counts from the most
        common to the least. If n is None, then list all element counts.

        >>> Counter('abracadabra').most_common(3)
        [('a', 5), ('r', 2), ('b', 2)]

        '''
        if n is None:
            return sorted(self.iteritems(), key=itemgetter(1), reverse=True)
        # heapq.nlargest avoids a full sort when only n items are wanted.
        return nlargest(n, self.iteritems(), key=itemgetter(1))

    def elements(self):
        '''Iterator over elements repeating each as many times as its count.

        >>> c = Counter('ABCABC')
        >>> sorted(c.elements())
        ['A', 'A', 'B', 'B', 'C', 'C']

        If an element's count has been set to zero or is a negative number,
        elements() will ignore it.

        '''
        for elem, count in self.iteritems():
            for _ in repeat(None, count):
                yield elem

    # Override dict methods where the meaning changes for Counter objects.

    @classmethod
    def fromkeys(cls, iterable, v=None):
        raise NotImplementedError(
            'Counter.fromkeys() is undefined. Use Counter(iterable) instead.')

    def update(self, iterable=None, **kwds):
        '''Like dict.update() but add counts instead of replacing them.

        Source can be an iterable, a dictionary, or another Counter instance.

        >>> c = Counter('which')
        >>> c.update('witch') # add elements from another iterable
        >>> d = Counter('watch')
        >>> c.update(d) # add elements from another counter
        >>> c['h'] # four 'h' in which, witch, and watch
        4

        '''
        if iterable is not None:
            if hasattr(iterable, 'iteritems'):
                # Mapping input: merge counts key by key.
                if self:
                    self_get = self.get
                    for elem, count in iterable.iteritems():
                        self[elem] = self_get(elem, 0) + count
                else:
                    dict.update(self, iterable) # fast path when counter is empty
            else:
                # Plain iterable: each occurrence adds one to the element.
                self_get = self.get
                for elem in iterable:
                    self[elem] = self_get(elem, 0) + 1
        if kwds:
            # Keyword arguments are merged as a mapping (recursive call).
            self.update(kwds)

    def copy(self):
        'Like dict.copy() but returns a Counter instance instead of a dict.'
        return Counter(self)

    def __delitem__(self, elem):
        'Like dict.__delitem__() but does not raise KeyError for missing values.'
        if elem in self:
            dict.__delitem__(self, elem)

    def __repr__(self):
        if not self:
            return '%s()' % self.__class__.__name__
        items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
        return '%s({%s})' % (self.__class__.__name__, items)

    # Multiset-style mathematical operations discussed in:
    #       Knuth TAOCP Volume II section 4.6.3 exercise 19
    #       and at http://en.wikipedia.org/wiki/Multiset
    #
    # Outputs guaranteed to only include positive counts.
    #
    # To strip negative and zero counts, add-in an empty counter:
    #       c += Counter()

    def __add__(self, other):
        '''Add counts from two counters.

        >>> Counter('abbb') + Counter('bcc')
        Counter({'b': 4, 'c': 2, 'a': 1})


        '''
        if not isinstance(other, Counter):
            return NotImplemented
        result = Counter()
        for elem in set(self) | set(other):
            newcount = self[elem] + other[elem]
            if newcount > 0:
                result[elem] = newcount
        return result

    def __sub__(self, other):
        ''' Subtract count, but keep only results with positive counts.

        >>> Counter('abbbc') - Counter('bccd')
        Counter({'b': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        result = Counter()
        for elem in set(self) | set(other):
            newcount = self[elem] - other[elem]
            if newcount > 0:
                result[elem] = newcount
        return result

    def __or__(self, other):
        '''Union is the maximum of value in either of the input counters.

        >>> Counter('abbb') | Counter('bcc')
        Counter({'b': 3, 'c': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        _max = max
        result = Counter()
        for elem in set(self) | set(other):
            newcount = _max(self[elem], other[elem])
            if newcount > 0:
                result[elem] = newcount
        return result

    def __and__(self, other):
        ''' Intersection is the minimum of corresponding counts.

        >>> Counter('abbb') & Counter('bcc')
        Counter({'b': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        _min = min
        result = Counter()
        # Iterate over the smaller counter for fewer lookups.
        if len(self) < len(other):
            self, other = other, self
        for elem in ifilter(self.__contains__, other):
            newcount = _min(self[elem], other[elem])
            if newcount > 0:
                result[elem] = newcount
        return result
|
|
@ -0,0 +1,149 @@
|
|||
import re
|
||||
import sqlparse
|
||||
from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation
|
||||
|
||||
|
||||
table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE)
|
||||
|
||||
|
||||
class FunctionMetadata(object):
    """Describes one PostgreSQL function: where it lives, its argument list,
    its return type, and the aggregate/window/set-returning flags.

    Instances are value objects: equality, hashing and repr are derived from
    all seven fields. Argument list and return type are stripped of
    surrounding whitespace on construction.
    """

    def __init__(self, schema_name, func_name, arg_list, return_type, is_aggregate,
                 is_window, is_set_returning):
        self.schema_name = schema_name
        self.func_name = func_name
        self.arg_list = arg_list.strip()
        self.return_type = return_type.strip()
        self.is_aggregate = is_aggregate
        self.is_window = is_window
        self.is_set_returning = is_set_returning

    def _key(self):
        """Tuple of all identifying fields, shared by __eq__/__hash__/__repr__."""
        return (self.schema_name, self.func_name, self.arg_list,
                self.return_type, self.is_aggregate, self.is_window,
                self.is_set_returning)

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self._key() == other._key()

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self._key())

    def __repr__(self):
        template = ('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,'
                    ' is_aggregate=%r, is_window=%r, is_set_returning=%r)')
        return template % ((self.__class__.__name__,) + self._key())

    def fieldnames(self):
        """Returns a list of output field names"""

        if self.return_type.lower() == 'void':
            # void functions produce no output columns at all
            return []

        match = table_def_regex.match(self.return_type)
        if match:
            # Function returns a table -- get the column names
            return list(field_names(match.group(1), mode_filter=None))

        # Function may have named output arguments -- find them and return
        # their names
        return list(field_names(self.arg_list, mode_filter=('OUT', 'INOUT')))
|
||||
|
||||
|
||||
class TypedFieldMetadata(object):
    """Describes typed field from a function signature or table definition

    Attributes are:
        name    The name of the argument/column
        mode    'IN', 'OUT', 'INOUT', 'VARIADIC' (defaults to 'IN')
        type    A list of tokens denoting the type
        default A list of tokens denoting the default value
        unknown A list of tokens not assigned to type or default

    Supports dict-style read access (obj['type']) so the parser can pick
    a token bucket by its current state name.
    """

    __slots__ = ['name', 'mode', 'type', 'default', 'unknown']

    def __init__(self):
        self.name = None
        self.mode = 'IN'
        # Each token bucket starts as its own fresh list.
        for bucket in ('type', 'default', 'unknown'):
            setattr(self, bucket, [])

    def __getitem__(self, attr):
        return getattr(self, attr)
|
||||
|
||||
|
||||
def parse_typed_field_list(tokens):
    """Parses a argument/column list, yielding TypedFieldMetadata objects

    Field/column lists are used in function signatures and table
    definitions. This function parses a flattened list of sqlparse tokens
    and yields one metadata argument per argument / column.
    """

    # postgres function argument list syntax:
    #   " ( [ [ argmode ] [ argname ] argtype
    #       [ { DEFAULT | = } default_expr ] [, ...] ] )"

    mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC'))
    # parse_state names the bucket ('type'/'default'/'unknown') that
    # non-structural tokens are currently appended to.
    parse_state = 'type'
    # parens tracks nesting depth so commas inside e.g. numeric(10,2)
    # don't terminate the current field.
    parens = 0
    field = TypedFieldMetadata()

    for tok in tokens:
        if tok.ttype in Whitespace or tok.ttype in Comment:
            continue
        elif tok.ttype in Punctuation:
            if parens == 0 and tok.value == ',':
                # End of the current field specification
                if field.type:
                    yield field
                # Initialize metadata holder for the next field
                field, parse_state = TypedFieldMetadata(), 'type'
            elif parens == 0 and tok.value == '=':
                parse_state = 'default'
            else:
                # Nested punctuation belongs to the current bucket and may
                # change the nesting depth.
                field[parse_state].append(tok)
                if tok.value == '(':
                    parens += 1
                elif tok.value == ')':
                    parens -= 1
        elif parens == 0:
            if tok.ttype in Keyword:
                if not field.name and tok.value.upper() in mode_names:
                    # No other keywords allowed before arg name
                    field.mode = tok.value.upper()
                elif tok.value.upper() == 'DEFAULT':
                    parse_state = 'default'
                else:
                    # Unexpected keyword: collect into 'unknown' from here on.
                    parse_state = 'unknown'
            elif tok.ttype == Name and not field.name:
                # note that `ttype in Name` would also match Name.Builtin
                field.name = tok.value
            else:
                field[parse_state].append(tok)
        else:
            # Inside parentheses everything goes to the current bucket.
            field[parse_state].append(tok)

    # Final argument won't be followed by a comma, so make sure it gets yielded
    if field.type:
        yield field
|
||||
|
||||
|
||||
def field_names(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')):
    """Yields field names from a table declaration

    Args:
        sql: declaration text like "x int, y text, ..."; falsy input
            yields nothing
        mode_filter: argument modes to keep; a falsy filter keeps every
            named field
    """

    if not sql:
        return

    # sql is something like "x int, y text, ..."
    flattened = sqlparse.parse(sql)[0].flatten()
    for field in parse_typed_field_list(flattened):
        if not field.name:
            continue
        if mode_filter and field.mode not in mode_filter:
            continue
        yield field.name
|
|
@ -0,0 +1,288 @@
|
|||
|
||||
import re
|
||||
import sqlparse
|
||||
from collections import namedtuple
|
||||
from sqlparse.sql import IdentifierList, Identifier, Function
|
||||
from sqlparse.tokens import Keyword, DML, Punctuation, Token, Error
|
||||
|
||||
# Regexes that capture the trailing "word" of a string, keyed by how
# permissive the notion of a word character should be. All patterns are
# raw strings (the original 'all_punctuations' entry used a plain string,
# which triggers invalid-escape warnings on newer Pythons).
cleanup_regex = {
    # This matches only alphanumerics and underscores.
    'alphanum_underscore': re.compile(r'(\w+)$'),
    # This matches everything except spaces, parens, colon, and comma
    'many_punctuations': re.compile(r'([^():,\s]+)$'),
    # This matches everything except spaces, parens, colon, comma, and period
    'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
    # This matches everything except a space.
    'all_punctuations': re.compile(r'([^\s]+)$'),
}


def last_word(text, include='alphanum_underscore'):
    """
    Find the last word in a sentence.

    Args:
        text: string to scan
        include: key into cleanup_regex selecting which characters may be
            part of a word

    >>> last_word('abc')
    'abc'
    >>> last_word(' abc')
    'abc'
    >>> last_word('')
    ''
    >>> last_word(' ')
    ''
    >>> last_word('abc ')
    ''
    >>> last_word('abc def')
    'def'
    >>> last_word('abc def ')
    ''
    >>> last_word('abc def;')
    ''
    >>> last_word('bac $def')
    'def'
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    >>> last_word('bac \def', include='most_punctuations')
    '\\\\def'
    >>> last_word('bac \def;', include='most_punctuations')
    '\\\\def;'
    >>> last_word('bac::def', include='most_punctuations')
    'def'
    >>> last_word('"foo*bar', include='most_punctuations')
    '"foo*bar'
    """

    if not text:  # Empty string
        return ''

    # Trailing whitespace means the previous word is already complete.
    if text[-1].isspace():
        return ''

    matches = cleanup_regex[include].search(text)
    return matches.group(0) if matches else ''
|
||||
|
||||
|
||||
# Lightweight record for one table-like source referenced by a query:
# schema (may be None), object name, alias (may be None), and whether the
# reference is a function call used as a table source.
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
                                               'is_function'])
|
||||
|
||||
|
||||
# This code is borrowed from the sqlparse example script:
# https://github.com/andialbrecht/sqlparse/blob/master/examples/extract_table_names.py
|
||||
def is_subselect(parsed):
    """Return True when the token group *parsed* opens a DML sub-statement
    (SELECT/INSERT/UPDATE/CREATE/DELETE), e.g. a parenthesized sub-select."""
    if not parsed.is_group():
        return False
    return any(
        tok.ttype is DML and tok.value.upper() in ('SELECT', 'INSERT',
                                                   'UPDATE', 'CREATE', 'DELETE')
        for tok in parsed.tokens)
|
||||
|
||||
|
||||
def _identifier_is_function(identifier):
    # True when any sub-token of the identifier is a sqlparse Function group,
    # i.e. the identifier is really a function call like "generate_series(...)".
    return any(isinstance(t, Function) for t in identifier.tokens)
|
||||
|
||||
|
||||
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens following a table-introducing keyword.

    Walks a parsed sqlparse statement and, once a keyword such as FROM,
    any JOIN variant, COPY, INTO, UPDATE or TABLE has been seen, yields the
    subsequent tokens (recursing into sub-selects) until the table section
    ends.

    Args:
        parsed: a parsed sqlparse statement (TokenList)
        stop_at_punctuation: when True, stop at the first punctuation token.
            Needed for INSERT statements, where the parenthesized column
            list would otherwise be yielded as if it named tables.
    """
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                for x in extract_from_part(item, stop_at_punctuation):
                    yield x
            elif stop_at_punctuation and item.ttype is Punctuation:
                # End the generator with `return`, not `raise StopIteration`:
                # PEP 479 turns StopIteration raised inside a generator into a
                # RuntimeError on Python 3.7+ (equivalent behavior on Python 2).
                return
            # An incomplete nested select won't be recognized correctly as a
            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
            # the second FROM to trigger this elif condition resulting in a
            # StopIteration. So we need to ignore the keyword if the keyword
            # FROM.
            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
            # condition. So we need to ignore the keyword JOIN and its variants
            # INNER JOIN, FULL OUTER JOIN, etc.
            elif item.ttype is Keyword and (
                    not item.value.upper() == 'FROM') and (
                    not item.value.upper().endswith('JOIN')):
                tbl_prefix_seen = False
            else:
                yield item
        elif item.ttype is Keyword or item.ttype is Keyword.DML:
            item_val = item.value.upper()
            if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
                    item_val.endswith('JOIN')):
                tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if (identifier.ttype is Keyword and
                        identifier.value.upper() == 'FROM'):
                    tbl_prefix_seen = True
                    break
|
||||
|
||||
|
||||
def extract_table_identifiers(token_stream, allow_functions=True):
    """yields tuples of TableReference namedtuples

    Args:
        token_stream: iterable of sqlparse tokens (typically the output of
            extract_from_part)
        allow_functions: when False, every yielded reference has
            is_function=False (used for INSERT statements, where sqlparse
            mistakes the column list for a function call)
    """

    for item in token_stream:
        if isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                # Sometimes Keywords (such as FROM ) are classified as
                # identifiers which don't have the get_real_name() method.
                try:
                    schema_name = identifier.get_parent_name()
                    real_name = identifier.get_real_name()
                    is_function = (allow_functions and
                                   _identifier_is_function(identifier))
                except AttributeError:
                    continue
                if real_name:
                    yield TableReference(schema_name, real_name,
                                         identifier.get_alias(), is_function)
        elif isinstance(item, Identifier):
            real_name = item.get_real_name()
            schema_name = item.get_parent_name()
            is_function = allow_functions and _identifier_is_function(item)

            if real_name:
                yield TableReference(schema_name, real_name, item.get_alias(),
                                     is_function)
            else:
                # No parseable real name -- fall back to the raw name, using
                # it as its own alias when none is present.
                name = item.get_name()
                yield TableReference(None, name, item.get_alias() or name,
                                     is_function)
        elif isinstance(item, Function):
            # A bare function call used as a table source, e.g.
            # "FROM generate_series(...)".
            yield TableReference(None, item.get_real_name(), item.get_alias(),
                                 allow_functions)
|
||||
|
||||
|
||||
# extract_tables is inspired from examples in the sqlparse lib.
|
||||
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
    """Extract the table names from an SQL statment.

    Returns a tuple of TableReference namedtuples (empty for unparseable
    input). Only the first statement in *sql* is examined.
    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return ()

    # INSERT statements must stop looking for tables at the sign of first
    # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
    # abc is the table name, but if we don't stop at the first lparen, then
    # we'll identify abc, col1 and col2 as table names.
    insert_stmt = parsed[0].token_first().value.lower() == 'insert'
    stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)

    # Kludge: sqlparse mistakenly identifies insert statements as
    # function calls due to the parenthesized column list, e.g. interprets
    # "insert into foo (bar, baz)" as a function call to foo with arguments
    # (bar, baz). So don't allow any identifiers in insert statements
    # to have is_function=True
    identifiers = extract_table_identifiers(stream,
                                            allow_functions=not insert_stmt)
    return tuple(identifiers)
|
||||
|
||||
|
||||
def find_prev_keyword(sql):
    """Find the last SQL keyword in an SQL statement.

    :param sql: SQL text
    :return: (token, text) pair -- the last keyword token, and the query
             text with everything after that keyword stripped; (None, '')
             when no keyword is present.
    """
    if not sql.strip():
        return None, ''

    tokens = list(sqlparse.parse(sql)[0].flatten())

    # Logical connectives merely join clauses; they are not the keyword
    # that determines what kind of completion is appropriate.
    skip_words = ('AND', 'OR', 'NOT', 'BETWEEN')

    for pos in range(len(tokens) - 1, -1, -1):
        tok = tokens[pos]
        significant_keyword = (tok.is_keyword and
                               tok.value.upper() not in skip_words)
        if tok.value == '(' or significant_keyword:
            # We track the flattened position ourselves rather than using
            # parsed.token_index(tok): that call throws ValueError when the
            # token is nested inside a TokenList, e.g. the "Where" token of
            # sqlparse.parse('select * from foo where bar').
            #
            # Join the string values of every token up to and including the
            # keyword, producing the query with the trailing part removed.
            prefix = ''.join(t.value for t in tokens[:pos + 1])
            return tok, prefix

    return None, ''
|
||||
|
||||
|
||||
# PostgreSQL dollar-quote delimiters look like `$$` or `$tag$`; the tag may
# be empty but cannot itself contain a `$`.
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
|
||||
|
||||
|
||||
def is_open_quote(sql):
    """Return True if *sql* contains an unclosed quote.

    :param sql: SQL text, possibly several semicolon-separated commands
    """
    # sqlparse splits the text into one parsed statement per command;
    # an open quote in any of them counts.
    for statement in sqlparse.parse(sql):
        if _parsed_is_open_quote(statement):
            return True
    return False
|
||||
|
||||
|
||||
def _parsed_is_open_quote(parsed):
    """Return True if a single parsed statement has an unterminated quote."""
    toks = list(parsed.flatten())
    total = len(toks)

    pos = 0
    while pos < total:
        token = toks[pos]
        if token.match(Token.Error, "'"):
            # sqlparse flags an unmatched single quote as an Error token.
            return True
        if (token.ttype in Token.Name.Builtin
                and dollar_quote_regex.match(token.value)):
            # Opening dollar quote -- scan ahead for the identical closing
            # delimiter.
            closing = None
            for ahead in range(pos + 1, total):
                if toks[ahead].match(Token.Name.Builtin, token.value):
                    closing = ahead
                    break
            if closing is None:
                # No matching closing dollar quote anywhere.
                return True
            # Resume the open-quote scan just past the closing delimiter.
            pos = closing
        pos += 1

    return False
|
||||
|
||||
|
||||
def parse_partial_identifier(word):
    """Attempt to parse a (partially typed) word as an identifier.

    word may include a schema qualification, like `schema_name.partial_name`
    or `schema_name.` There may also be unclosed quotation marks, like
    `"schema`, or `schema."partial_name`

    :param word: string representing a (partially complete) identifier
    :return: sqlparse.sql.Identifier, or None
    """
    parsed = sqlparse.parse(word)[0]
    toks = parsed.tokens

    if len(toks) == 1 and isinstance(toks[0], Identifier):
        return toks[0]

    if parsed.token_next_match(0, Error, '"'):
        # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'.
        # Close the quote and try parsing again.
        return parse_partial_identifier(word + '"')

    return None
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Ad-hoc smoke test: extract tables from a partial query.
    query = 'select * from (select t. from tabl t'
    print(extract_tables(query))
|
|
@ -0,0 +1,49 @@
|
|||
import re
|
||||
import sqlparse
|
||||
from sqlparse.tokens import Name
|
||||
from collections import defaultdict
|
||||
|
||||
# Matches any run of whitespace (spaces, tabs, newlines); used to normalize
# multi-word keywords like "GROUP BY".  Raw string avoids double-escaping.
white_space_regex = re.compile(r'\s+', re.MULTILINE)
|
||||
|
||||
|
||||
def _compile_regex(keyword):
|
||||
# Surround the keyword with word boundaries and replace interior whitespace
|
||||
# with whitespace wildcards
|
||||
pattern = '\\b' + re.sub(white_space_regex, '\\s+', keyword) + '\\b'
|
||||
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
|
||||
|
||||
|
||||
class PrevalenceCounter(object):
    """Tracks how often SQL keywords and identifier names occur in text.

    Used to rank autocomplete suggestions by how frequently the user has
    typed each keyword/name.
    """

    def __init__(self, keywords):
        self.keyword_counts = defaultdict(int)
        self.name_counts = defaultdict(int)
        # One pre-compiled whole-word regex per keyword of interest.
        self.keyword_regexs = {kw: _compile_regex(kw) for kw in keywords}

    def update(self, text):
        """Record all keyword and identifier occurrences found in *text*."""
        self.update_keywords(text)
        self.update_names(text)

    def update_names(self, text):
        """Count identifier (Name token) occurrences via sqlparse."""
        for statement in sqlparse.parse(text):
            for tok in statement.flatten():
                if tok.ttype in Name:
                    self.name_counts[tok.value] += 1

    def clear_names(self):
        """Reset the identifier-name counters."""
        self.name_counts = defaultdict(int)

    def update_keywords(self, text):
        # Keywords are counted with our own regexes; sqlparse's keyword
        # recognition can't be relied on here because it is database
        # agnostic.
        for kw, rgx in self.keyword_regexs.items():
            self.keyword_counts[kw] += sum(1 for _ in rgx.finditer(text))

    def keyword_count(self, keyword):
        return self.keyword_counts[keyword]

    def name_count(self, name):
        return self.name_counts[name]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue