pgadmin4/web/pgadmin/misc/cloud/utils/rds.py

172 lines
5.3 KiB
Python

##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2022, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
# AWS RDS PostgreSQL provider
import boto3
import pickle
from flask import session
from boto3.session import Session
class RDS():
    """Helper around the boto3 clients used to query AWS RDS.

    Stores the AWS credentials and a default region, and lazily creates one
    boto3 client per service name, caching it in ``self._clients``.
    """

    def __init__(self, access_key, secret_key, session_token=None,
                 default_region='ap-south-1'):
        """Remember the credentials and region; no AWS call is made here.

        :param access_key: AWS access key id.
        :param secret_key: AWS secret access key.
        :param session_token: optional AWS session token.
        :param default_region: region passed to every client created.
        """
        self._clients = {}
        self._access_key = access_key
        self._secret_key = secret_key
        self._session_token = session_token
        self._default_region = default_region

    ##########################################################################
    # AWS Helper functions
    ##########################################################################
    def _get_aws_client(self, type):
        """Create/cache/return an AWS client object for service ``type``.

        The parameter name shadows the ``type`` builtin; it is kept
        unchanged so that existing keyword callers keep working.
        """
        if type in self._clients:
            return self._clients[type]

        # Named ``aws_session`` so it does not shadow the Flask ``session``
        # imported at module level.
        aws_session = boto3.Session(
            aws_access_key_id=self._access_key,
            aws_secret_access_key=self._secret_key,
            aws_session_token=self._session_token
        )

        self._clients[type] = aws_session.client(
            type, region_name=self._default_region)
        return self._clients[type]

    def get_available_db_version(self, engine='postgres'):
        """Return the raw ``describe_db_engine_versions`` response."""
        rds = self._get_aws_client('rds')
        return rds.describe_db_engine_versions(Engine=engine)

    def get_available_db_instance_class(self, engine='postgres',
                                        engine_version='9.6'):
        """Return every orderable DB instance option for an engine version.

        Follows the ``Marker`` returned by the service until no further
        page is announced and returns the concatenated
        ``OrderableDBInstanceOptions`` entries as one list.
        """
        rds = self._get_aws_client('rds')
        request = {'Engine': engine, 'EngineVersion': engine_version}
        options = []

        while True:
            page = rds.describe_orderable_db_instance_options(**request)
            # Extend in place rather than rebuilding the list for each page.
            options.extend(page['OrderableDBInstanceOptions'])

            marker = page.get('Marker')
            if not marker:
                return options
            request['Marker'] = marker

    def get_db_instance(self, instance_name):
        """Return the raw ``describe_db_instances`` response for a name."""
        rds = self._get_aws_client('rds')
        return rds.describe_db_instances(
            DBInstanceIdentifier=instance_name)

    def validate_credentials(self):
        """Check the stored credentials by asking STS for the caller identity.

        :return: ``(True, identity)`` on success, ``(False, message)`` when
            creating the client or calling STS raises.
        """
        try:
            # Client creation is inside the ``try`` so that a failure here
            # is reported as ``(False, message)`` like any STS error.
            client = self._get_aws_client('sts')
            return True, client.get_caller_identity()
        except Exception as e:
            return False, str(e)
        finally:
            # The STS client is always dropped from the cache; presumably
            # this keeps the instance picklable (verify_aws_credentials
            # pickles it right after this call) - TODO confirm.
            # The default avoids a KeyError when creation itself failed.
            self._clients.pop('sts', None)
