Merge pgcli code with version 1.10.3, which is used for the auto-complete feature.

pull/14/head
Akshay Joshi 2018-08-27 15:00:56 +05:30
parent 25679fd542
commit 7a3f3046df
17 changed files with 2191 additions and 1004 deletions

View File

@@ -21,6 +21,7 @@ Bug fixes
| `Bug #3325 <https://redmine.postgresql.org/issues/3325>`_ - Fix sort/filter dialog issue where it incorrectly requires ASC/DESC.
| `Bug #3347 <https://redmine.postgresql.org/issues/3347>`_ - Ensure backup works with '--data-only' and '--schema-only' for any format.
| `Bug #3407 <https://redmine.postgresql.org/issues/3407>`_ - Fix keyboard shortcuts layout in the preferences panel.
| `Bug #3420 <https://redmine.postgresql.org/issues/3420>`_ - Merge pgcli code with version 1.10.3, which is used for the auto-complete feature.
| `Bug #3461 <https://redmine.postgresql.org/issues/3461>`_ - Ensure that refreshing a node also updates the Property list.
| `Bug #3528 <https://redmine.postgresql.org/issues/3528>`_ - Handle connection errors properly in the query tool.
| `Bug #3547 <https://redmine.postgresql.org/issues/3547>`_ - Make the session implementation thread safe.

View File

@@ -0,0 +1,156 @@
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2018, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
import sys
import random
from selenium.webdriver import ActionChains
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from regression.python_test_utils import test_utils
from regression.feature_utils.base_feature_test import BaseFeatureTest
class QueryToolAutoCompleteFeatureTest(BaseFeatureTest):
"""
This feature test will test the query tool auto complete feature.
"""
first_table_name = ""
second_table_name = ""
scenarios = [
("Query tool auto complete feature test", dict())
]
def before(self):
self.page.wait_for_spinner_to_disappear()
self.page.add_server(self.server)
self.first_table_name = "auto_comp_" + \
str(random.randint(1000, 3000))
test_utils.create_table(self.server, self.test_db,
self.first_table_name)
self.second_table_name = "auto_comp_" + \
str(random.randint(1000, 3000))
test_utils.create_table(self.server, self.test_db,
self.second_table_name)
self._locate_database_tree_node()
self.page.open_query_tool()
self.page.wait_for_spinner_to_disappear()
def runTest(self):
# Test case for keywords
print("\nAuto complete ALTER keyword... ", file=sys.stderr, end="")
self._auto_complete("A", "ALTER")
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete BEGIN keyword... ", file=sys.stderr, end="")
self._auto_complete("BE", "BEGIN")
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete CASCADED keyword... ", file=sys.stderr, end="")
self._auto_complete("CAS", "CASCADED")
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete SELECT keyword... ", file=sys.stderr, end="")
self._auto_complete("SE", "SELECT")
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete pg_backend_pid() function ... ",
file=sys.stderr, end="")
self._auto_complete("SELECT pg_", "pg_backend_pid()")
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete current_query() function ... ",
file=sys.stderr, end="")
self._auto_complete("SELECT current_", "current_query()")
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete function with argument ... ",
file=sys.stderr, end="")
self._auto_complete("SELECT pg_st", "pg_stat_file(filename)")
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete first table in public schema ... ",
file=sys.stderr, end="")
self._auto_complete("SELECT * FROM public.", self.first_table_name)
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete second table in public schema ... ",
file=sys.stderr, end="")
self._auto_complete("SELECT * FROM public.", self.second_table_name)
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete JOIN second table with after schema name ... ",
file=sys.stderr, end="")
query = "SELECT * FROM public." + self.first_table_name + \
" JOIN public."
self._auto_complete(query, self.second_table_name)
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete JOIN ON some columns ... ",
file=sys.stderr, end="")
query = "SELECT * FROM public." + self.first_table_name + \
" JOIN public." + self.second_table_name + " ON " + \
self.second_table_name + "."
expected_string = "some_column = " + self.first_table_name + \
".some_column"
self._auto_complete(query, expected_string)
print("OK.", file=sys.stderr)
self._clear_query_tool()
print("Auto complete JOIN ON some columns using tabel alias ... ",
file=sys.stderr, end="")
query = "SELECT * FROM public." + self.first_table_name + \
" t1 JOIN public." + self.second_table_name + " t2 ON t2."
self._auto_complete(query, "some_column = t1.some_column")
print("OK.", file=sys.stderr)
self._clear_query_tool()
def after(self):
self.page.remove_server(self.server)
def _locate_database_tree_node(self):
self.page.toggle_open_tree_item(self.server['name'])
self.page.toggle_open_tree_item('Databases')
self.page.toggle_open_tree_item(self.test_db)
def _clear_query_tool(self):
self.page.click_element(
self.page.find_by_xpath("//*[@id='btn-clear-dropdown']")
)
ActionChains(self.driver) \
.move_to_element(self.page.find_by_xpath("//*[@id='btn-clear']")) \
.perform()
self.page.click_element(
self.page.find_by_xpath("//*[@id='btn-clear']")
)
self.page.click_modal('Yes')
def _auto_complete(self, word, expected_string):
self.page.fill_codemirror_area_with(word)
ActionChains(self.page.driver).key_down(
Keys.CONTROL).send_keys(Keys.SPACE).key_up(Keys.CONTROL).perform()
self.page.find_by_xpath(
"//ul[contains(@class, 'CodeMirror-hints') and "
"contains(., '" + expected_string + "')]")

View File

@@ -1,30 +1,16 @@
{# ============= Fetch the list of functions based on given schema_names ============= #}
{% if func_name %}
SELECT n.nspname schema_name,
p.proname func_name,
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
pg_catalog.pg_get_function_result(p.oid) return_type,
p.proargnames arg_names,
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[] arg_types,
p.proargmodes arg_modes,
prorettype::regtype::text return_type,
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
p.proretset is_set_returning
p.proretset is_set_returning,
pg_get_expr(proargdefaults, 0) AS arg_defaults
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
WHERE n.nspname = '{{schema_name}}' AND p.proname = '{{func_name}}'
AND p.proretset
ORDER BY 1, 2
{% else %}
SELECT n.nspname schema_name,
p.proname object_name,
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
pg_catalog.pg_get_function_result(p.oid) return_type,
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
p.proretset is_set_returning
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
WHERE n.nspname IN ({{schema_names}})
{% if is_set_returning %}
AND p.proretset
{% endif %}
ORDER BY 1, 2
{% endif %}
WHERE p.prorettype::regtype != 'trigger'::regtype
AND n.nspname IN ({{schema_names}})
ORDER BY 1, 2

View File

@@ -1,29 +1,43 @@
{# SQL query for getting columns #}
{% if object_name == 'table' %}
SELECT
att.attname column_name
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
WHERE cls.relkind = ANY(array['r'])
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
att.atttypid::regtype::text type_name,
att.atthasdef AS has_default,
def.adsrc as default
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
LEFT OUTER JOIN pg_attrdef def
ON def.adrelid = att.attrelid
AND def.adnum = att.attnum
WHERE nsp.nspname IN ({{schema_names}})
AND cls.relkind = ANY(array['r'])
AND NOT att.attisdropped
AND att.attnum > 0
AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
ORDER BY 1
ORDER BY 1, 2, att.attnum
{% endif %}
{% if object_name == 'view' %}
SELECT
att.attname column_name
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
WHERE cls.relkind = ANY(array['v', 'm'])
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
att.atttypid::regtype::text type_name,
att.atthasdef AS has_default,
def.adsrc as default
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
LEFT OUTER JOIN pg_attrdef def
ON def.adrelid = att.attrelid
AND def.adnum = att.attnum
WHERE nsp.nspname IN ({{schema_names}})
AND cls.relkind = ANY(array['v', 'm'])
AND NOT att.attisdropped
AND att.attnum > 0
AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
ORDER BY 1
{% endif %}
ORDER BY 1, 2, att.attnum
{% endif %}

View File

@@ -0,0 +1,27 @@
{# SQL query for getting foreign keys #}
SELECT s_p.nspname AS parentschema,
t_p.relname AS parenttable,
unnest((
select
array_agg(attname ORDER BY i)
from
(select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
JOIN pg_catalog.pg_attribute c USING(attnum)
WHERE c.attrelid = fk.confrelid
)) AS parentcolumn,
s_c.nspname AS childschema,
t_c.relname AS childtable,
unnest((
select
array_agg(attname ORDER BY i)
from
(select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
JOIN pg_catalog.pg_attribute c USING(attnum)
WHERE c.attrelid = fk.conrelid
)) AS childcolumn
FROM pg_catalog.pg_constraint fk
JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid
JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace
JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
WHERE fk.contype = 'f' AND s_p.nspname IN ({{schema_names}})
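As an illustration with a hypothetical schema: for a single foreign key public.orders.user_id referencing public.users.id, this query returns one row with parentschema=public, parenttable=users, parentcolumn=id, childschema=public, childtable=orders, childcolumn=user_id. The paired unnest(... ORDER BY i) subqueries keep multi-column keys aligned, so the i-th parent column always lines up with the i-th child column.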

View File

@@ -1,30 +1,16 @@
{# ============= Fetch the list of functions based on given schema_names ============= #}
{% if func_name %}
SELECT n.nspname schema_name,
p.proname func_name,
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
pg_catalog.pg_get_function_result(p.oid) return_type,
p.proargnames arg_names,
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[] arg_types,
p.proargmodes arg_modes,
prorettype::regtype::text return_type,
p.proisagg is_aggregate,
p.proiswindow is_window,
p.proretset is_set_returning
p.proretset is_set_returning,
pg_get_expr(proargdefaults, 0) AS arg_defaults
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
WHERE n.nspname = '{{schema_name}}' AND p.proname = '{{func_name}}'
AND p.proretset
ORDER BY 1, 2
{% else %}
SELECT n.nspname schema_name,
p.proname object_name,
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
pg_catalog.pg_get_function_result(p.oid) return_type,
p.proisagg is_aggregate,
p.proiswindow is_window,
p.proretset is_set_returning
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
WHERE n.nspname IN ({{schema_names}})
{% if is_set_returning %}
AND p.proretset
{% endif %}
ORDER BY 1, 2
{% endif %}
WHERE p.prorettype::regtype != 'trigger'::regtype
AND n.nspname IN ({{schema_names}})
ORDER BY 1, 2

File diff suppressed because it is too large.

View File

@@ -1,197 +0,0 @@
"""
Copied from http://code.activestate.com/recipes/576611-counter-class/
"""
from heapq import nlargest
from itertools import repeat
try:
from itertools import ifilter
except ImportError:
# Python 3 removed itertools.ifilter; the built-in filter is lazy there
ifilter = filter
from operator import itemgetter
class Counter(dict):
'''Dict subclass for counting hashable objects. Sometimes called a bag
or multiset. Elements are stored as dictionary keys and their counts
are stored as dictionary values.
>>> Counter('zyzygy')
Counter({'y': 3, 'z': 2, 'g': 1})
'''
def __init__(self, iterable=None, **kwds):
'''Create a new, empty Counter object. And if given, count elements
from an input iterable. Or, initialize the count from another mapping
of elements to their counts.
>>> c = Counter() # a new, empty counter
>>> c = Counter('gallahad') # a new counter from an iterable
>>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
>>> c = Counter(a=4, b=2) # a new counter from keyword args
'''
self.update(iterable, **kwds)
def __missing__(self, key):
return 0
def most_common(self, n=None):
'''List the n most common elements and their counts from the most
common to the least. If n is None, then list all element counts.
>>> Counter('abracadabra').most_common(3)
[('a', 5), ('r', 2), ('b', 2)]
'''
if n is None:
return sorted(self.iteritems(), key=itemgetter(1), reverse=True)
return nlargest(n, self.iteritems(), key=itemgetter(1))
def elements(self):
'''Iterator over elements repeating each as many times as its count.
>>> c = Counter('ABCABC')
>>> sorted(c.elements())
['A', 'A', 'B', 'B', 'C', 'C']
If an element's count has been set to zero or is a negative number,
elements() will ignore it.
'''
for elem, count in self.iteritems():
for _ in repeat(None, count):
yield elem
# Override dict methods where the meaning changes for Counter objects.
@classmethod
def fromkeys(cls, iterable, v=None):
raise NotImplementedError(
'Counter.fromkeys() is undefined. Use Counter(iterable) instead.'
)
def update(self, iterable=None, **kwds):
'''Like dict.update() but add counts instead of replacing them.
Source can be an iterable, a dictionary, or another Counter instance.
>>> c = Counter('which')
>>> c.update('witch') # add elements from another iterable
>>> d = Counter('watch')
>>> c.update(d) # add elements from another counter
>>> c['h'] # four 'h' in which, witch, and watch
4
'''
if iterable is not None:
if hasattr(iterable, 'iteritems'):
if self:
self_get = self.get
for elem, count in iterable.iteritems():
self[elem] = self_get(elem, 0) + count
else:
# fast path when counter is empty
dict.update(self, iterable)
else:
self_get = self.get
for elem in iterable:
self[elem] = self_get(elem, 0) + 1
if kwds:
self.update(kwds)
def copy(self):
'Like dict.copy() but returns a Counter instance instead of a dict.'
return Counter(self)
def __delitem__(self, elem):
"""Like dict.__delitem__() but does not raise KeyError for missing
values."""
if elem in self:
dict.__delitem__(self, elem)
def __repr__(self):
if not self:
return '%s()' % self.__class__.__name__
items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
return '%s({%s})' % (self.__class__.__name__, items)
# Multiset-style mathematical operations discussed in:
# Knuth TAOCP Volume II section 4.6.3 exercise 19
# and at http://en.wikipedia.org/wiki/Multiset
#
# Outputs guaranteed to only include positive counts.
#
# To strip negative and zero counts, add-in an empty counter:
# c += Counter()
def __add__(self, other):
'''Add counts from two counters.
>>> Counter('abbb') + Counter('bcc')
Counter({'b': 4, 'c': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem in set(self) | set(other):
newcount = self[elem] + other[elem]
if newcount > 0:
result[elem] = newcount
return result
def __sub__(self, other):
''' Subtract count, but keep only results with positive counts.
>>> Counter('abbbc') - Counter('bccd')
Counter({'b': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
result = Counter()
for elem in set(self) | set(other):
newcount = self[elem] - other[elem]
if newcount > 0:
result[elem] = newcount
return result
def __or__(self, other):
'''Union is the maximum of value in either of the input counters.
>>> Counter('abbb') | Counter('bcc')
Counter({'b': 3, 'c': 2, 'a': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
_max = max
result = Counter()
for elem in set(self) | set(other):
newcount = _max(self[elem], other[elem])
if newcount > 0:
result[elem] = newcount
return result
def __and__(self, other):
''' Intersection is the minimum of corresponding counts.
>>> Counter('abbb') & Counter('bcc')
Counter({'b': 1})
'''
if not isinstance(other, Counter):
return NotImplemented
_min = min
result = Counter()
if len(self) < len(other):
self, other = other, self
for elem in ifilter(self.__contains__, other):
newcount = _min(self[elem], other[elem])
if newcount > 0:
result[elem] = newcount
return result

View File

@@ -1,151 +0,0 @@
import re
import sqlparse
from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation
table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE)
class FunctionMetadata(object):
def __init__(self, schema_name, func_name, arg_list, return_type,
is_aggregate, is_window, is_set_returning):
"""Class for describing a postgresql function"""
self.schema_name = schema_name
self.func_name = func_name
self.arg_list = arg_list.strip()
self.return_type = return_type.strip()
self.is_aggregate = is_aggregate
self.is_window = is_window
self.is_set_returning = is_set_returning
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
def __ne__(self, other):
return not self.__eq__(other)
def __hash__(self):
return hash((self.schema_name, self.func_name, self.arg_list,
self.return_type, self.is_aggregate, self.is_window,
self.is_set_returning))
def __repr__(self):
return (
('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,'
' is_aggregate=%r, is_window=%r, is_set_returning=%r)')
% (self.__class__.__name__, self.schema_name, self.func_name,
self.arg_list, self.return_type, self.is_aggregate,
self.is_window, self.is_set_returning)
)
def fieldnames(self):
"""Returns a list of output field names"""
if self.return_type.lower() == 'void':
return []
match = table_def_regex.match(self.return_type)
if match:
# Function returns a table -- get the column names
return list(field_names(match.group(1), mode_filter=None))
# Function may have named output arguments -- find them and return
# their names
return list(field_names(self.arg_list, mode_filter=('OUT', 'INOUT')))
class TypedFieldMetadata(object):
"""Describes typed field from a function signature or table definition
Attributes are:
name The name of the argument/column
mode 'IN', 'OUT', 'INOUT', 'VARIADIC'
type A list of tokens denoting the type
default A list of tokens denoting the default value
unknown A list of tokens not assigned to type or default
"""
__slots__ = ['name', 'mode', 'type', 'default', 'unknown']
def __init__(self):
self.name = None
self.mode = 'IN'
self.type = []
self.default = []
self.unknown = []
def __getitem__(self, attr):
return getattr(self, attr)
def parse_typed_field_list(tokens):
"""Parses a argument/column list, yielding TypedFieldMetadata objects
Field/column lists are used in function signatures and table
definitions. This function parses a flattened list of sqlparse tokens
and yields one metadata argument per argument / column.
"""
# postgres function argument list syntax:
# " ( [ [ argmode ] [ argname ] argtype
# [ { DEFAULT | = } default_expr ] [, ...] ] )"
mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC'))
parse_state = 'type'
parens = 0
field = TypedFieldMetadata()
for tok in tokens:
if tok.ttype in Whitespace or tok.ttype in Comment:
continue
elif tok.ttype in Punctuation:
if parens == 0 and tok.value == ',':
# End of the current field specification
if field.type:
yield field
# Initialize metadata holder for the next field
field, parse_state = TypedFieldMetadata(), 'type'
elif parens == 0 and tok.value == '=':
parse_state = 'default'
else:
field[parse_state].append(tok)
if tok.value == '(':
parens += 1
elif tok.value == ')':
parens -= 1
elif parens == 0:
if tok.ttype in Keyword:
if not field.name and tok.value.upper() in mode_names:
# No other keywords allowed before arg name
field.mode = tok.value.upper()
elif tok.value.upper() == 'DEFAULT':
parse_state = 'default'
else:
parse_state = 'unknown'
elif tok.ttype == Name and not field.name:
# note that `ttype in Name` would also match Name.Builtin
field.name = tok.value
else:
field[parse_state].append(tok)
else:
field[parse_state].append(tok)
# Final argument won't be followed by a comma, so make sure it gets
# yielded
if field.type:
yield field
def field_names(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')):
"""Yields field names from a table declaration"""
if not sql:
return
# sql is something like "x int, y text, ..."
tokens = sqlparse.parse(sql)[0].flatten()
for f in parse_typed_field_list(tokens):
if f.name and (not mode_filter or f.mode in mode_filter):
yield f.name

View File

@@ -73,7 +73,7 @@ TableReference = namedtuple(
# This code is borrowed from sqlparse example script.
# <url>
def is_subselect(parsed):
if not parsed.is_group:
if not parsed.is_group():
return False
sql_type = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
for item in parsed.tokens:
@@ -104,7 +104,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
# condition. So we need to ignore the keyword JOIN and its
# variants INNER JOIN, FULL OUTER JOIN, etc.
elif item.ttype is Keyword and (
not item.value.upper() == 'FROM') and (
not item.value.upper() == 'FROM') and (
not item.value.upper().endswith('JOIN')):
tbl_prefix_seen = False
else:

View File

@@ -0,0 +1,22 @@
import sqlparse
def query_starts_with(query, prefixes):
"""Check if the query starts with any item from *prefixes*."""
prefixes = [prefix.lower() for prefix in prefixes]
formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
def queries_start_with(queries, prefixes):
"""Check if any queries start with any item from *prefixes*."""
for query in sqlparse.split(queries):
if query and query_starts_with(query, prefixes) is True:
return True
return False
def is_destructive(queries):
"""Returns if any of the queries in *queries* is destructive."""
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
return queries_start_with(queries, keywords)
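A doctest-style sketch of the expected behaviour (values inferred from the logic above, not part of the commit):

>>> query_starts_with('-- cleanup\nDROP TABLE foo;', ('drop',))
True
>>> queries_start_with('SELECT 1; TRUNCATE bar;', ('truncate',))
True
>>> is_destructive('SELECT * FROM some_table')
False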

View File

@@ -0,0 +1,143 @@
from sqlparse import parse
from sqlparse.tokens import Keyword, CTE, DML
from sqlparse.sql import Identifier, IdentifierList, Parenthesis
from collections import namedtuple
from .meta import TableMetadata, ColumnMetadata
# TableExpression is a namedtuple representing a CTE, used internally
# name: cte alias assigned in the query
# columns: list of column names
# start: index into the original string of the left parens starting the CTE
# stop: index into the original string of the right parens ending the CTE
TableExpression = namedtuple('TableExpression', 'name columns start stop')
def isolate_query_ctes(full_text, text_before_cursor):
"""Simplify a query by converting CTEs into table metadata objects
"""
if not full_text:
return full_text, text_before_cursor, tuple()
ctes, remainder = extract_ctes(full_text)
if not ctes:
return full_text, text_before_cursor, ()
current_position = len(text_before_cursor)
meta = []
for cte in ctes:
if cte.start < current_position < cte.stop:
# Currently editing a cte - treat its body as the current full_text
text_before_cursor = full_text[cte.start:current_position]
full_text = full_text[cte.start:cte.stop]
return full_text, text_before_cursor, meta
# Append this cte to the list of available table metadata
cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
meta.append(TableMetadata(cte.name, cols))
# Editing past the last cte (ie the main body of the query)
full_text = full_text[ctes[-1].stop:]
text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]
return full_text, text_before_cursor, tuple(meta)
def extract_ctes(sql):
""" Extract constant table expresseions from a query
Returns tuple (ctes, remainder_sql)
ctes is a list of TableExpression namedtuples
remainder_sql is the text from the original query after the CTEs have
been stripped.
"""
p = parse(sql)[0]
# Make sure the first meaningful token is "WITH" which is necessary to
# define CTEs
idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
if not (tok and tok.ttype == CTE):
return [], sql
# Get the next (meaningful) token, which should be the first CTE
idx, tok = p.token_next(idx)
if not tok:
return ([], '')
start_pos = token_start_pos(p.tokens, idx)
ctes = []
if isinstance(tok, IdentifierList):
# Multiple ctes
for t in tok.get_identifiers():
cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
cte = get_cte_from_token(t, start_pos + cte_start_offset)
if not cte:
continue
ctes.append(cte)
elif isinstance(tok, Identifier):
# A single CTE
cte = get_cte_from_token(tok, start_pos)
if cte:
ctes.append(cte)
idx = p.token_index(tok) + 1
# Collapse everything after the ctes into a remainder query
remainder = u''.join(str(tok) for tok in p.tokens[idx:])
return ctes, remainder
def get_cte_from_token(tok, pos0):
cte_name = tok.get_real_name()
if not cte_name:
return None
# Find the start position of the opening parens enclosing the cte body
idx, parens = tok.token_next_by(Parenthesis)
if not parens:
return None
start_pos = pos0 + token_start_pos(tok.tokens, idx)
cte_len = len(str(parens)) # includes parens
stop_pos = start_pos + cte_len
column_names = extract_column_names(parens)
return TableExpression(cte_name, column_names, start_pos, stop_pos)
def extract_column_names(parsed):
# Find the first DML token to check if it's a SELECT or
# INSERT/UPDATE/DELETE
idx, tok = parsed.token_next_by(t=DML)
tok_val = tok and tok.value.lower()
if tok_val in ('insert', 'update', 'delete'):
# Jump ahead to the RETURNING clause where the list of column names is
idx, tok = parsed.token_next_by(idx, (Keyword, 'returning'))
elif not tok_val == 'select':
# Must be invalid CTE
return ()
# The next token should be either a column name, or a list of column names
idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
return tuple(t.get_name() for t in _identifiers(tok))
def token_start_pos(tokens, idx):
return sum(len(str(t)) for t in tokens[:idx])
def _identifiers(tok):
if isinstance(tok, IdentifierList):
for t in tok.get_identifiers():
# NB: IdentifierList.get_identifiers() can return non-identifiers!
if isinstance(t, Identifier):
yield t
elif isinstance(tok, Identifier):
yield tok
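For example (a sketch; expected values inferred from the code above, and the exact grouping can vary across sqlparse versions), a single CTE is reported with its name, column names, and the span of its parenthesized body:

>>> ctes, remainder = extract_ctes(
...     'WITH cte1 AS (SELECT a, b FROM t) SELECT * FROM cte1')
>>> ctes[0].name, ctes[0].columns
('cte1', ('a', 'b'))
>>> remainder
' SELECT * FROM cte1'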

View File

@@ -0,0 +1,151 @@
from collections import namedtuple
_ColumnMetadata = namedtuple(
'ColumnMetadata',
['name', 'datatype', 'foreignkeys', 'default', 'has_default']
)
def ColumnMetadata(
name, datatype, foreignkeys=None, default=None, has_default=False
):
return _ColumnMetadata(
name, datatype, foreignkeys or [], default, has_default
)
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable',
'parentcolumn', 'childschema',
'childtable', 'childcolumn'])
TableMetadata = namedtuple('TableMetadata', 'name columns')
def parse_defaults(defaults_string):
"""Yields default values for a function, given the string provided by
pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
if not defaults_string:
return
current = ''
in_quote = None
for char in defaults_string:
if current == '' and char == ' ':
# Skip space after comma separating default expressions
continue
if char == '"' or char == '\'':
if in_quote and char == in_quote:
# End quote
in_quote = None
elif not in_quote:
# Begin quote
in_quote = char
elif char == ',' and not in_quote:
# End of expression
yield current
current = ''
continue
current += char
yield current
class FunctionMetadata(object):
def __init__(
self, schema_name, func_name, arg_names, arg_types, arg_modes,
return_type, is_aggregate, is_window, is_set_returning,
arg_defaults
):
"""Class for describing a postgresql function"""
self.schema_name = schema_name
self.func_name = func_name
self.arg_modes = tuple(arg_modes) if arg_modes else None
self.arg_names = tuple(arg_names) if arg_names else None
# Be flexible in not requiring arg_types -- use None as a placeholder
# for each arg. (Used for compatibility with old versions of postgresql
# where such info is hard to get.)
if arg_types:
self.arg_types = tuple(arg_types)
elif arg_modes:
self.arg_types = tuple([None] * len(arg_modes))
elif arg_names:
self.arg_types = tuple([None] * len(arg_names))
else:
self.arg_types = None
self.arg_defaults = tuple(parse_defaults(arg_defaults))
self.return_type = return_type.strip()
self.is_aggregate = is_aggregate
self.is_window = is_window
self.is_set_returning = is_set_returning
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
def __ne__(self, other):
return not self.__eq__(other)
def _signature(self):
return (
self.schema_name, self.func_name, self.arg_names, self.arg_types,
self.arg_modes, self.return_type, self.is_aggregate,
self.is_window, self.is_set_returning, self.arg_defaults
)
def __hash__(self):
return hash(self._signature())
def __repr__(self):
return (
(
'%s(schema_name=%r, func_name=%r, arg_names=%r, '
'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, '
'is_window=%r, is_set_returning=%r, arg_defaults=%r)'
) % (self.__class__.__name__,) + self._signature()
)
def has_variadic(self):
return self.arg_modes and any(
arg_mode == 'v' for arg_mode in self.arg_modes)
def args(self):
"""Returns a list of input-parameter ColumnMetadata namedtuples."""
if not self.arg_names:
return []
modes = self.arg_modes or ['i'] * len(self.arg_names)
args = [
(name, typ)
for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC
]
def arg(name, typ, num):
num_args = len(args)
num_defaults = len(self.arg_defaults)
has_default = num + num_defaults >= num_args
default = (
self.arg_defaults[num - num_args + num_defaults] if has_default
else None
)
return ColumnMetadata(name, typ, [], default, has_default)
return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]
def fields(self):
"""Returns a list of output-field ColumnMetadata namedtuples"""
if self.return_type.lower() == 'void':
return []
elif not self.arg_modes:
# For functions without output parameters, the function name
# is used as the name of the output column.
# E.g. 'SELECT unnest FROM unnest(...);'
return [ColumnMetadata(self.func_name, self.return_type, [])]
return [ColumnMetadata(name, typ, [])
for name, typ, mode in zip(
self.arg_names, self.arg_types, self.arg_modes)
if mode in ('o', 'b', 't')] # OUT, INOUT, TABLE
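parse_defaults splits the pg_get_expr() output on top-level commas while leaving quoted commas alone; a doctest-style sketch (expected values inferred from the logic above):

>>> list(parse_defaults("'a, b'::text, 42"))
["'a, b'::text", '42']

Note that commas inside unquoted constructs such as ARRAY[1, 2] are not protected, so such defaults would be split.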

View File

@@ -0,0 +1,149 @@
from __future__ import print_function
import sqlparse
from collections import namedtuple
from sqlparse.sql import IdentifierList, Identifier, Function
from sqlparse.tokens import Keyword, DML, Punctuation
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
'is_function'])
TableReference.ref = property(
lambda self: self.alias or (
self.name if self.name.islower() or self.name[0] == '"'
else '"' + self.name + '"')
)
# This code is borrowed from sqlparse example script.
# <url>
def is_subselect(parsed):
if not parsed.is_group:
return False
for item in parsed.tokens:
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
'UPDATE', 'CREATE',
'DELETE'):
return True
return False
def _identifier_is_function(identifier):
return any(isinstance(t, Function) for t in identifier.tokens)
def extract_from_part(parsed, stop_at_punctuation=True):
tbl_prefix_seen = False
for item in parsed.tokens:
if tbl_prefix_seen:
if is_subselect(item):
for x in extract_from_part(item, stop_at_punctuation):
yield x
elif stop_at_punctuation and item.ttype is Punctuation:
return  # PEP 479: raising StopIteration inside a generator
# is an error on Python 3.7+, so plain return is used instead.
# An incomplete nested select won't be recognized correctly as a
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
# the second FROM to trigger this elif condition, so we need to
# ignore the keyword FROM in that case.
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
# condition. So we need to ignore the keyword JOIN and its variants
# INNER JOIN, FULL OUTER JOIN, etc.
elif item.ttype is Keyword and (
not item.value.upper() == 'FROM') and (
not item.value.upper().endswith('JOIN')):
tbl_prefix_seen = False
else:
yield item
elif item.ttype is Keyword or item.ttype is Keyword.DML:
item_val = item.value.upper()
if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
item_val.endswith('JOIN')):
tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary.
elif isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
if (identifier.ttype is Keyword and
identifier.value.upper() == 'FROM'):
tbl_prefix_seen = True
break
def extract_table_identifiers(token_stream, allow_functions=True):
"""yields tuples of TableReference namedtuples"""
# We need to do some massaging of the names because postgres is case-
# insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
def parse_identifier(item):
name = item.get_real_name()
schema_name = item.get_parent_name()
alias = item.get_alias()
if not name:
schema_name = None
name = item.get_name()
alias = alias or name
schema_quoted = schema_name and item.value[0] == '"'
if schema_name and not schema_quoted:
schema_name = schema_name.lower()
quote_count = item.value.count('"')
name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
alias_quoted = alias and item.value[-1] == '"'
if alias_quoted or name_quoted and not alias and name.islower():
alias = '"' + (alias or name) + '"'
if name and not name_quoted and not name.islower():
if not alias:
alias = name
name = name.lower()
return schema_name, name, alias
for item in token_stream:
if isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
# Sometimes Keywords (such as FROM) are classified as
# identifiers which don't have the get_real_name() method.
try:
schema_name = identifier.get_parent_name()
real_name = identifier.get_real_name()
is_function = (allow_functions and
_identifier_is_function(identifier))
except AttributeError:
continue
if real_name:
yield TableReference(schema_name, real_name,
identifier.get_alias(), is_function)
elif isinstance(item, Identifier):
schema_name, real_name, alias = parse_identifier(item)
is_function = allow_functions and _identifier_is_function(item)
yield TableReference(schema_name, real_name, alias, is_function)
elif isinstance(item, Function):
schema_name, real_name, alias = parse_identifier(item)
yield TableReference(None, real_name, alias, allow_functions)
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
"""Extract the table names from an SQL statment.
Returns a list of TableReference namedtuples
"""
parsed = sqlparse.parse(sql)
if not parsed:
return ()
# INSERT statements must stop looking for tables at the sign of first
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
# abc is the table name, but if we don't stop at the first lparen, then
# we'll identify abc, col1 and col2 as table names.
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
# Kludge: sqlparse mistakenly identifies insert statements as
# function calls due to the parenthesized column list, e.g. interprets
# "insert into foo (bar, baz)" as a function call to foo with arguments
# (bar, baz). So don't allow any identifiers in insert statements
# to have is_function=True
identifiers = extract_table_identifiers(stream,
allow_functions=not insert_stmt)
# In the case 'sche.<cursor>', we get an empty TableReference; remove that
return tuple(i for i in identifiers if i.name)
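A usage sketch (expected output inferred from the code above; sqlparse grouping details may differ slightly by version):

>>> extract_tables('SELECT * FROM public.foo f JOIN bar ON f.id = bar.id')
(TableReference(schema='public', name='foo', alias='f', is_function=False),
 TableReference(schema=None, name='bar', alias=None, is_function=False))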

View File

@@ -0,0 +1,140 @@
from __future__ import print_function
import re
import sqlparse
from sqlparse.sql import Identifier
from sqlparse.tokens import Token, Error
cleanup_regex = {
# This matches only alphanumerics and underscores.
'alphanum_underscore': re.compile(r'(\w+)$'),
# This matches everything except spaces, parens, colon, and comma
'many_punctuations': re.compile(r'([^():,\s]+)$'),
# This matches everything except spaces, parens, colon, comma, and period
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
# This matches everything except a space.
'all_punctuations': re.compile(r'([^\s]+)$'),
}
def last_word(text, include='alphanum_underscore'):
r"""
Find the last word in a sentence.
>>> last_word('abc')
'abc'
>>> last_word(' abc')
'abc'
>>> last_word('')
''
>>> last_word(' ')
''
>>> last_word('abc ')
''
>>> last_word('abc def')
'def'
>>> last_word('abc def ')
''
>>> last_word('abc def;')
''
>>> last_word('bac $def')
'def'
>>> last_word('bac $def', include='most_punctuations')
'$def'
>>> last_word('bac \def', include='most_punctuations')
'\\def'
>>> last_word('bac \def;', include='most_punctuations')
'\\def;'
>>> last_word('bac::def', include='most_punctuations')
'def'
>>> last_word('"foo*bar', include='most_punctuations')
'"foo*bar'
"""
if not text: # Empty string
return ''
if text[-1].isspace():
return ''
else:
regex = cleanup_regex[include]
matches = regex.search(text)
if matches:
return matches.group(0)
else:
return ''
def find_prev_keyword(sql, n_skip=0):
""" Find the last sql keyword in an SQL statement
Returns the value of the last keyword, and the text of the query with
everything after the last keyword stripped
"""
if not sql.strip():
return None, ''
parsed = sqlparse.parse(sql)[0]
flattened = list(parsed.flatten())
flattened = flattened[:len(flattened) - n_skip]
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
for t in reversed(flattened):
if t.value == '(' or (t.is_keyword and (
t.value.upper() not in logical_operators)):
# Find the location of token t in the original parsed statement
# We can't use parsed.token_index(t) because t may be a child token
# inside a TokenList, in which case token_index throws an error
# Minimal example:
# p = sqlparse.parse('select * from foo where bar')[0]
# t = list(p.flatten())[-3] # The "Where" token
# p.token_index(t) # Throws ValueError: not in list
idx = flattened.index(t)
# Combine the string values of all tokens in the original list
# up to and including the target keyword token t, to produce a
# query string with everything after the keyword token removed
text = ''.join(tok.value for tok in flattened[:idx + 1])
return t, text
return None, ''
# Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
def is_open_quote(sql):
"""Returns true if the query contains an unclosed quote"""
# parsed can contain one or more semi-colon separated commands
parsed = sqlparse.parse(sql)
return any(_parsed_is_open_quote(p) for p in parsed)
def _parsed_is_open_quote(parsed):
# Look for unmatched single quotes, or unmatched dollar sign quotes
return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten())
def parse_partial_identifier(word):
"""Attempt to parse a (partially typed) word as an identifier
word may include a schema qualification, like `schema_name.partial_name`
or `schema_name.` There may also be unclosed quotation marks, like
`"schema`, or `schema."partial_name`
:param word: string representing a (partially complete) identifier
:return: sqlparse.sql.Identifier, or None
"""
p = sqlparse.parse(word)[0]
n_tok = len(p.tokens)
if n_tok == 1 and isinstance(p.tokens[0], Identifier):
return p.tokens[0]
elif p.token_next_by(m=(Error, '"'))[1]:
# An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
# Close the double quote, then reparse
return parse_partial_identifier(word + '"')
else:
return None
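A doctest-style sketch of find_prev_keyword (expected values inferred from the logic above): the logical operator AND is skipped, and everything after the WHERE keyword is stripped from the returned text:

>>> tok, text = find_prev_keyword('SELECT * FROM foo WHERE bar = 1 AND ')
>>> tok.value, text
('WHERE', 'SELECT * FROM foo WHERE')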

View File

@@ -1,8 +1,8 @@
import re
from collections import defaultdict
import sqlparse
from sqlparse.tokens import Name
from collections import defaultdict
white_space_regex = re.compile(r'\s+', re.MULTILINE)
@@ -10,7 +10,7 @@ white_space_regex = re.compile(r'\s+', re.MULTILINE)
def _compile_regex(keyword):
# Surround the keyword with word boundaries and replace interior whitespace
# with whitespace wildcards
pattern = r'\b' + re.sub(white_space_regex, r'\\s+', keyword) + r'\b'
pattern = r'\b' + white_space_regex.sub(r'\\s+', keyword) + r'\b'
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
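The net effect of _compile_regex is to turn a multi-word keyword into a word-bounded, whitespace-tolerant pattern; the replacement string r'\\s+' is deliberately escaped so the literal text \s+ ends up in the pattern. A doctest-style sketch (expected value inferred from the code above):

>>> _compile_regex('GROUP BY').pattern
'\\bGROUP\\s+BY\\b'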

View File

@@ -0,0 +1,521 @@
from __future__ import print_function
import sys
import re
import sqlparse
from collections import namedtuple
from sqlparse.sql import Comparison, Identifier, Where
from .parseutils.utils import (
last_word, find_prev_keyword, parse_partial_identifier)
from .parseutils.tables import extract_tables
from .parseutils.ctes import isolate_query_ctes
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3
if PY3:
string_types = str
else:
string_types = basestring
Special = namedtuple('Special', [])
Database = namedtuple('Database', [])
Schema = namedtuple('Schema', ['quoted'])
Schema.__new__.__defaults__ = (False,)
# FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables')
Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables'])
TableFormat = namedtuple('TableFormat', [])
View = namedtuple('View', ['schema', 'table_refs'])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent'])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple('Join', ['table_refs', 'schema'])
Function = namedtuple('Function', ['schema', 'table_refs', 'usage'])
# For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None)
Table.__new__.__defaults__ = (None, tuple(), tuple())
View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
Column = namedtuple(
'Column',
['table_refs', 'require_last_table', 'local_tables',
'qualifiable', 'context']
)
Column.__new__.__defaults__ = (None, None, tuple(), False, None)
Keyword = namedtuple('Keyword', ['last_token'])
Keyword.__new__.__defaults__ = (None,)
NamedQuery = namedtuple('NamedQuery', [])
Datatype = namedtuple('Datatype', ['schema'])
Alias = namedtuple('Alias', ['aliases'])
Path = namedtuple('Path', [])
class SqlStatement(object):
def __init__(self, full_text, text_before_cursor):
self.identifier = None
self.word_before_cursor = word_before_cursor = last_word(
text_before_cursor, include='many_punctuations')
full_text = _strip_named_query(full_text)
text_before_cursor = _strip_named_query(text_before_cursor)
full_text, text_before_cursor, self.local_tables = \
isolate_query_ctes(full_text, text_before_cursor)
self.text_before_cursor_including_last_word = text_before_cursor
# If we've partially typed a word then word_before_cursor won't be an
# empty string. In that case we want to remove the partially typed
# string before sending it to the sqlparser. Otherwise the last token
# will always be the partially typed string which renders the smart
# completion useless because it will always return the list of
# keywords as completion.
if self.word_before_cursor:
if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
parsed = sqlparse.parse(text_before_cursor)
else:
text_before_cursor = \
text_before_cursor[:-len(word_before_cursor)]
parsed = sqlparse.parse(text_before_cursor)
self.identifier = parse_partial_identifier(word_before_cursor)
else:
parsed = sqlparse.parse(text_before_cursor)
full_text, text_before_cursor, parsed = \
_split_multiple_statements(full_text, text_before_cursor, parsed)
self.full_text = full_text
self.text_before_cursor = text_before_cursor
self.parsed = parsed
self.last_token = \
parsed and parsed.token_prev(len(parsed.tokens))[1] or ''
def is_insert(self):
return self.parsed.token_first().value.lower() == 'insert'
def get_tables(self, scope='full'):
""" Gets the tables available in the statement.
param `scope`: possible values: 'full', 'insert', 'before'
If 'insert', only the first table is returned.
If 'before', only tables before the cursor are returned.
If not 'insert' and the stmt is an insert, the first table is skipped.
"""
tables = extract_tables(
self.full_text if scope == 'full' else self.text_before_cursor)
if scope == 'insert':
tables = tables[:1]
elif self.is_insert():
tables = tables[1:]
return tables
def get_previous_token(self, token):
return self.parsed.token_prev(self.parsed.token_index(token))[1]
def get_identifier_schema(self):
schema = \
(self.identifier and self.identifier.get_parent_name()) or None
# If schema name is unquoted, lower-case it
if schema and self.identifier.value[0] != '"':
schema = schema.lower()
return schema
def reduce_to_prev_keyword(self, n_skip=0):
prev_keyword, self.text_before_cursor = \
find_prev_keyword(self.text_before_cursor, n_skip=n_skip)
return prev_keyword
def suggest_type(full_text, text_before_cursor):
"""Takes the full_text that is typed so far and also the text before the
cursor to suggest completion type and scope.
Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
A scope for a column category will be a list of tables.
"""
if full_text.startswith('\\i '):
return (Path(),)
# This is a temporary hack; the exception handling
# here should be removed once sqlparse has been fixed
try:
stmt = SqlStatement(full_text, text_before_cursor)
except (TypeError, AttributeError):
return []
return suggest_based_on_last_token(stmt.last_token, stmt)
named_query_regex = re.compile(r'^\s*\\ns\s+[A-Za-z0-9\-_]+\s+')
def _strip_named_query(txt):
"""
This will strip the "save named query" command at the beginning of the line:
'\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
"""
if named_query_regex.match(txt):
txt = named_query_regex.sub('', txt)
return txt
function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M)
def _find_function_body(text):
split = function_body_pattern.search(text)
return (split.start(2), split.end(2)) if split else (None, None)
def _statement_from_function(full_text, text_before_cursor, statement):
current_pos = len(text_before_cursor)
body_start, body_end = _find_function_body(full_text)
if body_start is None:
return full_text, text_before_cursor, statement
if not body_start <= current_pos < body_end:
return full_text, text_before_cursor, statement
full_text = full_text[body_start:body_end]
text_before_cursor = text_before_cursor[body_start:]
parsed = sqlparse.parse(text_before_cursor)
return _split_multiple_statements(full_text, text_before_cursor, parsed)
def _split_multiple_statements(full_text, text_before_cursor, parsed):
if len(parsed) > 1:
# Multiple statements being edited -- isolate the current one by
# cumulatively summing statement lengths to find the one that bounds
# the current position
current_pos = len(text_before_cursor)
stmt_start, stmt_end = 0, 0
for statement in parsed:
stmt_len = len(str(statement))
stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
if stmt_end >= current_pos:
text_before_cursor = full_text[stmt_start:current_pos]
full_text = full_text[stmt_start:]
break
elif parsed:
# A single statement
statement = parsed[0]
else:
# The empty string
return full_text, text_before_cursor, None
token2 = None
if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'):
token1 = statement.token_first()
if token1:
token1_idx = statement.token_index(token1)
token2 = statement.token_next(token1_idx)[1]
if token2 and token2.value.upper() == 'FUNCTION':
full_text, text_before_cursor, statement = _statement_from_function(
full_text, text_before_cursor, statement
)
return full_text, text_before_cursor, statement
def suggest_based_on_last_token(token, stmt):
if isinstance(token, string_types):
token_v = token.lower()
elif isinstance(token, Comparison):
# If 'token' is a Comparison type such as
# 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
# token.value on the comparison type will only return the lhs of the
# comparison. In this case a.id. So we need to do token.tokens to get
# both sides of the comparison and pick the last token out of that
# list.
token_v = token.tokens[-1].value.lower()
elif isinstance(token, Where):
# sqlparse groups all tokens from the where clause into a single token
# list. This means that token.value may be something like
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
# suggestions in complicated where clauses correctly
prev_keyword = stmt.reduce_to_prev_keyword()
return suggest_based_on_last_token(prev_keyword, stmt)
elif isinstance(token, Identifier):
# If the previous token is an identifier, we can suggest datatypes if
# we're in a parenthesized column/field list, e.g.:
# CREATE TABLE foo (Identifier <CURSOR>
# CREATE FUNCTION foo (Identifier <CURSOR>
# If we're not in a parenthesized list, the most likely scenario is the
# user is about to specify an alias, e.g.:
# SELECT Identifier <CURSOR>
# SELECT foo FROM Identifier <CURSOR>
prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
if prev_keyword and prev_keyword.value == '(':
# Suggest datatypes
return suggest_based_on_last_token('type', stmt)
else:
return (Keyword(),)
else:
token_v = token.value.lower()
if not token:
return (Keyword(),)
elif token_v.endswith('('):
p = sqlparse.parse(stmt.text_before_cursor)[0]
if p.tokens and isinstance(p.tokens[-1], Where):
# Four possibilities:
# 1 - Parenthesized clause like "WHERE foo AND ("
# Suggest columns/functions
# 2 - Function call like "WHERE foo("
# Suggest columns/functions
# 3 - Subquery expression like "WHERE EXISTS ("
# Suggest keywords, in order to do a subquery
# 4 - Subquery OR array comparison like "WHERE foo = ANY("
# Suggest columns/functions AND keywords. (If we wanted to be
# really fancy, we could suggest only array-typed columns)
column_suggestions = suggest_based_on_last_token('where', stmt)
# Check for a subquery expression (cases 3 & 4)
where = p.tokens[-1]
prev_tok = where.token_prev(len(where.tokens) - 1)[1]
if isinstance(prev_tok, Comparison):
# e.g. "SELECT foo FROM bar WHERE foo = ANY("
prev_tok = prev_tok.tokens[-1]
prev_tok = prev_tok.value.lower()
if prev_tok == 'exists':
return (Keyword(),)
else:
return column_suggestions
# Get the token before the parens
prev_tok = p.token_prev(len(p.tokens) - 1)[1]
if (
prev_tok and prev_tok.value and
prev_tok.value.lower().split(' ')[-1] == 'using'
):
# tbl1 INNER JOIN tbl2 USING (col1, col2)
tables = stmt.get_tables('before')
# suggest columns that are present in more than one table
return (Column(table_refs=tables,
require_last_table=True,
local_tables=stmt.local_tables),)
elif p.token_first().value.lower() == 'select':
# If the lparen is preceded by a space, chances are we're about to
# do a sub-select.
if last_word(stmt.text_before_cursor,
'all_punctuations').startswith('('):
return (Keyword(),)
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
return (
Column(table_refs=stmt.get_tables('insert'), context='insert'),
)
# We're probably in a function argument list
return (Column(table_refs=extract_tables(stmt.full_text),
local_tables=stmt.local_tables, qualifiable=True),)
elif token_v == 'set':
return (Column(table_refs=stmt.get_tables(),
local_tables=stmt.local_tables),)
elif token_v in ('select', 'where', 'having', 'by', 'distinct'):
# Check for a table alias or schema qualification
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
tables = stmt.get_tables()
if parent:
tables = tuple(t for t in tables if identifies(parent, t))
return (Column(table_refs=tables, local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent),)
else:
return (Column(table_refs=tables, local_tables=stmt.local_tables,
qualifiable=True),
Function(schema=None),
Keyword(token_v.upper()),)
elif token_v == 'as':
# Don't suggest anything for aliases
return ()
elif (
(token_v.endswith('join') and token.is_keyword) or
(token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate'))
):
schema = stmt.get_identifier_schema()
tables = extract_tables(stmt.text_before_cursor)
is_join = token_v.endswith('join') and token.is_keyword
# Suggest tables from either the currently-selected schema or the
# public schema if no schema has been specified
suggest = []
if not schema:
# Suggest schemas
suggest.insert(0, Schema())
if token_v == 'from' or is_join:
suggest.append(FromClauseItem(schema=schema,
table_refs=tables,
local_tables=stmt.local_tables))
elif token_v == 'truncate':
suggest.append(Table(schema))
else:
suggest.extend((Table(schema), View(schema)))
if is_join and _allow_join(stmt.parsed):
tables = stmt.get_tables('before')
suggest.append(Join(table_refs=tables, schema=schema))
return tuple(suggest)
elif token_v == 'function':
schema = stmt.get_identifier_schema()
# stmt.get_previous_token will fail for e.g.
# `SELECT 1 FROM functions WHERE function:`
try:
prev = stmt.get_previous_token(token).value.lower()
if prev in ('drop', 'alter', 'create', 'create or replace'):
return (Function(schema=schema, usage='signature'),)
except ValueError:
pass
return tuple()
elif token_v in ('table', 'view'):
# E.g. 'ALTER TABLE <tablname>'
rel_type = \
{'table': Table, 'view': View, 'function': Function}[token_v]
schema = stmt.get_identifier_schema()
if schema:
return (rel_type(schema=schema),)
else:
return (Schema(), rel_type(schema=schema))
elif token_v == 'column':
# E.g. 'ALTER TABLE foo ALTER COLUMN bar'
return (Column(table_refs=stmt.get_tables()),)
elif token_v == 'on':
tables = stmt.get_tables('before')
parent = \
(stmt.identifier and stmt.identifier.get_parent_name()) or None
if parent:
# "ON parent.<suggestion>"
# parent can be either a schema name or table alias
filteredtables = tuple(t for t in tables if identifies(parent, t))
sugs = [Column(table_refs=filteredtables,
local_tables=stmt.local_tables),
Table(schema=parent),
View(schema=parent),
Function(schema=parent)]
if filteredtables and _allow_join_condition(stmt.parsed):
sugs.append(JoinCondition(table_refs=tables,
parent=filteredtables[-1]))
return tuple(sugs)
else:
# ON <suggestion>
# Use table alias if there is one, otherwise the table name
aliases = tuple(t.ref for t in tables)
if _allow_join_condition(stmt.parsed):
return (Alias(aliases=aliases), JoinCondition(
table_refs=tables, parent=None))
else:
return (Alias(aliases=aliases),)
elif token_v in ('c', 'use', 'database', 'template'):
# "\c <db", "use <db>", "DROP DATABASE <db>",
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
return (Database(),)
elif token_v == 'schema':
# DROP SCHEMA schema_name, SET SCHEMA schema_name
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
quoted = prev_keyword and prev_keyword.value.lower() == 'set'
return (Schema(quoted),)
elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
prev_keyword = stmt.reduce_to_prev_keyword()
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, stmt)
else:
return ()
elif token_v in ('type', '::'):
# ALTER TABLE foo SET DATA TYPE bar
# SELECT foo::bar
# Note that tables are a form of composite type in postgresql, so
# they're suggested here as well
schema = stmt.get_identifier_schema()
suggestions = [Datatype(schema=schema),
Table(schema=schema)]
if not schema:
suggestions.append(Schema())
return tuple(suggestions)
elif token_v in {'alter', 'create', 'drop'}:
return (Keyword(token_v.upper()),)
elif token.is_keyword:
# token is a keyword we haven't implemented any special handling for
# go backwards in the query until we find one we do recognize
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1)
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, stmt)
else:
return (Keyword(token_v.upper()),)
else:
return (Keyword(),)
def identifies(id, ref):
"""Returns true if string `id` matches TableReference `ref`"""
return id == ref.alias or id == ref.name or (
ref.schema and (id == ref.schema + '.' + ref.name))
def _allow_join_condition(statement):
"""
Tests if a join condition should be suggested
We need this to avoid bad suggestions when entering e.g.
select * from tbl1 a join tbl2 b on a.id = <cursor>
So check that the preceding token is an ON, AND, or OR keyword, instead of
e.g. an equals sign.
:param statement: an sqlparse.sql.Statement
:return: boolean
"""
if not statement or not statement.tokens:
return False
last_tok = statement.token_prev(len(statement.tokens))[1]
return last_tok.value.lower() in ('on', 'and', 'or')
def _allow_join(statement):
"""
Tests if a join should be suggested
We need this to avoid bad suggestions when entering e.g.
select * from tbl1 a join tbl2 b <cursor>
So check that the preceding token is a JOIN keyword
:param statement: an sqlparse.sql.Statement
:return: boolean
"""
if not statement or not statement.tokens:
return False
last_tok = statement.token_prev(len(statement.tokens))[1]
return (
last_tok.value.lower().endswith('join') and
last_tok.value.lower() not in ('cross join', 'natural join')
)
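As a closing usage sketch (expected value inferred from the namedtuple defaults above, not part of the commit): completing right after FROM with no schema typed yields schema suggestions plus from-clause items:

>>> suggest_type('SELECT * FROM ', 'SELECT * FROM ')
(Schema(quoted=False),
 FromClauseItem(schema=None, table_refs=(), local_tables=()))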