Merge pgcli code with version 1.10.3, which is used for auto complete feature.
parent
25679fd542
commit
7a3f3046df
|
@ -21,6 +21,7 @@ Bug fixes
|
|||
| `Bug #3325 <https://redmine.postgresql.org/issues/3325>`_ - Fix sort/filter dialog issue where it incorrectly requires ASC/DESC.
|
||||
| `Bug #3347 <https://redmine.postgresql.org/issues/3347>`_ - Ensure backup should work with '--data-only' and '--schema-only' for any format.
|
||||
| `Bug #3407 <https://redmine.postgresql.org/issues/3407>`_ - Fix keyboard shortcuts layout in the preferences panel.
|
||||
| `Bug #3420 <https://redmine.postgresql.org/issues/3420>`_ - Merge pgcli code with version 1.10.3, which is used for auto complete feature.
|
||||
| `Bug #3461 <https://redmine.postgresql.org/issues/3461>`_ - Ensure that refreshing a node also updates the Property list.
|
||||
| `Bug #3528 <https://redmine.postgresql.org/issues/3528>`_ - Handle connection errors properly in the query tool.
|
||||
| `Bug #3547 <https://redmine.postgresql.org/issues/3547>`_ - Make session implementation thread safe
|
||||
|
|
|
@ -0,0 +1,156 @@
|
|||
##########################################################################
|
||||
#
|
||||
# pgAdmin 4 - PostgreSQL Tools
|
||||
#
|
||||
# Copyright (C) 2013 - 2018, The pgAdmin Development Team
|
||||
# This software is released under the PostgreSQL Licence
|
||||
#
|
||||
##########################################################################
|
||||
|
||||
import sys
|
||||
import random
|
||||
|
||||
from selenium.webdriver import ActionChains
|
||||
from selenium.webdriver.common.keys import Keys
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from regression.python_test_utils import test_utils
|
||||
from regression.feature_utils.base_feature_test import BaseFeatureTest
|
||||
|
||||
|
||||
class QueryToolAutoCompleteFeatureTest(BaseFeatureTest):
|
||||
"""
|
||||
This feature test will test the query tool auto complete feature.
|
||||
"""
|
||||
|
||||
first_table_name = ""
|
||||
second_table_name = ""
|
||||
|
||||
scenarios = [
|
||||
("Query tool auto complete feature test", dict())
|
||||
]
|
||||
|
||||
def before(self):
|
||||
self.page.wait_for_spinner_to_disappear()
|
||||
|
||||
self.page.add_server(self.server)
|
||||
self.first_table_name = "auto_comp_" + \
|
||||
str(random.randint(1000, 3000))
|
||||
test_utils.create_table(self.server, self.test_db,
|
||||
self.first_table_name)
|
||||
|
||||
self.second_table_name = "auto_comp_" + \
|
||||
str(random.randint(1000, 3000))
|
||||
test_utils.create_table(self.server, self.test_db,
|
||||
self.second_table_name)
|
||||
|
||||
self._locate_database_tree_node()
|
||||
self.page.open_query_tool()
|
||||
self.page.wait_for_spinner_to_disappear()
|
||||
|
||||
def runTest(self):
|
||||
# Test case for keywords
|
||||
print("\nAuto complete ALTER keyword... ", file=sys.stderr, end="")
|
||||
self._auto_complete("A", "ALTER")
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete BEGIN keyword... ", file=sys.stderr, end="")
|
||||
self._auto_complete("BE", "BEGIN")
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete CASCADED keyword... ", file=sys.stderr, end="")
|
||||
self._auto_complete("CAS", "CASCADED")
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete SELECT keyword... ", file=sys.stderr, end="")
|
||||
self._auto_complete("SE", "SELECT")
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete pg_backend_pid() function ... ",
|
||||
file=sys.stderr, end="")
|
||||
self._auto_complete("SELECT pg_", "pg_backend_pid()")
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete current_query() function ... ",
|
||||
file=sys.stderr, end="")
|
||||
self._auto_complete("SELECT current_", "current_query()")
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete function with argument ... ",
|
||||
file=sys.stderr, end="")
|
||||
self._auto_complete("SELECT pg_st", "pg_stat_file(filename)")
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete first table in public schema ... ",
|
||||
file=sys.stderr, end="")
|
||||
self._auto_complete("SELECT * FROM public.", self.first_table_name)
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete second table in public schema ... ",
|
||||
file=sys.stderr, end="")
|
||||
self._auto_complete("SELECT * FROM public.", self.second_table_name)
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete JOIN second table with after schema name ... ",
|
||||
file=sys.stderr, end="")
|
||||
query = "SELECT * FROM public." + self.first_table_name + \
|
||||
" JOIN public."
|
||||
self._auto_complete(query, self.second_table_name)
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete JOIN ON some columns ... ",
|
||||
file=sys.stderr, end="")
|
||||
query = "SELECT * FROM public." + self.first_table_name + \
|
||||
" JOIN public." + self.second_table_name + " ON " + \
|
||||
self.second_table_name + "."
|
||||
expected_string = "some_column = " + self.first_table_name + \
|
||||
".some_column"
|
||||
self._auto_complete(query, expected_string)
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
print("Auto complete JOIN ON some columns using tabel alias ... ",
|
||||
file=sys.stderr, end="")
|
||||
query = "SELECT * FROM public." + self.first_table_name + \
|
||||
" t1 JOIN public." + self.second_table_name + " t2 ON t2."
|
||||
self._auto_complete(query, "some_column = t1.some_column")
|
||||
print("OK.", file=sys.stderr)
|
||||
self._clear_query_tool()
|
||||
|
||||
def after(self):
|
||||
self.page.remove_server(self.server)
|
||||
|
||||
def _locate_database_tree_node(self):
|
||||
self.page.toggle_open_tree_item(self.server['name'])
|
||||
self.page.toggle_open_tree_item('Databases')
|
||||
self.page.toggle_open_tree_item(self.test_db)
|
||||
|
||||
def _clear_query_tool(self):
|
||||
self.page.click_element(
|
||||
self.page.find_by_xpath("//*[@id='btn-clear-dropdown']")
|
||||
)
|
||||
ActionChains(self.driver) \
|
||||
.move_to_element(self.page.find_by_xpath("//*[@id='btn-clear']")) \
|
||||
.perform()
|
||||
self.page.click_element(
|
||||
self.page.find_by_xpath("//*[@id='btn-clear']")
|
||||
)
|
||||
self.page.click_modal('Yes')
|
||||
|
||||
def _auto_complete(self, word, expected_string):
|
||||
self.page.fill_codemirror_area_with(word)
|
||||
ActionChains(self.page.driver).key_down(
|
||||
Keys.CONTROL).send_keys(Keys.SPACE).key_up(Keys.CONTROL).perform()
|
||||
self.page.find_by_xpath(
|
||||
"//ul[contains(@class, 'CodeMirror-hints') and "
|
||||
"contains(., '" + expected_string + "')]")
|
|
@ -1,30 +1,16 @@
|
|||
{# ============= Fetch the list of functions based on given schema_names ============= #}
|
||||
{% if func_name %}
|
||||
SELECT n.nspname schema_name,
|
||||
p.proname func_name,
|
||||
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
|
||||
pg_catalog.pg_get_function_result(p.oid) return_type,
|
||||
p.proargnames arg_names,
|
||||
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[] arg_types,
|
||||
p.proargmodes arg_modes,
|
||||
prorettype::regtype::text return_type,
|
||||
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
|
||||
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
|
||||
p.proretset is_set_returning
|
||||
p.proretset is_set_returning,
|
||||
pg_get_expr(proargdefaults, 0) AS arg_defaults
|
||||
FROM pg_catalog.pg_proc p
|
||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||
WHERE n.nspname = '{{schema_name}}' AND p.proname = '{{func_name}}'
|
||||
AND p.proretset
|
||||
ORDER BY 1, 2
|
||||
{% else %}
|
||||
SELECT n.nspname schema_name,
|
||||
p.proname object_name,
|
||||
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
|
||||
pg_catalog.pg_get_function_result(p.oid) return_type,
|
||||
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
|
||||
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
|
||||
p.proretset is_set_returning
|
||||
FROM pg_catalog.pg_proc p
|
||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||
WHERE n.nspname IN ({{schema_names}})
|
||||
{% if is_set_returning %}
|
||||
AND p.proretset
|
||||
{% endif %}
|
||||
ORDER BY 1, 2
|
||||
{% endif %}
|
||||
WHERE p.prorettype::regtype != 'trigger'::regtype
|
||||
AND n.nspname IN ({{schema_names}})
|
||||
ORDER BY 1, 2
|
||||
|
|
|
@ -1,29 +1,43 @@
|
|||
{# SQL query for getting columns #}
|
||||
{% if object_name == 'table' %}
|
||||
SELECT
|
||||
att.attname column_name
|
||||
FROM pg_catalog.pg_attribute att
|
||||
INNER JOIN pg_catalog.pg_class cls
|
||||
ON att.attrelid = cls.oid
|
||||
INNER JOIN pg_catalog.pg_namespace nsp
|
||||
ON cls.relnamespace = nsp.oid
|
||||
WHERE cls.relkind = ANY(array['r'])
|
||||
SELECT nsp.nspname schema_name,
|
||||
cls.relname table_name,
|
||||
att.attname column_name,
|
||||
att.atttypid::regtype::text type_name,
|
||||
att.atthasdef AS has_default,
|
||||
def.adsrc as default
|
||||
FROM pg_catalog.pg_attribute att
|
||||
INNER JOIN pg_catalog.pg_class cls
|
||||
ON att.attrelid = cls.oid
|
||||
INNER JOIN pg_catalog.pg_namespace nsp
|
||||
ON cls.relnamespace = nsp.oid
|
||||
LEFT OUTER JOIN pg_attrdef def
|
||||
ON def.adrelid = att.attrelid
|
||||
AND def.adnum = att.attnum
|
||||
WHERE nsp.nspname IN ({{schema_names}})
|
||||
AND cls.relkind = ANY(array['r'])
|
||||
AND NOT att.attisdropped
|
||||
AND att.attnum > 0
|
||||
AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
|
||||
ORDER BY 1
|
||||
ORDER BY 1, 2, att.attnum
|
||||
{% endif %}
|
||||
{% if object_name == 'view' %}
|
||||
SELECT
|
||||
att.attname column_name
|
||||
FROM pg_catalog.pg_attribute att
|
||||
INNER JOIN pg_catalog.pg_class cls
|
||||
ON att.attrelid = cls.oid
|
||||
INNER JOIN pg_catalog.pg_namespace nsp
|
||||
ON cls.relnamespace = nsp.oid
|
||||
WHERE cls.relkind = ANY(array['v', 'm'])
|
||||
SELECT nsp.nspname schema_name,
|
||||
cls.relname table_name,
|
||||
att.attname column_name,
|
||||
att.atttypid::regtype::text type_name,
|
||||
att.atthasdef AS has_default,
|
||||
def.adsrc as default
|
||||
FROM pg_catalog.pg_attribute att
|
||||
INNER JOIN pg_catalog.pg_class cls
|
||||
ON att.attrelid = cls.oid
|
||||
INNER JOIN pg_catalog.pg_namespace nsp
|
||||
ON cls.relnamespace = nsp.oid
|
||||
LEFT OUTER JOIN pg_attrdef def
|
||||
ON def.adrelid = att.attrelid
|
||||
AND def.adnum = att.attnum
|
||||
WHERE nsp.nspname IN ({{schema_names}})
|
||||
AND cls.relkind = ANY(array['v', 'm'])
|
||||
AND NOT att.attisdropped
|
||||
AND att.attnum > 0
|
||||
AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
|
||||
ORDER BY 1
|
||||
{% endif %}
|
||||
ORDER BY 1, 2, att.attnum
|
||||
{% endif %}
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
{# SQL query for getting foreign keys #}
|
||||
SELECT s_p.nspname AS parentschema,
|
||||
t_p.relname AS parenttable,
|
||||
unnest((
|
||||
select
|
||||
array_agg(attname ORDER BY i)
|
||||
from
|
||||
(select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
|
||||
JOIN pg_catalog.pg_attribute c USING(attnum)
|
||||
WHERE c.attrelid = fk.confrelid
|
||||
)) AS parentcolumn,
|
||||
s_c.nspname AS childschema,
|
||||
t_c.relname AS childtable,
|
||||
unnest((
|
||||
select
|
||||
array_agg(attname ORDER BY i)
|
||||
from
|
||||
(select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
|
||||
JOIN pg_catalog.pg_attribute c USING(attnum)
|
||||
WHERE c.attrelid = fk.conrelid
|
||||
)) AS childcolumn
|
||||
FROM pg_catalog.pg_constraint fk
|
||||
JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid
|
||||
JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace
|
||||
JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
|
||||
JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
|
||||
WHERE fk.contype = 'f' AND s_p.nspname IN ({{schema_names}})
|
|
@ -1,30 +1,16 @@
|
|||
{# ============= Fetch the list of functions based on given schema_names ============= #}
|
||||
{% if func_name %}
|
||||
SELECT n.nspname schema_name,
|
||||
p.proname func_name,
|
||||
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
|
||||
pg_catalog.pg_get_function_result(p.oid) return_type,
|
||||
p.proargnames arg_names,
|
||||
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[] arg_types,
|
||||
p.proargmodes arg_modes,
|
||||
prorettype::regtype::text return_type,
|
||||
p.proisagg is_aggregate,
|
||||
p.proiswindow is_window,
|
||||
p.proretset is_set_returning
|
||||
p.proretset is_set_returning,
|
||||
pg_get_expr(proargdefaults, 0) AS arg_defaults
|
||||
FROM pg_catalog.pg_proc p
|
||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||
WHERE n.nspname = '{{schema_name}}' AND p.proname = '{{func_name}}'
|
||||
AND p.proretset
|
||||
ORDER BY 1, 2
|
||||
{% else %}
|
||||
SELECT n.nspname schema_name,
|
||||
p.proname object_name,
|
||||
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
|
||||
pg_catalog.pg_get_function_result(p.oid) return_type,
|
||||
p.proisagg is_aggregate,
|
||||
p.proiswindow is_window,
|
||||
p.proretset is_set_returning
|
||||
FROM pg_catalog.pg_proc p
|
||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||
WHERE n.nspname IN ({{schema_names}})
|
||||
{% if is_set_returning %}
|
||||
AND p.proretset
|
||||
{% endif %}
|
||||
ORDER BY 1, 2
|
||||
{% endif %}
|
||||
WHERE p.prorettype::regtype != 'trigger'::regtype
|
||||
AND n.nspname IN ({{schema_names}})
|
||||
ORDER BY 1, 2
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,197 +0,0 @@
|
|||
"""
|
||||
Copied from http://code.activestate.com/recipes/576611-counter-class/
|
||||
"""
|
||||
|
||||
from heapq import nlargest
|
||||
from itertools import repeat
|
||||
try:
|
||||
from itertools import ifilter
|
||||
except ImportError:
|
||||
# ifilter is in-built function in Python3 as filter
|
||||
ifilter = filter
|
||||
from operator import itemgetter
|
||||
|
||||
|
||||
class Counter(dict):
|
||||
'''Dict subclass for counting hashable objects. Sometimes called a bag
|
||||
or multiset. Elements are stored as dictionary keys and their counts
|
||||
are stored as dictionary values.
|
||||
|
||||
>>> Counter('zyzygy')
|
||||
Counter({'y': 3, 'z': 2, 'g': 1})
|
||||
|
||||
'''
|
||||
|
||||
def __init__(self, iterable=None, **kwds):
|
||||
'''Create a new, empty Counter object. And if given, count elements
|
||||
from an input iterable. Or, initialize the count from another mapping
|
||||
of elements to their counts.
|
||||
|
||||
>>> c = Counter() # a new, empty counter
|
||||
>>> c = Counter('gallahad') # a new counter from an iterable
|
||||
>>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
|
||||
>>> c = Counter(a=4, b=2) # a new counter from keyword args
|
||||
|
||||
'''
|
||||
self.update(iterable, **kwds)
|
||||
|
||||
def __missing__(self, key):
|
||||
return 0
|
||||
|
||||
def most_common(self, n=None):
|
||||
'''List the n most common elements and their counts from the most
|
||||
common to the least. If n is None, then list all element counts.
|
||||
|
||||
>>> Counter('abracadabra').most_common(3)
|
||||
[('a', 5), ('r', 2), ('b', 2)]
|
||||
|
||||
'''
|
||||
if n is None:
|
||||
return sorted(self.iteritems(), key=itemgetter(1), reverse=True)
|
||||
return nlargest(n, self.iteritems(), key=itemgetter(1))
|
||||
|
||||
def elements(self):
|
||||
'''Iterator over elements repeating each as many times as its count.
|
||||
|
||||
>>> c = Counter('ABCABC')
|
||||
>>> sorted(c.elements())
|
||||
['A', 'A', 'B', 'B', 'C', 'C']
|
||||
|
||||
If an element's count has been set to zero or is a negative number,
|
||||
elements() will ignore it.
|
||||
|
||||
'''
|
||||
for elem, count in self.iteritems():
|
||||
for _ in repeat(None, count):
|
||||
yield elem
|
||||
|
||||
# Override dict methods where the meaning changes for Counter objects.
|
||||
|
||||
@classmethod
|
||||
def fromkeys(cls, iterable, v=None):
|
||||
raise NotImplementedError(
|
||||
'Counter.fromkeys() is undefined. Use Counter(iterable) instead.'
|
||||
)
|
||||
|
||||
def update(self, iterable=None, **kwds):
|
||||
'''Like dict.update() but add counts instead of replacing them.
|
||||
|
||||
Source can be an iterable, a dictionary, or another Counter instance.
|
||||
|
||||
>>> c = Counter('which')
|
||||
>>> c.update('witch') # add elements from another iterable
|
||||
>>> d = Counter('watch')
|
||||
>>> c.update(d) # add elements from another counter
|
||||
>>> c['h'] # four 'h' in which, witch, and watch
|
||||
4
|
||||
|
||||
'''
|
||||
if iterable is not None:
|
||||
if hasattr(iterable, 'iteritems'):
|
||||
if self:
|
||||
self_get = self.get
|
||||
for elem, count in iterable.iteritems():
|
||||
self[elem] = self_get(elem, 0) + count
|
||||
else:
|
||||
# fast path when counter is empty
|
||||
dict.update(self, iterable)
|
||||
else:
|
||||
self_get = self.get
|
||||
for elem in iterable:
|
||||
self[elem] = self_get(elem, 0) + 1
|
||||
if kwds:
|
||||
self.update(kwds)
|
||||
|
||||
def copy(self):
|
||||
'Like dict.copy() but returns a Counter instance instead of a dict.'
|
||||
return Counter(self)
|
||||
|
||||
def __delitem__(self, elem):
|
||||
"""Like dict.__delitem__() but does not raise KeyError for missing
|
||||
values."""
|
||||
if elem in self:
|
||||
dict.__delitem__(self, elem)
|
||||
|
||||
def __repr__(self):
|
||||
if not self:
|
||||
return '%s()' % self.__class__.__name__
|
||||
items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
|
||||
return '%s({%s})' % (self.__class__.__name__, items)
|
||||
|
||||
# Multiset-style mathematical operations discussed in:
|
||||
# Knuth TAOCP Volume II section 4.6.3 exercise 19
|
||||
# and at http://en.wikipedia.org/wiki/Multiset
|
||||
#
|
||||
# Outputs guaranteed to only include positive counts.
|
||||
#
|
||||
# To strip negative and zero counts, add-in an empty counter:
|
||||
# c += Counter()
|
||||
|
||||
def __add__(self, other):
|
||||
'''Add counts from two counters.
|
||||
|
||||
>>> Counter('abbb') + Counter('bcc')
|
||||
Counter({'b': 4, 'c': 2, 'a': 1})
|
||||
|
||||
|
||||
'''
|
||||
if not isinstance(other, Counter):
|
||||
return NotImplemented
|
||||
result = Counter()
|
||||
for elem in set(self) | set(other):
|
||||
newcount = self[elem] + other[elem]
|
||||
if newcount > 0:
|
||||
result[elem] = newcount
|
||||
return result
|
||||
|
||||
def __sub__(self, other):
|
||||
''' Subtract count, but keep only results with positive counts.
|
||||
|
||||
>>> Counter('abbbc') - Counter('bccd')
|
||||
Counter({'b': 2, 'a': 1})
|
||||
|
||||
'''
|
||||
if not isinstance(other, Counter):
|
||||
return NotImplemented
|
||||
result = Counter()
|
||||
for elem in set(self) | set(other):
|
||||
newcount = self[elem] - other[elem]
|
||||
if newcount > 0:
|
||||
result[elem] = newcount
|
||||
return result
|
||||
|
||||
def __or__(self, other):
|
||||
'''Union is the maximum of value in either of the input counters.
|
||||
|
||||
>>> Counter('abbb') | Counter('bcc')
|
||||
Counter({'b': 3, 'c': 2, 'a': 1})
|
||||
|
||||
'''
|
||||
if not isinstance(other, Counter):
|
||||
return NotImplemented
|
||||
_max = max
|
||||
result = Counter()
|
||||
for elem in set(self) | set(other):
|
||||
newcount = _max(self[elem], other[elem])
|
||||
if newcount > 0:
|
||||
result[elem] = newcount
|
||||
return result
|
||||
|
||||
def __and__(self, other):
|
||||
''' Intersection is the minimum of corresponding counts.
|
||||
|
||||
>>> Counter('abbb') & Counter('bcc')
|
||||
Counter({'b': 1})
|
||||
|
||||
'''
|
||||
if not isinstance(other, Counter):
|
||||
return NotImplemented
|
||||
_min = min
|
||||
result = Counter()
|
||||
if len(self) < len(other):
|
||||
self, other = other, self
|
||||
for elem in ifilter(self.__contains__, other):
|
||||
newcount = _min(self[elem], other[elem])
|
||||
if newcount > 0:
|
||||
result[elem] = newcount
|
||||
return result
|
|
@ -1,151 +0,0 @@
|
|||
import re
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation
|
||||
|
||||
table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE)
|
||||
|
||||
|
||||
class FunctionMetadata(object):
|
||||
def __init__(self, schema_name, func_name, arg_list, return_type,
|
||||
is_aggregate, is_window, is_set_returning):
|
||||
"""Class for describing a postgresql function"""
|
||||
|
||||
self.schema_name = schema_name
|
||||
self.func_name = func_name
|
||||
self.arg_list = arg_list.strip()
|
||||
self.return_type = return_type.strip()
|
||||
self.is_aggregate = is_aggregate
|
||||
self.is_window = is_window
|
||||
self.is_set_returning = is_set_returning
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, self.__class__) and
|
||||
self.__dict__ == other.__dict__)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.schema_name, self.func_name, self.arg_list,
|
||||
self.return_type, self.is_aggregate, self.is_window,
|
||||
self.is_set_returning))
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,'
|
||||
' is_aggregate=%r, is_window=%r, is_set_returning=%r)')
|
||||
% (self.__class__.__name__, self.schema_name, self.func_name,
|
||||
self.arg_list, self.return_type, self.is_aggregate,
|
||||
self.is_window, self.is_set_returning)
|
||||
)
|
||||
|
||||
def fieldnames(self):
|
||||
"""Returns a list of output field names"""
|
||||
|
||||
if self.return_type.lower() == 'void':
|
||||
return []
|
||||
|
||||
match = table_def_regex.match(self.return_type)
|
||||
if match:
|
||||
# Function returns a table -- get the column names
|
||||
return list(field_names(match.group(1), mode_filter=None))
|
||||
|
||||
# Function may have named output arguments -- find them and return
|
||||
# their names
|
||||
return list(field_names(self.arg_list, mode_filter=('OUT', 'INOUT')))
|
||||
|
||||
|
||||
class TypedFieldMetadata(object):
|
||||
"""Describes typed field from a function signature or table definition
|
||||
|
||||
Attributes are:
|
||||
name The name of the argument/column
|
||||
mode 'IN', 'OUT', 'INOUT', 'VARIADIC'
|
||||
type A list of tokens denoting the type
|
||||
default A list of tokens denoting the default value
|
||||
unknown A list of tokens not assigned to type or default
|
||||
"""
|
||||
|
||||
__slots__ = ['name', 'mode', 'type', 'default', 'unknown']
|
||||
|
||||
def __init__(self):
|
||||
self.name = None
|
||||
self.mode = 'IN'
|
||||
self.type = []
|
||||
self.default = []
|
||||
self.unknown = []
|
||||
|
||||
def __getitem__(self, attr):
|
||||
return getattr(self, attr)
|
||||
|
||||
|
||||
def parse_typed_field_list(tokens):
|
||||
"""Parses a argument/column list, yielding TypedFieldMetadata objects
|
||||
|
||||
Field/column lists are used in function signatures and table
|
||||
definitions. This function parses a flattened list of sqlparse tokens
|
||||
and yields one metadata argument per argument / column.
|
||||
"""
|
||||
|
||||
# postgres function argument list syntax:
|
||||
# " ( [ [ argmode ] [ argname ] argtype
|
||||
# [ { DEFAULT | = } default_expr ] [, ...] ] )"
|
||||
|
||||
mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC'))
|
||||
parse_state = 'type'
|
||||
parens = 0
|
||||
field = TypedFieldMetadata()
|
||||
|
||||
for tok in tokens:
|
||||
if tok.ttype in Whitespace or tok.ttype in Comment:
|
||||
continue
|
||||
elif tok.ttype in Punctuation:
|
||||
if parens == 0 and tok.value == ',':
|
||||
# End of the current field specification
|
||||
if field.type:
|
||||
yield field
|
||||
# Initialize metadata holder for the next field
|
||||
field, parse_state = TypedFieldMetadata(), 'type'
|
||||
elif parens == 0 and tok.value == '=':
|
||||
parse_state = 'default'
|
||||
else:
|
||||
field[parse_state].append(tok)
|
||||
if tok.value == '(':
|
||||
parens += 1
|
||||
elif tok.value == ')':
|
||||
parens -= 1
|
||||
elif parens == 0:
|
||||
if tok.ttype in Keyword:
|
||||
if not field.name and tok.value.upper() in mode_names:
|
||||
# No other keywords allowed before arg name
|
||||
field.mode = tok.value.upper()
|
||||
elif tok.value.upper() == 'DEFAULT':
|
||||
parse_state = 'default'
|
||||
else:
|
||||
parse_state = 'unknown'
|
||||
elif tok.ttype == Name and not field.name:
|
||||
# note that `ttype in Name` would also match Name.Builtin
|
||||
field.name = tok.value
|
||||
else:
|
||||
field[parse_state].append(tok)
|
||||
else:
|
||||
field[parse_state].append(tok)
|
||||
|
||||
# Final argument won't be followed by a comma, so make sure it gets
|
||||
# yielded
|
||||
if field.type:
|
||||
yield field
|
||||
|
||||
|
||||
def field_names(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')):
|
||||
"""Yields field names from a table declaration"""
|
||||
|
||||
if not sql:
|
||||
return
|
||||
|
||||
# sql is something like "x int, y text, ..."
|
||||
tokens = sqlparse.parse(sql)[0].flatten()
|
||||
for f in parse_typed_field_list(tokens):
|
||||
if f.name and (not mode_filter or f.mode in mode_filter):
|
||||
yield f.name
|
|
@ -73,7 +73,7 @@ TableReference = namedtuple(
|
|||
# This code is borrowed from sqlparse example script.
|
||||
# <url>
|
||||
def is_subselect(parsed):
|
||||
if not parsed.is_group:
|
||||
if not parsed.is_group():
|
||||
return False
|
||||
sql_type = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
|
||||
for item in parsed.tokens:
|
||||
|
@ -104,7 +104,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
|
|||
# condition. So we need to ignore the keyword JOIN and its
|
||||
# variants INNER JOIN, FULL OUTER JOIN, etc.
|
||||
elif item.ttype is Keyword and (
|
||||
not item.value.upper() == 'FROM') and (
|
||||
not item.value.upper() == 'FROM') and (
|
||||
not item.value.upper().endswith('JOIN')):
|
||||
tbl_prefix_seen = False
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
import sqlparse
|
||||
|
||||
|
||||
def query_starts_with(query, prefixes):
|
||||
"""Check if the query starts with any item from *prefixes*."""
|
||||
prefixes = [prefix.lower() for prefix in prefixes]
|
||||
formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
|
||||
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
|
||||
|
||||
|
||||
def queries_start_with(queries, prefixes):
|
||||
"""Check if any queries start with any item from *prefixes*."""
|
||||
for query in sqlparse.split(queries):
|
||||
if query and query_starts_with(query, prefixes) is True:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_destructive(queries):
|
||||
"""Returns if any of the queries in *queries* is destructive."""
|
||||
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
|
||||
return queries_start_with(queries, keywords)
|
|
@ -0,0 +1,143 @@
|
|||
from sqlparse import parse
|
||||
from sqlparse.tokens import Keyword, CTE, DML
|
||||
from sqlparse.sql import Identifier, IdentifierList, Parenthesis
|
||||
from collections import namedtuple
|
||||
from .meta import TableMetadata, ColumnMetadata
|
||||
|
||||
|
||||
# TableExpression is a namedtuple representing a CTE, used internally
|
||||
# name: cte alias assigned in the query
|
||||
# columns: list of column names
|
||||
# start: index into the original string of the left parens starting the CTE
|
||||
# stop: index into the original string of the right parens ending the CTE
|
||||
TableExpression = namedtuple('TableExpression', 'name columns start stop')
|
||||
|
||||
|
||||
def isolate_query_ctes(full_text, text_before_cursor):
|
||||
"""Simplify a query by converting CTEs into table metadata objects
|
||||
"""
|
||||
|
||||
if not full_text:
|
||||
return full_text, text_before_cursor, tuple()
|
||||
|
||||
ctes, remainder = extract_ctes(full_text)
|
||||
if not ctes:
|
||||
return full_text, text_before_cursor, ()
|
||||
|
||||
current_position = len(text_before_cursor)
|
||||
meta = []
|
||||
|
||||
for cte in ctes:
|
||||
if cte.start < current_position < cte.stop:
|
||||
# Currently editing a cte - treat its body as the current full_text
|
||||
text_before_cursor = full_text[cte.start:current_position]
|
||||
full_text = full_text[cte.start:cte.stop]
|
||||
return full_text, text_before_cursor, meta
|
||||
|
||||
# Append this cte to the list of available table metadata
|
||||
cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
|
||||
meta.append(TableMetadata(cte.name, cols))
|
||||
|
||||
# Editing past the last cte (ie the main body of the query)
|
||||
full_text = full_text[ctes[-1].stop:]
|
||||
text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]
|
||||
|
||||
return full_text, text_before_cursor, tuple(meta)
|
||||
|
||||
|
||||
def extract_ctes(sql):
|
||||
""" Extract constant table expresseions from a query
|
||||
|
||||
Returns tuple (ctes, remainder_sql)
|
||||
|
||||
ctes is a list of TableExpression namedtuples
|
||||
remainder_sql is the text from the original query after the CTEs have
|
||||
been stripped.
|
||||
"""
|
||||
|
||||
p = parse(sql)[0]
|
||||
|
||||
# Make sure the first meaningful token is "WITH" which is necessary to
|
||||
# define CTEs
|
||||
idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
|
||||
if not (tok and tok.ttype == CTE):
|
||||
return [], sql
|
||||
|
||||
# Get the next (meaningful) token, which should be the first CTE
|
||||
idx, tok = p.token_next(idx)
|
||||
if not tok:
|
||||
return ([], '')
|
||||
start_pos = token_start_pos(p.tokens, idx)
|
||||
ctes = []
|
||||
|
||||
if isinstance(tok, IdentifierList):
|
||||
# Multiple ctes
|
||||
for t in tok.get_identifiers():
|
||||
cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
|
||||
cte = get_cte_from_token(t, start_pos + cte_start_offset)
|
||||
if not cte:
|
||||
continue
|
||||
ctes.append(cte)
|
||||
elif isinstance(tok, Identifier):
|
||||
# A single CTE
|
||||
cte = get_cte_from_token(tok, start_pos)
|
||||
if cte:
|
||||
ctes.append(cte)
|
||||
|
||||
idx = p.token_index(tok) + 1
|
||||
|
||||
# Collapse everything after the ctes into a remainder query
|
||||
remainder = u''.join(str(tok) for tok in p.tokens[idx:])
|
||||
|
||||
return ctes, remainder
|
||||
|
||||
|
||||
def get_cte_from_token(tok, pos0):
|
||||
cte_name = tok.get_real_name()
|
||||
if not cte_name:
|
||||
return None
|
||||
|
||||
# Find the start position of the opening parens enclosing the cte body
|
||||
idx, parens = tok.token_next_by(Parenthesis)
|
||||
if not parens:
|
||||
return None
|
||||
|
||||
start_pos = pos0 + token_start_pos(tok.tokens, idx)
|
||||
cte_len = len(str(parens)) # includes parens
|
||||
stop_pos = start_pos + cte_len
|
||||
|
||||
column_names = extract_column_names(parens)
|
||||
|
||||
return TableExpression(cte_name, column_names, start_pos, stop_pos)
|
||||
|
||||
|
||||
def extract_column_names(parsed):
|
||||
# Find the first DML token to check if it's a SELECT or
|
||||
# INSERT/UPDATE/DELETE
|
||||
idx, tok = parsed.token_next_by(t=DML)
|
||||
tok_val = tok and tok.value.lower()
|
||||
|
||||
if tok_val in ('insert', 'update', 'delete'):
|
||||
# Jump ahead to the RETURNING clause where the list of column names is
|
||||
idx, tok = parsed.token_next_by(idx, (Keyword, 'returning'))
|
||||
elif not tok_val == 'select':
|
||||
# Must be invalid CTE
|
||||
return ()
|
||||
|
||||
# The next token should be either a column name, or a list of column names
|
||||
idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
|
||||
return tuple(t.get_name() for t in _identifiers(tok))
|
||||
|
||||
|
||||
def token_start_pos(tokens, idx):
|
||||
return sum(len(str(t)) for t in tokens[:idx])
|
||||
|
||||
|
||||
def _identifiers(tok):
|
||||
if isinstance(tok, IdentifierList):
|
||||
for t in tok.get_identifiers():
|
||||
# NB: IdentifierList.get_identifiers() can return non-identifiers!
|
||||
if isinstance(t, Identifier):
|
||||
yield t
|
||||
elif isinstance(tok, Identifier):
|
||||
yield tok
|
|
@ -0,0 +1,151 @@
|
|||
from collections import namedtuple
|
||||
|
||||
_ColumnMetadata = namedtuple(
|
||||
'ColumnMetadata',
|
||||
['name', 'datatype', 'foreignkeys', 'default', 'has_default']
|
||||
)
|
||||
|
||||
|
||||
def ColumnMetadata(
|
||||
name, datatype, foreignkeys=None, default=None, has_default=False
|
||||
):
|
||||
return _ColumnMetadata(
|
||||
name, datatype, foreignkeys or [], default, has_default
|
||||
)
|
||||
|
||||
|
||||
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable',
|
||||
'parentcolumn', 'childschema',
|
||||
'childtable', 'childcolumn'])
|
||||
TableMetadata = namedtuple('TableMetadata', 'name columns')
|
||||
|
||||
|
||||
def parse_defaults(defaults_string):
|
||||
"""Yields default values for a function, given the string provided by
|
||||
pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)"""
|
||||
if not defaults_string:
|
||||
return
|
||||
current = ''
|
||||
in_quote = None
|
||||
for char in defaults_string:
|
||||
if current == '' and char == ' ':
|
||||
# Skip space after comma separating default expressions
|
||||
continue
|
||||
if char == '"' or char == '\'':
|
||||
if in_quote and char == in_quote:
|
||||
# End quote
|
||||
in_quote = None
|
||||
elif not in_quote:
|
||||
# Begin quote
|
||||
in_quote = char
|
||||
elif char == ',' and not in_quote:
|
||||
# End of expression
|
||||
yield current
|
||||
current = ''
|
||||
continue
|
||||
current += char
|
||||
yield current
|
||||
|
||||
|
||||
class FunctionMetadata(object):
|
||||
|
||||
def __init__(
|
||||
self, schema_name, func_name, arg_names, arg_types, arg_modes,
|
||||
return_type, is_aggregate, is_window, is_set_returning,
|
||||
arg_defaults
|
||||
):
|
||||
"""Class for describing a postgresql function"""
|
||||
|
||||
self.schema_name = schema_name
|
||||
self.func_name = func_name
|
||||
|
||||
self.arg_modes = tuple(arg_modes) if arg_modes else None
|
||||
self.arg_names = tuple(arg_names) if arg_names else None
|
||||
|
||||
# Be flexible in not requiring arg_types -- use None as a placeholder
|
||||
# for each arg. (Used for compatibility with old versions of postgresql
|
||||
# where such info is hard to get.
|
||||
if arg_types:
|
||||
self.arg_types = tuple(arg_types)
|
||||
elif arg_modes:
|
||||
self.arg_types = tuple([None] * len(arg_modes))
|
||||
elif arg_names:
|
||||
self.arg_types = tuple([None] * len(arg_names))
|
||||
else:
|
||||
self.arg_types = None
|
||||
|
||||
self.arg_defaults = tuple(parse_defaults(arg_defaults))
|
||||
|
||||
self.return_type = return_type.strip()
|
||||
self.is_aggregate = is_aggregate
|
||||
self.is_window = is_window
|
||||
self.is_set_returning = is_set_returning
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, self.__class__) and
|
||||
self.__dict__ == other.__dict__)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def _signature(self):
|
||||
return (
|
||||
self.schema_name, self.func_name, self.arg_names, self.arg_types,
|
||||
self.arg_modes, self.return_type, self.is_aggregate,
|
||||
self.is_window, self.is_set_returning, self.arg_defaults
|
||||
)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self._signature())
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
(
|
||||
'%s(schema_name=%r, func_name=%r, arg_names=%r, '
|
||||
'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, '
|
||||
'is_window=%r, is_set_returning=%r, arg_defaults=%r)'
|
||||
) % (self.__class__.__name__,) + self._signature()
|
||||
)
|
||||
|
||||
def has_variadic(self):
|
||||
return self.arg_modes and any(
|
||||
arg_mode == 'v' for arg_mode in self.arg_modes)
|
||||
|
||||
def args(self):
|
||||
"""Returns a list of input-parameter ColumnMetadata namedtuples."""
|
||||
if not self.arg_names:
|
||||
return []
|
||||
modes = self.arg_modes or ['i'] * len(self.arg_names)
|
||||
args = [
|
||||
(name, typ)
|
||||
for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
|
||||
if mode in ('i', 'b', 'v') # IN, INOUT, VARIADIC
|
||||
]
|
||||
|
||||
def arg(name, typ, num):
|
||||
num_args = len(args)
|
||||
num_defaults = len(self.arg_defaults)
|
||||
has_default = num + num_defaults >= num_args
|
||||
default = (
|
||||
self.arg_defaults[num - num_args + num_defaults] if has_default
|
||||
else None
|
||||
)
|
||||
return ColumnMetadata(name, typ, [], default, has_default)
|
||||
|
||||
return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]
|
||||
|
||||
def fields(self):
|
||||
"""Returns a list of output-field ColumnMetadata namedtuples"""
|
||||
|
||||
if self.return_type.lower() == 'void':
|
||||
return []
|
||||
elif not self.arg_modes:
|
||||
# For functions without output parameters, the function name
|
||||
# is used as the name of the output column.
|
||||
# E.g. 'SELECT unnest FROM unnest(...);'
|
||||
return [ColumnMetadata(self.func_name, self.return_type, [])]
|
||||
|
||||
return [ColumnMetadata(name, typ, [])
|
||||
for name, typ, mode in zip(
|
||||
self.arg_names, self.arg_types, self.arg_modes)
|
||||
if mode in ('o', 'b', 't')] # OUT, INOUT, TABLE
|
|
@ -0,0 +1,149 @@
|
|||
from __future__ import print_function
|
||||
import sqlparse
|
||||
from collections import namedtuple
|
||||
from sqlparse.sql import IdentifierList, Identifier, Function
|
||||
from sqlparse.tokens import Keyword, DML, Punctuation
|
||||
|
||||
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
|
||||
'is_function'])
|
||||
TableReference.ref = property(
|
||||
lambda self: self.alias or (
|
||||
self.name if self.name.islower() or self.name[0] == '"'
|
||||
else '"' + self.name + '"')
|
||||
)
|
||||
|
||||
|
||||
# This code is borrowed from sqlparse example script.
|
||||
# <url>
|
||||
def is_subselect(parsed):
|
||||
if not parsed.is_group:
|
||||
return False
|
||||
for item in parsed.tokens:
|
||||
if item.ttype is DML and item.value.upper() in ('SELECT', 'INSERT',
|
||||
'UPDATE', 'CREATE',
|
||||
'DELETE'):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _identifier_is_function(identifier):
|
||||
return any(isinstance(t, Function) for t in identifier.tokens)
|
||||
|
||||
|
||||
def extract_from_part(parsed, stop_at_punctuation=True):
|
||||
tbl_prefix_seen = False
|
||||
for item in parsed.tokens:
|
||||
if tbl_prefix_seen:
|
||||
if is_subselect(item):
|
||||
for x in extract_from_part(item, stop_at_punctuation):
|
||||
yield x
|
||||
elif stop_at_punctuation and item.ttype is Punctuation:
|
||||
raise StopIteration
|
||||
# An incomplete nested select won't be recognized correctly as a
|
||||
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
|
||||
# the second FROM to trigger this elif condition resulting in a
|
||||
# StopIteration. So we need to ignore the keyword if the keyword
|
||||
# FROM.
|
||||
# Also 'SELECT * FROM abc JOIN def' will trigger this elif
|
||||
# condition. So we need to ignore the keyword JOIN and its variants
|
||||
# INNER JOIN, FULL OUTER JOIN, etc.
|
||||
elif item.ttype is Keyword and (
|
||||
not item.value.upper() == 'FROM') and (
|
||||
not item.value.upper().endswith('JOIN')):
|
||||
tbl_prefix_seen = False
|
||||
else:
|
||||
yield item
|
||||
elif item.ttype is Keyword or item.ttype is Keyword.DML:
|
||||
item_val = item.value.upper()
|
||||
if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
|
||||
item_val.endswith('JOIN')):
|
||||
tbl_prefix_seen = True
|
||||
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
|
||||
# So this check here is necessary.
|
||||
elif isinstance(item, IdentifierList):
|
||||
for identifier in item.get_identifiers():
|
||||
if (identifier.ttype is Keyword and
|
||||
identifier.value.upper() == 'FROM'):
|
||||
tbl_prefix_seen = True
|
||||
break
|
||||
|
||||
|
||||
def extract_table_identifiers(token_stream, allow_functions=True):
|
||||
"""yields tuples of TableReference namedtuples"""
|
||||
|
||||
# We need to do some massaging of the names because postgres is case-
|
||||
# insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
|
||||
def parse_identifier(item):
|
||||
name = item.get_real_name()
|
||||
schema_name = item.get_parent_name()
|
||||
alias = item.get_alias()
|
||||
if not name:
|
||||
schema_name = None
|
||||
name = item.get_name()
|
||||
alias = alias or name
|
||||
schema_quoted = schema_name and item.value[0] == '"'
|
||||
if schema_name and not schema_quoted:
|
||||
schema_name = schema_name.lower()
|
||||
quote_count = item.value.count('"')
|
||||
name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
|
||||
alias_quoted = alias and item.value[-1] == '"'
|
||||
if alias_quoted or name_quoted and not alias and name.islower():
|
||||
alias = '"' + (alias or name) + '"'
|
||||
if name and not name_quoted and not name.islower():
|
||||
if not alias:
|
||||
alias = name
|
||||
name = name.lower()
|
||||
return schema_name, name, alias
|
||||
|
||||
for item in token_stream:
|
||||
if isinstance(item, IdentifierList):
|
||||
for identifier in item.get_identifiers():
|
||||
# Sometimes Keywords (such as FROM ) are classified as
|
||||
# identifiers which don't have the get_real_name() method.
|
||||
try:
|
||||
schema_name = identifier.get_parent_name()
|
||||
real_name = identifier.get_real_name()
|
||||
is_function = (allow_functions and
|
||||
_identifier_is_function(identifier))
|
||||
except AttributeError:
|
||||
continue
|
||||
if real_name:
|
||||
yield TableReference(schema_name, real_name,
|
||||
identifier.get_alias(), is_function)
|
||||
elif isinstance(item, Identifier):
|
||||
schema_name, real_name, alias = parse_identifier(item)
|
||||
is_function = allow_functions and _identifier_is_function(item)
|
||||
|
||||
yield TableReference(schema_name, real_name, alias, is_function)
|
||||
elif isinstance(item, Function):
|
||||
schema_name, real_name, alias = parse_identifier(item)
|
||||
yield TableReference(None, real_name, alias, allow_functions)
|
||||
|
||||
|
||||
# extract_tables is inspired from examples in the sqlparse lib.
|
||||
def extract_tables(sql):
|
||||
"""Extract the table names from an SQL statment.
|
||||
|
||||
Returns a list of TableReference namedtuples
|
||||
|
||||
"""
|
||||
parsed = sqlparse.parse(sql)
|
||||
if not parsed:
|
||||
return ()
|
||||
|
||||
# INSERT statements must stop looking for tables at the sign of first
|
||||
# Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
|
||||
# abc is the table name, but if we don't stop at the first lparen, then
|
||||
# we'll identify abc, col1 and col2 as table names.
|
||||
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
|
||||
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
|
||||
|
||||
# Kludge: sqlparse mistakenly identifies insert statements as
|
||||
# function calls due to the parenthesized column list, e.g. interprets
|
||||
# "insert into foo (bar, baz)" as a function call to foo with arguments
|
||||
# (bar, baz). So don't allow any identifiers in insert statements
|
||||
# to have is_function=True
|
||||
identifiers = extract_table_identifiers(stream,
|
||||
allow_functions=not insert_stmt)
|
||||
# In the case 'sche.<cursor>', we get an empty TableReference; remove that
|
||||
return tuple(i for i in identifiers if i.name)
|
|
@ -0,0 +1,140 @@
|
|||
from __future__ import print_function
|
||||
import re
|
||||
import sqlparse
|
||||
from sqlparse.sql import Identifier
|
||||
from sqlparse.tokens import Token, Error
|
||||
|
||||
cleanup_regex = {
|
||||
# This matches only alphanumerics and underscores.
|
||||
'alphanum_underscore': re.compile(r'(\w+)$'),
|
||||
# This matches everything except spaces, parens, colon, and comma
|
||||
'many_punctuations': re.compile(r'([^():,\s]+)$'),
|
||||
# This matches everything except spaces, parens, colon, comma, and period
|
||||
'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
|
||||
# This matches everything except a space.
|
||||
'all_punctuations': re.compile(r'([^\s]+)$'),
|
||||
}
|
||||
|
||||
|
||||
def last_word(text, include='alphanum_underscore'):
|
||||
r"""
|
||||
Find the last word in a sentence.
|
||||
|
||||
>>> last_word('abc')
|
||||
'abc'
|
||||
>>> last_word(' abc')
|
||||
'abc'
|
||||
>>> last_word('')
|
||||
''
|
||||
>>> last_word(' ')
|
||||
''
|
||||
>>> last_word('abc ')
|
||||
''
|
||||
>>> last_word('abc def')
|
||||
'def'
|
||||
>>> last_word('abc def ')
|
||||
''
|
||||
>>> last_word('abc def;')
|
||||
''
|
||||
>>> last_word('bac $def')
|
||||
'def'
|
||||
>>> last_word('bac $def', include='most_punctuations')
|
||||
'$def'
|
||||
>>> last_word('bac \def', include='most_punctuations')
|
||||
'\\\\def'
|
||||
>>> last_word('bac \def;', include='most_punctuations')
|
||||
'\\\\def;'
|
||||
>>> last_word('bac::def', include='most_punctuations')
|
||||
'def'
|
||||
>>> last_word('"foo*bar', include='most_punctuations')
|
||||
'"foo*bar'
|
||||
"""
|
||||
|
||||
if not text: # Empty string
|
||||
return ''
|
||||
|
||||
if text[-1].isspace():
|
||||
return ''
|
||||
else:
|
||||
regex = cleanup_regex[include]
|
||||
matches = regex.search(text)
|
||||
if matches:
|
||||
return matches.group(0)
|
||||
else:
|
||||
return ''
|
||||
|
||||
|
||||
def find_prev_keyword(sql, n_skip=0):
|
||||
""" Find the last sql keyword in an SQL statement
|
||||
|
||||
Returns the value of the last keyword, and the text of the query with
|
||||
everything after the last keyword stripped
|
||||
"""
|
||||
if not sql.strip():
|
||||
return None, ''
|
||||
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
flattened = list(parsed.flatten())
|
||||
flattened = flattened[:len(flattened) - n_skip]
|
||||
|
||||
logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')
|
||||
|
||||
for t in reversed(flattened):
|
||||
if t.value == '(' or (t.is_keyword and (
|
||||
t.value.upper() not in logical_operators)):
|
||||
# Find the location of token t in the original parsed statement
|
||||
# We can't use parsed.token_index(t) because t may be a child token
|
||||
# inside a TokenList, in which case token_index thows an error
|
||||
# Minimal example:
|
||||
# p = sqlparse.parse('select * from foo where bar')
|
||||
# t = list(p.flatten())[-3] # The "Where" token
|
||||
# p.token_index(t) # Throws ValueError: not in list
|
||||
idx = flattened.index(t)
|
||||
|
||||
# Combine the string values of all tokens in the original list
|
||||
# up to and including the target keyword token t, to produce a
|
||||
# query string with everything after the keyword token removed
|
||||
text = ''.join(tok.value for tok in flattened[:idx + 1])
|
||||
return t, text
|
||||
|
||||
return None, ''
|
||||
|
||||
|
||||
# Postgresql dollar quote signs look like `$$` or `$tag$`
|
||||
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')
|
||||
|
||||
|
||||
def is_open_quote(sql):
|
||||
"""Returns true if the query contains an unclosed quote"""
|
||||
|
||||
# parsed can contain one or more semi-colon separated commands
|
||||
parsed = sqlparse.parse(sql)
|
||||
return any(_parsed_is_open_quote(p) for p in parsed)
|
||||
|
||||
|
||||
def _parsed_is_open_quote(parsed):
|
||||
# Look for unmatched single quotes, or unmatched dollar sign quotes
|
||||
return any(tok.match(Token.Error, ("'", "$")) for tok in parsed.flatten())
|
||||
|
||||
|
||||
def parse_partial_identifier(word):
|
||||
"""Attempt to parse a (partially typed) word as an identifier
|
||||
|
||||
word may include a schema qualification, like `schema_name.partial_name`
|
||||
or `schema_name.` There may also be unclosed quotation marks, like
|
||||
`"schema`, or `schema."partial_name`
|
||||
|
||||
:param word: string representing a (partially complete) identifier
|
||||
:return: sqlparse.sql.Identifier, or None
|
||||
"""
|
||||
|
||||
p = sqlparse.parse(word)[0]
|
||||
n_tok = len(p.tokens)
|
||||
if n_tok == 1 and isinstance(p.tokens[0], Identifier):
|
||||
return p.tokens[0]
|
||||
elif p.token_next_by(m=(Error, '"'))[1]:
|
||||
# An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
|
||||
# Close the double quote, then reparse
|
||||
return parse_partial_identifier(word + '"')
|
||||
else:
|
||||
return None
|
|
@ -1,8 +1,8 @@
|
|||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.tokens import Name
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
white_space_regex = re.compile(r'\\s+', re.MULTILINE)
|
||||
|
||||
|
@ -10,7 +10,7 @@ white_space_regex = re.compile(r'\\s+', re.MULTILINE)
|
|||
def _compile_regex(keyword):
|
||||
# Surround the keyword with word boundaries and replace interior whitespace
|
||||
# with whitespace wildcards
|
||||
pattern = r'\\b' + re.sub(white_space_regex, r'\\s+', keyword) + r'\\b'
|
||||
pattern = r'\\b' + white_space_regex.sub(r'\\s+', keyword) + r'\\b'
|
||||
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,521 @@
|
|||
from __future__ import print_function
|
||||
import sys
|
||||
import re
|
||||
import sqlparse
|
||||
from collections import namedtuple
|
||||
from sqlparse.sql import Comparison, Identifier, Where
|
||||
from .parseutils.utils import (
|
||||
last_word, find_prev_keyword, parse_partial_identifier)
|
||||
from .parseutils.tables import extract_tables
|
||||
from .parseutils.ctes import isolate_query_ctes
|
||||
|
||||
PY2 = sys.version_info[0] == 2
|
||||
PY3 = sys.version_info[0] == 3
|
||||
|
||||
if PY3:
|
||||
string_types = str
|
||||
else:
|
||||
string_types = basestring
|
||||
|
||||
|
||||
Special = namedtuple('Special', [])
|
||||
Database = namedtuple('Database', [])
|
||||
Schema = namedtuple('Schema', ['quoted'])
|
||||
Schema.__new__.__defaults__ = (False,)
|
||||
# FromClauseItem is a table/view/function used in the FROM clause
|
||||
# `table_refs` contains the list of tables/... already in the statement,
|
||||
# used to ensure that the alias we suggest is unique
|
||||
FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables')
|
||||
Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables'])
|
||||
TableFormat = namedtuple('TableFormat', [])
|
||||
View = namedtuple('View', ['schema', 'table_refs'])
|
||||
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
|
||||
JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent'])
|
||||
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
|
||||
Join = namedtuple('Join', ['table_refs', 'schema'])
|
||||
|
||||
Function = namedtuple('Function', ['schema', 'table_refs', 'usage'])
|
||||
# For convenience, don't require the `usage` argument in Function constructor
|
||||
Function.__new__.__defaults__ = (None, tuple(), None)
|
||||
Table.__new__.__defaults__ = (None, tuple(), tuple())
|
||||
View.__new__.__defaults__ = (None, tuple())
|
||||
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())
|
||||
|
||||
Column = namedtuple(
|
||||
'Column',
|
||||
['table_refs', 'require_last_table', 'local_tables',
|
||||
'qualifiable', 'context']
|
||||
)
|
||||
Column.__new__.__defaults__ = (None, None, tuple(), False, None)
|
||||
|
||||
Keyword = namedtuple('Keyword', ['last_token'])
|
||||
Keyword.__new__.__defaults__ = (None,)
|
||||
NamedQuery = namedtuple('NamedQuery', [])
|
||||
Datatype = namedtuple('Datatype', ['schema'])
|
||||
Alias = namedtuple('Alias', ['aliases'])
|
||||
|
||||
Path = namedtuple('Path', [])
|
||||
|
||||
|
||||
class SqlStatement(object):
|
||||
def __init__(self, full_text, text_before_cursor):
|
||||
self.identifier = None
|
||||
self.word_before_cursor = word_before_cursor = last_word(
|
||||
text_before_cursor, include='many_punctuations')
|
||||
full_text = _strip_named_query(full_text)
|
||||
text_before_cursor = _strip_named_query(text_before_cursor)
|
||||
|
||||
full_text, text_before_cursor, self.local_tables = \
|
||||
isolate_query_ctes(full_text, text_before_cursor)
|
||||
|
||||
self.text_before_cursor_including_last_word = text_before_cursor
|
||||
|
||||
# If we've partially typed a word then word_before_cursor won't be an
|
||||
# empty string. In that case we want to remove the partially typed
|
||||
# string before sending it to the sqlparser. Otherwise the last token
|
||||
# will always be the partially typed string which renders the smart
|
||||
# completion useless because it will always return the list of
|
||||
# keywords as completion.
|
||||
if self.word_before_cursor:
|
||||
if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
else:
|
||||
text_before_cursor = \
|
||||
text_before_cursor[:-len(word_before_cursor)]
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
self.identifier = parse_partial_identifier(word_before_cursor)
|
||||
else:
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
|
||||
full_text, text_before_cursor, parsed = \
|
||||
_split_multiple_statements(full_text, text_before_cursor, parsed)
|
||||
|
||||
self.full_text = full_text
|
||||
self.text_before_cursor = text_before_cursor
|
||||
self.parsed = parsed
|
||||
|
||||
self.last_token = \
|
||||
parsed and parsed.token_prev(len(parsed.tokens))[1] or ''
|
||||
|
||||
def is_insert(self):
|
||||
return self.parsed.token_first().value.lower() == 'insert'
|
||||
|
||||
def get_tables(self, scope='full'):
|
||||
""" Gets the tables available in the statement.
|
||||
param `scope:` possible values: 'full', 'insert', 'before'
|
||||
If 'insert', only the first table is returned.
|
||||
If 'before', only tables before the cursor are returned.
|
||||
If not 'insert' and the stmt is an insert, the first table is skipped.
|
||||
"""
|
||||
tables = extract_tables(
|
||||
self.full_text if scope == 'full' else self.text_before_cursor)
|
||||
if scope == 'insert':
|
||||
tables = tables[:1]
|
||||
elif self.is_insert():
|
||||
tables = tables[1:]
|
||||
return tables
|
||||
|
||||
def get_previous_token(self, token):
|
||||
return self.parsed.token_prev(self.parsed.token_index(token))[1]
|
||||
|
||||
def get_identifier_schema(self):
|
||||
schema = \
|
||||
(self.identifier and self.identifier.get_parent_name()) or None
|
||||
# If schema name is unquoted, lower-case it
|
||||
if schema and self.identifier.value[0] != '"':
|
||||
schema = schema.lower()
|
||||
|
||||
return schema
|
||||
|
||||
def reduce_to_prev_keyword(self, n_skip=0):
|
||||
prev_keyword, self.text_before_cursor = \
|
||||
find_prev_keyword(self.text_before_cursor, n_skip=n_skip)
|
||||
return prev_keyword
|
||||
|
||||
|
||||
def suggest_type(full_text, text_before_cursor):
|
||||
"""Takes the full_text that is typed so far and also the text before the
|
||||
cursor to suggest completion type and scope.
|
||||
|
||||
Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
|
||||
A scope for a column category will be a list of tables.
|
||||
"""
|
||||
|
||||
if full_text.startswith('\\i '):
|
||||
return (Path(),)
|
||||
|
||||
# This is a temporary hack; the exception handling
|
||||
# here should be removed once sqlparse has been fixed
|
||||
try:
|
||||
stmt = SqlStatement(full_text, text_before_cursor)
|
||||
except (TypeError, AttributeError):
|
||||
return []
|
||||
|
||||
return suggest_based_on_last_token(stmt.last_token, stmt)
|
||||
|
||||
|
||||
named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
|
||||
|
||||
|
||||
def _strip_named_query(txt):
|
||||
"""
|
||||
This will strip "save named query" command in the beginning of the line:
|
||||
'\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
|
||||
' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
|
||||
"""
|
||||
|
||||
if named_query_regex.match(txt):
|
||||
txt = named_query_regex.sub('', txt)
|
||||
return txt
|
||||
|
||||
|
||||
function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M)
|
||||
|
||||
|
||||
def _find_function_body(text):
|
||||
split = function_body_pattern.search(text)
|
||||
return (split.start(2), split.end(2)) if split else (None, None)
|
||||
|
||||
|
||||
def _statement_from_function(full_text, text_before_cursor, statement):
|
||||
current_pos = len(text_before_cursor)
|
||||
body_start, body_end = _find_function_body(full_text)
|
||||
if body_start is None:
|
||||
return full_text, text_before_cursor, statement
|
||||
if not body_start <= current_pos < body_end:
|
||||
return full_text, text_before_cursor, statement
|
||||
full_text = full_text[body_start:body_end]
|
||||
text_before_cursor = text_before_cursor[body_start:]
|
||||
parsed = sqlparse.parse(text_before_cursor)
|
||||
return _split_multiple_statements(full_text, text_before_cursor, parsed)
|
||||
|
||||
|
||||
def _split_multiple_statements(full_text, text_before_cursor, parsed):
|
||||
if len(parsed) > 1:
|
||||
# Multiple statements being edited -- isolate the current one by
|
||||
# cumulatively summing statement lengths to find the one that bounds
|
||||
# the current position
|
||||
current_pos = len(text_before_cursor)
|
||||
stmt_start, stmt_end = 0, 0
|
||||
|
||||
for statement in parsed:
|
||||
stmt_len = len(str(statement))
|
||||
stmt_start, stmt_end = stmt_end, stmt_end + stmt_len
|
||||
|
||||
if stmt_end >= current_pos:
|
||||
text_before_cursor = full_text[stmt_start:current_pos]
|
||||
full_text = full_text[stmt_start:]
|
||||
break
|
||||
|
||||
elif parsed:
|
||||
# A single statement
|
||||
statement = parsed[0]
|
||||
else:
|
||||
# The empty string
|
||||
return full_text, text_before_cursor, None
|
||||
|
||||
token2 = None
|
||||
if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'):
|
||||
token1 = statement.token_first()
|
||||
if token1:
|
||||
token1_idx = statement.token_index(token1)
|
||||
token2 = statement.token_next(token1_idx)[1]
|
||||
if token2 and token2.value.upper() == 'FUNCTION':
|
||||
full_text, text_before_cursor, statement = _statement_from_function(
|
||||
full_text, text_before_cursor, statement
|
||||
)
|
||||
return full_text, text_before_cursor, statement
|
||||
|
||||
|
||||
def suggest_based_on_last_token(token, stmt):
|
||||
|
||||
if isinstance(token, string_types):
|
||||
token_v = token.lower()
|
||||
elif isinstance(token, Comparison):
|
||||
# If 'token' is a Comparison type such as
|
||||
# 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
|
||||
# token.value on the comparison type will only return the lhs of the
|
||||
# comparison. In this case a.id. So we need to do token.tokens to get
|
||||
# both sides of the comparison and pick the last token out of that
|
||||
# list.
|
||||
token_v = token.tokens[-1].value.lower()
|
||||
elif isinstance(token, Where):
|
||||
# sqlparse groups all tokens from the where clause into a single token
|
||||
# list. This means that token.value may be something like
|
||||
# 'where foo > 5 and '. We need to look "inside" token.tokens to handle
|
||||
# suggestions in complicated where clauses correctly
|
||||
prev_keyword = stmt.reduce_to_prev_keyword()
|
||||
return suggest_based_on_last_token(prev_keyword, stmt)
|
||||
elif isinstance(token, Identifier):
|
||||
# If the previous token is an identifier, we can suggest datatypes if
|
||||
# we're in a parenthesized column/field list, e.g.:
|
||||
# CREATE TABLE foo (Identifier <CURSOR>
|
||||
# CREATE FUNCTION foo (Identifier <CURSOR>
|
||||
# If we're not in a parenthesized list, the most likely scenario is the
|
||||
# user is about to specify an alias, e.g.:
|
||||
# SELECT Identifier <CURSOR>
|
||||
# SELECT foo FROM Identifier <CURSOR>
|
||||
prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
|
||||
if prev_keyword and prev_keyword.value == '(':
|
||||
# Suggest datatypes
|
||||
return suggest_based_on_last_token('type', stmt)
|
||||
else:
|
||||
return (Keyword(),)
|
||||
else:
|
||||
token_v = token.value.lower()
|
||||
|
||||
if not token:
|
||||
return (Keyword(),)
|
||||
elif token_v.endswith('('):
|
||||
p = sqlparse.parse(stmt.text_before_cursor)[0]
|
||||
|
||||
if p.tokens and isinstance(p.tokens[-1], Where):
|
||||
# Four possibilities:
|
||||
# 1 - Parenthesized clause like "WHERE foo AND ("
|
||||
# Suggest columns/functions
|
||||
# 2 - Function call like "WHERE foo("
|
||||
# Suggest columns/functions
|
||||
# 3 - Subquery expression like "WHERE EXISTS ("
|
||||
# Suggest keywords, in order to do a subquery
|
||||
# 4 - Subquery OR array comparison like "WHERE foo = ANY("
|
||||
# Suggest columns/functions AND keywords. (If we wanted to be
|
||||
# really fancy, we could suggest only array-typed columns)
|
||||
|
||||
column_suggestions = suggest_based_on_last_token('where', stmt)
|
||||
|
||||
# Check for a subquery expression (cases 3 & 4)
|
||||
where = p.tokens[-1]
|
||||
prev_tok = where.token_prev(len(where.tokens) - 1)[1]
|
||||
|
||||
if isinstance(prev_tok, Comparison):
|
||||
# e.g. "SELECT foo FROM bar WHERE foo = ANY("
|
||||
prev_tok = prev_tok.tokens[-1]
|
||||
|
||||
prev_tok = prev_tok.value.lower()
|
||||
if prev_tok == 'exists':
|
||||
return (Keyword(),)
|
||||
else:
|
||||
return column_suggestions
|
||||
|
||||
# Get the token before the parens
|
||||
prev_tok = p.token_prev(len(p.tokens) - 1)[1]
|
||||
|
||||
if (
|
||||
prev_tok and prev_tok.value and
|
||||
prev_tok.value.lower().split(' ')[-1] == 'using'
|
||||
):
|
||||
# tbl1 INNER JOIN tbl2 USING (col1, col2)
|
||||
tables = stmt.get_tables('before')
|
||||
|
||||
# suggest columns that are present in more than one table
|
||||
return (Column(table_refs=tables,
|
||||
require_last_table=True,
|
||||
local_tables=stmt.local_tables),)
|
||||
|
||||
elif p.token_first().value.lower() == 'select':
|
||||
# If the lparen is preceeded by a space chances are we're about to
|
||||
# do a sub-select.
|
||||
if last_word(stmt.text_before_cursor,
|
||||
'all_punctuations').startswith('('):
|
||||
return (Keyword(),)
|
||||
prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
|
||||
if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
|
||||
return (
|
||||
Column(table_refs=stmt.get_tables('insert'), context='insert'),
|
||||
)
|
||||
# We're probably in a function argument list
|
||||
return (Column(table_refs=extract_tables(stmt.full_text),
|
||||
local_tables=stmt.local_tables, qualifiable=True),)
|
||||
elif token_v == 'set':
|
||||
return (Column(table_refs=stmt.get_tables(),
|
||||
local_tables=stmt.local_tables),)
|
||||
elif token_v in ('select', 'where', 'having', 'by', 'distinct'):
|
||||
# Check for a table alias or schema qualification
|
||||
parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
|
||||
tables = stmt.get_tables()
|
||||
if parent:
|
||||
tables = tuple(t for t in tables if identifies(parent, t))
|
||||
return (Column(table_refs=tables, local_tables=stmt.local_tables),
|
||||
Table(schema=parent),
|
||||
View(schema=parent),
|
||||
Function(schema=parent),)
|
||||
else:
|
||||
return (Column(table_refs=tables, local_tables=stmt.local_tables,
|
||||
qualifiable=True),
|
||||
Function(schema=None),
|
||||
Keyword(token_v.upper()),)
|
||||
elif token_v == 'as':
|
||||
# Don't suggest anything for aliases
|
||||
return ()
|
||||
elif (
|
||||
(token_v.endswith('join') and token.is_keyword) or
|
||||
(token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate'))
|
||||
):
|
||||
|
||||
schema = stmt.get_identifier_schema()
|
||||
tables = extract_tables(stmt.text_before_cursor)
|
||||
is_join = token_v.endswith('join') and token.is_keyword
|
||||
|
||||
# Suggest tables from either the currently-selected schema or the
|
||||
# public schema if no schema has been specified
|
||||
suggest = []
|
||||
|
||||
if not schema:
|
||||
# Suggest schemas
|
||||
suggest.insert(0, Schema())
|
||||
|
||||
if token_v == 'from' or is_join:
|
||||
suggest.append(FromClauseItem(schema=schema,
|
||||
table_refs=tables,
|
||||
local_tables=stmt.local_tables))
|
||||
elif token_v == 'truncate':
|
||||
suggest.append(Table(schema))
|
||||
else:
|
||||
suggest.extend((Table(schema), View(schema)))
|
||||
|
||||
if is_join and _allow_join(stmt.parsed):
|
||||
tables = stmt.get_tables('before')
|
||||
suggest.append(Join(table_refs=tables, schema=schema))
|
||||
|
||||
return tuple(suggest)
|
||||
|
||||
elif token_v == 'function':
|
||||
schema = stmt.get_identifier_schema()
|
||||
# stmt.get_previous_token will fail for e.g.
|
||||
# `SELECT 1 FROM functions WHERE function:`
|
||||
try:
|
||||
prev = stmt.get_previous_token(token).value.lower()
|
||||
if prev in('drop', 'alter', 'create', 'create or replace'):
|
||||
return (Function(schema=schema, usage='signature'),)
|
||||
except ValueError:
|
||||
pass
|
||||
return tuple()
|
||||
|
||||
elif token_v in ('table', 'view'):
|
||||
# E.g. 'ALTER TABLE <tablname>'
|
||||
rel_type = \
|
||||
{'table': Table, 'view': View, 'function': Function}[token_v]
|
||||
schema = stmt.get_identifier_schema()
|
||||
if schema:
|
||||
return (rel_type(schema=schema),)
|
||||
else:
|
||||
return (Schema(), rel_type(schema=schema))
|
||||
|
||||
elif token_v == 'column':
|
||||
# E.g. 'ALTER TABLE foo ALTER COLUMN bar
|
||||
return (Column(table_refs=stmt.get_tables()),)
|
||||
|
||||
elif token_v == 'on':
|
||||
tables = stmt.get_tables('before')
|
||||
parent = \
|
||||
(stmt.identifier and stmt.identifier.get_parent_name()) or None
|
||||
if parent:
|
||||
# "ON parent.<suggestion>"
|
||||
# parent can be either a schema name or table alias
|
||||
filteredtables = tuple(t for t in tables if identifies(parent, t))
|
||||
sugs = [Column(table_refs=filteredtables,
|
||||
local_tables=stmt.local_tables),
|
||||
Table(schema=parent),
|
||||
View(schema=parent),
|
||||
Function(schema=parent)]
|
||||
if filteredtables and _allow_join_condition(stmt.parsed):
|
||||
sugs.append(JoinCondition(table_refs=tables,
|
||||
parent=filteredtables[-1]))
|
||||
return tuple(sugs)
|
||||
else:
|
||||
# ON <suggestion>
|
||||
# Use table alias if there is one, otherwise the table name
|
||||
aliases = tuple(t.ref for t in tables)
|
||||
if _allow_join_condition(stmt.parsed):
|
||||
return (Alias(aliases=aliases), JoinCondition(
|
||||
table_refs=tables, parent=None))
|
||||
else:
|
||||
return (Alias(aliases=aliases),)
|
||||
|
||||
elif token_v in ('c', 'use', 'database', 'template'):
|
||||
# "\c <db", "use <db>", "DROP DATABASE <db>",
|
||||
# "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
|
||||
return (Database(),)
|
||||
elif token_v == 'schema':
|
||||
# DROP SCHEMA schema_name, SET SCHEMA schema name
|
||||
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
|
||||
quoted = prev_keyword and prev_keyword.value.lower() == 'set'
|
||||
return (Schema(quoted),)
|
||||
elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
|
||||
prev_keyword = stmt.reduce_to_prev_keyword()
|
||||
if prev_keyword:
|
||||
return suggest_based_on_last_token(prev_keyword, stmt)
|
||||
else:
|
||||
return ()
|
||||
elif token_v in ('type', '::'):
|
||||
# ALTER TABLE foo SET DATA TYPE bar
|
||||
# SELECT foo::bar
|
||||
# Note that tables are a form of composite type in postgresql, so
|
||||
# they're suggested here as well
|
||||
schema = stmt.get_identifier_schema()
|
||||
suggestions = [Datatype(schema=schema),
|
||||
Table(schema=schema)]
|
||||
if not schema:
|
||||
suggestions.append(Schema())
|
||||
return tuple(suggestions)
|
||||
elif token_v in {'alter', 'create', 'drop'}:
|
||||
return (Keyword(token_v.upper()),)
|
||||
elif token.is_keyword:
|
||||
# token is a keyword we haven't implemented any special handling for
|
||||
# go backwards in the query until we find one we do recognize
|
||||
prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1)
|
||||
if prev_keyword:
|
||||
return suggest_based_on_last_token(prev_keyword, stmt)
|
||||
else:
|
||||
return (Keyword(token_v.upper()),)
|
||||
else:
|
||||
return (Keyword(),)
|
||||
|
||||
|
||||
def identifies(id, ref):
|
||||
"""Returns true if string `id` matches TableReference `ref`"""
|
||||
|
||||
return id == ref.alias or id == ref.name or (
|
||||
ref.schema and (id == ref.schema + '.' + ref.name))
|
||||
|
||||
|
||||
def _allow_join_condition(statement):
|
||||
"""
|
||||
Tests if a join condition should be suggested
|
||||
|
||||
We need this to avoid bad suggestions when entering e.g.
|
||||
select * from tbl1 a join tbl2 b on a.id = <cursor>
|
||||
So check that the preceding token is a ON, AND, or OR keyword, instead of
|
||||
e.g. an equals sign.
|
||||
|
||||
:param statement: an sqlparse.sql.Statement
|
||||
:return: boolean
|
||||
"""
|
||||
|
||||
if not statement or not statement.tokens:
|
||||
return False
|
||||
|
||||
last_tok = statement.token_prev(len(statement.tokens))[1]
|
||||
return last_tok.value.lower() in ('on', 'and', 'or')
|
||||
|
||||
|
||||
def _allow_join(statement):
|
||||
"""
|
||||
Tests if a join should be suggested
|
||||
|
||||
We need this to avoid bad suggestions when entering e.g.
|
||||
select * from tbl1 a join tbl2 b <cursor>
|
||||
So check that the preceding token is a JOIN keyword
|
||||
|
||||
:param statement: an sqlparse.sql.Statement
|
||||
:return: boolean
|
||||
"""
|
||||
|
||||
if not statement or not statement.tokens:
|
||||
return False
|
||||
|
||||
last_tok = statement.token_prev(len(statement.tokens))[1]
|
||||
return (
|
||||
last_tok.value.lower().endswith('join') and
|
||||
last_tok.value.lower() not in('cross join', 'natural join')
|
||||
)
|
Loading…
Reference in New Issue