def verify_aws_credentials(data):
    """Validate the AWS credentials in ``data`` and cache them in the session.

    ``data['secret']`` must provide ``aws_access_key``,
    ``aws_secret_access_key`` and ``aws_region``; ``aws_session_token`` is
    optional.

    :return: ``(True, None)`` when the same secret is already cached with
        an RDS helper, otherwise the ``(status, identity_or_message)`` pair
        returned by ``RDS.validate_credentials``.
    """
    secret = data['secret']

    if 'aws' not in session:
        session['aws'] = {}
    aws_state = session['aws']

    # Same credentials as the cached helper: nothing to re-validate.
    if 'aws_rds_obj' in aws_state and aws_state.get('secret') == secret:
        return True, None

    _rds = RDS(
        access_key=secret['aws_access_key'],
        secret_key=secret['aws_secret_access_key'],
        session_token=secret.get('aws_session_token'),
        default_region=secret['aws_region'])
    status, identity = _rds.validate_credentials()

    if status:
        # NOTE(review): the credentials and a pickled helper are kept in
        # the session - confirm the session storage is server-side.
        aws_state['secret'] = secret
        aws_state['aws_rds_obj'] = pickle.dumps(_rds, -1)
        # Writes into a nested dict are not detected by the session object,
        # so flag the change explicitly; otherwise it may never be saved
        # when 'aws' already existed before this call.
        session.modified = True

    return status, identity
def clear_aws_session():
    """Drop the 'aws' entry from the session, if one is present."""
    if 'aws' not in session:
        return
    session.pop('aws')
def get_aws_db_instances(eng_version):
    """Get the AWS DB instance classes orderable for an engine version.

    :param eng_version: engine version string; a falsy value or the literal
        ``'undefined'`` falls back to ``'9.6.1'``.
    :return: ``(False, message)`` when no verified AWS helper is stored in
        the session, otherwise ``(True, options)`` where ``options`` is a
        list of ``{'label': cls, 'value': cls}`` dicts, one per distinct
        ``DBInstanceClass``, sorted by name.
    """
    # verify_aws_credentials leaves session['aws'] as an empty dict when
    # validation fails, so the helper itself has to be checked as well.
    if 'aws' not in session or 'aws_rds_obj' not in session['aws']:
        return False, 'Session has not created yet.'

    if not eng_version or eng_version == 'undefined':
        eng_version = '9.6.1'

    # NOTE(review): pickle.loads runs on data read back from the session;
    # this is only safe if the session contents cannot be tampered with.
    rds_obj = pickle.loads(session['aws']['aws_rds_obj'])
    res = rds_obj.get_available_db_instance_class(
        engine_version=eng_version)

    # De-duplicate, then sort so the result order is stable between calls
    # (plain set iteration order is not).
    instance_classes = sorted({value['DBInstanceClass'] for value in res})
    return True, [{'label': name, 'value': name}
                  for name in instance_classes]
def get_aws_db_versions():
    """Get the AWS DB engine versions available to the stored credentials.

    :return: ``(False, message)`` when no verified AWS helper is stored in
        the session, otherwise ``(True, versions)`` where ``versions`` is a
        list of ``{'label': description, 'value': engine_version}`` dicts
        built from the ``DBEngineVersions`` entries of the response.
    """
    # verify_aws_credentials leaves session['aws'] as an empty dict when
    # validation fails, so the helper itself has to be checked as well.
    if 'aws' not in session or 'aws_rds_obj' not in session['aws']:
        return False, 'Session has not created yet.'

    # NOTE(review): pickle.loads runs on data read back from the session;
    # this is only safe if the session contents cannot be tampered with.
    rds_obj = pickle.loads(session['aws']['aws_rds_obj'])
    db_versions = rds_obj.get_available_db_version()

    versions = [
        {
            'label': value['DBEngineVersionDescription'],
            'value': value['EngineVersion']
        }
        for value in db_versions['DBEngineVersions']
    ]
    return True, versions
def get_aws_regions():
    """Get the AWS regions available for the 'rds' service.

    Clears any AWS data held in the session first, then returns
    ``(True, regions)`` where ``regions`` holds one
    ``{'label': name, 'value': name}`` dict per region name.
    """
    clear_aws_session()
    region_names = Session().get_available_regions('rds')
    return True, [{'label': name, 'value': name} for name in region_names]