Merge pgcli code with version 1.10.3, which is used for auto complete feature.
parent
25679fd542
commit
7a3f3046df
|
@ -21,6 +21,7 @@ Bug fixes
|
||||||
| `Bug #3325 <https://redmine.postgresql.org/issues/3325>`_ - Fix sort/filter dialog issue where it incorrectly requires ASC/DESC.
|
| `Bug #3325 <https://redmine.postgresql.org/issues/3325>`_ - Fix sort/filter dialog issue where it incorrectly requires ASC/DESC.
|
||||||
| `Bug #3347 <https://redmine.postgresql.org/issues/3347>`_ - Ensure backup should work with '--data-only' and '--schema-only' for any format.
|
| `Bug #3347 <https://redmine.postgresql.org/issues/3347>`_ - Ensure backup should work with '--data-only' and '--schema-only' for any format.
|
||||||
| `Bug #3407 <https://redmine.postgresql.org/issues/3407>`_ - Fix keyboard shortcuts layout in the preferences panel.
|
| `Bug #3407 <https://redmine.postgresql.org/issues/3407>`_ - Fix keyboard shortcuts layout in the preferences panel.
|
||||||
|
| `Bug #3420 <https://redmine.postgresql.org/issues/3420>`_ - Merge pgcli code with version 1.10.3, which is used for auto complete feature.
|
||||||
| `Bug #3461 <https://redmine.postgresql.org/issues/3461>`_ - Ensure that refreshing a node also updates the Property list.
|
| `Bug #3461 <https://redmine.postgresql.org/issues/3461>`_ - Ensure that refreshing a node also updates the Property list.
|
||||||
| `Bug #3528 <https://redmine.postgresql.org/issues/3528>`_ - Handle connection errors properly in the query tool.
|
| `Bug #3528 <https://redmine.postgresql.org/issues/3528>`_ - Handle connection errors properly in the query tool.
|
||||||
| `Bug #3547 <https://redmine.postgresql.org/issues/3547>`_ - Make session implementation thread safe
|
| `Bug #3547 <https://redmine.postgresql.org/issues/3547>`_ - Make session implementation thread safe
|
||||||
|
|
|
@ -0,0 +1,156 @@
|
||||||
|
##########################################################################
|
||||||
|
#
|
||||||
|
# pgAdmin 4 - PostgreSQL Tools
|
||||||
|
#
|
||||||
|
# Copyright (C) 2013 - 2018, The pgAdmin Development Team
|
||||||
|
# This software is released under the PostgreSQL Licence
|
||||||
|
#
|
||||||
|
##########################################################################
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import random
|
||||||
|
|
||||||
|
from selenium.webdriver import ActionChains
|
||||||
|
from selenium.webdriver.common.keys import Keys
|
||||||
|
from selenium.webdriver.support.ui import WebDriverWait
|
||||||
|
from selenium.webdriver.support import expected_conditions as EC
|
||||||
|
from regression.python_test_utils import test_utils
|
||||||
|
from regression.feature_utils.base_feature_test import BaseFeatureTest
|
||||||
|
|
||||||
|
|
||||||
|
class QueryToolAutoCompleteFeatureTest(BaseFeatureTest):
|
||||||
|
"""
|
||||||
|
This feature test will test the query tool auto complete feature.
|
||||||
|
"""
|
||||||
|
|
||||||
|
first_table_name = ""
|
||||||
|
second_table_name = ""
|
||||||
|
|
||||||
|
scenarios = [
|
||||||
|
("Query tool auto complete feature test", dict())
|
||||||
|
]
|
||||||
|
|
||||||
|
def before(self):
|
||||||
|
self.page.wait_for_spinner_to_disappear()
|
||||||
|
|
||||||
|
self.page.add_server(self.server)
|
||||||
|
self.first_table_name = "auto_comp_" + \
|
||||||
|
str(random.randint(1000, 3000))
|
||||||
|
test_utils.create_table(self.server, self.test_db,
|
||||||
|
self.first_table_name)
|
||||||
|
|
||||||
|
self.second_table_name = "auto_comp_" + \
|
||||||
|
str(random.randint(1000, 3000))
|
||||||
|
test_utils.create_table(self.server, self.test_db,
|
||||||
|
self.second_table_name)
|
||||||
|
|
||||||
|
self._locate_database_tree_node()
|
||||||
|
self.page.open_query_tool()
|
||||||
|
self.page.wait_for_spinner_to_disappear()
|
||||||
|
|
||||||
|
def runTest(self):
|
||||||
|
# Test case for keywords
|
||||||
|
print("\nAuto complete ALTER keyword... ", file=sys.stderr, end="")
|
||||||
|
self._auto_complete("A", "ALTER")
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete BEGIN keyword... ", file=sys.stderr, end="")
|
||||||
|
self._auto_complete("BE", "BEGIN")
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete CASCADED keyword... ", file=sys.stderr, end="")
|
||||||
|
self._auto_complete("CAS", "CASCADED")
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete SELECT keyword... ", file=sys.stderr, end="")
|
||||||
|
self._auto_complete("SE", "SELECT")
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete pg_backend_pid() function ... ",
|
||||||
|
file=sys.stderr, end="")
|
||||||
|
self._auto_complete("SELECT pg_", "pg_backend_pid()")
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete current_query() function ... ",
|
||||||
|
file=sys.stderr, end="")
|
||||||
|
self._auto_complete("SELECT current_", "current_query()")
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete function with argument ... ",
|
||||||
|
file=sys.stderr, end="")
|
||||||
|
self._auto_complete("SELECT pg_st", "pg_stat_file(filename)")
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete first table in public schema ... ",
|
||||||
|
file=sys.stderr, end="")
|
||||||
|
self._auto_complete("SELECT * FROM public.", self.first_table_name)
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete second table in public schema ... ",
|
||||||
|
file=sys.stderr, end="")
|
||||||
|
self._auto_complete("SELECT * FROM public.", self.second_table_name)
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete JOIN second table with after schema name ... ",
|
||||||
|
file=sys.stderr, end="")
|
||||||
|
query = "SELECT * FROM public." + self.first_table_name + \
|
||||||
|
" JOIN public."
|
||||||
|
self._auto_complete(query, self.second_table_name)
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete JOIN ON some columns ... ",
|
||||||
|
file=sys.stderr, end="")
|
||||||
|
query = "SELECT * FROM public." + self.first_table_name + \
|
||||||
|
" JOIN public." + self.second_table_name + " ON " + \
|
||||||
|
self.second_table_name + "."
|
||||||
|
expected_string = "some_column = " + self.first_table_name + \
|
||||||
|
".some_column"
|
||||||
|
self._auto_complete(query, expected_string)
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
print("Auto complete JOIN ON some columns using tabel alias ... ",
|
||||||
|
file=sys.stderr, end="")
|
||||||
|
query = "SELECT * FROM public." + self.first_table_name + \
|
||||||
|
" t1 JOIN public." + self.second_table_name + " t2 ON t2."
|
||||||
|
self._auto_complete(query, "some_column = t1.some_column")
|
||||||
|
print("OK.", file=sys.stderr)
|
||||||
|
self._clear_query_tool()
|
||||||
|
|
||||||
|
def after(self):
|
||||||
|
self.page.remove_server(self.server)
|
||||||
|
|
||||||
|
def _locate_database_tree_node(self):
|
||||||
|
self.page.toggle_open_tree_item(self.server['name'])
|
||||||
|
self.page.toggle_open_tree_item('Databases')
|
||||||
|
self.page.toggle_open_tree_item(self.test_db)
|
||||||
|
|
||||||
|
def _clear_query_tool(self):
|
||||||
|
self.page.click_element(
|
||||||
|
self.page.find_by_xpath("//*[@id='btn-clear-dropdown']")
|
||||||
|
)
|
||||||
|
ActionChains(self.driver) \
|
||||||
|
.move_to_element(self.page.find_by_xpath("//*[@id='btn-clear']")) \
|
||||||
|
.perform()
|
||||||
|
self.page.click_element(
|
||||||
|
self.page.find_by_xpath("//*[@id='btn-clear']")
|
||||||
|
)
|
||||||
|
self.page.click_modal('Yes')
|
||||||
|
|
||||||
|
def _auto_complete(self, word, expected_string):
|
||||||
|
self.page.fill_codemirror_area_with(word)
|
||||||
|
ActionChains(self.page.driver).key_down(
|
||||||
|
Keys.CONTROL).send_keys(Keys.SPACE).key_up(Keys.CONTROL).perform()
|
||||||
|
self.page.find_by_xpath(
|
||||||
|
"//ul[contains(@class, 'CodeMirror-hints') and "
|
||||||
|
"contains(., '" + expected_string + "')]")
|
|
@ -1,30 +1,16 @@
|
||||||
{# ============= Fetch the list of functions based on given schema_names ============= #}
|
{# ============= Fetch the list of functions based on given schema_names ============= #}
|
||||||
{% if func_name %}
|
|
||||||
SELECT n.nspname schema_name,
|
SELECT n.nspname schema_name,
|
||||||
p.proname func_name,
|
p.proname func_name,
|
||||||
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
|
p.proargnames arg_names,
|
||||||
pg_catalog.pg_get_function_result(p.oid) return_type,
|
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[] arg_types,
|
||||||
|
p.proargmodes arg_modes,
|
||||||
|
prorettype::regtype::text return_type,
|
||||||
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
|
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
|
||||||
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
|
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
|
||||||
p.proretset is_set_returning
|
p.proretset is_set_returning,
|
||||||
|
pg_get_expr(proargdefaults, 0) AS arg_defaults
|
||||||
FROM pg_catalog.pg_proc p
|
FROM pg_catalog.pg_proc p
|
||||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||||
WHERE n.nspname = '{{schema_name}}' AND p.proname = '{{func_name}}'
|
WHERE p.prorettype::regtype != 'trigger'::regtype
|
||||||
AND p.proretset
|
AND n.nspname IN ({{schema_names}})
|
||||||
ORDER BY 1, 2
|
ORDER BY 1, 2
|
||||||
{% else %}
|
|
||||||
SELECT n.nspname schema_name,
|
|
||||||
p.proname object_name,
|
|
||||||
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
|
|
||||||
pg_catalog.pg_get_function_result(p.oid) return_type,
|
|
||||||
CASE WHEN p.prokind = 'a' THEN true ELSE false END is_aggregate,
|
|
||||||
CASE WHEN p.prokind = 'w' THEN true ELSE false END is_window,
|
|
||||||
p.proretset is_set_returning
|
|
||||||
FROM pg_catalog.pg_proc p
|
|
||||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
|
||||||
WHERE n.nspname IN ({{schema_names}})
|
|
||||||
{% if is_set_returning %}
|
|
||||||
AND p.proretset
|
|
||||||
{% endif %}
|
|
||||||
ORDER BY 1, 2
|
|
||||||
{% endif %}
|
|
||||||
|
|
|
@ -1,29 +1,43 @@
|
||||||
{# SQL query for getting columns #}
|
{# SQL query for getting columns #}
|
||||||
{% if object_name == 'table' %}
|
{% if object_name == 'table' %}
|
||||||
SELECT
|
SELECT nsp.nspname schema_name,
|
||||||
att.attname column_name
|
cls.relname table_name,
|
||||||
FROM pg_catalog.pg_attribute att
|
att.attname column_name,
|
||||||
INNER JOIN pg_catalog.pg_class cls
|
att.atttypid::regtype::text type_name,
|
||||||
ON att.attrelid = cls.oid
|
att.atthasdef AS has_default,
|
||||||
INNER JOIN pg_catalog.pg_namespace nsp
|
def.adsrc as default
|
||||||
ON cls.relnamespace = nsp.oid
|
FROM pg_catalog.pg_attribute att
|
||||||
WHERE cls.relkind = ANY(array['r'])
|
INNER JOIN pg_catalog.pg_class cls
|
||||||
|
ON att.attrelid = cls.oid
|
||||||
|
INNER JOIN pg_catalog.pg_namespace nsp
|
||||||
|
ON cls.relnamespace = nsp.oid
|
||||||
|
LEFT OUTER JOIN pg_attrdef def
|
||||||
|
ON def.adrelid = att.attrelid
|
||||||
|
AND def.adnum = att.attnum
|
||||||
|
WHERE nsp.nspname IN ({{schema_names}})
|
||||||
|
AND cls.relkind = ANY(array['r'])
|
||||||
AND NOT att.attisdropped
|
AND NOT att.attisdropped
|
||||||
AND att.attnum > 0
|
AND att.attnum > 0
|
||||||
AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
|
ORDER BY 1, 2, att.attnum
|
||||||
ORDER BY 1
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% if object_name == 'view' %}
|
{% if object_name == 'view' %}
|
||||||
SELECT
|
SELECT nsp.nspname schema_name,
|
||||||
att.attname column_name
|
cls.relname table_name,
|
||||||
FROM pg_catalog.pg_attribute att
|
att.attname column_name,
|
||||||
INNER JOIN pg_catalog.pg_class cls
|
att.atttypid::regtype::text type_name,
|
||||||
ON att.attrelid = cls.oid
|
att.atthasdef AS has_default,
|
||||||
INNER JOIN pg_catalog.pg_namespace nsp
|
def.adsrc as default
|
||||||
ON cls.relnamespace = nsp.oid
|
FROM pg_catalog.pg_attribute att
|
||||||
WHERE cls.relkind = ANY(array['v', 'm'])
|
INNER JOIN pg_catalog.pg_class cls
|
||||||
|
ON att.attrelid = cls.oid
|
||||||
|
INNER JOIN pg_catalog.pg_namespace nsp
|
||||||
|
ON cls.relnamespace = nsp.oid
|
||||||
|
LEFT OUTER JOIN pg_attrdef def
|
||||||
|
ON def.adrelid = att.attrelid
|
||||||
|
AND def.adnum = att.attnum
|
||||||
|
WHERE nsp.nspname IN ({{schema_names}})
|
||||||
|
AND cls.relkind = ANY(array['v', 'm'])
|
||||||
AND NOT att.attisdropped
|
AND NOT att.attisdropped
|
||||||
AND att.attnum > 0
|
AND att.attnum > 0
|
||||||
AND (nsp.nspname = '{{schema_name}}' AND cls.relname = '{{rel_name}}')
|
ORDER BY 1, 2, att.attnum
|
||||||
ORDER BY 1
|
|
||||||
{% endif %}
|
{% endif %}
|
|
@ -0,0 +1,27 @@
|
||||||
|
{# SQL query for getting foreign keys #}
|
||||||
|
SELECT s_p.nspname AS parentschema,
|
||||||
|
t_p.relname AS parenttable,
|
||||||
|
unnest((
|
||||||
|
select
|
||||||
|
array_agg(attname ORDER BY i)
|
||||||
|
from
|
||||||
|
(select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
|
||||||
|
JOIN pg_catalog.pg_attribute c USING(attnum)
|
||||||
|
WHERE c.attrelid = fk.confrelid
|
||||||
|
)) AS parentcolumn,
|
||||||
|
s_c.nspname AS childschema,
|
||||||
|
t_c.relname AS childtable,
|
||||||
|
unnest((
|
||||||
|
select
|
||||||
|
array_agg(attname ORDER BY i)
|
||||||
|
from
|
||||||
|
(select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
|
||||||
|
JOIN pg_catalog.pg_attribute c USING(attnum)
|
||||||
|
WHERE c.attrelid = fk.conrelid
|
||||||
|
)) AS childcolumn
|
||||||
|
FROM pg_catalog.pg_constraint fk
|
||||||
|
JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid
|
||||||
|
JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace
|
||||||
|
JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
|
||||||
|
JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
|
||||||
|
WHERE fk.contype = 'f' AND s_p.nspname IN ({{schema_names}})
|
|
@ -1,30 +1,16 @@
|
||||||
{# ============= Fetch the list of functions based on given schema_names ============= #}
|
{# ============= Fetch the list of functions based on given schema_names ============= #}
|
||||||
{% if func_name %}
|
|
||||||
SELECT n.nspname schema_name,
|
SELECT n.nspname schema_name,
|
||||||
p.proname func_name,
|
p.proname func_name,
|
||||||
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
|
p.proargnames arg_names,
|
||||||
pg_catalog.pg_get_function_result(p.oid) return_type,
|
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[] arg_types,
|
||||||
|
p.proargmodes arg_modes,
|
||||||
|
prorettype::regtype::text return_type,
|
||||||
p.proisagg is_aggregate,
|
p.proisagg is_aggregate,
|
||||||
p.proiswindow is_window,
|
p.proiswindow is_window,
|
||||||
p.proretset is_set_returning
|
p.proretset is_set_returning,
|
||||||
|
pg_get_expr(proargdefaults, 0) AS arg_defaults
|
||||||
FROM pg_catalog.pg_proc p
|
FROM pg_catalog.pg_proc p
|
||||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
||||||
WHERE n.nspname = '{{schema_name}}' AND p.proname = '{{func_name}}'
|
WHERE p.prorettype::regtype != 'trigger'::regtype
|
||||||
AND p.proretset
|
AND n.nspname IN ({{schema_names}})
|
||||||
ORDER BY 1, 2
|
ORDER BY 1, 2
|
||||||
{% else %}
|
|
||||||
SELECT n.nspname schema_name,
|
|
||||||
p.proname object_name,
|
|
||||||
pg_catalog.pg_get_function_arguments(p.oid) arg_list,
|
|
||||||
pg_catalog.pg_get_function_result(p.oid) return_type,
|
|
||||||
p.proisagg is_aggregate,
|
|
||||||
p.proiswindow is_window,
|
|
||||||
p.proretset is_set_returning
|
|
||||||
FROM pg_catalog.pg_proc p
|
|
||||||
INNER JOIN pg_catalog.pg_namespace n ON n.oid = p.pronamespace
|
|
||||||
WHERE n.nspname IN ({{schema_names}})
|
|
||||||
{% if is_set_returning %}
|
|
||||||
AND p.proretset
|
|
||||||
{% endif %}
|
|
||||||
ORDER BY 1, 2
|
|
||||||
{% endif %}
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,197 +0,0 @@
|
||||||
"""
|
|
||||||
Copied from http://code.activestate.com/recipes/576611-counter-class/
|
|
||||||
"""
|
|
||||||
|
|
||||||
from heapq import nlargest
|
|
||||||
from itertools import repeat
|
|
||||||
try:
|
|
||||||
from itertools import ifilter
|
|
||||||
except ImportError:
|
|
||||||
# ifilter is in-built function in Python3 as filter
|
|
||||||
ifilter = filter
|
|
||||||
from operator import itemgetter
|
|
||||||
|
|
||||||
|
|
||||||
class Counter(dict):
|
|
||||||
'''Dict subclass for counting hashable objects. Sometimes called a bag
|
|
||||||
or multiset. Elements are stored as dictionary keys and their counts
|
|
||||||
are stored as dictionary values.
|
|
||||||
|
|
||||||
>>> Counter('zyzygy')
|
|
||||||
Counter({'y': 3, 'z': 2, 'g': 1})
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __init__(self, iterable=None, **kwds):
|
|
||||||
'''Create a new, empty Counter object. And if given, count elements
|
|
||||||
from an input iterable. Or, initialize the count from another mapping
|
|
||||||
of elements to their counts.
|
|
||||||
|
|
||||||
>>> c = Counter() # a new, empty counter
|
|
||||||
>>> c = Counter('gallahad') # a new counter from an iterable
|
|
||||||
>>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
|
|
||||||
>>> c = Counter(a=4, b=2) # a new counter from keyword args
|
|
||||||
|
|
||||||
'''
|
|
||||||
self.update(iterable, **kwds)
|
|
||||||
|
|
||||||
def __missing__(self, key):
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def most_common(self, n=None):
|
|
||||||
'''List the n most common elements and their counts from the most
|
|
||||||
common to the least. If n is None, then list all element counts.
|
|
||||||
|
|
||||||
>>> Counter('abracadabra').most_common(3)
|
|
||||||
[('a', 5), ('r', 2), ('b', 2)]
|
|
||||||
|
|
||||||
'''
|
|
||||||
if n is None:
|
|
||||||
return sorted(self.iteritems(), key=itemgetter(1), reverse=True)
|
|
||||||
return nlargest(n, self.iteritems(), key=itemgetter(1))
|
|
||||||
|
|
||||||
def elements(self):
|
|
||||||
'''Iterator over elements repeating each as many times as its count.
|
|
||||||
|
|
||||||
>>> c = Counter('ABCABC')
|
|
||||||
>>> sorted(c.elements())
|
|
||||||
['A', 'A', 'B', 'B', 'C', 'C']
|
|
||||||
|
|
||||||
If an element's count has been set to zero or is a negative number,
|
|
||||||
elements() will ignore it.
|
|
||||||
|
|
||||||
'''
|
|
||||||
for elem, count in self.iteritems():
|
|
||||||
for _ in repeat(None, count):
|
|
||||||
yield elem
|
|
||||||
|
|
||||||
# Override dict methods where the meaning changes for Counter objects.
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def fromkeys(cls, iterable, v=None):
|
|
||||||
raise NotImplementedError(
|
|
||||||
'Counter.fromkeys() is undefined. Use Counter(iterable) instead.'
|
|
||||||
)
|
|
||||||
|
|
||||||
def update(self, iterable=None, **kwds):
|
|
||||||
'''Like dict.update() but add counts instead of replacing them.
|
|
||||||
|
|
||||||
Source can be an iterable, a dictionary, or another Counter instance.
|
|
||||||
|
|
||||||
>>> c = Counter('which')
|
|
||||||
>>> c.update('witch') # add elements from another iterable
|
|
||||||
>>> d = Counter('watch')
|
|
||||||
>>> c.update(d) # add elements from another counter
|
|
||||||
>>> c['h'] # four 'h' in which, witch, and watch
|
|
||||||
4
|
|
||||||
|
|
||||||
'''
|
|
||||||
if iterable is not None:
|
|
||||||
if hasattr(iterable, 'iteritems'):
|
|
||||||
if self:
|
|
||||||
self_get = self.get
|
|
||||||
for elem, count in iterable.iteritems():
|
|
||||||
self[elem] = self_get(elem, 0) + count
|
|
||||||
else:
|
|
||||||
# fast path when counter is empty
|
|
||||||
dict.update(self, iterable)
|
|
||||||
else:
|
|
||||||
self_get = self.get
|
|
||||||
for elem in iterable:
|
|
||||||
self[elem] = self_get(elem, 0) + 1
|
|
||||||
if kwds:
|
|
||||||
self.update(kwds)
|
|
||||||
|
|
||||||
def copy(self):
|
|
||||||
'Like dict.copy() but returns a Counter instance instead of a dict.'
|
|
||||||
return Counter(self)
|
|
||||||
|
|
||||||
def __delitem__(self, elem):
|
|
||||||
"""Like dict.__delitem__() but does not raise KeyError for missing
|
|
||||||
values."""
|
|
||||||
if elem in self:
|
|
||||||
dict.__delitem__(self, elem)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
if not self:
|
|
||||||
return '%s()' % self.__class__.__name__
|
|
||||||
items = ', '.join(map('%r: %r'.__mod__, self.most_common()))
|
|
||||||
return '%s({%s})' % (self.__class__.__name__, items)
|
|
||||||
|
|
||||||
# Multiset-style mathematical operations discussed in:
|
|
||||||
# Knuth TAOCP Volume II section 4.6.3 exercise 19
|
|
||||||
# and at http://en.wikipedia.org/wiki/Multiset
|
|
||||||
#
|
|
||||||
# Outputs guaranteed to only include positive counts.
|
|
||||||
#
|
|
||||||
# To strip negative and zero counts, add-in an empty counter:
|
|
||||||
# c += Counter()
|
|
||||||
|
|
||||||
def __add__(self, other):
|
|
||||||
'''Add counts from two counters.
|
|
||||||
|
|
||||||
>>> Counter('abbb') + Counter('bcc')
|
|
||||||
Counter({'b': 4, 'c': 2, 'a': 1})
|
|
||||||
|
|
||||||
|
|
||||||
'''
|
|
||||||
if not isinstance(other, Counter):
|
|
||||||
return NotImplemented
|
|
||||||
result = Counter()
|
|
||||||
for elem in set(self) | set(other):
|
|
||||||
newcount = self[elem] + other[elem]
|
|
||||||
if newcount > 0:
|
|
||||||
result[elem] = newcount
|
|
||||||
return result
|
|
||||||
|
|
||||||
def __sub__(self, other):
|
|
||||||
''' Subtract count, but keep only results with positive counts.
|
|
||||||
|
|
||||||
>>> Counter('abbbc') - Counter('bccd')
|
|
||||||
Counter({'b': 2, 'a': 1})
|
|
||||||
|
|
||||||
'''
|
|
||||||
if not isinstance(other, Counter):
|
|
||||||
return NotImplemented
|
|
||||||
result = Counter()
|
|
||||||
for elem in set(self) | set(other):
|
|
||||||
newcount = self[elem] - other[elem]
|
|
||||||
if newcount > 0:
|
|
||||||
result[elem] = newcount
|
|
||||||
return result
|
|
||||||
|
|
||||||
def __or__(self, other):
|
|
||||||
'''Union is the maximum of value in either of the input counters.
|
|
||||||
|
|
||||||
>>> Counter('abbb') | Counter('bcc')
|
|
||||||
Counter({'b': 3, 'c': 2, 'a': 1})
|
|
||||||
|
|
||||||
'''
|
|
||||||
if not isinstance(other, Counter):
|
|
||||||
return NotImplemented
|
|
||||||
_max = max
|
|
||||||
result = Counter()
|
|
||||||
for elem in set(self) | set(other):
|
|
||||||
newcount = _max(self[elem], other[elem])
|
|
||||||
if newcount > 0:
|
|
||||||
result[elem] = newcount
|
|
||||||
return result
|
|
||||||
|
|
||||||
def __and__(self, other):
|
|
||||||
''' Intersection is the minimum of corresponding counts.
|
|
||||||
|
|
||||||
>>> Counter('abbb') & Counter('bcc')
|
|
||||||
Counter({'b': 1})
|
|
||||||
|
|
||||||
'''
|
|
||||||
if not isinstance(other, Counter):
|
|
||||||
return NotImplemented
|
|
||||||
_min = min
|
|
||||||
result = Counter()
|
|
||||||
if len(self) < len(other):
|
|
||||||
self, other = other, self
|
|
||||||
for elem in ifilter(self.__contains__, other):
|
|
||||||
newcount = _min(self[elem], other[elem])
|
|
||||||
if newcount > 0:
|
|
||||||
result[elem] = newcount
|
|
||||||
return result
|
|
|
@ -1,151 +0,0 @@
|
||||||
import re
|
|
||||||
|
|
||||||
import sqlparse
|
|
||||||
from sqlparse.tokens import Whitespace, Comment, Keyword, Name, Punctuation
|
|
||||||
|
|
||||||
table_def_regex = re.compile(r'^TABLE\s*\((.+)\)$', re.IGNORECASE)
|
|
||||||
|
|
||||||
|
|
||||||
class FunctionMetadata(object):
|
|
||||||
def __init__(self, schema_name, func_name, arg_list, return_type,
|
|
||||||
is_aggregate, is_window, is_set_returning):
|
|
||||||
"""Class for describing a postgresql function"""
|
|
||||||
|
|
||||||
self.schema_name = schema_name
|
|
||||||
self.func_name = func_name
|
|
||||||
self.arg_list = arg_list.strip()
|
|
||||||
self.return_type = return_type.strip()
|
|
||||||
self.is_aggregate = is_aggregate
|
|
||||||
self.is_window = is_window
|
|
||||||
self.is_set_returning = is_set_returning
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
return (isinstance(other, self.__class__) and
|
|
||||||
self.__dict__ == other.__dict__)
|
|
||||||
|
|
||||||
def __ne__(self, other):
|
|
||||||
return not self.__eq__(other)
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash((self.schema_name, self.func_name, self.arg_list,
|
|
||||||
self.return_type, self.is_aggregate, self.is_window,
|
|
||||||
self.is_set_returning))
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return (
|
|
||||||
('%s(schema_name=%r, func_name=%r, arg_list=%r, return_type=%r,'
|
|
||||||
' is_aggregate=%r, is_window=%r, is_set_returning=%r)')
|
|
||||||
% (self.__class__.__name__, self.schema_name, self.func_name,
|
|
||||||
self.arg_list, self.return_type, self.is_aggregate,
|
|
||||||
self.is_window, self.is_set_returning)
|
|
||||||
)
|
|
||||||
|
|
||||||
def fieldnames(self):
|
|
||||||
"""Returns a list of output field names"""
|
|
||||||
|
|
||||||
if self.return_type.lower() == 'void':
|
|
||||||
return []
|
|
||||||
|
|
||||||
match = table_def_regex.match(self.return_type)
|
|
||||||
if match:
|
|
||||||
# Function returns a table -- get the column names
|
|
||||||
return list(field_names(match.group(1), mode_filter=None))
|
|
||||||
|
|
||||||
# Function may have named output arguments -- find them and return
|
|
||||||
# their names
|
|
||||||
return list(field_names(self.arg_list, mode_filter=('OUT', 'INOUT')))
|
|
||||||
|
|
||||||
|
|
||||||
class TypedFieldMetadata(object):
|
|
||||||
"""Describes typed field from a function signature or table definition
|
|
||||||
|
|
||||||
Attributes are:
|
|
||||||
name The name of the argument/column
|
|
||||||
mode 'IN', 'OUT', 'INOUT', 'VARIADIC'
|
|
||||||
type A list of tokens denoting the type
|
|
||||||
default A list of tokens denoting the default value
|
|
||||||
unknown A list of tokens not assigned to type or default
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ['name', 'mode', 'type', 'default', 'unknown']
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.name = None
|
|
||||||
self.mode = 'IN'
|
|
||||||
self.type = []
|
|
||||||
self.default = []
|
|
||||||
self.unknown = []
|
|
||||||
|
|
||||||
def __getitem__(self, attr):
|
|
||||||
return getattr(self, attr)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_typed_field_list(tokens):
|
|
||||||
"""Parses a argument/column list, yielding TypedFieldMetadata objects
|
|
||||||
|
|
||||||
Field/column lists are used in function signatures and table
|
|
||||||
definitions. This function parses a flattened list of sqlparse tokens
|
|
||||||
and yields one metadata argument per argument / column.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# postgres function argument list syntax:
|
|
||||||
# " ( [ [ argmode ] [ argname ] argtype
|
|
||||||
# [ { DEFAULT | = } default_expr ] [, ...] ] )"
|
|
||||||
|
|
||||||
mode_names = set(('IN', 'OUT', 'INOUT', 'VARIADIC'))
|
|
||||||
parse_state = 'type'
|
|
||||||
parens = 0
|
|
||||||
field = TypedFieldMetadata()
|
|
||||||
|
|
||||||
for tok in tokens:
|
|
||||||
if tok.ttype in Whitespace or tok.ttype in Comment:
|
|
||||||
continue
|
|
||||||
elif tok.ttype in Punctuation:
|
|
||||||
if parens == 0 and tok.value == ',':
|
|
||||||
# End of the current field specification
|
|
||||||
if field.type:
|
|
||||||
yield field
|
|
||||||
# Initialize metadata holder for the next field
|
|
||||||
field, parse_state = TypedFieldMetadata(), 'type'
|
|
||||||
elif parens == 0 and tok.value == '=':
|
|
||||||
parse_state = 'default'
|
|
||||||
else:
|
|
||||||
field[parse_state].append(tok)
|
|
||||||
if tok.value == '(':
|
|
||||||
parens += 1
|
|
||||||
elif tok.value == ')':
|
|
||||||
parens -= 1
|
|
||||||
elif parens == 0:
|
|
||||||
if tok.ttype in Keyword:
|
|
||||||
if not field.name and tok.value.upper() in mode_names:
|
|
||||||
# No other keywords allowed before arg name
|
|
||||||
field.mode = tok.value.upper()
|
|
||||||
elif tok.value.upper() == 'DEFAULT':
|
|
||||||
parse_state = 'default'
|
|
||||||
else:
|
|
||||||
parse_state = 'unknown'
|
|
||||||
elif tok.ttype == Name and not field.name:
|
|
||||||
# note that `ttype in Name` would also match Name.Builtin
|
|
||||||
field.name = tok.value
|
|
||||||
else:
|
|
||||||
field[parse_state].append(tok)
|
|
||||||
else:
|
|
||||||
field[parse_state].append(tok)
|
|
||||||
|
|
||||||
# Final argument won't be followed by a comma, so make sure it gets
|
|
||||||
# yielded
|
|
||||||
if field.type:
|
|
||||||
yield field
|
|
||||||
|
|
||||||
|
|
||||||
def field_names(sql, mode_filter=('IN', 'OUT', 'INOUT', 'VARIADIC')):
|
|
||||||
"""Yields field names from a table declaration"""
|
|
||||||
|
|
||||||
if not sql:
|
|
||||||
return
|
|
||||||
|
|
||||||
# sql is something like "x int, y text, ..."
|
|
||||||
tokens = sqlparse.parse(sql)[0].flatten()
|
|
||||||
for f in parse_typed_field_list(tokens):
|
|
||||||
if f.name and (not mode_filter or f.mode in mode_filter):
|
|
||||||
yield f.name
|
|
|
@ -73,7 +73,7 @@ TableReference = namedtuple(
|
||||||
# This code is borrowed from sqlparse example script.
|
# This code is borrowed from sqlparse example script.
|
||||||
# <url>
|
# <url>
|
||||||
def is_subselect(parsed):
|
def is_subselect(parsed):
|
||||||
if not parsed.is_group:
|
if not parsed.is_group():
|
||||||
return False
|
return False
|
||||||
sql_type = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
|
sql_type = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
|
||||||
for item in parsed.tokens:
|
for item in parsed.tokens:
|
||||||
|
@ -104,7 +104,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
|
||||||
# condition. So we need to ignore the keyword JOIN and its
|
# condition. So we need to ignore the keyword JOIN and its
|
||||||
# variants INNER JOIN, FULL OUTER JOIN, etc.
|
# variants INNER JOIN, FULL OUTER JOIN, etc.
|
||||||
elif item.ttype is Keyword and (
|
elif item.ttype is Keyword and (
|
||||||
not item.value.upper() == 'FROM') and (
|
not item.value.upper() == 'FROM') and (
|
||||||
not item.value.upper().endswith('JOIN')):
|
not item.value.upper().endswith('JOIN')):
|
||||||
tbl_prefix_seen = False
|
tbl_prefix_seen = False
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
import sqlparse
|
||||||
|
|
||||||
|
|
||||||
|
def query_starts_with(query, prefixes):
|
||||||
|
"""Check if the query starts with any item from *prefixes*."""
|
||||||
|
prefixes = [prefix.lower() for prefix in prefixes]
|
||||||
|
formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
|
||||||
|
return bool(formatted_sql) and formatted_sql.split()[0] in prefixes
|
||||||
|
|
||||||
|
|
||||||
|
def queries_start_with(queries, prefixes):
|
||||||
|
"""Check if any queries start with any item from *prefixes*."""
|
||||||
|
for query in sqlparse.split(queries):
|
||||||
|
if query and query_starts_with(query, prefixes) is True:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_destructive(queries):
|
||||||
|
"""Returns if any of the queries in *queries* is destructive."""
|
||||||
|
keywords = ('drop', 'shutdown', 'delete', 'truncate', 'alter')
|
||||||
|
return queries_start_with(queries, keywords)
|
|
@ -0,0 +1,143 @@
|
||||||
|
from sqlparse import parse
|
||||||
|
from sqlparse.tokens import Keyword, CTE, DML
|
||||||
|
from sqlparse.sql import Identifier, IdentifierList, Parenthesis
|
||||||
|
from collections import namedtuple
|
||||||
|
from .meta import TableMetadata, ColumnMetadata
|
||||||
|
|
||||||
|
|
||||||
|
# TableExpression is a namedtuple representing a CTE, used internally
|
||||||
|
# name: cte alias assigned in the query
|
||||||
|
# columns: list of column names
|
||||||
|
# start: index into the original string of the left parens starting the CTE
|
||||||
|
# stop: index into the original string of the right parens ending the CTE
|
||||||
|
TableExpression = namedtuple('TableExpression', 'name columns start stop')
|
||||||
|
|
||||||
|
|
||||||
|
def isolate_query_ctes(full_text, text_before_cursor):
    """Simplify a query by converting CTEs into table metadata objects.

    Returns a (full_text, text_before_cursor, ctes) triple where *ctes* is
    a sequence of TableMetadata built from the query's WITH clauses.  If the
    cursor is inside a CTE body, the query is re-scoped to that body instead.
    """

    if not full_text:
        return full_text, text_before_cursor, tuple()

    ctes, remainder = extract_ctes(full_text)
    if not ctes:
        return full_text, text_before_cursor, ()

    current_position = len(text_before_cursor)
    meta = []

    for cte in ctes:
        if cte.start < current_position < cte.stop:
            # Currently editing a cte - treat its body as the current full_text
            text_before_cursor = full_text[cte.start:current_position]
            full_text = full_text[cte.start:cte.stop]
            # Only CTEs preceding the one being edited are visible to it.
            return full_text, text_before_cursor, meta

        # Append this cte to the list of available table metadata
        # NOTE(review): cols is a generator, so each TableMetadata's columns
        # can only be iterated once — presumably consumers iterate only once.
        cols = (ColumnMetadata(name, None, ()) for name in cte.columns)
        meta.append(TableMetadata(cte.name, cols))

    # Editing past the last cte (ie the main body of the query)
    full_text = full_text[ctes[-1].stop:]
    text_before_cursor = text_before_cursor[ctes[-1].stop:current_position]

    return full_text, text_before_cursor, tuple(meta)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_ctes(sql):
    """Extract common table expressions from a query.

    Returns tuple (ctes, remainder_sql)

    ctes is a list of TableExpression namedtuples
    remainder_sql is the text from the original query after the CTEs have
    been stripped.
    """

    p = parse(sql)[0]

    # Make sure the first meaningful token is "WITH" which is necessary to
    # define CTEs
    idx, tok = p.token_next(-1, skip_ws=True, skip_cm=True)
    if not (tok and tok.ttype == CTE):
        return [], sql

    # Get the next (meaningful) token, which should be the first CTE
    idx, tok = p.token_next(idx)
    if not tok:
        # NOTE(review): here the remainder is '' while the branch above
        # returns the full sql — looks intentional (nothing after WITH).
        return ([], '')
    start_pos = token_start_pos(p.tokens, idx)
    ctes = []

    if isinstance(tok, IdentifierList):
        # Multiple ctes: "WITH a AS (...), b AS (...)"
        for t in tok.get_identifiers():
            cte_start_offset = token_start_pos(tok.tokens, tok.token_index(t))
            cte = get_cte_from_token(t, start_pos + cte_start_offset)
            if not cte:
                continue
            ctes.append(cte)
    elif isinstance(tok, Identifier):
        # A single CTE
        cte = get_cte_from_token(tok, start_pos)
        if cte:
            ctes.append(cte)

    idx = p.token_index(tok) + 1

    # Collapse everything after the ctes into a remainder query
    remainder = u''.join(str(tok) for tok in p.tokens[idx:])

    return ctes, remainder
|
||||||
|
|
||||||
|
|
||||||
|
def get_cte_from_token(tok, pos0):
    """Build a TableExpression from an sqlparse Identifier token.

    *pos0* is the token's absolute character offset in the original SQL.
    Returns None when the token has no real name or no parenthesized body.
    """
    cte_name = tok.get_real_name()
    if not cte_name:
        return None

    # Find the start position of the opening parens enclosing the cte body
    idx, parens = tok.token_next_by(Parenthesis)
    if not parens:
        return None

    start_pos = pos0 + token_start_pos(tok.tokens, idx)
    cte_len = len(str(parens))  # includes parens
    stop_pos = start_pos + cte_len

    column_names = extract_column_names(parens)

    return TableExpression(cte_name, column_names, start_pos, stop_pos)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_column_names(parsed):
    """Return a tuple of the column names produced by a CTE body.

    For SELECT the names come from the select list; for
    INSERT/UPDATE/DELETE they come from the RETURNING clause.
    """
    # Find the first DML token to check if it's a SELECT or
    # INSERT/UPDATE/DELETE
    idx, tok = parsed.token_next_by(t=DML)
    tok_val = tok and tok.value.lower()

    if tok_val in ('insert', 'update', 'delete'):
        # Jump ahead to the RETURNING clause where the list of column names is
        idx, tok = parsed.token_next_by(idx, (Keyword, 'returning'))
    elif not tok_val == 'select':
        # Must be invalid CTE
        return ()

    # The next token should be either a column name, or a list of column names
    idx, tok = parsed.token_next(idx, skip_ws=True, skip_cm=True)
    return tuple(t.get_name() for t in _identifiers(tok))
|
||||||
|
|
||||||
|
|
||||||
|
def token_start_pos(tokens, idx):
    """Character offset of tokens[idx] from the start of the token list."""
    offset = 0
    for preceding in tokens[:idx]:
        offset += len(str(preceding))
    return offset
|
||||||
|
|
||||||
|
|
||||||
|
def _identifiers(tok):
    """Yield the Identifier tokens inside *tok* (or *tok* itself)."""
    if isinstance(tok, IdentifierList):
        # NB: IdentifierList.get_identifiers() can return non-identifiers!
        return (t for t in tok.get_identifiers() if isinstance(t, Identifier))
    if isinstance(tok, Identifier):
        return iter((tok,))
    return iter(())
|
|
@ -0,0 +1,151 @@
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
# Internal record type; the ColumnMetadata factory below normalizes defaults.
_ColumnMetadata = namedtuple(
    'ColumnMetadata',
    ['name', 'datatype', 'foreignkeys', 'default', 'has_default']
)


def ColumnMetadata(
    name, datatype, foreignkeys=None, default=None, has_default=False
):
    """Build a ColumnMetadata tuple, replacing a falsy *foreignkeys* with []."""
    fks = foreignkeys if foreignkeys else []
    return _ColumnMetadata(name, datatype, fks, default, has_default)
|
||||||
|
|
||||||
|
|
||||||
|
# Fully-qualified description of a foreign-key relationship between a parent
# (referenced) column and a child (referencing) column.
ForeignKey = namedtuple('ForeignKey', ['parentschema', 'parenttable',
                                       'parentcolumn', 'childschema',
                                       'childtable', 'childcolumn'])
# A table-like object (table/view/CTE) and its iterable of ColumnMetadata.
TableMetadata = namedtuple('TableMetadata', 'name columns')
|
||||||
|
|
||||||
|
|
||||||
|
def parse_defaults(defaults_string):
    """Yield each default-value expression for a function.

    *defaults_string* is the text produced by
    pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0): a comma-separated
    list in which string literals may themselves contain commas.
    """
    if not defaults_string:
        return
    expr = ''
    quote_char = None
    for ch in defaults_string:
        if not expr and ch == ' ':
            # Skip the space that follows a separating comma
            continue
        if ch in ('"', "'"):
            if quote_char == ch:
                # Closing quote
                quote_char = None
            elif quote_char is None:
                # Opening quote
                quote_char = ch
        elif ch == ',' and quote_char is None:
            # A comma outside quotes ends the current expression
            yield expr
            expr = ''
            continue
        expr += ch
    yield expr
|
||||||
|
|
||||||
|
|
||||||
|
class FunctionMetadata(object):
    """Describes a postgresql function for completion purposes."""

    def __init__(
        self, schema_name, func_name, arg_names, arg_types, arg_modes,
        return_type, is_aggregate, is_window, is_set_returning,
        arg_defaults
    ):
        """Class for describing a postgresql function"""

        self.schema_name = schema_name
        self.func_name = func_name

        # Normalize sequence arguments to tuples (hashable), or None.
        self.arg_modes = tuple(arg_modes) if arg_modes else None
        self.arg_names = tuple(arg_names) if arg_names else None

        # Be flexible in not requiring arg_types -- use None as a placeholder
        # for each arg. (Used for compatibility with old versions of postgresql
        # where such info is hard to get.
        if arg_types:
            self.arg_types = tuple(arg_types)
        elif arg_modes:
            self.arg_types = tuple([None] * len(arg_modes))
        elif arg_names:
            self.arg_types = tuple([None] * len(arg_names))
        else:
            self.arg_types = None

        # arg_defaults is the raw pg_get_expr() string; split it into one
        # expression per defaulted (trailing) argument.
        self.arg_defaults = tuple(parse_defaults(arg_defaults))

        self.return_type = return_type.strip()
        self.is_aggregate = is_aggregate
        self.is_window = is_window
        self.is_set_returning = is_set_returning

    def __eq__(self, other):
        return (isinstance(other, self.__class__) and
                self.__dict__ == other.__dict__)

    def __ne__(self, other):
        # Needed for Python 2, which does not derive != from __eq__.
        return not self.__eq__(other)

    def _signature(self):
        # Tuple of all identifying fields; shared by __hash__ and __repr__.
        return (
            self.schema_name, self.func_name, self.arg_names, self.arg_types,
            self.arg_modes, self.return_type, self.is_aggregate,
            self.is_window, self.is_set_returning, self.arg_defaults
        )

    def __hash__(self):
        return hash(self._signature())

    def __repr__(self):
        # BUG FIX: '%' binds tighter than '+', so the original code formatted
        # an 11-conversion string with a 1-tuple and then added a str to a
        # tuple -- __repr__ always raised TypeError.  Parenthesize so the
        # whole tuple feeds the format string.
        return (
            (
                '%s(schema_name=%r, func_name=%r, arg_names=%r, '
                'arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, '
                'is_window=%r, is_set_returning=%r, arg_defaults=%r)'
            ) % ((self.__class__.__name__,) + self._signature())
        )

    def has_variadic(self):
        # Mode 'v' marks a VARIADIC parameter.
        return self.arg_modes and any(
            arg_mode == 'v' for arg_mode in self.arg_modes)

    def args(self):
        """Returns a list of input-parameter ColumnMetadata namedtuples."""
        if not self.arg_names:
            return []
        modes = self.arg_modes or ['i'] * len(self.arg_names)
        args = [
            (name, typ)
            for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
            if mode in ('i', 'b', 'v')  # IN, INOUT, VARIADIC
        ]

        def arg(name, typ, num):
            # Defaults apply to the trailing num_defaults arguments; map the
            # positional index to the matching default expression.
            num_args = len(args)
            num_defaults = len(self.arg_defaults)
            has_default = num + num_defaults >= num_args
            default = (
                self.arg_defaults[num - num_args + num_defaults] if has_default
                else None
            )
            return ColumnMetadata(name, typ, [], default, has_default)

        return [arg(name, typ, num) for num, (name, typ) in enumerate(args)]

    def fields(self):
        """Returns a list of output-field ColumnMetadata namedtuples"""

        if self.return_type.lower() == 'void':
            return []
        elif not self.arg_modes:
            # For functions without output parameters, the function name
            # is used as the name of the output column.
            # E.g. 'SELECT unnest FROM unnest(...);'
            return [ColumnMetadata(self.func_name, self.return_type, [])]

        return [ColumnMetadata(name, typ, [])
                for name, typ, mode in zip(
                    self.arg_names, self.arg_types, self.arg_modes)
                if mode in ('o', 'b', 't')]  # OUT, INOUT, TABLE
|
|
@ -0,0 +1,149 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import sqlparse
|
||||||
|
from collections import namedtuple
|
||||||
|
from sqlparse.sql import IdentifierList, Identifier, Function
|
||||||
|
from sqlparse.tokens import Keyword, DML, Punctuation
|
||||||
|
|
||||||
|
# A table/view/function referenced in a FROM clause.
TableReference = namedtuple('TableReference', ['schema', 'name', 'alias',
                                               'is_function'])
# .ref is the name used to qualify columns: the alias when present, otherwise
# the table name (double-quoted unless it is all lower-case or pre-quoted).
TableReference.ref = property(
    lambda self: self.alias or (
        self.name if self.name.islower() or self.name[0] == '"'
        else '"' + self.name + '"')
)
|
||||||
|
|
||||||
|
|
||||||
|
# This code is borrowed from sqlparse example script.
|
||||||
|
# <url>
|
||||||
|
# This code is borrowed from sqlparse example script.
# <url>
def is_subselect(parsed):
    """True when *parsed* is a grouped token containing a DML statement."""
    if not parsed.is_group:
        return False
    dml_values = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
    return any(
        item.ttype is DML and item.value.upper() in dml_values
        for item in parsed.tokens
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _identifier_is_function(identifier):
    """True when the identifier wraps a Function token (e.g. 'foo(x) f')."""
    for child in identifier.tokens:
        if isinstance(child, Function):
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens following FROM/INTO/UPDATE/JOIN/... in *parsed*.

    When *stop_at_punctuation* is true (used for INSERT statements), stop
    at the first punctuation token so a parenthesized column list is not
    mistaken for table names.
    """
    tbl_prefix_seen = False
    for item in parsed.tokens:
        if tbl_prefix_seen:
            if is_subselect(item):
                for x in extract_from_part(item, stop_at_punctuation):
                    yield x
            elif stop_at_punctuation and item.ttype is Punctuation:
                # BUG FIX: the original raised StopIteration here.  Under
                # PEP 479 (Python 3.7+) raising StopIteration inside a
                # generator becomes a RuntimeError; a plain return ends the
                # generator cleanly with identical semantics.
                return
            # An incomplete nested select won't be recognized correctly as a
            # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
            # the second FROM to trigger this elif condition resulting in a
            # StopIteration. So we need to ignore the keyword if the keyword
            # FROM.
            # Also 'SELECT * FROM abc JOIN def' will trigger this elif
            # condition. So we need to ignore the keyword JOIN and its variants
            # INNER JOIN, FULL OUTER JOIN, etc.
            elif item.ttype is Keyword and (
                    not item.value.upper() == 'FROM') and (
                    not item.value.upper().endswith('JOIN')):
                tbl_prefix_seen = False
            else:
                yield item
        elif item.ttype is Keyword or item.ttype is Keyword.DML:
            item_val = item.value.upper()
            if (item_val in ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE') or
                    item_val.endswith('JOIN')):
                tbl_prefix_seen = True
        # 'SELECT a, FROM abc' will detect FROM as part of the column list.
        # So this check here is necessary.
        elif isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                if (identifier.ttype is Keyword and
                        identifier.value.upper() == 'FROM'):
                    tbl_prefix_seen = True
                    break
|
||||||
|
|
||||||
|
|
||||||
|
def extract_table_identifiers(token_stream, allow_functions=True):
    """yields tuples of TableReference namedtuples"""

    # We need to do some massaging of the names because postgres is case-
    # insensitive and '"Foo"' is not the same table as 'Foo' (while 'foo' is)
    def parse_identifier(item):
        # Returns (schema_name, name, alias) with unquoted names lower-cased
        # and quoted mixed-case names preserved via the alias.
        name = item.get_real_name()
        schema_name = item.get_parent_name()
        alias = item.get_alias()
        if not name:
            schema_name = None
            name = item.get_name()
            alias = alias or name
        schema_quoted = schema_name and item.value[0] == '"'
        if schema_name and not schema_quoted:
            schema_name = schema_name.lower()
        # More than two quotes (schema + name both quoted), or any quotes
        # when the schema itself is unquoted, means the name was quoted.
        quote_count = item.value.count('"')
        name_quoted = quote_count > 2 or (quote_count and not schema_quoted)
        alias_quoted = alias and item.value[-1] == '"'
        if alias_quoted or name_quoted and not alias and name.islower():
            alias = '"' + (alias or name) + '"'
        if name and not name_quoted and not name.islower():
            if not alias:
                alias = name
            name = name.lower()
        return schema_name, name, alias

    for item in token_stream:
        if isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                # Sometimes Keywords (such as FROM ) are classified as
                # identifiers which don't have the get_real_name() method.
                try:
                    schema_name = identifier.get_parent_name()
                    real_name = identifier.get_real_name()
                    is_function = (allow_functions and
                                   _identifier_is_function(identifier))
                except AttributeError:
                    continue
                if real_name:
                    yield TableReference(schema_name, real_name,
                                         identifier.get_alias(), is_function)
        elif isinstance(item, Identifier):
            schema_name, real_name, alias = parse_identifier(item)
            is_function = allow_functions and _identifier_is_function(item)

            yield TableReference(schema_name, real_name, alias, is_function)
        elif isinstance(item, Function):
            schema_name, real_name, alias = parse_identifier(item)
            yield TableReference(None, real_name, alias, allow_functions)
|
||||||
|
|
||||||
|
|
||||||
|
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql):
    """Extract the table names from an SQL statement.

    Returns a tuple of TableReference namedtuples.
    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return ()

    # INSERT statements must stop looking for tables at the sign of first
    # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
    # abc is the table name, but if we don't stop at the first lparen, then
    # we'll identify abc, col1 and col2 as table names.
    insert_stmt = parsed[0].token_first().value.lower() == 'insert'
    stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)

    # Kludge: sqlparse mistakenly identifies insert statements as
    # function calls due to the parenthesized column list, e.g. interprets
    # "insert into foo (bar, baz)" as a function call to foo with arguments
    # (bar, baz). So don't allow any identifiers in insert statements
    # to have is_function=True
    identifiers = extract_table_identifiers(stream,
                                            allow_functions=not insert_stmt)
    # In the case 'sche.<cursor>', we get an empty TableReference; remove that
    return tuple(i for i in identifiers if i.name)
|
|
@ -0,0 +1,140 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import re
|
||||||
|
import sqlparse
|
||||||
|
from sqlparse.sql import Identifier
|
||||||
|
from sqlparse.tokens import Token, Error
|
||||||
|
|
||||||
|
# Regexes that capture the trailing "word" of a string, with progressively
# looser definitions of what characters a word may contain.
cleanup_regex = {
    # Only alphanumerics and underscores
    'alphanum_underscore': re.compile(r'(\w+)$'),
    # Everything except spaces, parens, colon and comma
    'many_punctuations': re.compile(r'([^():,\s]+)$'),
    # Everything except spaces, parens, colon, comma and period
    'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
    # Everything except a space
    'all_punctuations': re.compile(r'([^\s]+)$'),
}


def last_word(text, include='alphanum_underscore'):
    """Return the last word of *text*, '' when the text is empty or ends
    with whitespace.

    *include* selects which cleanup_regex defines a "word".

    >>> last_word('abc def')
    'def'
    >>> last_word('abc ')
    ''
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    """
    if not text or text[-1].isspace():
        return ''
    match = cleanup_regex[include].search(text)
    return match.group(0) if match else ''
|
||||||
|
|
||||||
|
|
||||||
|
def find_prev_keyword(sql, n_skip=0):
    """ Find the last sql keyword in an SQL statement

    Returns the value of the last keyword, and the text of the query with
    everything after the last keyword stripped.  *n_skip* trailing tokens
    are ignored before searching.
    """
    if not sql.strip():
        return None, ''

    parsed = sqlparse.parse(sql)[0]
    flattened = list(parsed.flatten())
    flattened = flattened[:len(flattened) - n_skip]

    # Logical operators do not end a clause, so they don't count as the
    # "previous keyword" for completion purposes.
    logical_operators = ('AND', 'OR', 'NOT', 'BETWEEN')

    for t in reversed(flattened):
        if t.value == '(' or (t.is_keyword and (
                t.value.upper() not in logical_operators)):
            # Find the location of token t in the original parsed statement
            # We can't use parsed.token_index(t) because t may be a child token
            # inside a TokenList, in which case token_index throws an error
            # Minimal example:
            # p = sqlparse.parse('select * from foo where bar')
            # t = list(p.flatten())[-3]  # The "Where" token
            # p.token_index(t)  # Throws ValueError: not in list
            idx = flattened.index(t)

            # Combine the string values of all tokens in the original list
            # up to and including the target keyword token t, to produce a
            # query string with everything after the keyword token removed
            text = ''.join(tok.value for tok in flattened[:idx + 1])
            return t, text

    return None, ''
|
||||||
|
|
||||||
|
|
||||||
|
# Postgresql dollar quote signs look like `$$` or `$tag$`
dollar_quote_regex = re.compile(r'^\$[^$]*\$$')


def is_open_quote(sql):
    """Return True when *sql* contains an unterminated quote."""
    # sql may hold several semicolon-separated statements; check each one.
    statements = sqlparse.parse(sql)
    return any(_parsed_is_open_quote(statement) for statement in statements)
|
||||||
|
|
||||||
|
|
||||||
|
def _parsed_is_open_quote(parsed):
    """True when the statement holds an Error token for an unmatched ' or $."""
    for tok in parsed.flatten():
        # sqlparse marks unterminated quotes as Token.Error
        if tok.match(Token.Error, ("'", "$")):
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def parse_partial_identifier(word):
    """Attempt to parse a (partially typed) word as an identifier

    word may include a schema qualification, like `schema_name.partial_name`
    or `schema_name.` There may also be unclosed quotation marks, like
    `"schema`, or `schema."partial_name`

    :param word: string representing a (partially complete) identifier
    :return: sqlparse.sql.Identifier, or None
    """
    parsed = sqlparse.parse(word)[0]
    tokens = parsed.tokens
    if len(tokens) == 1 and isinstance(tokens[0], Identifier):
        return tokens[0]
    if parsed.token_next_by(m=(Error, '"'))[1]:
        # An unmatched double quote, e.g. '"foo', 'foo."', or 'foo."bar'
        # Close the double quote, then reparse
        return parse_partial_identifier(word + '"')
    return None
|
|
@ -1,8 +1,8 @@
|
||||||
import re
|
import re
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import sqlparse
|
import sqlparse
|
||||||
from sqlparse.tokens import Name
|
from sqlparse.tokens import Name
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
white_space_regex = re.compile(r'\\s+', re.MULTILINE)
|
white_space_regex = re.compile(r'\\s+', re.MULTILINE)
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ white_space_regex = re.compile(r'\\s+', re.MULTILINE)
|
||||||
def _compile_regex(keyword):
|
def _compile_regex(keyword):
|
||||||
# Surround the keyword with word boundaries and replace interior whitespace
|
# Surround the keyword with word boundaries and replace interior whitespace
|
||||||
# with whitespace wildcards
|
# with whitespace wildcards
|
||||||
pattern = r'\\b' + re.sub(white_space_regex, r'\\s+', keyword) + r'\\b'
|
pattern = r'\\b' + white_space_regex.sub(r'\\s+', keyword) + r'\\b'
|
||||||
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
|
return re.compile(pattern, re.MULTILINE | re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,521 @@
|
||||||
|
from __future__ import print_function
|
||||||
|
import sys
|
||||||
|
import re
|
||||||
|
import sqlparse
|
||||||
|
from collections import namedtuple
|
||||||
|
from sqlparse.sql import Comparison, Identifier, Where
|
||||||
|
from .parseutils.utils import (
|
||||||
|
last_word, find_prev_keyword, parse_partial_identifier)
|
||||||
|
from .parseutils.tables import extract_tables
|
||||||
|
from .parseutils.ctes import isolate_query_ctes
|
||||||
|
|
||||||
|
# Python 2/3 compatibility flags.
PY2 = sys.version_info[0] == 2
PY3 = sys.version_info[0] == 3

# Base type for isinstance() checks against strings on either Python version.
if PY3:
    string_types = str
else:
    string_types = basestring  # noqa: F821 -- only defined on Python 2
|
||||||
|
|
||||||
|
|
||||||
|
# Suggestion types returned by suggest_type(): each namedtuple names a kind
# of completion (tables, columns, keywords, ...) plus the scope needed to
# resolve it.
Special = namedtuple('Special', [])
Database = namedtuple('Database', [])
Schema = namedtuple('Schema', ['quoted'])
Schema.__new__.__defaults__ = (False,)
# FromClauseItem is a table/view/function used in the FROM clause
# `table_refs` contains the list of tables/... already in the statement,
# used to ensure that the alias we suggest is unique
FromClauseItem = namedtuple('FromClauseItem', 'schema table_refs local_tables')
Table = namedtuple('Table', ['schema', 'table_refs', 'local_tables'])
TableFormat = namedtuple('TableFormat', [])
View = namedtuple('View', ['schema', 'table_refs'])
# JoinConditions are suggested after ON, e.g. 'foo.barid = bar.barid'
JoinCondition = namedtuple('JoinCondition', ['table_refs', 'parent'])
# Joins are suggested after JOIN, e.g. 'foo ON foo.barid = bar.barid'
Join = namedtuple('Join', ['table_refs', 'schema'])

Function = namedtuple('Function', ['schema', 'table_refs', 'usage'])
# For convenience, don't require the `usage` argument in Function constructor
Function.__new__.__defaults__ = (None, tuple(), None)
Table.__new__.__defaults__ = (None, tuple(), tuple())
View.__new__.__defaults__ = (None, tuple())
FromClauseItem.__new__.__defaults__ = (None, tuple(), tuple())

Column = namedtuple(
    'Column',
    ['table_refs', 'require_last_table', 'local_tables',
     'qualifiable', 'context']
)
Column.__new__.__defaults__ = (None, None, tuple(), False, None)

Keyword = namedtuple('Keyword', ['last_token'])
Keyword.__new__.__defaults__ = (None,)
NamedQuery = namedtuple('NamedQuery', [])
Datatype = namedtuple('Datatype', ['schema'])
Alias = namedtuple('Alias', ['aliases'])

Path = namedtuple('Path', [])
|
||||||
|
|
||||||
|
|
||||||
|
class SqlStatement(object):
    """Parsed view of the statement under the cursor, used to decide what
    kind of completion to suggest."""

    def __init__(self, full_text, text_before_cursor):
        # Partially-typed identifier under the cursor (sqlparse Identifier),
        # set below when applicable.
        self.identifier = None
        self.word_before_cursor = word_before_cursor = last_word(
            text_before_cursor, include='many_punctuations')
        full_text = _strip_named_query(full_text)
        text_before_cursor = _strip_named_query(text_before_cursor)

        full_text, text_before_cursor, self.local_tables = \
            isolate_query_ctes(full_text, text_before_cursor)

        self.text_before_cursor_including_last_word = text_before_cursor

        # If we've partially typed a word then word_before_cursor won't be an
        # empty string. In that case we want to remove the partially typed
        # string before sending it to the sqlparser. Otherwise the last token
        # will always be the partially typed string which renders the smart
        # completion useless because it will always return the list of
        # keywords as completion.
        if self.word_before_cursor:
            if word_before_cursor[-1] == '(' or word_before_cursor[0] == '\\':
                parsed = sqlparse.parse(text_before_cursor)
            else:
                text_before_cursor = \
                    text_before_cursor[:-len(word_before_cursor)]
                parsed = sqlparse.parse(text_before_cursor)
                self.identifier = parse_partial_identifier(word_before_cursor)
        else:
            parsed = sqlparse.parse(text_before_cursor)

        full_text, text_before_cursor, parsed = \
            _split_multiple_statements(full_text, text_before_cursor, parsed)

        self.full_text = full_text
        self.text_before_cursor = text_before_cursor
        self.parsed = parsed

        # Last meaningful token before the cursor, or '' for an empty parse.
        self.last_token = \
            parsed and parsed.token_prev(len(parsed.tokens))[1] or ''

    def is_insert(self):
        # True when the statement's first token is the INSERT keyword.
        return self.parsed.token_first().value.lower() == 'insert'

    def get_tables(self, scope='full'):
        """ Gets the tables available in the statement.
        param `scope:` possible values: 'full', 'insert', 'before'
        If 'insert', only the first table is returned.
        If 'before', only tables before the cursor are returned.
        If not 'insert' and the stmt is an insert, the first table is skipped.
        """
        tables = extract_tables(
            self.full_text if scope == 'full' else self.text_before_cursor)
        if scope == 'insert':
            tables = tables[:1]
        elif self.is_insert():
            tables = tables[1:]
        return tables

    def get_previous_token(self, token):
        """Return the meaningful token immediately before *token*."""
        return self.parsed.token_prev(self.parsed.token_index(token))[1]

    def get_identifier_schema(self):
        """Return the schema qualifying the partial identifier, or None."""
        schema = \
            (self.identifier and self.identifier.get_parent_name()) or None
        # If schema name is unquoted, lower-case it
        if schema and self.identifier.value[0] != '"':
            schema = schema.lower()

        return schema

    def reduce_to_prev_keyword(self, n_skip=0):
        """Truncate text_before_cursor at the previous keyword and return it."""
        prev_keyword, self.text_before_cursor = \
            find_prev_keyword(self.text_before_cursor, n_skip=n_skip)
        return prev_keyword
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_type(full_text, text_before_cursor):
    """Takes the full_text that is typed so far and also the text before the
    cursor to suggest completion type and scope.

    Returns a tuple with a type of entity ('table', 'column' etc) and a scope.
    A scope for a column category will be a list of tables.
    """
    # '\i <file>' executes a script, so suggest a filesystem path.
    if full_text.startswith('\\i '):
        return (Path(),)

    # This is a temporary hack; the exception handling
    # here should be removed once sqlparse has been fixed
    try:
        statement = SqlStatement(full_text, text_before_cursor)
    except (TypeError, AttributeError):
        return []

    return suggest_based_on_last_token(statement.last_token, statement)
|
||||||
|
|
||||||
|
|
||||||
|
named_query_regex = re.compile(r'^\s*\\ns\s+[A-z0-9\-_]+\s+')
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_named_query(txt):
    """
    This will strip "save named query" command in the beginning of the line:
    '\ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
    ' \ns zzz SELECT * FROM abc' -> 'SELECT * FROM abc'
    """
    match = named_query_regex.match(txt)
    if match:
        # The pattern is anchored at the start, so dropping everything up
        # to the match end removes exactly the "\ns <name> " prefix.
        return txt[match.end():]
    return txt
|
||||||
|
|
||||||
|
# Matches a dollar-quoted function body: an opening tag such as $$ or $fn$
# (group 1), the lazily-matched body (group 2), then the identical closing
# tag via the \1 backreference.
function_body_pattern = re.compile(r'(\$.*?\$)([\s\S]*?)\1', re.M)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_function_body(text):
    """Locate the dollar-quoted function body inside *text*.

    Returns (start, end) character offsets of the body, or (None, None)
    when no dollar-quoted section is present.
    """
    match = function_body_pattern.search(text)
    if match is None:
        return None, None
    return match.start(2), match.end(2)
|
|
||||||
|
|
||||||
|
def _statement_from_function(full_text, text_before_cursor, statement):
    """Narrow the completion context to a dollar-quoted function body.

    When the cursor sits inside the body of a CREATE FUNCTION statement,
    re-parse just that body so suggestions are computed from the inner
    SQL. Otherwise the arguments are returned unchanged.
    """
    current_pos = len(text_before_cursor)
    body_start, body_end = _find_function_body(full_text)

    # Either there is no dollar-quoted body, or the cursor lies outside
    # of it -- nothing to narrow.
    if body_start is None or not (body_start <= current_pos < body_end):
        return full_text, text_before_cursor, statement

    full_text = full_text[body_start:body_end]
    text_before_cursor = text_before_cursor[body_start:]
    parsed = sqlparse.parse(text_before_cursor)
    return _split_multiple_statements(full_text, text_before_cursor, parsed)
|
|
||||||
|
|
||||||
|
def _split_multiple_statements(full_text, text_before_cursor, parsed):
    # Isolate the single statement containing the cursor; for CREATE
    # FUNCTION statements, additionally recurse into the dollar-quoted
    # body. Returns (full_text, text_before_cursor, statement), where
    # statement is None when the input is empty.
    if len(parsed) > 1:
        # Multiple statements being edited -- isolate the current one by
        # cumulatively summing statement lengths to find the one that bounds
        # the current position
        current_pos = len(text_before_cursor)
        stmt_start, stmt_end = 0, 0

        for statement in parsed:
            stmt_len = len(str(statement))
            stmt_start, stmt_end = stmt_end, stmt_end + stmt_len

            if stmt_end >= current_pos:
                # Found the statement bounding the cursor: rebase both
                # texts so offsets are relative to this statement.
                text_before_cursor = full_text[stmt_start:current_pos]
                full_text = full_text[stmt_start:]
                break

    elif parsed:
        # A single statement
        statement = parsed[0]
    else:
        # The empty string
        return full_text, text_before_cursor, None

    # Detect "CREATE [OR REPLACE] FUNCTION ..." so completion can be
    # scoped to the function body.
    token2 = None
    if statement.get_type() in ('CREATE', 'CREATE OR REPLACE'):
        token1 = statement.token_first()
        if token1:
            token1_idx = statement.token_index(token1)
            token2 = statement.token_next(token1_idx)[1]
    if token2 and token2.value.upper() == 'FUNCTION':
        full_text, text_before_cursor, statement = _statement_from_function(
            full_text, text_before_cursor, statement
        )
    return full_text, text_before_cursor, statement
|
|
||||||
|
|
||||||
|
def suggest_based_on_last_token(token, stmt):
    """Map the token preceding the cursor to a tuple of suggestion objects.

    *token* may be a plain string keyword (pushed in by recursive calls),
    a grouped sqlparse token (Comparison / Where / Identifier), or any
    other sqlparse token; *stmt* is the SqlStatement being completed.
    """

    if isinstance(token, string_types):
        token_v = token.lower()
    elif isinstance(token, Comparison):
        # If 'token' is a Comparison type such as
        # 'select * FROM abc a JOIN def d ON a.id = d.'. Then calling
        # token.value on the comparison type will only return the lhs of the
        # comparison. In this case a.id. So we need to do token.tokens to get
        # both sides of the comparison and pick the last token out of that
        # list.
        token_v = token.tokens[-1].value.lower()
    elif isinstance(token, Where):
        # sqlparse groups all tokens from the where clause into a single token
        # list. This means that token.value may be something like
        # 'where foo > 5 and '. We need to look "inside" token.tokens to handle
        # suggestions in complicated where clauses correctly
        prev_keyword = stmt.reduce_to_prev_keyword()
        return suggest_based_on_last_token(prev_keyword, stmt)
    elif isinstance(token, Identifier):
        # If the previous token is an identifier, we can suggest datatypes if
        # we're in a parenthesized column/field list, e.g.:
        #   CREATE TABLE foo (Identifier <CURSOR>
        #   CREATE FUNCTION foo (Identifier <CURSOR>
        # If we're not in a parenthesized list, the most likely scenario is the
        # user is about to specify an alias, e.g.:
        #   SELECT Identifier <CURSOR>
        #   SELECT foo FROM Identifier <CURSOR>
        prev_keyword, _ = find_prev_keyword(stmt.text_before_cursor)
        if prev_keyword and prev_keyword.value == '(':
            # Suggest datatypes
            return suggest_based_on_last_token('type', stmt)
        else:
            return (Keyword(),)
    else:
        token_v = token.value.lower()

    if not token:
        return (Keyword(),)
    elif token_v.endswith('('):
        p = sqlparse.parse(stmt.text_before_cursor)[0]

        if p.tokens and isinstance(p.tokens[-1], Where):
            # Four possibilities:
            # 1 - Parenthesized clause like "WHERE foo AND ("
            #     Suggest columns/functions
            # 2 - Function call like "WHERE foo("
            #     Suggest columns/functions
            # 3 - Subquery expression like "WHERE EXISTS ("
            #     Suggest keywords, in order to do a subquery
            # 4 - Subquery OR array comparison like "WHERE foo = ANY("
            #     Suggest columns/functions AND keywords. (If we wanted to be
            #     really fancy, we could suggest only array-typed columns)

            column_suggestions = suggest_based_on_last_token('where', stmt)

            # Check for a subquery expression (cases 3 & 4)
            where = p.tokens[-1]
            prev_tok = where.token_prev(len(where.tokens) - 1)[1]

            if isinstance(prev_tok, Comparison):
                # e.g. "SELECT foo FROM bar WHERE foo = ANY("
                prev_tok = prev_tok.tokens[-1]

            prev_tok = prev_tok.value.lower()
            if prev_tok == 'exists':
                return (Keyword(),)
            else:
                return column_suggestions

        # Get the token before the parens
        prev_tok = p.token_prev(len(p.tokens) - 1)[1]

        if (
            prev_tok and prev_tok.value and
            prev_tok.value.lower().split(' ')[-1] == 'using'
        ):
            # tbl1 INNER JOIN tbl2 USING (col1, col2)
            tables = stmt.get_tables('before')

            # suggest columns that are present in more than one table
            return (Column(table_refs=tables,
                           require_last_table=True,
                           local_tables=stmt.local_tables),)

        elif p.token_first().value.lower() == 'select':
            # If the lparen is preceeded by a space chances are we're about to
            # do a sub-select.
            if last_word(stmt.text_before_cursor,
                         'all_punctuations').startswith('('):
                return (Keyword(),)
        prev_prev_tok = prev_tok and p.token_prev(p.token_index(prev_tok))[1]
        if prev_prev_tok and prev_prev_tok.normalized == 'INTO':
            # INSERT INTO tbl ( -- suggest the target table's columns.
            return (
                Column(table_refs=stmt.get_tables('insert'), context='insert'),
            )
        # We're probably in a function argument list
        return (Column(table_refs=extract_tables(stmt.full_text),
                       local_tables=stmt.local_tables, qualifiable=True),)
    elif token_v == 'set':
        return (Column(table_refs=stmt.get_tables(),
                       local_tables=stmt.local_tables),)
    elif token_v in ('select', 'where', 'having', 'by', 'distinct'):
        # Check for a table alias or schema qualification
        parent = (stmt.identifier and stmt.identifier.get_parent_name()) or []
        tables = stmt.get_tables()
        if parent:
            # Qualified reference: only suggest from the matching table(s).
            tables = tuple(t for t in tables if identifies(parent, t))
            return (Column(table_refs=tables, local_tables=stmt.local_tables),
                    Table(schema=parent),
                    View(schema=parent),
                    Function(schema=parent),)
        else:
            return (Column(table_refs=tables, local_tables=stmt.local_tables,
                           qualifiable=True),
                    Function(schema=None),
                    Keyword(token_v.upper()),)
    elif token_v == 'as':
        # Don't suggest anything for aliases
        return ()
    elif (
        (token_v.endswith('join') and token.is_keyword) or
        (token_v in ('copy', 'from', 'update', 'into', 'describe', 'truncate'))
    ):

        schema = stmt.get_identifier_schema()
        tables = extract_tables(stmt.text_before_cursor)
        is_join = token_v.endswith('join') and token.is_keyword

        # Suggest tables from either the currently-selected schema or the
        # public schema if no schema has been specified
        suggest = []

        if not schema:
            # Suggest schemas
            suggest.insert(0, Schema())

        if token_v == 'from' or is_join:
            suggest.append(FromClauseItem(schema=schema,
                                          table_refs=tables,
                                          local_tables=stmt.local_tables))
        elif token_v == 'truncate':
            suggest.append(Table(schema))
        else:
            suggest.extend((Table(schema), View(schema)))

        if is_join and _allow_join(stmt.parsed):
            tables = stmt.get_tables('before')
            suggest.append(Join(table_refs=tables, schema=schema))

        return tuple(suggest)

    elif token_v == 'function':
        schema = stmt.get_identifier_schema()
        # stmt.get_previous_token will fail for e.g.
        # `SELECT 1 FROM functions WHERE function:`
        try:
            prev = stmt.get_previous_token(token).value.lower()
            if prev in('drop', 'alter', 'create', 'create or replace'):
                return (Function(schema=schema, usage='signature'),)
        except ValueError:
            pass
        return tuple()

    elif token_v in ('table', 'view'):
        # E.g. 'ALTER TABLE <tablname>'
        rel_type = \
            {'table': Table, 'view': View, 'function': Function}[token_v]
        schema = stmt.get_identifier_schema()
        if schema:
            return (rel_type(schema=schema),)
        else:
            return (Schema(), rel_type(schema=schema))

    elif token_v == 'column':
        # E.g. 'ALTER TABLE foo ALTER COLUMN bar
        return (Column(table_refs=stmt.get_tables()),)

    elif token_v == 'on':
        tables = stmt.get_tables('before')
        parent = \
            (stmt.identifier and stmt.identifier.get_parent_name()) or None
        if parent:
            # "ON parent.<suggestion>"
            # parent can be either a schema name or table alias
            filteredtables = tuple(t for t in tables if identifies(parent, t))
            sugs = [Column(table_refs=filteredtables,
                           local_tables=stmt.local_tables),
                    Table(schema=parent),
                    View(schema=parent),
                    Function(schema=parent)]
            if filteredtables and _allow_join_condition(stmt.parsed):
                sugs.append(JoinCondition(table_refs=tables,
                                          parent=filteredtables[-1]))
            return tuple(sugs)
        else:
            # ON <suggestion>
            # Use table alias if there is one, otherwise the table name
            aliases = tuple(t.ref for t in tables)
            if _allow_join_condition(stmt.parsed):
                return (Alias(aliases=aliases), JoinCondition(
                    table_refs=tables, parent=None))
            else:
                return (Alias(aliases=aliases),)

    elif token_v in ('c', 'use', 'database', 'template'):
        # "\c <db", "use <db>", "DROP DATABASE <db>",
        # "CREATE DATABASE <newdb> WITH TEMPLATE <db>"
        return (Database(),)
    elif token_v == 'schema':
        # DROP SCHEMA schema_name, SET SCHEMA schema name
        prev_keyword = stmt.reduce_to_prev_keyword(n_skip=2)
        quoted = prev_keyword and prev_keyword.value.lower() == 'set'
        return (Schema(quoted),)
    elif token_v.endswith(',') or token_v in ('=', 'and', 'or'):
        prev_keyword = stmt.reduce_to_prev_keyword()
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, stmt)
        else:
            return ()
    elif token_v in ('type', '::'):
        # ALTER TABLE foo SET DATA TYPE bar
        # SELECT foo::bar
        # Note that tables are a form of composite type in postgresql, so
        # they're suggested here as well
        schema = stmt.get_identifier_schema()
        suggestions = [Datatype(schema=schema),
                       Table(schema=schema)]
        if not schema:
            suggestions.append(Schema())
        return tuple(suggestions)
    elif token_v in {'alter', 'create', 'drop'}:
        return (Keyword(token_v.upper()),)
    elif token.is_keyword:
        # token is a keyword we haven't implemented any special handling for
        # go backwards in the query until we find one we do recognize
        prev_keyword = stmt.reduce_to_prev_keyword(n_skip=1)
        if prev_keyword:
            return suggest_based_on_last_token(prev_keyword, stmt)
        else:
            return (Keyword(token_v.upper()),)
    else:
        return (Keyword(),)
|
|
||||||
|
|
||||||
|
def identifies(id, ref):
    """Returns true if string `id` matches TableReference `ref`"""
    # Match against the alias, the bare table name, or the
    # schema-qualified "schema.name" form.
    if id in (ref.alias, ref.name):
        return True
    return ref.schema and id == ref.schema + '.' + ref.name
|
|
||||||
|
|
||||||
|
def _allow_join_condition(statement):
|
||||||
|
"""
|
||||||
|
Tests if a join condition should be suggested
|
||||||
|
|
||||||
|
We need this to avoid bad suggestions when entering e.g.
|
||||||
|
select * from tbl1 a join tbl2 b on a.id = <cursor>
|
||||||
|
So check that the preceding token is a ON, AND, or OR keyword, instead of
|
||||||
|
e.g. an equals sign.
|
||||||
|
|
||||||
|
:param statement: an sqlparse.sql.Statement
|
||||||
|
:return: boolean
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not statement or not statement.tokens:
|
||||||
|
return False
|
||||||
|
|
||||||
|
last_tok = statement.token_prev(len(statement.tokens))[1]
|
||||||
|
return last_tok.value.lower() in ('on', 'and', 'or')
|
||||||
|
|
||||||
|
|
||||||
|
def _allow_join(statement):
|
||||||
|
"""
|
||||||
|
Tests if a join should be suggested
|
||||||
|
|
||||||
|
We need this to avoid bad suggestions when entering e.g.
|
||||||
|
select * from tbl1 a join tbl2 b <cursor>
|
||||||
|
So check that the preceding token is a JOIN keyword
|
||||||
|
|
||||||
|
:param statement: an sqlparse.sql.Statement
|
||||||
|
:return: boolean
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not statement or not statement.tokens:
|
||||||
|
return False
|
||||||
|
|
||||||
|
last_tok = statement.token_prev(len(statement.tokens))[1]
|
||||||
|
return (
|
||||||
|
last_tok.value.lower().endswith('join') and
|
||||||
|
last_tok.value.lower() not in('cross join', 'natural join')
|
||||||
|
)
|
Loading…
Reference in New Issue