226 lines
8.7 KiB
Python
226 lines
8.7 KiB
Python
##########################################################################
|
|
#
|
|
# pgAdmin 4 - PostgreSQL Tools
|
|
#
|
|
# Copyright (C) 2013 - 2019, The pgAdmin Development Team
|
|
# This software is released under the PostgreSQL Licence
|
|
#
|
|
##########################################################################
|
|
|
|
import json
|
|
import os
|
|
|
|
from flask import url_for
|
|
from pgadmin.utils.route import BaseTestGenerator
|
|
from regression.python_test_utils import test_utils as utils
|
|
from pgadmin.browser.server_groups.servers.databases.tests import \
|
|
utils as database_utils
|
|
from pgadmin.utils.versioned_template_loader import \
|
|
get_version_mapping_directories
|
|
|
|
|
|
def create_resql_module_list(all_modules, exclude_pkgs):
|
|
"""
|
|
This function is used to create the module list for reverse engineering
|
|
SQL by iterating all the modules.
|
|
|
|
:param all_modules: List of all the modules
|
|
:param exclude_pkgs: List of exclude packages
|
|
:return:
|
|
"""
|
|
resql_module_list = dict()
|
|
|
|
for module in all_modules:
|
|
if "tests." in str(module) and not any(str(module).startswith(
|
|
'pgadmin.' + str(exclude_pkg)) for exclude_pkg in exclude_pkgs
|
|
):
|
|
complete_module_name = module.split(".test")
|
|
module_name_list = complete_module_name[0].split(".")
|
|
module_name = module_name_list[len(module_name_list) - 1]
|
|
|
|
resql_module_list[module_name] = os.path.join(*module_name_list)
|
|
|
|
return resql_module_list
|
|
|
|
|
|
class ReverseEngineeredSQLTestCases(BaseTestGenerator):
|
|
""" This class will test the reverse engineering SQL"""
|
|
|
|
scenarios = [
|
|
('Reverse Engineered SQL Test Cases', dict())
|
|
]
|
|
|
|
def setUp(self):
|
|
# Get the database connection
|
|
self.db_con = database_utils.connect_database(
|
|
self, utils.SERVER_GROUP, self.server_information['server_id'],
|
|
self.server_information['db_id'])
|
|
if not self.db_con['info'] == "Database connected.":
|
|
raise Exception("Could not connect to database.")
|
|
|
|
# Get the application path
|
|
self.apppath = os.getcwd()
|
|
|
|
def runTest(self):
|
|
# Create the module list on which reverse engineering sql test
|
|
# cases will be executed.
|
|
resql_module_list = create_resql_module_list(
|
|
BaseTestGenerator.re_sql_module_list,
|
|
BaseTestGenerator.exclude_pkgs)
|
|
|
|
for module in resql_module_list:
|
|
module_path = resql_module_list[module]
|
|
# Get the folder name based on server version number and
|
|
# their existence.
|
|
status, self.test_folder = self.get_test_folder(module_path)
|
|
if not status:
|
|
continue
|
|
|
|
# Iterate all the files in the test folder and check for
|
|
# the JSON files.
|
|
for filename in os.listdir(self.test_folder):
|
|
if filename.endswith(".json"):
|
|
complete_file_name = os.path.join(self.test_folder,
|
|
filename)
|
|
with open(complete_file_name) as jsonfp:
|
|
data = json.load(jsonfp)
|
|
for key, scenarios in data.items():
|
|
self.execute_test_case(scenarios)
|
|
|
|
def tearDown(self):
|
|
database_utils.disconnect_database(
|
|
self, self.server_information['server_id'],
|
|
self.server_information['db_id'])
|
|
|
|
def get_url(self, endpoint, object_id=None):
|
|
"""
|
|
This function is used to get the url.
|
|
|
|
:param endpoint:
|
|
:param object_id:
|
|
:return:
|
|
"""
|
|
object_url = None
|
|
for rule in self.app.url_map.iter_rules(endpoint):
|
|
options = {}
|
|
for arg in rule.arguments:
|
|
if arg == 'gid':
|
|
options['gid'] = int(utils.SERVER_GROUP)
|
|
elif arg == 'sid':
|
|
options['sid'] = int(self.server_information['server_id'])
|
|
elif arg == 'did':
|
|
options['did'] = int(self.server_information['db_id'])
|
|
elif arg == 'scid':
|
|
options['scid'] = int(self.server_information['schema_id'])
|
|
else:
|
|
if object_id is not None:
|
|
options[arg] = int(object_id)
|
|
|
|
with self.app.test_request_context():
|
|
object_url = url_for(rule.endpoint, **options)
|
|
|
|
return object_url
|
|
|
|
def execute_test_case(self, scenarios):
|
|
"""
|
|
This function will run the test cases for specific module.
|
|
|
|
:param module_name: Name of the module
|
|
:param scenarios: List of scenarios
|
|
:return:
|
|
"""
|
|
object_id = None
|
|
# Added line break after scenario name
|
|
print("\n")
|
|
|
|
for scenario in scenarios:
|
|
print(scenario['name'])
|
|
|
|
if 'type' in scenario and scenario['type'] == 'create':
|
|
# Get the url and create the specific node.
|
|
create_url = self.get_url(scenario['endpoint'])
|
|
response = self.tester.post(create_url,
|
|
data=json.dumps(scenario['data']),
|
|
content_type='html/json')
|
|
self.assertEquals(response.status_code, 200)
|
|
resp_data = json.loads(response.data.decode('utf8'))
|
|
object_id = resp_data['node']['_id']
|
|
|
|
# Compare the reverse engineering SQL
|
|
self.check_re_sql(scenario, object_id)
|
|
elif 'type' in scenario and scenario['type'] == 'alter':
|
|
# Get the url and create the specific node.
|
|
alter_url = self.get_url(scenario['endpoint'], object_id)
|
|
response = self.tester.put(alter_url,
|
|
data=json.dumps(scenario['data']),
|
|
follow_redirects=True)
|
|
self.assertEquals(response.status_code, 200)
|
|
resp_data = json.loads(response.data.decode('utf8'))
|
|
object_id = resp_data['node']['_id']
|
|
|
|
# Compare the reverse engineering SQL
|
|
self.check_re_sql(scenario, object_id)
|
|
elif 'type' in scenario and scenario['type'] == 'delete':
|
|
# Get the delete url and delete the object created above.
|
|
delete_url = self.get_url(scenario['endpoint'], object_id)
|
|
delete_response = self.tester.delete(delete_url,
|
|
follow_redirects=True)
|
|
self.assertEquals(delete_response.status_code, 200)
|
|
|
|
def get_test_folder(self, module_path):
|
|
"""
|
|
This function will get the appropriate test folder based on
|
|
server version and their existence.
|
|
|
|
:param module_path: Path of the module to be tested.
|
|
:return:
|
|
"""
|
|
# Join the application path and the module path
|
|
absolute_path = os.path.join(self.apppath, module_path)
|
|
# Iterate the version mapping directories.
|
|
for version_mapping in get_version_mapping_directories(
|
|
self.server['type']):
|
|
if version_mapping['number'] > \
|
|
self.server_information['server_version']:
|
|
continue
|
|
|
|
complete_path = os.path.join(absolute_path, 'tests',
|
|
version_mapping['name'])
|
|
|
|
if os.path.exists(complete_path):
|
|
return True, complete_path
|
|
|
|
return False, None
|
|
|
|
def check_re_sql(self, scenario, object_id):
|
|
"""
|
|
This function is used to get the reverse engineering SQL.
|
|
:param scenario:
|
|
:param object_id:
|
|
:return:
|
|
"""
|
|
sql_url = self.get_url(scenario['sql_endpoint'], object_id)
|
|
response = self.tester.get(sql_url)
|
|
self.assertEquals(response.status_code, 200)
|
|
resp_sql = response.data.decode('unicode_escape')
|
|
|
|
# Remove first and last double quotes
|
|
if resp_sql.startswith('"') and resp_sql.endswith('"'):
|
|
resp_sql = resp_sql[1:-1]
|
|
|
|
# Check if expected sql is given in JSON file or path of the output
|
|
# file is given
|
|
if 'expected_sql_file' in scenario:
|
|
output_file = os.path.join(self.test_folder,
|
|
scenario['expected_sql_file'])
|
|
|
|
if os.path.exists(output_file):
|
|
fp = open(output_file, "r")
|
|
# Used rstrip to remove trailing \n
|
|
sql = fp.read().rstrip()
|
|
self.assertEquals(sql, resp_sql)
|
|
else:
|
|
self.assertFalse("Expected SQL File not found")
|
|
elif 'expected_sql' in scenario:
|
|
self.assertEquals(scenario['expected_sql'], resp_sql)
|