Adding support for autocomplete in the SQL Editor.

In the Query editor, the autocomplete feature can be invoked with the
keyboard shortcut 'Ctrl + Space'.
pull/3/head
Akshay Joshi 2016-05-21 16:04:05 +05:30 committed by Ashesh Vashi
parent da28dc8507
commit 0a354055a9
19 changed files with 1895 additions and 2 deletions

View File

@ -43,3 +43,4 @@ traceback2==1.4.0
unittest2==1.1.0
Werkzeug==0.9.6
WTForms==2.0.2
sqlparse==0.1.19

View File

@ -37,3 +37,4 @@ unittest2==1.1.0
Werkzeug==0.9.6
wheel==0.24.0
WTForms==2.0.2
sqlparse==0.1.19

View File

@ -0,0 +1,29 @@
{# SQL query for getting columns #}
{# Two variants selected by object_name: 'table' searches ordinary tables
   only (relkind 'r'); 'view' searches views and materialized views
   (relkind 'v', 'm').  System and dropped columns are excluded via
   attnum > 0 and NOT attisdropped.
   NOTE(review): schema_name/rel_name are interpolated without quoting or
   escaping -- confirm callers always pass pre-escaped, trusted names. #}
{% if object_name == 'table' %}
SELECT
    att.attname column_name
FROM pg_catalog.pg_attribute att
    INNER JOIN pg_catalog.pg_class cls
        ON att.attrelid = cls.oid
    INNER JOIN pg_catalog.pg_namespace nsp
        ON cls.relnamespace = nsp.oid
    WHERE cls.relkind = ANY(array['r'])
        AND NOT att.attisdropped
        AND att.attnum  > 0
        AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
    ORDER BY 1
{% endif %}
{% if object_name == 'view' %}
SELECT
    att.attname column_name
FROM pg_catalog.pg_attribute att
    INNER JOIN pg_catalog.pg_class cls
        ON att.attrelid = cls.oid
    INNER JOIN pg_catalog.pg_namespace nsp
        ON cls.relnamespace = nsp.oid
    WHERE cls.relkind = ANY(array['v', 'm'])
        AND NOT att.attisdropped
        AND att.attnum  > 0
        AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
    ORDER BY 1
{% endif %}

View File

@ -0,0 +1,4 @@
{# SQL query for getting databases #}
{# Lists every database name from pg_database, sorted alphabetically. #}
SELECT d.datname
FROM pg_catalog.pg_database d
ORDER BY 1

View File

@ -0,0 +1,9 @@
{# SQL query for getting datatypes #}
{# Lists user-visible types in the given schemas.  The typrelid filter
   keeps base types plus composite types ('c' relkind only), and the
   NOT EXISTS clause drops the implicit array type generated for each
   base type.  schema_names must be a quoted, comma-separated list. #}
SELECT n.nspname schema_name,
       t.typname object_name
FROM pg_catalog.pg_type t
     INNER JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
      AND NOT EXISTS(SELECT 1 FROM pg_catalog.pg_type el WHERE el.oid = t.typelem AND el.typarray = t.oid)
      AND n.nspname IN ({{schema_names}})
ORDER BY 1, 2;

View File

@ -0,0 +1,30 @@
{# ============= Fetch the list of functions based on given schema_names ============= #}
{# Two modes:
   - func_name given: look up a single function by schema + name to fetch
     its full metadata (argument list, return type, flags).
     NOTE(review): this branch filters on p.proretset unconditionally, so
     only set-returning functions are ever found here -- confirm this is
     intended for metadata lookups.
   - otherwise: list all functions in the given schemas; the proretset
     filter is applied only when is_set_returning is set. #}
{% if func_name %}
SELECT n.nspname schema_name,
       p.proname func_name,
       pg_catalog.pg_get_function_arguments(p.oid) arg_list,
       pg_catalog.pg_get_function_result(p.oid) return_type,
       p.proisagg is_aggregate,
       p.proiswindow is_window,
       p.proretset is_set_returning
FROM pg_catalog.pg_proc p
     INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
WHERE n.nspname = '{{schema_name}}' AND p.proname = '{{func_name}}'
      AND p.proretset
ORDER BY 1, 2
{% else %}
SELECT n.nspname schema_name,
       p.proname object_name,
       pg_catalog.pg_get_function_arguments(p.oid) arg_list,
       pg_catalog.pg_get_function_result(p.oid) return_type,
       p.proisagg is_aggregate,
       p.proiswindow is_window,
       p.proretset is_set_returning
FROM pg_catalog.pg_proc p
     INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
WHERE n.nspname IN ({{schema_names}})
{% if is_set_returning %}
      AND p.proretset
{% endif %}
ORDER BY 1, 2
{% endif %}

View File

@ -0,0 +1,2 @@
{# SQL query for getting keywords #}
{# pg_get_keywords() returns the server's full SQL keyword list;
   upper-cased so completions are suggested in upper case. #}
SELECT upper(word) as word FROM pg_get_keywords()

View File

@ -0,0 +1,6 @@
{# SQL query for getting current_schemas #}
{# When search_path is set, return only the schemas on the effective
   search path (current_schemas(true) includes implicit ones such as
   pg_catalog); otherwise list every schema in the database. #}
{% if search_path %}
SELECT * FROM unnest(current_schemas(true)) AS schema
{% else %}
SELECT nspname AS schema FROM pg_catalog.pg_namespace ORDER BY 1
{% endif %}

View File

@ -0,0 +1,17 @@
{# ============= Fetch the list of tables/view based on given schema_names ============= #}
{# object_name selects the variant: 'tables' lists ordinary tables
   (relkind 'r'); 'views' lists views and materialized views
   (relkind 'v', 'm').  schema_names must be a quoted, comma-separated
   list suitable for the IN (...) clause. #}
{% if object_name == 'tables' %}
SELECT n.nspname schema_name,
       c.relname object_name
FROM pg_catalog.pg_class c
     LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind = ANY(array['r']) and n.nspname IN ({{schema_names}})
ORDER BY 1,2
{% endif %}
{% if object_name == 'views' %}
SELECT n.nspname schema_name,
       c.relname object_name
FROM pg_catalog.pg_class c
     LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
WHERE c.relkind = ANY(array['v', 'm']) and n.nspname IN ({{schema_names}})
ORDER BY 1,2
{% endif %}

View File

@ -24,6 +24,7 @@ from pgadmin.utils.driver import get_driver
from config import PG_DEFAULT_DRIVER
from pgadmin.tools.sqleditor.command import QueryToolCommand
from pgadmin.utils import get_storage_directory
from pgadmin.utils.sqlautocomplete.autocomplete import SQLAutoComplete
# import unquote from urlib for python2.x and python3.x
try:
@ -890,6 +891,44 @@ def set_auto_rollback(trans_id):
return make_json_response(data={'status': status, 'result': res})
@blueprint.route('/autocomplete/<int:trans_id>', methods=["PUT", "POST"])
@login_required
def auto_complete(trans_id):
    """
    Return autocomplete suggestions for the SQL editor.

    The request payload is a two-element list: the full SQL text and the
    text from the start of the document up to the cursor.

    Args:
        trans_id: unique transaction id

    Returns:
        JSON response with 'status' and 'result' (a dict of suggestions,
        or the error message when the transaction/connection is invalid).
    """
    full_sql = ''
    text_before_cursor = ''

    if request.data:
        data = json.loads(request.data.decode())
    else:
        data = request.args or request.form

    # Bug fix: the original guard was `len(data) > 0`, which allowed a
    # one-element payload through and then raised IndexError on data[1].
    if len(data) > 1:
        full_sql = data[0]
        text_before_cursor = data[1]

    # Check the transaction and connection status
    status, error_msg, conn, trans_obj, session_obj = \
        check_transaction_status(trans_id)
    if status and conn is not None \
            and trans_obj is not None and session_obj is not None:
        # Create object of SQLAutoComplete class and pass connection object
        auto_complete_obj = SQLAutoComplete(
            sid=trans_obj.sid, did=trans_obj.did, conn=conn)

        # Get the auto completion suggestions.
        res = auto_complete_obj.get_completions(full_sql, text_before_cursor)
    else:
        status = False
        res = error_msg

    return make_json_response(data={'status': status, 'result': res})
@blueprint.route("/sqleditor.js")
@login_required
def script():

View File

@ -243,3 +243,50 @@
width: 100%;
overflow: auto;
}
/* Container for the CodeMirror autocomplete dropdown (floats above the
 * editor, scrolls when the suggestion list is long). */
.CodeMirror-hints {
  position: absolute;
  z-index: 10;
  overflow: hidden;
  list-style: none;

  margin: 0;
  padding: 2px;

  -webkit-box-shadow: 2px 3px 5px rgba(0,0,0,.2);
  -moz-box-shadow: 2px 3px 5px rgba(0,0,0,.2);
  box-shadow: 2px 3px 5px rgba(0,0,0,.2);
  border-radius: 3px;
  border: 1px solid silver;

  background: white;
  font-size: 90%;
  font-family: monospace;

  max-height: 20em;
  overflow-y: auto;
}

/* A single suggestion row inside the dropdown. */
.CodeMirror-hint {
  margin: 0;
  padding: 0 4px;
  border-radius: 2px;
  max-width: 19em;
  overflow: hidden;
  white-space: pre;
  color: black;
  cursor: pointer;
}

/* The currently highlighted suggestion. */
li.CodeMirror-hint-active {
  background: #08f;
  color: white;
}

/* Leaves room on the left for the object-type icon. */
.sqleditor-hint {
  padding-left: 20px;
}

/* Spacing between a font-awesome icon and the suggestion text. */
.CodeMirror-hint .fa::before {
  padding-right: 7px;
}

View File

@ -6,6 +6,7 @@ define(
'codemirror/mode/sql/sql', 'codemirror/addon/selection/mark-selection',
'codemirror/addon/selection/active-line', 'backbone.paginator',
'codemirror/addon/fold/foldgutter', 'codemirror/addon/fold/foldcode',
'codemirror/addon/hint/show-hint', 'codemirror/addon/hint/sql-hint',
'codemirror/addon/fold/pgadmin-sqlfoldcode', 'backgrid.paginator',
'wcdocker', 'pgadmin.file_manager'
],
@ -238,7 +239,8 @@ define(
rangeFinder: CodeMirror.fold.combine(CodeMirror.pgadminBeginRangeFinder, CodeMirror.pgadminIfRangeFinder,
CodeMirror.pgadminLoopRangeFinder, CodeMirror.pgadminCaseRangeFinder)
},
gutters: ["CodeMirror-linenumbers", "CodeMirror-foldgutter"]
gutters: ["CodeMirror-linenumbers", "CodeMirror-foldgutter"],
extraKeys: {"Ctrl-Space": "autocomplete"}
});
// Create panels for 'Data Output', 'Explain', 'Messages' and 'History'
@ -295,6 +297,107 @@ define(
self.history_panel = main_docker.addPanel('history', wcDocker.DOCK.STACKED, self.data_output_panel);
self.render_history_grid();
/* Override/register the CodeMirror "sql" hint helper so suggestions
 * come from our server-side autocomplete endpoint.
 */
CodeMirror.registerHelper("hint", "sql", function(editor, options) {
  var data = [],
      result = [];
  var doc = editor.getDoc();
  var cur = doc.getCursor();
  var current_line = cur.line;  // line number at the cursor position
  var current_cur = cur.ch;     // column (character) at the cursor position

  /* Render function for a hint entry: adds our own css class and an
   * icon matching the suggested object's type.
   */
  var hint_render = function(elt, data, cur) {
    var el = document.createElement('span');

    switch (cur.type) {
      case 'database':
        el.className = 'sqleditor-hint pg-icon-' + cur.type;
        break;
      case 'datatype':
        el.className = 'sqleditor-hint icon-type';
        break;
      case 'keyword':
        el.className = 'fa fa-key';
        break;
      case 'table alias':
        el.className = 'fa fa-at';
        break;
      default:
        el.className = 'sqleditor-hint icon-' + cur.type;
    }

    el.appendChild(document.createTextNode(cur.text));
    elt.appendChild(el);
  };

  var full_text = doc.getValue();
  // Text from the start of the document up to the cursor.
  var text_before_cursor = doc.getRange({ line: 0, ch: 0 },
                                        { line: current_line, ch: current_cur });

  data.push(full_text);
  data.push(text_before_cursor);

  // Ask the server for completion candidates.
  // NOTE(review): async:false blocks the UI thread and is deprecated in
  // jQuery -- consider an asynchronous hint implementation.
  $.ajax({
    url: "{{ url_for('sqleditor.index') }}" + "autocomplete/" + self.transId,
    method: 'POST',
    async: false,
    contentType: "application/json",
    data: JSON.stringify(data),
    success: function(res) {
      _.each(res.data.result, function(obj, key) {
        result.push({
          text: key, type: obj.object_type,
          render: hint_render
        });
      });

      // Sort the suggestions alphabetically (case-insensitive).
      result.sort(function(a, b) {
        var textA = a.text.toLowerCase(),
            textB = b.text.toLowerCase();
        if (textA < textB)  // sort string ascending
          return -1;
        if (textA > textB)
          return 1;
        return 0;           // no sorting needed
      });
    }
  });

  /* Find the start and end of the token being completed, so the
   * selected suggestion replaces exactly that span.
   */
  var token = editor.getTokenAt(cur), start, end, search;
  if (token.end > cur.ch) {
    token.end = cur.ch;
    token.string = token.string.slice(0, cur.ch - token.start);
  }

  if (token.string.match(/^[.`\w@]\w*$/)) {
    search = token.string;
    start = token.start;
    end = token.end;
  } else {
    start = end = cur.ch;
    search = "";
  }

  /* Skip a leading "." or "`" so code mirror does not remove it when the
   * user selects a suggestion.
   * Bug fix: the original compared charAt(0) (one character) against the
   * two-character string "``", which can never match, so a leading
   * backtick was never handled.
   */
  if (search.charAt(0) == "." || search.charAt(0) == "`")
    start += 1;

  return {list: result, from: {line: current_line, ch: start}, to: {line: current_line, ch: end}};
});
},
/* This function is responsible to create and render the
@ -782,7 +885,7 @@ define(
el: self.container,
handler: self
});
self.transId = self.container.data('transId');
self.transId = self.gridView.transId = self.container.data('transId');
self.gridView.editor_title = editor_title;
self.gridView.current_file = undefined;

View File

@ -0,0 +1,863 @@
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2016, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
"""A blueprint module implementing the sql auto complete feature."""
import sys
import re
import sqlparse
import itertools
import operator
from collections import namedtuple
from sqlparse.sql import Comparison, Identifier, Where
from .parseutils import (
last_word, extract_tables, find_prev_keyword, parse_partial_identifier)
from .prioritization import PrevalenceCounter
from .completion import Completion
from .function_metadata import FunctionMetadata
from flask import render_template
from pgadmin.utils.driver import get_driver
from config import PG_DEFAULT_DRIVER
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
if PY3:
string_types = str
else:
string_types = basestring
Database = namedtuple('Database', [])
Schema = namedtuple('Schema', [])
Table = namedtuple('Table', ['schema'])
Function = namedtuple('Function', ['schema', 'filter'])
# For convenience, don't require the `filter` argument in Function constructor
Function.__new__.__defaults__ = (None, None)
Column = namedtuple('Column', ['tables', 'drop_unique'])
Column.__new__.__defaults__ = (None, None)
View = namedtuple('View', ['schema'])
Keyword = namedtuple('Keyword', [])
Datatype = namedtuple('Datatype', ['schema'])
Alias = namedtuple('Alias', ['aliases'])
Match = namedtuple('Match', ['completion', 'priority'])
try:
from collections import Counter
except ImportError:
# python 2.6
from .counter import Counter
# Regex for finding "words" in documents.
_FIND_WORD_RE = re.compile(r'([a-zA-Z0-9_]+|[^a-zA-Z0-9_\s]+)')
_FIND_BIG_WORD_RE = re.compile(r'([^\s]+)')
class SQLAutoComplete(object):
"""
class SQLAutoComplete
This class is used to provide the postgresql's autocomplete feature.
This class used sqlparse to parse the given sql and psycopg2 to make
the connection and get the tables, schemas, functions etc. based on
the query.
"""
def __init__(self, **kwargs):
    """
    Initialize the auto-complete object for one connection.

    Args:
        **kwargs: supported keys are
            sid:  server id used to look up the connection manager
            did:  database id (stored, not used directly here)
            conn: connection wrapper used to run the metadata queries
    """
    self.sid = kwargs['sid'] if 'sid' in kwargs else None
    self.did = kwargs['did'] if 'did' in kwargs else None
    self.conn = kwargs['conn'] if 'conn' in kwargs else None
    self.keywords = []

    manager = get_driver(PG_DEFAULT_DRIVER).connection_manager(self.sid)

    ver = manager.version

    # we will set template path for sql scripts
    # NOTE(review): sql_path is only assigned for server version >= 9.1;
    # on older servers the later self.sql_path references would raise
    # AttributeError -- confirm the minimum supported server version.
    if ver >= 90100:
        self.sql_path = 'sqlautocomplete/sql/9.1_plus'

    self.search_path = []
    # Fetch the search path (schemas visible without qualification)
    if self.conn.connected():
        query = render_template("/".join([self.sql_path, 'schema.sql']),
                                search_path=True)
        status, res = self.conn.execute_dict(query)
        if status:
            for record in res['rows']:
                self.search_path.append(record['schema'])

        # Fetch the keywords
        query = render_template("/".join([self.sql_path, 'keywords.sql']))
        status, res = self.conn.execute_dict(query)
        if status:
            for record in res['rows']:
                self.keywords.append(record['word'])

    self.text_before_cursor = None
    self.prioritizer = PrevalenceCounter(self.keywords)

    # Individual reserved words (keywords may be multi-word phrases).
    self.reserved_words = set()
    for x in self.keywords:
        self.reserved_words.update(x.split())
    # Pattern for names that do NOT need double-quoting.
    self.name_pattern = re.compile("^[_a-z][_a-z0-9\$]*$")
def escape_name(self, name):
if name and ((not self.name_pattern.match(name)) or
(name.upper() in self.reserved_words)):
name = '"%s"' % name
return name
def unescape_name(self, name):
if name and name[0] == '"' and name[-1] == '"':
name = name[1:-1]
return name
def escaped_names(self, names):
return [self.escape_name(name) for name in names]
def find_matches(self, text, collection, mode='fuzzy',
                 meta=None, meta_collection=None):
    """
    Find completion matches for the given text.

    Given the user's input text and a collection of available
    completions, find completions matching the last word of the
    text.

    `mode` can be either 'fuzzy', or 'strict'
        'fuzzy': fuzzy matching, ties broken by name prevalence
        `keyword`: start only matching, ties broken by keyword prevalence

    Returns a list of Match namedtuples (completion + sort priority)
    for any matches found in the collection of available completions.

    Args:
        text: text the user has typed so far
        collection: iterable of candidate completion strings
        mode: 'fuzzy' or 'strict' (see above)
        meta: meta-display string shared by all candidates
        meta_collection: per-candidate meta strings (overrides meta)
    """
    text = last_word(text, include='most_punctuations').lower()
    text_len = len(text)

    if text and text[0] == '"':
        # text starts with double quote; user is manually escaping a name
        # Match on everything that follows the double-quote. Note that
        # text_len is calculated before removing the quote, so the
        # Completion.position value is correct
        text = text[1:]

    if mode == 'fuzzy':
        fuzzy = True
        priority_func = self.prioritizer.name_count
    else:
        fuzzy = False
        priority_func = self.prioritizer.keyword_count

    # Construct a `_match` function for either fuzzy or non-fuzzy matching
    # The match function returns a 2-tuple used for sorting the matches,
    # or None if the item doesn't match
    # Note: higher priority values mean more important, so use negative
    # signs to flip the direction of the tuple
    if fuzzy:
        # e.g. text 'abc' becomes pattern 'a.*?b.*?c' -- characters must
        # appear in order but not necessarily adjacently.
        regex = '.*?'.join(map(re.escape, text))
        pat = re.compile('(%s)' % regex)

        def _match(item):
            r = pat.search(self.unescape_name(item.lower()))
            if r:
                return -len(r.group()), -r.start()
    else:
        match_end_limit = len(text)

        def _match(item):
            match_point = item.lower().find(text, 0, match_end_limit)
            if match_point >= 0:
                # Use negative infinity to force keywords to sort after all
                # fuzzy matches
                return -float('Infinity'), -match_point

    if meta_collection:
        # Each possible completion in the collection has a corresponding
        # meta-display string
        collection = zip(collection, meta_collection)
    else:
        # All completions have an identical meta
        collection = zip(collection, itertools.repeat(meta))

    matches = []

    for item, meta in collection:
        sort_key = _match(item)
        if sort_key:
            if meta and len(meta) > 50:
                # Truncate meta-text to 50 characters, if necessary
                meta = meta[:47] + u'...'

            # Lexical order of items in the collection, used for
            # tiebreaking items with the same match group length and start
            # position. Since we use *higher* priority to mean "more
            # important," we use -ord(c) to prioritize "aa" > "ab" and end
            # with 1 to prioritize shorter strings (ie "user" > "users").
            # We also use the unescape_name to make sure quoted names have
            # the same priority as unquoted names.
            lexical_priority = tuple(-ord(c) for c in self.unescape_name(item)) + (1,)

            priority = sort_key, priority_func(item), lexical_priority

            matches.append(Match(
                completion=Completion(item, -text_len, display_meta=meta),
                priority=priority))

    return matches
def get_completions(self, text, text_before_cursor):
    """
    Produce autocomplete suggestions for the current cursor position.

    Args:
        text: the complete query text
        text_before_cursor: text from the start of the document up to
            the cursor

    Returns:
        dict mapping the suggested display text to a dict holding its
        'object_type' (table, column, keyword, ...), ordered by priority.
    """
    self.text_before_cursor = text_before_cursor

    word_before_cursor = self.get_word_before_cursor(word=True)
    matches = []
    suggestions = self.suggest_type(text, text_before_cursor)

    for suggestion in suggestions:
        suggestion_type = type(suggestion)

        # Map suggestion type to method
        # e.g. 'table' -> self.get_table_matches
        matcher = self.suggestion_matchers[suggestion_type]
        matches.extend(matcher(self, suggestion, word_before_cursor))

    # Sort matches so highest priorities are first
    matches = sorted(matches, key=operator.attrgetter('priority'),
                     reverse=True)

    result = dict()
    for m in matches:
        result[m.completion.display] = {'object_type': m.completion.display_meta}

    return result
def get_column_matches(self, suggestion, word_before_cursor):
    """Suggest column names for the tables referenced by *suggestion*."""
    tables = suggestion.tables
    scoped_cols = self.populate_scoped_cols(tables)

    if suggestion.drop_unique:
        # drop_unique is used for 'tb11 JOIN tbl2 USING (...' which should
        # suggest only columns that appear in more than one table
        scoped_cols = [col for (col, count)
                       in Counter(scoped_cols).items()
                       if count > 1 and col != '*']

    return self.find_matches(word_before_cursor, scoped_cols, mode='strict', meta='column')
def get_function_matches(self, suggestion, word_before_cursor):
    """Suggest function names, optionally restricted to set-returning ones."""
    if suggestion.filter == 'is_set_returning':
        # Only suggest set-returning functions
        funcs = self.populate_functions(suggestion.schema)
    else:
        funcs = self.populate_schema_objects(suggestion.schema, 'functions')

    # Function overloading means we may have multiple functions of the same
    # name at this point, so keep unique names only
    funcs = set(funcs)

    funcs = self.find_matches(word_before_cursor, funcs, mode='strict', meta='function')

    return funcs
def get_schema_matches(self, _, word_before_cursor):
    """Suggest schema names; system 'pg_' schemas are hidden unless asked for."""
    schema_names = []
    query = render_template("/".join([self.sql_path, 'schema.sql']))
    status, res = self.conn.execute_dict(query)
    if status:
        for record in res['rows']:
            schema_names.append(record['schema'])

    # Unless we're sure the user really wants them, hide schema names
    # starting with pg_, which are mostly temporary schemas
    if not word_before_cursor.startswith('pg_'):
        schema_names = [s for s in schema_names if not s.startswith('pg_')]

    return self.find_matches(word_before_cursor, schema_names, mode='strict', meta='schema')
def get_table_matches(self, suggestion, word_before_cursor):
    """Suggest table names from the (optional) schema in *suggestion*.

    The implicit pg_catalog tables are hidden unless the user is clearly
    asking for them by typing a 'pg_' prefix.
    """
    tables = self.populate_schema_objects(suggestion.schema, 'tables')

    hide_system = (not suggestion.schema and
                   not word_before_cursor.startswith('pg_'))
    if hide_system:
        tables = [name for name in tables if not name.startswith('pg_')]

    return self.find_matches(word_before_cursor, tables,
                             mode='strict', meta='table')
def get_view_matches(self, suggestion, word_before_cursor):
    """Suggest view names from the (optional) schema in *suggestion*.

    System 'pg_' views are hidden unless the user typed a 'pg_' prefix.
    """
    views = self.populate_schema_objects(suggestion.schema, 'views')

    hide_system = (not suggestion.schema and
                   not word_before_cursor.startswith('pg_'))
    if hide_system:
        views = [name for name in views if not name.startswith('pg_')]

    return self.find_matches(word_before_cursor, views,
                             mode='strict', meta='view')
def get_alias_matches(self, suggestion, word_before_cursor):
    """Suggest the table aliases collected from the current statement."""
    return self.find_matches(word_before_cursor, suggestion.aliases,
                             mode='strict', meta='table alias')
def get_database_matches(self, _, word_before_cursor):
    """Suggest database names fetched from pg_database."""
    databases = []
    query = render_template("/".join([self.sql_path, 'databases.sql']))
    if self.conn.connected():
        status, res = self.conn.execute_dict(query)
        if status:
            for record in res['rows']:
                databases.append(record['datname'])

    return self.find_matches(word_before_cursor, databases, mode='strict',
                             meta='database')
def get_keyword_matches(self, _, word_before_cursor):
    """Suggest the SQL keywords fetched from the server at init time."""
    keywords = self.keywords
    return self.find_matches(word_before_cursor, keywords,
                             mode='strict', meta='keyword')
def get_datatype_matches(self, suggestion, word_before_cursor):
    """Suggest custom datatypes from the (optional) schema in *suggestion*."""
    type_names = self.populate_schema_objects(suggestion.schema, 'datatypes')
    return self.find_matches(word_before_cursor, type_names,
                             mode='strict', meta='datatype')
def get_word_before_cursor(self, word=False):
"""
Give the word before the cursor.
If we have whitespace before the cursor this returns an empty string.
Args:
word:
"""
if self.text_before_cursor[-1:].isspace():
return ''
else:
return self.text_before_cursor[self.find_start_of_previous_word(word=word):]
def find_start_of_previous_word(self, count=1, word=False):
    """
    Return an index relative to the cursor position pointing to the start
    of the previous word. Return `None` if nothing was found.

    Args:
        count: which word boundary to report (1 = the nearest one)
        word: when True use the "big word" regex (whitespace-delimited);
              otherwise punctuation also splits words
    """
    # Reverse the text before the cursor, in order to do an efficient
    # backwards search.
    text_before_cursor = self.text_before_cursor[::-1]

    regex = _FIND_BIG_WORD_RE if word else _FIND_WORD_RE
    iterator = regex.finditer(text_before_cursor)

    try:
        for i, match in enumerate(iterator):
            if i + 1 == count:
                # match.end(1) is measured in the reversed string, so its
                # negative indexes the word start relative to the cursor.
                return - match.end(1)
    except StopIteration:
        pass
# Dispatch table: maps each suggestion namedtuple type to the (unbound)
# method that produces completion matches for it.  Looked up in
# get_completions and invoked as matcher(self, suggestion, word).
suggestion_matchers = {
    Column: get_column_matches,
    Function: get_function_matches,
    Schema: get_schema_matches,
    Table: get_table_matches,
    View: get_view_matches,
    Alias: get_alias_matches,
    Database: get_database_matches,
    Keyword: get_keyword_matches,
    Datatype: get_datatype_matches,
}
def populate_scoped_cols(self, scoped_tbls):
    """ Find all columns in a set of scoped_tables

    For each referenced table, view or set-returning function, look up
    its columns (or, for functions, the fields of the return type).

    NOTE(review): the schema-qualified branch and the search-path branch
    below duplicate the same lookup logic; a shared helper would avoid
    the repetition -- left as-is to preserve behavior.

    :param scoped_tbls: list of TableReference namedtuples
    :return: list of column names
    """
    columns = []
    for tbl in scoped_tbls:
        if tbl.schema:
            # A fully qualified schema.relname reference
            schema = self.escape_name(tbl.schema)
            relname = self.escape_name(tbl.name)

            if tbl.is_function:
                # Function metadata is fetched so we can expand the
                # fields of its composite return type.
                query = render_template("/".join([self.sql_path, 'functions.sql']),
                                        schema_name=schema,
                                        func_name=relname)

                if self.conn.connected():
                    status, res = self.conn.execute_dict(query)
                    func = None
                    if status:
                        for row in res['rows']:
                            func = FunctionMetadata(row['schema_name'], row['func_name'],
                                                    row['arg_list'], row['return_type'],
                                                    row['is_aggregate'], row['is_window'],
                                                    row['is_set_returning'])
                    if func:
                        columns.extend(func.fieldnames())
            else:
                # We don't know if schema.relname is a table or view. Since
                # tables and views cannot share the same name, we can check
                # one at a time
                query = render_template("/".join([self.sql_path, 'columns.sql']),
                                        object_name='table',
                                        schema_name=schema,
                                        rel_name=relname)

                if self.conn.connected():
                    status, res = self.conn.execute_dict(query)
                    if status:
                        if len(res['rows']) > 0:
                            # Table exists, so don't bother checking for a view
                            for record in res['rows']:
                                columns.append(record['column_name'])
                        else:
                            query = render_template("/".join([self.sql_path, 'columns.sql']),
                                                    object_name='view',
                                                    schema_name=schema,
                                                    rel_name=relname)

                            if self.conn.connected():
                                status, res = self.conn.execute_dict(query)
                                if status:
                                    for record in res['rows']:
                                        columns.append(record['column_name'])
        else:
            # Schema not specified, so traverse the search path looking for
            # a table or view that matches. Note that in order to get proper
            # shadowing behavior, we need to check both views and tables for
            # each schema before checking the next schema
            for schema in self.search_path:
                relname = self.escape_name(tbl.name)

                if tbl.is_function:
                    query = render_template("/".join([self.sql_path, 'functions.sql']),
                                            schema_name=schema,
                                            func_name=relname)

                    if self.conn.connected():
                        status, res = self.conn.execute_dict(query)
                        func = None
                        if status:
                            for row in res['rows']:
                                func = FunctionMetadata(row['schema_name'], row['func_name'],
                                                        row['arg_list'], row['return_type'],
                                                        row['is_aggregate'], row['is_window'],
                                                        row['is_set_returning'])
                        if func:
                            columns.extend(func.fieldnames())
                else:
                    query = render_template("/".join([self.sql_path, 'columns.sql']),
                                            object_name='table',
                                            schema_name=schema,
                                            rel_name=relname)

                    if self.conn.connected():
                        status, res = self.conn.execute_dict(query)
                        if status:
                            if len(res['rows']) > 0:
                                # Table exists, so don't bother checking for a view
                                for record in res['rows']:
                                    columns.append(record['column_name'])
                            else:
                                query = render_template("/".join([self.sql_path, 'columns.sql']),
                                                        object_name='view',
                                                        schema_name=schema,
                                                        rel_name=relname)

                                if self.conn.connected():
                                    status, res = self.conn.execute_dict(query)
                                    if status:
                                        for record in res['rows']:
                                            columns.append(record['column_name'])
    return columns
def populate_schema_objects(self, schema, obj_type):
    """
    Returns list of tables, views, functions or datatypes, either for a
    single schema or for every schema on the search path.

    Args:
        schema: schema name, or falsy to use the search path
        obj_type: one of 'tables', 'views', 'functions', 'datatypes'

    Returns:
        list of object names (empty when not connected, on error, or for
        an unknown obj_type)
    """
    # Build the quoted, comma-separated list for the SQL IN (...) clause.
    # (Replaces the original manual concatenate-then-trim-the-comma loop.)
    # NOTE(review): schema names are interpolated into the template
    # unescaped -- confirm they are trusted/pre-escaped upstream.
    schemas = [schema] if schema else self.search_path
    in_clause = ','.join("'%s'" % s for s in schemas)

    # Map object type to its template file and extra template arguments.
    template_args = {
        'tables': ('tableview.sql', {'object_name': 'tables'}),
        'views': ('tableview.sql', {'object_name': 'views'}),
        'functions': ('functions.sql', {}),
        'datatypes': ('datatypes.sql', {}),
    }

    objects = []
    if obj_type not in template_args:
        # Unknown object type: the original executed an empty query here;
        # returning early avoids the pointless round-trip.
        return objects

    sql_file, extra = template_args[obj_type]
    query = render_template("/".join([self.sql_path, sql_file]),
                            schema_names=in_clause, **extra)

    if self.conn.connected():
        status, res = self.conn.execute_dict(query)
        if status:
            for record in res['rows']:
                objects.append(record['object_name'])

    return objects
def populate_functions(self, schema):
    """
    Returns a list of set-returning function names, either for a single
    schema or for every schema on the search path.

    Args:
        schema: schema name, or falsy to use the search path

    Returns:
        list of function names (empty when not connected or on error)
    """
    # Build the quoted, comma-separated list for the SQL IN (...) clause.
    # (Replaces the original manual concatenate-then-trim-the-comma loop.)
    schemas = [schema] if schema else self.search_path
    in_clause = ','.join("'%s'" % s for s in schemas)

    funcs = []
    # is_set_returning=True makes the template filter on p.proretset.
    query = render_template("/".join([self.sql_path, 'functions.sql']),
                            schema_names=in_clause,
                            is_set_returning=True)

    if self.conn.connected():
        status, res = self.conn.execute_dict(query)
        if status:
            for record in res['rows']:
                funcs.append(record['object_name'])

    return funcs
def suggest_type(self, full_text, text_before_cursor):
    """
    Takes the full_text that is typed so far and also the text before the
    cursor to suggest completion type and scope.

    Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
    A scope for a column category will be a list of tables.

    Args:
        full_text: Contains complete query
        text_before_cursor: Contains text before the cursor
    """
    word_before_cursor = last_word(text_before_cursor, include='many_punctuations')

    identifier = None

    def strip_named_query(txt):
        """
        This will strip "save named query" command in the beginning of the line:
        '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
        '  \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'

        Args:
            txt: query text to strip
        """
        pattern = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
        if pattern.match(txt):
            txt = pattern.sub('', txt)
        return txt

    full_text = strip_named_query(full_text)
    text_before_cursor = strip_named_query(text_before_cursor)

    # If we've partially typed a word then word_before_cursor won't be an empty
    # string. In that case we want to remove the partially typed string before
    # sending it to the sqlparser. Otherwise the last token will always be the
    # partially typed string which renders the smart completion useless because
    # it will always return the list of keywords as completion.
    if word_before_cursor:
        if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
            parsed = sqlparse.parse(text_before_cursor)
        else:
            parsed = sqlparse.parse(
                text_before_cursor[:-len(word_before_cursor)])

            identifier = parse_partial_identifier(word_before_cursor)
    else:
        parsed = sqlparse.parse(text_before_cursor)

    statement = None
    if len(parsed) > 1:
        # Multiple statements being edited -- isolate the current one by
        # cumulatively summing statement lengths to find the one that bounds the
        # current position
        current_pos = len(text_before_cursor)
        stmt_start, stmt_end = 0, 0

        for statement in parsed:
            stmt_len = len(statement.to_unicode())
            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

            if stmt_end >= current_pos:
                # Found the statement that contains the cursor.
                break

        text_before_cursor = full_text[stmt_start:current_pos]
        full_text = full_text[stmt_start:]
    elif parsed:
        # A single statement
        statement = parsed[0]
    else:
        # The empty string
        statement = None

    # token_prev skips whitespace/comments; falls back to '' when empty.
    last_token = statement and statement.token_prev(len(statement.tokens)) or ''

    return self.suggest_based_on_last_token(last_token, text_before_cursor,
                                            full_text, identifier)
def suggest_based_on_last_token(self, token, text_before_cursor, full_text, identifier):
    """Return a tuple of suggestion objects appropriate after `token`.

    Args:
        token: the last significant token before the cursor. Either a
            sqlparse token (Comparison/Where/Identifier/...) or a plain
            string naming a keyword (e.g. 'where', 'type') on recursive
            calls.
        text_before_cursor: text of the current statement up to the cursor
        full_text: full text of the current statement
        identifier: partially-typed sqlparse Identifier under the cursor,
            or None

    Returns:
        Tuple of suggestion instances (Column, Table, View, Function,
        Schema, Database, Datatype, Alias, Keyword).
    """
    if isinstance(token, string_types):
        token_v = token.lower()
    elif isinstance(token, Comparison):
        # If 'token' is a Comparison type such as
        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
        # token.value on the comparison type will only return the lhs of the
        # comparison. In this case a.id. So we need to do token.tokens to get
        # both sides of the comparison and pick the last token out of that
        # list.
        token_v = token.tokens[-1].value.lower()
    elif isinstance(token, Where):
        # sqlparse groups all tokens from the where clause into a single token
        # list. This means that token.value may be something like
        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
        # suggestions in complicated where clauses correctly
        prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
        return self.suggest_based_on_last_token(
            prev_keyword, text_before_cursor, full_text, identifier)
    elif isinstance(token, Identifier):
        # If the previous token is an identifier, we can suggest datatypes if
        # we're in a parenthesized column/field list, e.g.:
        #       CREATE TABLE foo (Identifier <CURSOR>
        #       CREATE FUNCTION foo (Identifier <CURSOR>
        # If we're not in a parenthesized list, the most likely scenario is the
        # user is about to specify an alias, e.g.:
        #       SELECT Identifier <CURSOR>
        #       SELECT foo FROM Identifier <CURSOR>
        prev_keyword, _ = find_prev_keyword(text_before_cursor)
        if prev_keyword and prev_keyword.value == '(':
            # Suggest datatypes
            return self.suggest_based_on_last_token(
                'type', text_before_cursor, full_text, identifier)
        else:
            return Keyword(),
    else:
        token_v = token.value.lower()

    if not token:
        return Keyword(),
    elif token_v.endswith('('):
        p = sqlparse.parse(text_before_cursor)[0]

        if p.tokens and isinstance(p.tokens[-1], Where):
            # Four possibilities:
            #  1 - Parenthesized clause like "WHERE foo AND ("
            #        Suggest columns/functions
            #  2 - Function call like "WHERE foo("
            #        Suggest columns/functions
            #  3 - Subquery expression like "WHERE EXISTS ("
            #        Suggest keywords, in order to do a subquery
            #  4 - Subquery OR array comparison like "WHERE foo = ANY("
            #        Suggest columns/functions AND keywords. (If we wanted to be
            #        really fancy, we could suggest only array-typed columns)
            column_suggestions = self.suggest_based_on_last_token(
                'where', text_before_cursor, full_text, identifier)

            # Check for a subquery expression (cases 3 & 4)
            where = p.tokens[-1]
            prev_tok = where.token_prev(len(where.tokens) - 1)

            if isinstance(prev_tok, Comparison):
                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                prev_tok = prev_tok.tokens[-1]

            prev_tok = prev_tok.value.lower()
            if prev_tok == 'exists':
                return Keyword(),
            else:
                return column_suggestions

        # Get the token before the parens
        prev_tok = p.token_prev(len(p.tokens) - 1)
        if prev_tok and prev_tok.value and prev_tok.value.lower() == 'using':
            # tbl1 INNER JOIN tbl2 USING (col1, col2)
            tables = extract_tables(full_text)

            # suggest columns that are present in more than one table
            return Column(tables=tables, drop_unique=True),
        elif p.token_first().value.lower() == 'select':
            # If the lparen is preceeded by a space chances are we're about to
            # do a sub-select.
            if last_word(text_before_cursor,
                         'all_punctuations').startswith('('):
                return Keyword(),
        # We're probably in a function argument list
        return Column(tables=extract_tables(full_text)),
    elif token_v in ('set', 'by', 'distinct'):
        return Column(tables=extract_tables(full_text)),
    elif token_v in ('select', 'where', 'having'):
        # Check for a table alias or schema qualification
        parent = (identifier and identifier.get_parent_name()) or []

        if parent:
            tables = extract_tables(full_text)
            tables = tuple(t for t in tables if self.identifies(parent, t))
            return (Column(tables=tables),
                    Table(schema=parent),
                    View(schema=parent),
                    Function(schema=parent),)
        else:
            return (Column(tables=extract_tables(full_text)),
                    Function(schema=None),
                    Keyword(),)
    elif (token_v.endswith('join') and token.is_keyword) or \
            (token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate')):
        schema = (identifier and identifier.get_parent_name()) or None

        # Suggest tables from either the currently-selected schema or the
        # public schema if no schema has been specified
        suggest = [Table(schema=schema)]

        if not schema:
            # Suggest schemas
            suggest.insert(0, Schema())

        # Only tables can be TRUNCATED, otherwise suggest views
        if token_v != 'truncate':
            suggest.append(View(schema=schema))

        # Suggest set-returning functions in the FROM clause
        if token_v == 'from' or (token_v.endswith('join') and token.is_keyword):
            suggest.append(Function(schema=schema, filter='is_set_returning'))

        return tuple(suggest)
    elif token_v in ('table', 'view', 'function'):
        # E.g. 'DROP FUNCTION <funcname>', 'ALTER TABLE <tablname>'
        rel_type = {'table': Table, 'view': View, 'function': Function}[token_v]
        schema = (identifier and identifier.get_parent_name()) or None
        if schema:
            return rel_type(schema=schema),
        else:
            return Schema(), rel_type(schema=schema)
    elif token_v == 'on':
        tables = extract_tables(full_text)  # [(schema, table, alias), ...]
        parent = (identifier and identifier.get_parent_name()) or None
        if parent:
            # "ON parent.<suggestion>"
            # parent can be either a schema name or table alias
            tables = tuple(t for t in tables if self.identifies(parent, t))
            return (Column(tables=tables),
                    Table(schema=parent),
                    View(schema=parent),
                    Function(schema=parent))
        else:
            # ON <suggestion>
            # Use table alias if there is one, otherwise the table name
            aliases = tuple(t.alias or t.name for t in tables)
            return Alias(aliases=aliases),
    elif token_v in ('c', 'use', 'database', 'template'):
        # "\c <db", "use <db>", "DROP DATABASE <db>",
        # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
        return Database(),
    elif token_v == 'schema':
        # DROP SCHEMA schema_name
        return Schema(),
    elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
        prev_keyword, text_before_cursor = find_prev_keyword(text_before_cursor)
        if prev_keyword:
            return self.suggest_based_on_last_token(
                prev_keyword, text_before_cursor, full_text, identifier)
        else:
            return ()
    elif token_v in ('type', '::'):
        #   ALTER TABLE foo SET DATA TYPE bar
        #   SELECT foo::bar
        # Note that tables are a form of composite type in postgresql, so
        # they're suggested here as well
        schema = (identifier and identifier.get_parent_name()) or None
        suggestions = [Datatype(schema=schema),
                       Table(schema=schema)]
        if not schema:
            suggestions.append(Schema())
        return tuple(suggestions)
    else:
        return Keyword(),
def identifies(self, id, ref):
    """
    Returns true if string `id` matches TableReference `ref`, either by
    alias, by bare name, or by schema-qualified name.

    Args:
        id: candidate identifier string
        ref: TableReference namedtuple to match against
    """
    if id == ref.alias or id == ref.name:
        return True
    qualified_name = ref.schema and (ref.schema + '.' + ref.name)
    return ref.schema and id == qualified_name

View File

@ -0,0 +1,67 @@
"""
Using Completion class from
https://github.com/jonathanslenders/python-prompt-toolkit/blob/master/prompt_toolkit/completion.py
"""
from __future__ import unicode_literals
from abc import ABCMeta, abstractmethod
from six import with_metaclass
__all__ = (
'Completion'
)
class Completion(object):
"""
:param text: The new string that will be inserted into the document.
:param start_position: Position relative to the cursor_position where the
new text will start. The text will be inserted between the
start_position and the original cursor position.
:param display: (optional string) If the completion has to be displayed
differently in the completion menu.
:param display_meta: (Optional string) Meta information about the
completion, e.g. the path or source where it's coming from.
:param get_display_meta: Lazy `display_meta`. Retrieve meta information
only when meta is displayed.
"""
def __init__(self, text, start_position=0, display=None, display_meta=None,
get_display_meta=None):
self.text = text
self.start_position = start_position
self._display_meta = display_meta
self._get_display_meta = get_display_meta
if display is None:
self.display = text
else:
self.display = display
assert self.start_position <= 0
def __repr__(self):
return '%s(text=%r, start_position=%r)' % (
self.__class__.__name__, self.text, self.start_position)
def __eq__(self, other):
return (
self.text == other.text and
self.start_position == other.start_position and
self.display == other.display and
self.display_meta == other.display_meta)
def __hash__(self):
return hash((self.text, self.start_position, self.display, self.display_meta))
@property
def display_meta(self):
# Return meta-text. (This is lazy when using "get_display_meta".)
if self._display_meta is not None:
return self._display_meta
elif self._get_display_meta:
self._display_meta = self._get_display_meta()
return self._display_meta
else:
return ''

View File

@ -0,0 +1,189 @@
"""
Copied from http://code.activestate.com/recipes/576611-counter-class/
"""
from operator import itemgetter
from heapq import nlargest
from itertools import repeat, ifilter
class Counter(dict):
    '''Dict subclass for counting hashable objects.  Sometimes called a bag
    or multiset.  Elements are stored as dictionary keys and their counts
    are stored as dictionary values.

    >>> Counter('zyzygy')
    Counter({'y': 3, 'z': 2, 'g': 1})

    NOTE: this is the Python 2.6 activestate backport; the original used
    `iteritems`/`ifilter`, which do not exist on Python 3 and broke most
    methods there.  It now uses `items`/`filter`, which behave identically
    for these purposes on both Python 2 and 3.
    '''

    def __init__(self, iterable=None, **kwds):
        '''Create a new, empty Counter object.  And if given, count elements
        from an input iterable.  Or, initialize the count from another mapping
        of elements to their counts.

        >>> c = Counter()                    # a new, empty counter
        >>> c = Counter('gallahad')          # a new counter from an iterable
        >>> c = Counter({'a': 4, 'b': 2})    # a new counter from a mapping
        >>> c = Counter(a=4, b=2)            # a new counter from keyword args
        '''
        self.update(iterable, **kwds)

    def __missing__(self, key):
        # Absent elements count as zero rather than raising KeyError
        return 0

    def most_common(self, n=None):
        '''List the n most common elements and their counts from the most
        common to the least.  If n is None, then list all element counts.

        >>> Counter('abracadabra').most_common(3)
        [('a', 5), ('r', 2), ('b', 2)]
        '''
        if n is None:
            return sorted(self.items(), key=itemgetter(1), reverse=True)
        return nlargest(n, self.items(), key=itemgetter(1))

    def elements(self):
        '''Iterator over elements repeating each as many times as its count.

        >>> c = Counter('ABCABC')
        >>> sorted(c.elements())
        ['A', 'A', 'B', 'B', 'C', 'C']

        If an element's count has been set to zero or is a negative number,
        elements() will ignore it.
        '''
        for elem, count in self.items():
            for _ in repeat(None, count):
                yield elem

    # Override dict methods where the meaning changes for Counter objects.

    @classmethod
    def fromkeys(cls, iterable, v=None):
        raise NotImplementedError(
            'Counter.fromkeys() is undefined. Use Counter(iterable) instead.')

    def update(self, iterable=None, **kwds):
        '''Like dict.update() but add counts instead of replacing them.

        Source can be an iterable, a dictionary, or another Counter instance.

        >>> c = Counter('which')
        >>> c.update('witch')           # add elements from another iterable
        >>> d = Counter('watch')
        >>> c.update(d)                 # add elements from another counter
        >>> c['h']                      # four 'h' in which, witch, and watch
        4
        '''
        if iterable is not None:
            # Mappings contribute their counts; plain iterables contribute
            # one per element.  (`items` works for dicts on both py2 & py3,
            # unlike the original `iteritems` check.)
            if hasattr(iterable, 'items'):
                if self:
                    self_get = self.get
                    for elem, count in iterable.items():
                        self[elem] = self_get(elem, 0) + count
                else:
                    dict.update(self, iterable)  # fast path when counter is empty
            else:
                self_get = self.get
                for elem in iterable:
                    self[elem] = self_get(elem, 0) + 1
        if kwds:
            self.update(kwds)

    def copy(self):
        'Like dict.copy() but returns a Counter instance instead of a dict.'
        return Counter(self)

    def __delitem__(self, elem):
        'Like dict.__delitem__() but does not raise KeyError for missing values.'
        if elem in self:
            dict.__delitem__(self, elem)

    def __repr__(self):
        if not self:
            return '%s()' % self.__class__.__name__
        items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
        return '%s({%s})' % (self.__class__.__name__, items)

    # Multiset-style mathematical operations discussed in:
    #       Knuth TAOCP Volume II section 4.6.3 exercise 19
    #       and at http://en.wikipedia.org/wiki/Multiset
    #
    # Outputs guaranteed to only include positive counts.
    #
    # To strip negative and zero counts, add-in an empty counter:
    #       c += Counter()

    def __add__(self, other):
        '''Add counts from two counters.

        >>> Counter('abbb') + Counter('bcc')
        Counter({'b': 4, 'c': 2, 'a': 1})
        '''
        if not isinstance(other, Counter):
            return NotImplemented
        result = Counter()
        for elem in set(self) | set(other):
            newcount = self[elem] + other[elem]
            if newcount > 0:
                result[elem] = newcount
        return result

    def __sub__(self, other):
        ''' Subtract count, but keep only results with positive counts.

        >>> Counter('abbbc') - Counter('bccd')
        Counter({'b': 2, 'a': 1})
        '''
        if not isinstance(other, Counter):
            return NotImplemented
        result = Counter()
        for elem in set(self) | set(other):
            newcount = self[elem] - other[elem]
            if newcount > 0:
                result[elem] = newcount
        return result

    def __or__(self, other):
        '''Union is the maximum of value in either of the input counters.

        >>> Counter('abbb') | Counter('bcc')
        Counter({'b': 3, 'c': 2, 'a': 1})
        '''
        if not isinstance(other, Counter):
            return NotImplemented
        _max = max
        result = Counter()
        for elem in set(self) | set(other):
            newcount = _max(self[elem], other[elem])
            if newcount > 0:
                result[elem] = newcount
        return result

    def __and__(self, other):
        ''' Intersection is the minimum of corresponding counts.

        >>> Counter('abbb') & Counter('bcc')
        Counter({'b': 1})
        '''
        if not isinstance(other, Counter):
            return NotImplemented
        _min = min
        result = Counter()
        if len(self) < len(other):
            self, other = other, self
        # Iterate the smaller counter, keeping only shared elements
        # (`filter` replaces the py2-only `ifilter`; iteration semantics
        # are identical).
        for elem in filter(self.__contains__, other):
            newcount = _min(self[elem], other[elem])
            if newcount > 0:
                result[elem] = newcount
        return result

View File

@ -0,0 +1,149 @@
import re
import sqlparse
from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation
table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE)
class FunctionMetadata(object):
    """Value object describing one postgresql function.

    Holds the schema/name pair, the raw argument list and return type text
    (whitespace-stripped), and three behavioural flags.
    """

    def __init__(self, schema_name, func_name, arg_list, return_type, is_aggregate,
                 is_window, is_set_returning):
        self.schema_name = schema_name
        self.func_name = func_name
        self.arg_list = arg_list.strip()
        self.return_type = return_type.strip()
        self.is_aggregate = is_aggregate
        self.is_window = is_window
        self.is_set_returning = is_set_returning

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash((self.schema_name, self.func_name, self.arg_list,
                     self.return_type, self.is_aggregate, self.is_window,
                     self.is_set_returning))

    def __repr__(self):
        template = ('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,'
                    ' is_aggregate=%r, is_window=%r, is_set_returning=%r)')
        return template % (self.__class__.__name__, self.schema_name,
                           self.func_name, self.arg_list, self.return_type,
                           self.is_aggregate, self.is_window,
                           self.is_set_returning)

    def fieldnames(self):
        """Returns a list of output field names"""
        if self.return_type.lower() == 'void':
            return []

        table_match = table_def_regex.match(self.return_type)
        if table_match:
            # Table-returning function: the columns come from the
            # TABLE(...) specification
            return list(field_names(table_match.group(1), mode_filter=None))

        # Otherwise look for named OUT/INOUT arguments
        return list(field_names(self.arg_list, mode_filter=('OUT', 'INOUT')))
class TypedFieldMetadata(object):
    """Describes typed field from a function signature or table definition

    Attributes are:
        name        The name of the argument/column
        mode        'IN', 'OUT', 'INOUT', 'VARIADIC'
        type        A list of tokens denoting the type
        default     A list of tokens denoting the default value
        unknown     A list of tokens not assigned to type or default
    """

    __slots__ = ['name', 'mode', 'type', 'default', 'unknown']

    def __init__(self):
        self.name = None
        self.mode = 'IN'
        # Each token bucket starts out as its own empty list
        for bucket in ('type', 'default', 'unknown'):
            setattr(self, bucket, [])

    def __getitem__(self, attr):
        # Allow dict-style access, e.g. field['default']
        return getattr(self, attr)
def parse_typed_field_list(tokens):
    """Parses a argument/column list, yielding TypedFieldMetadata objects

    Field/column lists are used in function signatures and table
    definitions. This function parses a flattened list of sqlparse tokens
    and yields one metadata argument per argument / column.

    Args:
        tokens: iterable of flattened sqlparse tokens

    Yields:
        TypedFieldMetadata: one per comma-separated field specification
    """
    # postgres function argument list syntax:
    #   " ( [ [ argmode ] [ argname ] argtype
    #         [ { DEFAULT | = } default_expr ] [, ...] ] )"
    mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC'))
    # parse_state names the bucket currently being filled: 'type',
    # 'default', or 'unknown'
    parse_state = 'type'
    # Depth of nested parentheses; commas inside parens (e.g. numeric(10,2))
    # do not end a field
    parens = 0
    field = TypedFieldMetadata()
    for tok in tokens:
        if tok.ttype in Whitespace or tok.ttype in Comment:
            continue
        elif tok.ttype in Punctuation:
            if parens == 0 and tok.value == ',':
                # End of the current field specification
                if field.type:
                    yield field
                # Initialize metadata holder for the next field
                field, parse_state = TypedFieldMetadata(), 'type'
            elif parens == 0 and tok.value == '=':
                parse_state = 'default'
            else:
                field[parse_state].append(tok)
                if tok.value == '(':
                    parens += 1
                elif tok.value == ')':
                    parens -= 1
        elif parens == 0:
            if tok.ttype in Keyword:
                if not field.name and tok.value.upper() in mode_names:
                    # No other keywords allowed before arg name
                    field.mode = tok.value.upper()
                elif tok.value.upper() == 'DEFAULT':
                    parse_state = 'default'
                else:
                    parse_state = 'unknown'
            elif tok.ttype == Name and not field.name:
                # note that `ttype in Name` would also match Name.Builtin
                field.name = tok.value
            else:
                field[parse_state].append(tok)
        else:
            # Inside parentheses everything belongs to the current bucket
            field[parse_state].append(tok)

    # Final argument won't be followed by a comma, so make sure it gets yielded
    if field.type:
        yield field
def field_names(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')):
    """Yields field names from a table declaration

    Args:
        sql: field list text, e.g. "x int, y text, ..."
        mode_filter: argument modes to include; a falsy value includes
            every field
    """
    if not sql:
        return

    flattened = sqlparse.parse(sql)[0].flatten()
    for field in parse_typed_field_list(flattened):
        if not field.name:
            continue
        if not mode_filter or field.mode in mode_filter:
            yield field.name

View File

@ -0,0 +1,288 @@
import re
import sqlparse
from collections import namedtuple
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation, Token, Error
# Regexes picking the trailing "word" out of a string; keys name how liberal
# the definition of a word character is. All patterns are raw strings (the
# original 'all_punctuations' entry used a plain string with `\s`, which is
# an invalid escape sequence and deprecated on modern Python).
cleanup_regex = {
    # This matches only alphanumerics and underscores.
    'alphanum_underscore': re.compile(r'(\w+)$'),
    # This matches everything except spaces, parens, colon, and comma
    'many_punctuations': re.compile(r'([^():,\s]+)$'),
    # This matches everything except spaces, parens, colon, comma, and period
    'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
    # This matches everything except a space.
    'all_punctuations': re.compile(r'([^\s]+)$'),
}


def last_word(text, include='alphanum_underscore'):
    """
    Find the last word in a sentence.

    Args:
        text: the sentence to inspect
        include: key into cleanup_regex selecting which characters may be
            part of a "word"

    >>> last_word('abc')
    'abc'
    >>> last_word(' abc')
    'abc'
    >>> last_word('')
    ''
    >>> last_word(' ')
    ''
    >>> last_word('abc ')
    ''
    >>> last_word('abc def')
    'def'
    >>> last_word('abc def ')
    ''
    >>> last_word('abc def;')
    ''
    >>> last_word('bac $def')
    'def'
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    >>> last_word('bac \def', include='most_punctuations')
    '\\\\def'
    >>> last_word('bac \def;', include='most_punctuations')
    '\\\\def;'
    >>> last_word('bac::def', include='most_punctuations')
    'def'
    >>> last_word('"foo*bar', include='most_punctuations')
    '"foo*bar'
    """
    # Empty string or trailing whitespace: there is no word in progress
    if not text:
        return ''
    if text[-1].isspace():
        return ''

    regex = cleanup_regex[include]
    matches = regex.search(text)
    return matches.group(0) if matches else ''
# Describes one table-like reference in a query: optional schema qualifier,
# object name, optional alias, and whether the reference is actually a
# function call used as a table source (e.g. a set-returning function).
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
                                               'is_function'])
# This code is borrowed from sqlparse example script.
# <url>
def is_subselect(parsed):
    """Return True if `parsed` is a grouped token list containing a DML
    keyword (i.e. it looks like a nested statement)."""
    if not parsed.is_group():
        return False
    dml_words = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
    return any(item.ttype is DML and item.value.upper() in dml_words
               for item in parsed.tokens)
def _identifier_is_function(identifier):
    """Return True if any child token of `identifier` is a sqlparse
    Function (i.e. the identifier wraps a function call)."""
    for child in identifier.tokens:
        if isinstance(child, Function):
            return True
    return False
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens that form the FROM/INTO/UPDATE part of a statement.

    Args:
        parsed: a parsed sqlparse statement
        stop_at_punctuation: when True, stop at the first punctuation token
            after the table prefix (needed for INSERT statements, where the
            parenthesized column list would otherwise be yielded)
    """
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                for x in extract_from_part(item, stop_at_punctuation):
                    yield x
            elif stop_at_punctuation and item.ttype is Punctuation:
                # Bug fix: this was `raise StopIteration`, which PEP 479
                # (Python 3.7+) turns into a RuntimeError inside a
                # generator; a plain return ends iteration identically.
                return
            # An incomplete nested select won't be recognized correctly as a
            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
            # the second FROM to trigger this elif condition resulting in a
            # StopIteration. So we need to ignore the keyword if the keyword
            # FROM.
            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
            # condition. So we need to ignore the keyword JOIN and its variants
            # INNER JOIN, FULL OUTER JOIN, etc.
            elif item.ttype is Keyword and (
                    not item.value.upper() == 'FROM') and (
                    not item.value.upper().endswith('JOIN')):
                tbl_prefix_seen = False
            else:
                yield item
        elif item.ttype is Keyword or item.ttype is Keyword.DML:
            item_val = item.value.upper()
            if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
                    item_val.endswith('JOIN')):
                tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if (identifier.ttype is Keyword and
                        identifier.value.upper() == 'FROM'):
                    tbl_prefix_seen = True
                    break
def extract_table_identifiers(token_stream, allow_functions=True):
    """yields tuples of TableReference namedtuples

    Args:
        token_stream: iterable of sqlparse tokens, e.g. the output of
            extract_from_part
        allow_functions: when False, identifiers that parse as function
            calls are not flagged as functions (used for INSERT statements,
            whose parenthesized column lists sqlparse mis-parses as calls)
    """
    for item in token_stream:
        if isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                # Sometimes Keywords (such as FROM ) are classified as
                # identifiers which don't have the get_real_name() method.
                try:
                    schema_name = identifier.get_parent_name()
                    real_name = identifier.get_real_name()
                    is_function = (allow_functions and
                                   _identifier_is_function(identifier))
                except AttributeError:
                    continue
                if real_name:
                    yield TableReference(schema_name, real_name,
                                         identifier.get_alias(), is_function)
        elif isinstance(item, Identifier):
            real_name = item.get_real_name()
            schema_name = item.get_parent_name()
            is_function = allow_functions and _identifier_is_function(item)

            if real_name:
                yield TableReference(schema_name, real_name, item.get_alias(),
                                     is_function)
            else:
                # No parseable real name -- fall back on the raw name, using
                # it as its own alias when none was given
                name = item.get_name()
                yield TableReference(None, name, item.get_alias() or name,
                                     is_function)
        elif isinstance(item, Function):
            yield TableReference(None, item.get_real_name(), item.get_alias(),
                                 allow_functions)
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
    """Extract the table names from an SQL statment.

    Returns a list of TableReference namedtuples
    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return ()

    first_statement = parsed[0]

    # INSERT statements must stop looking for tables at the sign of first
    # Punctuation, e.g. in 'INSERT INTO abc (col1, col2) VALUES (1, 2)'
    # only abc is a table; col1/col2 would otherwise be picked up too.
    is_insert = first_statement.token_first().value.lower() == 'insert'

    token_stream = extract_from_part(first_statement,
                                     stop_at_punctuation=is_insert)

    # Kludge: sqlparse mistakenly identifies insert statements as
    # function calls due to the parenthesized column list, e.g. interprets
    # "insert into foo (bar, baz)" as a function call to foo with arguments
    # (bar, baz). So don't allow any identifiers in insert statements
    # to have is_function=True
    references = extract_table_identifiers(token_stream,
                                           allow_functions=not is_insert)
    return tuple(references)
def find_prev_keyword(sql):
    """ Find the last sql keyword in an SQL statement

    Returns the value of the last keyword, and the text of the query with
    everything after the last keyword stripped

    Args:
        sql: the (possibly partial) SQL text to scan

    Returns:
        (token, text) tuple: token is the sqlparse keyword (or open-paren)
        token, or None when no keyword is present; text is `sql` truncated
        just after that token (or '' when token is None)
    """
    if not sql.strip():
        return None, ''

    parsed = sqlparse.parse(sql)[0]
    flattened = list(parsed.flatten())

    # Logical operators do not end a clause, so they are not treated as
    # clause-delimiting keywords here
    logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')

    for t in reversed(flattened):
        if t.value == '(' or (t.is_keyword and (
                t.value.upper() not in logical_operators)):
            # Find the location of token t in the original parsed statement
            # We can't use parsed.token_index(t) because t may be a child token
            # inside a TokenList, in which case token_index thows an error
            # Minimal example:
            #   p = sqlparse.parse('select * from foo where bar')
            #   t = list(p.flatten())[-3]  # The "Where" token
            #   p.token_index(t)  # Throws ValueError: not in list
            idx = flattened.index(t)

            # Combine the string values of all tokens in the original list
            # up to and including the target keyword token t, to produce a
            # query string with everything after the keyword token removed
            text = ''.join(tok.value for tok in flattened[:idx+1])
            return t, text

    return None, ''
# Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')


def is_open_quote(sql):
    """Returns true if the query contains an unclosed quote"""
    # sql may hold several semicolon-separated commands; the query is
    # considered "open" if any one of them has an unterminated quote
    statements = sqlparse.parse(sql)
    for statement in statements:
        if _parsed_is_open_quote(statement):
            return True
    return False
def _parsed_is_open_quote(parsed):
    """Return True if a single parsed statement has an unterminated quote.

    Covers both ordinary single quotes (sqlparse marks an unmatched one as
    an Error token) and PostgreSQL dollar quoting ($$...$$ / $tag$...$tag$).
    """
    tokens = list(parsed.flatten())

    i = 0
    while i < len(tokens):
        tok = tokens[i]
        if tok.match(Token.Error, "'"):
            # An unmatched single quote
            return True
        elif (tok.ttype in Token.Name.Builtin
              and dollar_quote_regex.match(tok.value)):
            # Find the matching closing dollar quote sign
            for (j, tok2) in enumerate(tokens[i+1:], i+1):
                if tok2.match(Token.Name.Builtin, tok.value):
                    # Found the matching closing quote - continue our scan for
                    # open quotes thereafter
                    i = j
                    break
            else:
                # No matching dollar sign quote
                return True
        i += 1

    return False
def parse_partial_identifier(word):
    """Attempt to parse a (partially typed) word as an identifier

    word may include a schema qualification, like `schema_name.partial_name`
    or `schema_name.` There may also be unclosed quotation marks, like
    `"schema`, or `schema."partial_name`

    :param word: string representing a (partially complete) identifier
    :return: sqlparse.sql.Identifier, or None
    """
    parsed = sqlparse.parse(word)[0]
    tokens = parsed.tokens

    if len(tokens) == 1 and isinstance(tokens[0], Identifier):
        return tokens[0]

    if parsed.token_next_match(0, Error, '"'):
        # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
        # Close the double quote, then reparse
        return parse_partial_identifier(word + '"')

    return None
if __name__ == '__main__':
    # Quick manual smoke test for extract_tables
    sample = 'select * from (select t. from tabl t'
    print(extract_tables(sample))

View File

@ -0,0 +1,49 @@
import re
import sqlparse
from sqlparse.tokens import Name
from collections import defaultdict
white_space_regex = re.compile('\\s+', re.MULTILINE)
def _compile_regex(keyword):
# Surround the keyword with word boundaries and replace interior whitespace
# with whitespace wildcards
pattern = '\\b' + re.sub(white_space_regex, '\\s+', keyword) + '\\b'
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
class PrevalenceCounter(object):
    """Tracks how often keywords and identifier names occur in query text.

    Used to rank autocomplete suggestions: more frequently seen keywords
    and names can be ranked higher.
    """

    def __init__(self, keywords):
        # keyword phrase -> occurrence count
        self.keyword_counts = defaultdict(int)
        # identifier name -> occurrence count
        self.name_counts = defaultdict(int)
        # One pre-compiled, case-insensitive regex per keyword phrase
        self.keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)

    def update(self, text):
        """Record both keyword and name occurrences found in `text`."""
        self.update_keywords(text)
        self.update_names(text)

    def update_names(self, text):
        """Count every sqlparse Name token appearing in `text`."""
        for parsed in sqlparse.parse(text):
            for token in parsed.flatten():
                if token.ttype in Name:
                    self.name_counts[token.value] += 1

    def clear_names(self):
        """Forget all name counts (keyword counts are kept)."""
        self.name_counts = defaultdict(int)

    def update_keywords(self, text):
        # Count keywords. Can't rely for sqlparse for this, because it's
        # database agnostic
        for keyword, regex in self.keyword_regexs.items():
            for _ in regex.finditer(text):
                self.keyword_counts[keyword] += 1

    def keyword_count(self, keyword):
        """Number of times `keyword` has been seen (0 if never)."""
        return self.keyword_counts[keyword]

    def name_count(self, name):
        """Number of times identifier `name` has been seen (0 if never)."""
        return self.name_counts[name]