230 lines
7.1 KiB
Python
230 lines
7.1 KiB
Python
##########################################################################
#
# pgAdmin 4 - PostgreSQL Tools
#
# Copyright (C) 2013 - 2021, The pgAdmin Development Team
# This software is released under the PostgreSQL Licence
#
##########################################################################
|
|
|
|
"""A blueprint module implementing the Authentication."""
|
|
|
|
import flask
|
|
import pickle
|
|
from flask import current_app, flash, Response, request, url_for,\
|
|
render_template
|
|
from flask_babelex import gettext
|
|
from flask_security import current_user
|
|
from flask_security.views import _security, _ctx
|
|
from flask_security.utils import config_value, get_post_logout_redirect, \
|
|
get_post_login_redirect, logout_user
|
|
|
|
from flask import session
|
|
|
|
import config
|
|
from pgadmin.utils import PgAdminModule
|
|
from pgadmin.utils.constants import KERBEROS
|
|
from pgadmin.utils.csrf import pgCSRFProtect
|
|
|
|
from .registry import AuthSourceRegistry
|
|
|
|
# Name passed as the first argument when the module's blueprint object
# is constructed below.
MODULE_NAME = 'authenticate'
|
|
|
|
|
|
class AuthenticateModule(PgAdminModule):
    """pgAdmin module that hosts the authentication views."""

    def get_exposed_url_endpoints(self):
        """Return the fully-qualified endpoint names this module exposes.

        :returns: list of endpoint names (login and the two Kerberos views).
        """
        exposed = [
            'authenticate.login',
            'authenticate.kerberos_login',
            'authenticate.kerberos_logout',
        ]
        return exposed
|
|
|
|
|
|
# Module instance for this package; its ``route`` decorator is used below to
# register the login / Kerberos views.
blueprint = AuthenticateModule(MODULE_NAME, __name__, static_url_path='')
|
|
|
|
|
|
@blueprint.route("/login/kerberos",
                 endpoint="kerberos_login", methods=["GET"])
@pgCSRFProtect.exempt
def kerberos_login():
    """Log out any current user and render the Kerberos login page.

    GET-only and exempt from CSRF protection. The template receives the
    ``security.login`` URL as ``login_url``.
    """
    logout_user()
    security_login_url = url_for('security.login')
    page = render_template("browser/kerberos_login.html",
                           login_url=security_login_url)
    return Response(page)
|
|
|
|
|
|
@blueprint.route("/logout/kerberos",
                 endpoint="kerberos_logout", methods=["GET"])
@pgCSRFProtect.exempt
def kerberos_logout():
    """Log out the current user and render the Kerberos logout page.

    GET-only and exempt from CSRF protection. The template receives the
    ``security.login`` URL as ``login_url``.
    """
    logout_user()
    security_login_url = url_for('security.login')
    page = render_template("browser/kerberos_logout.html",
                           login_url=security_login_url)
    return Response(page)
|
|
|
|
|
|
@blueprint.route('/login', endpoint='login', methods=['GET', 'POST'])
def login():
    """
    Entry point for all the authentication sources.

    Validates the submitted login form and authenticates it against the
    sources listed in ``config.AUTHENTICATION_SOURCES``.

    On success, logs the user in, stores the manager's state in the session
    and redirects to the post-login URL. On failure, flashes the error and
    redirects to the post-logout URL. The exceptions are:

    - a source may hand back a ready-made ``Response``, which is returned
      as-is;
    - a failed Kerberos login redirects to the Kerberos login page.
    """
    form = _security.login_form()
    manager = AuthSourceManager(form, config.AUTHENTICATION_SOURCES)
    # Forget any auth-manager state left from an earlier login.
    session['_auth_source_manager_obj'] = None

    # Guard: no source accepted the form -> report every field error.
    if not manager.validate():
        for field_errors in form.errors.values():
            for error in field_errors:
                flash(error, 'warning')
        return flask.redirect(get_post_logout_redirect())

    authenticated, msg = manager.authenticate()
    if not authenticated:
        # A source may return its own response (e.g. an HTTP challenge);
        # pass it straight back to the client.
        if isinstance(msg, Response):
            return msg
        flash(gettext(msg), 'danger')
        return flask.redirect(get_post_logout_redirect())

    logged_in, msg = manager.login()
    auth_state = manager.as_dict()
    if logged_in:
        session['_auth_source_manager_obj'] = auth_state
        return flask.redirect(get_post_login_redirect())

    if auth_state['current_source'] == KERBEROS:
        kerberos_url = url_for('authenticate.kerberos_login')
        browser_url = url_for('browser.index')
        return flask.redirect('{0}?next={1}'.format(kerberos_url,
                                                     browser_url))

    flash(gettext(msg), 'danger')
    return flask.redirect(get_post_logout_redirect())
|
|
|
|
|
|
class AuthSourceManager:
    """This class will manage all the authentication sources.

    For one login attempt it walks the configured source names in order,
    validating / authenticating the login ``form`` against each source, and
    remembers which source succeeded so the user can be logged in through it.
    """

    def __init__(self, form, sources):
        # Submitted login form; assumed to expose ``data`` and ``_fields``
        # like a WTForms form (TODO confirm for every caller).
        self.form = form
        # Ordered names of the sources to try (resolved via get_auth_sources).
        self.auth_sources = sources
        # Source object that authenticated the user, once known.
        self.source = None
        # Friendly name of that source; set after a successful login().
        self.source_friendly_name = None
        # Name of the source that authenticated the user.
        self.current_source = None

    def as_dict(self):
        """
        Returns the dictionary object representing this object.
        """
        return {
            'source_friendly_name': self.source_friendly_name,
            'auth_sources': self.auth_sources,
            'current_source': self.current_source,
        }

    def set_current_source(self, source):
        """Record the name of the source that authenticated the user."""
        self.current_source = source

    @property
    def get_current_source(self):
        """Name of the source that authenticated the user, or None."""
        # A property getter receives only ``self``; an extra parameter here
        # made every access raise TypeError.
        return self.current_source

    def set_source(self, source):
        """Record the source object that authenticated the user."""
        self.source = source

    @property
    def get_source(self):
        """Source object that authenticated the user, or None."""
        return self.source

    def set_source_friendly_name(self, name):
        """Record the friendly name of the authenticating source."""
        self.source_friendly_name = name

    @property
    def get_source_friendly_name(self):
        """Friendly name of the authenticating source, or None."""
        return self.source_friendly_name

    def validate(self):
        """Validate through all the sources.

        :returns: True as soon as one source accepts the form, else False.
        """
        # any() short-circuits exactly like the early-return loop would.
        return any(get_auth_sources(src).validate(self.form)
                   for src in self.auth_sources)

    def authenticate(self):
        """Authenticate through all the sources.

        Sources are tried in order until one succeeds. On success the source
        is recorded on this manager and its ``(status, msg)`` is returned.

        :returns: ``(status, msg)`` from the last source tried, or
            ``(False, None)`` when no source was tried. On failure ``msg`` may
            be a response object that the caller should send back.
        """
        status = False
        msg = None
        for src in self.auth_sources:
            source = get_auth_sources(src)
            source_name = source.get_source_name()
            current_app.logger.debug(
                "Authentication initiated via source: %s", source_name)

            # Kerberos is skipped when the user typed both an e-mail and
            # a password into the form.
            if self.form.data['email'] and self.form.data['password'] and \
                    source_name == KERBEROS:
                continue

            status, msg = source.authenticate(self.form)

            if status:
                self.set_source(source)
                self.set_current_source(source_name)
                # A source may report the authenticated user name in a dict;
                # copy it into the form so login() uses it. The dict check
                # avoids a substring test / TypeError when msg is a string.
                if isinstance(msg, dict) and 'username' in msg:
                    self.form._fields['email'].data = msg['username']
                return status, msg

            # When server sends Unauthorized header to get the ticket over
            # HTTP, OR when kerberos authentication failed while accessing
            # pgadmin, we need to break the loop as no need to authenticate
            # further even if the authentication sources set to multiple.
            if getattr(msg, 'status', None) == '401 UNAUTHORIZED' or \
                    (source_name == KERBEROS and request.method == 'GET'):
                break

        return status, msg

    def login(self):
        """Log the user in through the source recorded by authenticate().

        :returns: ``(status, msg)`` from the source's ``login``; on success
            the source's friendly name is also recorded.
        """
        status, msg = self.source.login(self.form)
        if status:
            self.set_source_friendly_name(self.source.get_friendly_name())
            current_app.logger.debug(
                "Authentication and Login successfully done via source : %s",
                self.source.get_source_name())
        return status, msg
|
|
|
|
|
|
def get_auth_sources(type):
    """Get the authenticated source object from the registry.

    Source objects are cached in a dict stored on the current app as
    ``_pgadmin_auth_sources``; a missing or non-dict value is treated as an
    empty cache. On a miss the object comes from
    ``AuthSourceRegistry.create`` and is cached only if it is not None.

    :param type: source name (an entry of the configured source list).
        The name shadows the builtin but is kept for keyword callers.
    :returns: the source object, or None if the registry returned None.
    """
    cache = getattr(current_app, '_pgadmin_auth_sources', None)
    if not isinstance(cache, dict):
        cache = {}

    if type in cache:
        return cache[type]

    created = AuthSourceRegistry.create(type)
    if created is None:
        return None

    cache[type] = created
    setattr(current_app, '_pgadmin_auth_sources', cache)
    return created
|
|
|
|
|
|
def init_app(app):
    """Attach an empty source cache to ``app`` and load the auth sources.

    The cache is stored as ``app._pgadmin_auth_sources`` (the attribute that
    get_auth_sources reads). ``AuthSourceRegistry.load_auth_sources()`` is
    then called.

    :returns: the (still empty) cache dict attached to ``app``.
    """
    source_cache = {}
    app._pgadmin_auth_sources = source_cache
    AuthSourceRegistry.load_auth_sources()
    return source_cache
|