commit
75fe9eed96
|
@ -0,0 +1,26 @@
|
|||
Feature: Authentication with JWTs
|
||||
Some of the API endpoints contain information that is specific to a user.
|
||||
To ensure that information is seen only by the user that owns it, we will
|
||||
use a login mechanism coupled with authentication tokens to securely identify
|
||||
a user.
|
||||
|
||||
The code executed in these tests is embedded in every view call. These tests
|
||||
apply to any endpoint that requires authentication. These tests are meant to
|
||||
be the only place authentication logic needs to be tested.
|
||||
|
||||
Scenario: Request for user data includes valid access token
|
||||
Given an authenticated user
|
||||
When a user requests their profile
|
||||
Then the request will be successful
|
||||
And the authentication tokens will remain unchanged
|
||||
|
||||
Scenario: Access token expired
|
||||
Given an authenticated user with an expired access token
|
||||
When a user requests their profile
|
||||
Then the request will be successful
|
||||
And the authentication tokens will be refreshed
|
||||
|
||||
Scenario: Both access and refresh tokens expired
|
||||
Given a previously authenticated user with expired tokens
|
||||
When a user requests their profile
|
||||
Then the request will fail with an unauthorized error
|
|
@ -1,5 +1,4 @@
|
|||
from datetime import date, timedelta
|
||||
import os
|
||||
|
||||
from behave import fixture, use_fixture
|
||||
|
||||
|
@ -28,11 +27,9 @@ def acct_api_client(context):
|
|||
|
||||
def before_feature(context, _):
|
||||
use_fixture(acct_api_client, context)
|
||||
os.environ['SALT'] = 'testsalt'
|
||||
|
||||
|
||||
def before_scenario(context, _):
|
||||
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
_add_agreements(context, db)
|
||||
_add_account(context, db)
|
||||
|
@ -61,7 +58,7 @@ def _add_agreements(context, db):
|
|||
def _add_account(context, db):
|
||||
context.account = Account(
|
||||
email_address='foo@mycroft.ai',
|
||||
display_name='foobar',
|
||||
username='foobar',
|
||||
refresh_tokens=[],
|
||||
subscription=AccountSubscription(
|
||||
type='Monthly Supporter',
|
||||
|
|
|
@ -4,5 +4,5 @@ Feature: Manage account profiles
|
|||
|
||||
Scenario: Retrieve authenticated user's account
|
||||
Given an authenticated user
|
||||
When account endpoint is called to get user profile
|
||||
When a user requests their profile
|
||||
Then user profile is returned
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
from behave import given, then
|
||||
from hamcrest import assert_that, equal_to, is_not
|
||||
|
||||
from selene.api.testing import (
|
||||
generate_access_token,
|
||||
generate_refresh_token,
|
||||
validate_token_cookies
|
||||
)
|
||||
|
||||
|
||||
@given('an authenticated user with an expired access token')
|
||||
def generate_refresh_token_only(context):
|
||||
generate_access_token(context, expire=True)
|
||||
generate_refresh_token(context)
|
||||
context.old_refresh_token = context.refresh_token.jwt
|
||||
|
||||
|
||||
@given('a previously authenticated user with expired tokens')
|
||||
def expire_both_tokens(context):
|
||||
generate_access_token(context, expire=True)
|
||||
generate_refresh_token(context, expire=True)
|
||||
|
||||
|
||||
@then('the authentication tokens will remain unchanged')
|
||||
def check_for_no_new_cookie(context):
|
||||
cookies = context.response.headers.getlist('Set-Cookie')
|
||||
assert_that(cookies, equal_to([]))
|
||||
|
||||
|
||||
@then('the authentication tokens will be refreshed')
|
||||
def check_for_new_cookies(context):
|
||||
validate_token_cookies(context)
|
||||
assert_that(
|
||||
context.refresh_token,
|
||||
is_not(equal_to(context.old_refresh_token))
|
||||
)
|
|
@ -0,0 +1,25 @@
|
|||
from http import HTTPStatus
|
||||
|
||||
from behave import then
|
||||
from hamcrest import assert_that, equal_to
|
||||
|
||||
|
||||
@then('the request will be successful')
|
||||
def check_request_success(context):
|
||||
assert_that(context.response.status_code, equal_to(HTTPStatus.OK))
|
||||
|
||||
|
||||
@then('the request will fail with {error_type} error')
|
||||
def check_for_bad_request(context, error_type):
|
||||
if error_type == 'a bad request':
|
||||
assert_that(
|
||||
context.response.status_code,
|
||||
equal_to(HTTPStatus.BAD_REQUEST)
|
||||
)
|
||||
elif error_type == 'an unauthorized':
|
||||
assert_that(
|
||||
context.response.status_code,
|
||||
equal_to(HTTPStatus.UNAUTHORIZED)
|
||||
)
|
||||
else:
|
||||
raise ValueError('unsupported error_type')
|
|
@ -9,7 +9,7 @@ from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE
|
|||
from selene.util.db import get_db_connection
|
||||
|
||||
new_account_request = dict(
|
||||
displayName='barfoo',
|
||||
username='barfoo',
|
||||
termsOfUse=True,
|
||||
privacyPolicy=True,
|
||||
login=dict(
|
||||
|
@ -19,7 +19,7 @@ new_account_request = dict(
|
|||
),
|
||||
support=dict(
|
||||
openDataset=True,
|
||||
membership='Monthly Supporter',
|
||||
membership='MONTHLY SUPPORTER',
|
||||
stripeCustomerId='barstripe'
|
||||
)
|
||||
)
|
||||
|
@ -50,11 +50,6 @@ def create_account_without_email(context):
|
|||
)
|
||||
|
||||
|
||||
@then('the request will be successful')
|
||||
def check_request_success(context):
|
||||
assert_that(context.response.status_code, equal_to(HTTPStatus.OK))
|
||||
|
||||
|
||||
@then('the account will be added to the system')
|
||||
def check_db_for_account(context):
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
|
@ -64,7 +59,7 @@ def check_db_for_account(context):
|
|||
assert_that(
|
||||
account.email_address, equal_to('bar@mycroft.ai')
|
||||
)
|
||||
assert_that(account.display_name, equal_to('barfoo'))
|
||||
assert_that(account.username, equal_to('barfoo'))
|
||||
assert_that(account.subscription.type, equal_to('Monthly Supporter'))
|
||||
assert_that(
|
||||
account.subscription.stripe_customer_id,
|
||||
|
@ -74,8 +69,3 @@ def check_db_for_account(context):
|
|||
for agreement in account.agreements:
|
||||
assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))
|
||||
assert_that(agreement.accept_date, equal_to(str(date.today())))
|
||||
|
||||
|
||||
@then('the request will fail with a bad request error')
|
||||
def check_for_bad_request(context):
|
||||
assert_that(context.response.status_code, equal_to(HTTPStatus.BAD_REQUEST))
|
||||
|
|
|
@ -5,16 +5,17 @@ import json
|
|||
from behave import given, then, when
|
||||
from hamcrest import assert_that, equal_to, has_item
|
||||
|
||||
from selene.api.testing import generate_auth_tokens
|
||||
from selene.api.testing import generate_access_token, generate_refresh_token
|
||||
from selene.data.account import PRIVACY_POLICY
|
||||
|
||||
|
||||
@given('an authenticated user')
|
||||
def setup_authenticated_user(context):
|
||||
generate_auth_tokens(context)
|
||||
generate_access_token(context)
|
||||
generate_refresh_token(context)
|
||||
|
||||
|
||||
@when('account endpoint is called to get user profile')
|
||||
@when('a user requests their profile')
|
||||
def call_account_endpoint(context):
|
||||
context.response = context.client.get('/api/account')
|
||||
|
||||
|
|
|
@ -5,12 +5,11 @@ name = "pypi"
|
|||
|
||||
[packages]
|
||||
flask = "*"
|
||||
flask-restful = "*"
|
||||
certifi = "*"
|
||||
uwsgi = "*"
|
||||
|
||||
[dev-packages]
|
||||
selene = {path = "./../../shared"}
|
||||
selene = {editable = true,path = "./../../shared"}
|
||||
behave = "*"
|
||||
pyhamcrest = "*"
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "d82f5a7209d972cbf70f44f620e09c550dca2c32cea6fb419f90e9f898580c03"
|
||||
"sha256": "e27bc9018c42543c8594ffade1899d7d7c9cef2117f4c48462b0971310caeb0f"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
|
@ -16,13 +16,6 @@
|
|||
]
|
||||
},
|
||||
"default": {
|
||||
"aniso8601": {
|
||||
"hashes": [
|
||||
"sha256:03c0ffeeb04edeca1ed59684cc6836dc377f58e52e315dc7be3af879909889f4",
|
||||
"sha256:ac30cceff24aec920c37b8d74d7d8a5dd37b1f62a90b4f268a6234cabe147080"
|
||||
],
|
||||
"version": "==4.1.0"
|
||||
},
|
||||
"certifi": {
|
||||
"hashes": [
|
||||
"sha256:47f9c83ef4c0c621eaef743f133f09fa8a74a9b75f037e8624f83bd1b6626cb7",
|
||||
|
@ -46,14 +39,6 @@
|
|||
"index": "pypi",
|
||||
"version": "==1.0.2"
|
||||
},
|
||||
"flask-restful": {
|
||||
"hashes": [
|
||||
"sha256:ecd620c5cc29f663627f99e04f17d1f16d095c83dc1d618426e2ad68b03092f8",
|
||||
"sha256:f8240ec12349afe8df1db168ea7c336c4e5b0271a36982bff7394f93275f2ca9"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==0.3.7"
|
||||
},
|
||||
"itsdangerous": {
|
||||
"hashes": [
|
||||
"sha256:321b033d07f2a4136d3ec762eac9f16a10ccd60f53c0c91af90217ace7ba1f19",
|
||||
|
@ -101,26 +86,12 @@
|
|||
],
|
||||
"version": "==1.1.0"
|
||||
},
|
||||
"pytz": {
|
||||
"hashes": [
|
||||
"sha256:32b0891edff07e28efe91284ed9c31e123d84bea3fd98e1f72be2508f43ef8d9",
|
||||
"sha256:d5f05e487007e29e03409f9398d074e158d920d36eb82eaf66fb1136b0c5374c"
|
||||
],
|
||||
"version": "==2018.9"
|
||||
},
|
||||
"six": {
|
||||
"hashes": [
|
||||
"sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c",
|
||||
"sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
|
||||
],
|
||||
"version": "==1.12.0"
|
||||
},
|
||||
"uwsgi": {
|
||||
"hashes": [
|
||||
"sha256:d2318235c74665a60021a4fc7770e9c2756f9fc07de7b8c22805efe85b5ab277"
|
||||
"sha256:4972ac538800fb2d421027f49b4a1869b66048839507ccf0aa2fda792d99f583"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==2.0.17.1"
|
||||
"version": "==2.0.18"
|
||||
},
|
||||
"werkzeug": {
|
||||
"hashes": [
|
||||
|
@ -139,6 +110,97 @@
|
|||
"index": "pypi",
|
||||
"version": "==1.2.6"
|
||||
},
|
||||
"certifi": {
|
||||
"hashes": [
|
||||
"sha256:47f9c83ef4c0c621eaef743f133f09fa8a74a9b75f037e8624f83bd1b6626cb7",
|
||||
"sha256:993f830721089fef441cdfeb4b2c8c9df86f0c63239f06bd025a76a7daddb033"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==2018.11.29"
|
||||
},
|
||||
"chardet": {
|
||||
"hashes": [
|
||||
"sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae",
|
||||
"sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691"
|
||||
],
|
||||
"version": "==3.0.4"
|
||||
},
|
||||
"click": {
|
||||
"hashes": [
|
||||
"sha256:2335065e6395b9e67ca716de5f7526736bfa6ceead690adf616d925bdc622b13",
|
||||
"sha256:5b94b49521f6456670fdb30cd82a4eca9412788a93fa6dd6df72c94d5a8ff2d7"
|
||||
],
|
||||
"version": "==7.0"
|
||||
},
|
||||
"deprecated": {
|
||||
"hashes": [
|
||||
"sha256:8bfeba6e630abf42b5d111b68a05f7fe3d6de7004391b3cd614947594f87a4ff",
|
||||
"sha256:b784e0ca85a8c1e694d77e545c10827bd99772392e79d5f5442e761515a1246e"
|
||||
],
|
||||
"version": "==1.2.4"
|
||||
},
|
||||
"flask": {
|
||||
"hashes": [
|
||||
"sha256:2271c0070dbcb5275fad4a82e29f23ab92682dc45f9dfbc22c02ba9b9322ce48",
|
||||
"sha256:a080b744b7e345ccfcbc77954861cb05b3c63786e93f2b3875e0913d44b43f05"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==1.0.2"
|
||||
},
|
||||
"idna": {
|
||||
"hashes": [
|
||||
"sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407",
|
||||
"sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c"
|
||||
],
|
||||
"version": "==2.8"
|
||||
},
|
||||
"itsdangerous": {
|
||||
"hashes": [
|
||||
"sha256:321b033d07f2a4136d3ec762eac9f16a10ccd60f53c0c91af90217ace7ba1f19",
|
||||
"sha256:b12271b2047cb23eeb98c8b5622e2e5c5e9abd9784a153e9d8ef9cb4dd09d749"
|
||||
],
|
||||
"version": "==1.1.0"
|
||||
},
|
||||
"jinja2": {
|
||||
"hashes": [
|
||||
"sha256:74c935a1b8bb9a3947c50a54766a969d4846290e1e788ea44c1392163723c3bd",
|
||||
"sha256:f84be1bb0040caca4cea721fcbbbbd61f9be9464ca236387158b0feea01914a4"
|
||||
],
|
||||
"version": "==2.10"
|
||||
},
|
||||
"markupsafe": {
|
||||
"hashes": [
|
||||
"sha256:048ef924c1623740e70204aa7143ec592504045ae4429b59c30054cb31e3c432",
|
||||
"sha256:130f844e7f5bdd8e9f3f42e7102ef1d49b2e6fdf0d7526df3f87281a532d8c8b",
|
||||
"sha256:19f637c2ac5ae9da8bfd98cef74d64b7e1bb8a63038a3505cd182c3fac5eb4d9",
|
||||
"sha256:1b8a7a87ad1b92bd887568ce54b23565f3fd7018c4180136e1cf412b405a47af",
|
||||
"sha256:1c25694ca680b6919de53a4bb3bdd0602beafc63ff001fea2f2fc16ec3a11834",
|
||||
"sha256:1f19ef5d3908110e1e891deefb5586aae1b49a7440db952454b4e281b41620cd",
|
||||
"sha256:1fa6058938190ebe8290e5cae6c351e14e7bb44505c4a7624555ce57fbbeba0d",
|
||||
"sha256:31cbb1359e8c25f9f48e156e59e2eaad51cd5242c05ed18a8de6dbe85184e4b7",
|
||||
"sha256:3e835d8841ae7863f64e40e19477f7eb398674da6a47f09871673742531e6f4b",
|
||||
"sha256:4e97332c9ce444b0c2c38dd22ddc61c743eb208d916e4265a2a3b575bdccb1d3",
|
||||
"sha256:525396ee324ee2da82919f2ee9c9e73b012f23e7640131dd1b53a90206a0f09c",
|
||||
"sha256:52b07fbc32032c21ad4ab060fec137b76eb804c4b9a1c7c7dc562549306afad2",
|
||||
"sha256:52ccb45e77a1085ec5461cde794e1aa037df79f473cbc69b974e73940655c8d7",
|
||||
"sha256:5c3fbebd7de20ce93103cb3183b47671f2885307df4a17a0ad56a1dd51273d36",
|
||||
"sha256:5e5851969aea17660e55f6a3be00037a25b96a9b44d2083651812c99d53b14d1",
|
||||
"sha256:5edfa27b2d3eefa2210fb2f5d539fbed81722b49f083b2c6566455eb7422fd7e",
|
||||
"sha256:7d263e5770efddf465a9e31b78362d84d015cc894ca2c131901a4445eaa61ee1",
|
||||
"sha256:83381342bfc22b3c8c06f2dd93a505413888694302de25add756254beee8449c",
|
||||
"sha256:857eebb2c1dc60e4219ec8e98dfa19553dae33608237e107db9c6078b1167856",
|
||||
"sha256:98e439297f78fca3a6169fd330fbe88d78b3bb72f967ad9961bcac0d7fdd1550",
|
||||
"sha256:bf54103892a83c64db58125b3f2a43df6d2cb2d28889f14c78519394feb41492",
|
||||
"sha256:d9ac82be533394d341b41d78aca7ed0e0f4ba5a2231602e2f05aa87f25c51672",
|
||||
"sha256:e982fe07ede9fada6ff6705af70514a52beb1b2c3d25d4e873e82114cf3c5401",
|
||||
"sha256:edce2ea7f3dfc981c4ddc97add8a61381d9642dc3273737e756517cc03e84dd6",
|
||||
"sha256:efdc45ef1afc238db84cb4963aa689c0408912a0239b0721cb172b4016eb31d6",
|
||||
"sha256:f137c02498f8b935892d5c0172560d7ab54bc45039de8805075e19079c639a9c",
|
||||
"sha256:f82e347a72f955b7017a39708a3667f106e6ad4d10b25f237396a7115d8ed5fd",
|
||||
"sha256:fb7c206e01ad85ce57feeaaa0bf784b97fa3cad0d4a5737bc5295785f5c613a1"
|
||||
],
|
||||
"version": "==1.1.0"
|
||||
},
|
||||
"parse": {
|
||||
"hashes": [
|
||||
"sha256:870dd675c1ee8951db3e29b81ebe44fd131e3eb8c03a79483a58ea574f3145c2"
|
||||
|
@ -152,6 +214,54 @@
|
|||
],
|
||||
"version": "==0.4.2"
|
||||
},
|
||||
"passlib": {
|
||||
"hashes": [
|
||||
"sha256:3d948f64138c25633613f303bcc471126eae67c04d5e3f6b7b8ce6242f8653e0",
|
||||
"sha256:43526aea08fa32c6b6dbbbe9963c4c767285b78147b7437597f992812f69d280"
|
||||
],
|
||||
"version": "==1.7.1"
|
||||
},
|
||||
"psycopg2-binary": {
|
||||
"hashes": [
|
||||
"sha256:19a2d1f3567b30f6c2bb3baea23f74f69d51f0c06c2e2082d0d9c28b0733a4c2",
|
||||
"sha256:2b69cf4b0fa2716fd977aa4e1fd39af6110eb47b2bb30b4e5a469d8fbecfc102",
|
||||
"sha256:2e952fa17ba48cbc2dc063ddeec37d7dc4ea0ef7db0ac1eda8906365a8543f31",
|
||||
"sha256:348b49dd737ff74cfb5e663e18cb069b44c64f77ec0523b5794efafbfa7df0b8",
|
||||
"sha256:3d72a5fdc5f00ca85160915eb9a973cf9a0ab8148f6eda40708bf672c55ac1d1",
|
||||
"sha256:4957452f7868f43f32c090dadb4188e9c74a4687323c87a882e943c2bd4780c3",
|
||||
"sha256:5138cec2ee1e53a671e11cc519505eb08aaaaf390c508f25b09605763d48de4b",
|
||||
"sha256:587098ca4fc46c95736459d171102336af12f0d415b3b865972a79c03f06259f",
|
||||
"sha256:5b79368bcdb1da4a05f931b62760bea0955ee2c81531d8e84625df2defd3f709",
|
||||
"sha256:5cf43807392247d9bc99737160da32d3fa619e0bfd85ba24d1c78db205f472a4",
|
||||
"sha256:676d1a80b1eebc0cacae8dd09b2fde24213173bf65650d22b038c5ed4039f392",
|
||||
"sha256:6b0211ecda389101a7d1d3df2eba0cf7ffbdd2480ca6f1d2257c7bd739e84110",
|
||||
"sha256:79cde4660de6f0bb523c229763bd8ad9a93ac6760b72c369cf1213955c430934",
|
||||
"sha256:7aba9786ac32c2a6d5fb446002ed936b47d5e1f10c466ef7e48f66eb9f9ebe3b",
|
||||
"sha256:7c8159352244e11bdd422226aa17651110b600d175220c451a9acf795e7414e0",
|
||||
"sha256:945f2eedf4fc6b2432697eb90bb98cc467de5147869e57405bfc31fa0b824741",
|
||||
"sha256:96b4e902cde37a7fc6ab306b3ac089a3949e6ce3d824eeca5b19dc0bedb9f6e2",
|
||||
"sha256:9a7bccb1212e63f309eb9fab47b6eaef796f59850f169a25695b248ca1bf681b",
|
||||
"sha256:a3bfcac727538ec11af304b5eccadbac952d4cca1a551a29b8fe554e3ad535dc",
|
||||
"sha256:b19e9f1b85c5d6136f5a0549abdc55dcbd63aba18b4f10d0d063eb65ef2c68b4",
|
||||
"sha256:b664011bb14ca1f2287c17185e222f2098f7b4c857961dbcf9badb28786dbbf4",
|
||||
"sha256:bde7959ef012b628868d69c474ec4920252656d0800835ed999ba5e4f57e3e2e",
|
||||
"sha256:cb095a0657d792c8de9f7c9a0452385a309dfb1bbbb3357d6b1e216353ade6ca",
|
||||
"sha256:d16d42a1b9772152c1fe606f679b2316551f7e1a1ce273e7f808e82a136cdb3d",
|
||||
"sha256:d444b1545430ffc1e7a24ce5a9be122ccd3b135a7b7e695c5862c5aff0b11159",
|
||||
"sha256:d93ccc7bf409ec0a23f2ac70977507e0b8a8d8c54e5ee46109af2f0ec9e411f3",
|
||||
"sha256:df6444f952ca849016902662e1a47abf4fa0678d75f92fd9dd27f20525f809cd",
|
||||
"sha256:e63850d8c52ba2b502662bf3c02603175c2397a9acc756090e444ce49508d41e",
|
||||
"sha256:ec43358c105794bc2b6fd34c68d27f92bea7102393c01889e93f4b6a70975728",
|
||||
"sha256:f4c6926d9c03dadce7a3b378b40d2fea912c1344ef9b29869f984fb3d2a2420b"
|
||||
],
|
||||
"version": "==2.7.7"
|
||||
},
|
||||
"pygithub": {
|
||||
"hashes": [
|
||||
"sha256:263102b43a83e2943900c1313109db7a00b3b78aeeae2c9137ba694982864872"
|
||||
],
|
||||
"version": "==1.43.5"
|
||||
},
|
||||
"pyhamcrest": {
|
||||
"hashes": [
|
||||
"sha256:6b672c02fdf7470df9674ab82263841ce8333fb143f32f021f6cb26f0e512420",
|
||||
|
@ -160,7 +270,29 @@
|
|||
"index": "pypi",
|
||||
"version": "==1.9.0"
|
||||
},
|
||||
"pyjwt": {
|
||||
"hashes": [
|
||||
"sha256:5c6eca3c2940464d106b99ba83b00c6add741c9becaec087fb7ccdefea71350e",
|
||||
"sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96"
|
||||
],
|
||||
"version": "==1.7.1"
|
||||
},
|
||||
"requests": {
|
||||
"hashes": [
|
||||
"sha256:502a824f31acdacb3a35b6690b5fbf0bc41d63a24a45c4004352b0242707598e",
|
||||
"sha256:7bf2a778576d825600030a110f3c0e3e8edc51dfaafe1c146e39a2027784957b"
|
||||
],
|
||||
"version": "==2.21.0"
|
||||
},
|
||||
"schematics": {
|
||||
"hashes": [
|
||||
"sha256:8fcc6182606fd0b24410a1dbb066d9bbddbe8da9c9509f47b743495706239283",
|
||||
"sha256:a40b20635c0e43d18d3aff76220f6cd95ea4decb3f37765e49529b17d81b0439"
|
||||
],
|
||||
"version": "==2.1.0"
|
||||
},
|
||||
"selene": {
|
||||
"editable": true,
|
||||
"path": "./../../shared"
|
||||
},
|
||||
"six": {
|
||||
|
@ -169,6 +301,26 @@
|
|||
"sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
|
||||
],
|
||||
"version": "==1.12.0"
|
||||
},
|
||||
"urllib3": {
|
||||
"hashes": [
|
||||
"sha256:61bf29cada3fc2fbefad4fdf059ea4bd1b4a86d2b6d15e1c7c0b582b9752fe39",
|
||||
"sha256:de9529817c93f27c8ccbfead6985011db27bd0ddfcdb2d86f3f663385c6a9c22"
|
||||
],
|
||||
"version": "==1.24.1"
|
||||
},
|
||||
"werkzeug": {
|
||||
"hashes": [
|
||||
"sha256:c3fd7a7d41976d9f44db327260e263132466836cef6f91512889ed60ad26557c",
|
||||
"sha256:d5da73735293558eb1651ee2fddc4d0dedcfa06538b8813a2e20011583c9e49b"
|
||||
],
|
||||
"version": "==0.14.1"
|
||||
},
|
||||
"wrapt": {
|
||||
"hashes": [
|
||||
"sha256:4aea003270831cceb8a90ff27c4031da6ead7ec1886023b80ce0dfe0adf61533"
|
||||
],
|
||||
"version": "==1.11.1"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
"""Define the API that will support Mycroft single sign on (SSO)."""
|
||||
|
||||
from logging import getLogger
|
||||
|
||||
from flask import Flask, request
|
||||
from flask_restful import Api
|
||||
|
||||
from selene.api.base_config import get_base_config
|
||||
from selene.api import get_base_config, selene_api, SeleneResponse
|
||||
from selene.util.log import configure_logger
|
||||
|
||||
from .endpoints import (
|
||||
AuthenticateInternalEndpoint,
|
||||
|
@ -13,18 +11,30 @@ from .endpoints import (
|
|||
ValidateFederatedEndpoint
|
||||
)
|
||||
|
||||
_log = getLogger('sso_api')
|
||||
_log = configure_logger('sso_api')
|
||||
|
||||
# Initialize the Flask application and the Flask Restful API
|
||||
# Define the Flask application
|
||||
sso = Flask(__name__)
|
||||
sso.config.from_object(get_base_config())
|
||||
sso.response_class = SeleneResponse
|
||||
sso.register_blueprint(selene_api)
|
||||
|
||||
# Initialize the REST API and define the endpoints
|
||||
sso_api = Api(sso, catch_all_404s=True)
|
||||
sso_api.add_resource(AuthenticateInternalEndpoint, '/api/internal-login')
|
||||
sso_api.add_resource(ValidateFederatedEndpoint, '/api/validate-federated')
|
||||
|
||||
sso_api.add_resource(LogoutEndpoint, '/api/logout')
|
||||
# Define the endpoints
|
||||
sso.add_url_rule(
|
||||
'/api/internal-login',
|
||||
view_func=AuthenticateInternalEndpoint.as_view('internal_login'),
|
||||
methods=['GET']
|
||||
)
|
||||
sso.add_url_rule(
|
||||
'/api/validate-federated',
|
||||
view_func=ValidateFederatedEndpoint.as_view('federated_login'),
|
||||
methods=['POST']
|
||||
)
|
||||
sso.add_url_rule(
|
||||
'/api/logout',
|
||||
view_func=LogoutEndpoint.as_view('logout'),
|
||||
methods=['GET']
|
||||
)
|
||||
|
||||
|
||||
def add_cors_headers(response):
|
||||
|
|
|
@ -18,20 +18,16 @@ class AuthenticateInternalEndpoint(SeleneEndpoint):
|
|||
"""Sign in a user with an email address and password."""
|
||||
def __init__(self):
|
||||
super(AuthenticateInternalEndpoint, self).__init__()
|
||||
self.response_status_code = HTTPStatus.OK
|
||||
self.account: Account = None
|
||||
|
||||
def get(self):
|
||||
"""Process HTTP GET request."""
|
||||
try:
|
||||
self._authenticate_credentials()
|
||||
access_token, refresh_token = self._generate_tokens()
|
||||
self._add_refresh_token_to_db(refresh_token)
|
||||
self._set_token_cookies(access_token, refresh_token)
|
||||
except AuthenticationError as ae:
|
||||
self.response = (str(ae), HTTPStatus.UNAUTHORIZED)
|
||||
else:
|
||||
self.response = ({}, HTTPStatus.OK)
|
||||
self._authenticate_credentials()
|
||||
self._generate_tokens()
|
||||
self._add_refresh_token_to_db()
|
||||
self._set_token_cookies()
|
||||
|
||||
self.response = dict(result='user authenticated'), HTTPStatus.OK
|
||||
|
||||
return self.response
|
||||
|
||||
|
@ -52,15 +48,15 @@ class AuthenticateInternalEndpoint(SeleneEndpoint):
|
|||
)
|
||||
if self.account is None:
|
||||
raise AuthenticationError('provided credentials not found')
|
||||
self.access_token.account_id = self.account.id
|
||||
self.refresh_token.account_id = self.account.id
|
||||
|
||||
def _add_refresh_token_to_db(self, refresh_token: str):
|
||||
def _add_refresh_token_to_db(self):
|
||||
"""Track refresh tokens in the database.
|
||||
|
||||
We need to store the value of the refresh token in the database so
|
||||
that we can validate it when it is used to request new tokens.
|
||||
|
||||
:param refresh_token: the token to install into the database.
|
||||
"""
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
token_repo = RefreshTokenRepository(db, self.account)
|
||||
token_repo.add_refresh_token(refresh_token)
|
||||
token_repo = RefreshTokenRepository(db, self.account.id)
|
||||
token_repo.add_refresh_token(self.refresh_token.jwt)
|
||||
|
|
|
@ -13,8 +13,7 @@ _log = getLogger(__package__)
|
|||
class LogoutEndpoint(SeleneEndpoint):
|
||||
def get(self):
|
||||
self._authenticate()
|
||||
if self.authenticated or self.refresh_token_expired:
|
||||
self._logout()
|
||||
self._logout()
|
||||
|
||||
return self.response
|
||||
|
||||
|
@ -26,9 +25,9 @@ class LogoutEndpoint(SeleneEndpoint):
|
|||
"""
|
||||
request_refresh_token = self.request.cookies['seleneRefresh']
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
token_repository = RefreshTokenRepository(db, self.account)
|
||||
token_repository = RefreshTokenRepository(db, self.account.id)
|
||||
token_repository.delete_refresh_token(request_refresh_token)
|
||||
access_token, refresh_token = self._generate_tokens()
|
||||
self._set_token_cookies(access_token, refresh_token, expire=True)
|
||||
self._generate_tokens()
|
||||
self._set_token_cookies(expire=True)
|
||||
|
||||
self.response = ('logged out', HTTPStatus.OK)
|
||||
|
|
|
@ -16,19 +16,15 @@ from selene.util.db import get_db_connection
|
|||
class ValidateFederatedEndpoint(SeleneEndpoint):
|
||||
def post(self):
|
||||
"""Process a HTTP POST request."""
|
||||
try:
|
||||
self._get_account()
|
||||
except AuthenticationError as ae:
|
||||
self.response = str(ae), HTTPStatus.UNAUTHORIZED
|
||||
else:
|
||||
access_token, refresh_token = self._generate_tokens()
|
||||
self._set_token_cookies(access_token, refresh_token)
|
||||
self._add_refresh_token_to_db(refresh_token)
|
||||
self.response = 'account validated', HTTPStatus.OK
|
||||
self._get_account_by_email()
|
||||
self._generate_tokens()
|
||||
self._set_token_cookies()
|
||||
self._add_refresh_token_to_db()
|
||||
self.response = dict(result='account validated'), HTTPStatus.OK
|
||||
|
||||
return self.response
|
||||
|
||||
def _get_account(self):
|
||||
def _get_account_by_email(self):
|
||||
"""Use email returned by the authentication platform for validation"""
|
||||
email_address = self.request.form['email']
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
|
@ -38,14 +34,13 @@ class ValidateFederatedEndpoint(SeleneEndpoint):
|
|||
if self.account is None:
|
||||
raise AuthenticationError('account not found')
|
||||
|
||||
def _add_refresh_token_to_db(self, refresh_token):
|
||||
def _add_refresh_token_to_db(self):
|
||||
"""Track refresh tokens in the database.
|
||||
|
||||
We need to store the value of the refresh token in the database so
|
||||
that we can validate it when it is used to request new tokens.
|
||||
|
||||
:param refresh_token: the token to install into the database.
|
||||
"""
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
token_repo = RefreshTokenRepository(db, self.account)
|
||||
token_repo.add_refresh_token(refresh_token)
|
||||
token_repo = RefreshTokenRepository(db, self.account.id)
|
||||
token_repo.add_refresh_token(self.refresh_token.jwt)
|
||||
|
|
|
@ -38,29 +38,28 @@ def before_scenario(context, _):
|
|||
|
||||
|
||||
def _add_agreement(context, db):
|
||||
context.agreement = Agreement(
|
||||
agreement = Agreement(
|
||||
type='Privacy Policy',
|
||||
version='1',
|
||||
content='this is Privacy Policy version 1',
|
||||
version='999',
|
||||
content='this is Privacy Policy version 999',
|
||||
effective_date=date.today() - timedelta(days=5)
|
||||
)
|
||||
agreement_repository = AgreementRepository(db)
|
||||
agreement_repository.add(context.agreement)
|
||||
agreement_repository.add(agreement)
|
||||
context.agreement = agreement_repository.get_active_for_type(PRIVACY_POLICY)
|
||||
|
||||
|
||||
def _add_account(context, db):
|
||||
test_account = Account(
|
||||
id=None,
|
||||
email_address='foo@mycroft.ai',
|
||||
username='foobar',
|
||||
refresh_tokens=None,
|
||||
display_name='foobar',
|
||||
subscription=AccountSubscription(
|
||||
type='monthly supporter',
|
||||
start_date=None,
|
||||
type='Monthly Supporter',
|
||||
start_date=date.today(),
|
||||
stripe_customer_id='foo'
|
||||
),
|
||||
agreements=[
|
||||
AccountAgreement(name=PRIVACY_POLICY, accept_date=None)
|
||||
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today())
|
||||
]
|
||||
)
|
||||
acct_repository = AccountRepository(db)
|
||||
|
|
|
@ -62,4 +62,5 @@ def check_for_login_fail(context, error_message):
|
|||
equal_to('*')
|
||||
)
|
||||
assert_that(context.response.is_json, equal_to(True))
|
||||
assert_that(context.response.get_json(), equal_to(error_message))
|
||||
response_json = context.response.get_json()
|
||||
assert_that(response_json['error'], equal_to(error_message))
|
||||
|
|
|
@ -3,7 +3,8 @@ from behave import given, then, when
|
|||
from hamcrest import assert_that, equal_to, has_item, is_not
|
||||
|
||||
from selene.api.testing import (
|
||||
generate_auth_tokens,
|
||||
generate_access_token,
|
||||
generate_refresh_token,
|
||||
get_account,
|
||||
validate_token_cookies
|
||||
)
|
||||
|
@ -16,7 +17,8 @@ def save_email(context, email):
|
|||
|
||||
@when('user attempts to logout')
|
||||
def call_logout_endpoint(context):
|
||||
generate_auth_tokens(context)
|
||||
generate_access_token(context)
|
||||
generate_refresh_token(context)
|
||||
context.response = context.client.get('/api/logout')
|
||||
|
||||
|
||||
|
@ -39,7 +41,7 @@ def check_refresh_token_removed(context):
|
|||
account = get_account(context)
|
||||
assert_that(
|
||||
account.refresh_tokens,
|
||||
is_not(has_item(context.request_refresh_token))
|
||||
is_not(has_item(context.refresh_token))
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
CREATE TABLE account.account (
|
||||
id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
email_address text NOT NULL UNIQUE,
|
||||
display_name text NOT NULL UNIQUE,
|
||||
username text NOT NULL UNIQUE,
|
||||
password text,
|
||||
insert_ts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
|
|
@ -9,7 +9,6 @@ pygithub = "*"
|
|||
psycopg2-binary = "*"
|
||||
passlib = "*"
|
||||
pyhamcrest = "*"
|
||||
validator-collection = "*"
|
||||
schematics = "*"
|
||||
|
||||
[dev-packages]
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
{
|
||||
"_meta": {
|
||||
"hash": {
|
||||
"sha256": "4d6582a80eebdbca3cd6aedd17561dede1e8d01a1315d05248225350afcd82ef"
|
||||
"sha256": "63164ec5172150b56a6a3930e1eb1bfa195837d6dbc019864a1b5475cdfdf590"
|
||||
},
|
||||
"pipfile-spec": 6,
|
||||
"requires": {
|
||||
|
@ -44,13 +44,6 @@
|
|||
],
|
||||
"version": "==2.8"
|
||||
},
|
||||
"jsonschema": {
|
||||
"hashes": [
|
||||
"sha256:000e68abd33c972a5248544925a0cae7d1125f9bf6c58280d37546b946769a08",
|
||||
"sha256:6ff5f3180870836cae40f06fa10419f557208175f13ad7bc26caa77beb1f6e02"
|
||||
],
|
||||
"version": "==2.6.0"
|
||||
},
|
||||
"passlib": {
|
||||
"hashes": [
|
||||
"sha256:3d948f64138c25633613f303bcc471126eae67c04d5e3f6b7b8ce6242f8653e0",
|
||||
|
@ -147,13 +140,6 @@
|
|||
],
|
||||
"version": "==1.24.1"
|
||||
},
|
||||
"validator-collection": {
|
||||
"hashes": [
|
||||
"sha256:e8ddec6d301bd3be40cacb9d4f9f85573bc003e3e17a66ba7267ef46b9a8e3d2"
|
||||
],
|
||||
"index": "pypi",
|
||||
"version": "==1.3.2"
|
||||
},
|
||||
"wrapt": {
|
||||
"hashes": [
|
||||
"sha256:4aea003270831cceb8a90ff27c4031da6ead7ec1886023b80ce0dfe0adf61533"
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
"""Base class for Flask API endpoints"""
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from logging import getLogger
|
||||
from flask import after_this_request, current_app, request
|
||||
from flask.views import MethodView
|
||||
|
||||
|
@ -10,15 +8,16 @@ from selene.data.account import (
|
|||
AccountRepository,
|
||||
RefreshTokenRepository
|
||||
)
|
||||
from selene.util.auth import (
|
||||
AuthenticationError,
|
||||
AuthenticationTokenGenerator,
|
||||
AuthenticationTokenValidator,
|
||||
FIFTEEN_MINUTES,
|
||||
ONE_MONTH
|
||||
)
|
||||
from selene.util.auth import AuthenticationError, AuthenticationToken
|
||||
from selene.util.db import get_db_connection
|
||||
|
||||
ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess'
|
||||
FIFTEEN_MINUTES = 900
|
||||
ONE_MONTH = 2628000
|
||||
REFRESH_TOKEN_COOKIE_NAME = 'seleneRefresh'
|
||||
|
||||
_log = getLogger()
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
"""Raise this exception whenever a non-successful response is built"""
|
||||
|
@ -34,16 +33,19 @@ class SeleneEndpoint(MethodView):
|
|||
HTTP methods. Each list member must be a HTTPMethod enum
|
||||
- override the _build_response_data method
|
||||
"""
|
||||
authentication_required: bool = True
|
||||
|
||||
def __init__(self):
|
||||
self.config: dict = current_app.config
|
||||
self.authenticated = False
|
||||
self.request = request
|
||||
self.response: tuple = None
|
||||
self.access_token_expired: bool = False
|
||||
self.refresh_token_expired: bool = False
|
||||
self.account: Account = None
|
||||
self.access_token = AuthenticationToken(
|
||||
self.config['ACCESS_SECRET'],
|
||||
FIFTEEN_MINUTES
|
||||
)
|
||||
self.refresh_token = AuthenticationToken(
|
||||
self.config['REFRESH_SECRET'],
|
||||
ONE_MONTH
|
||||
)
|
||||
|
||||
def _authenticate(self):
|
||||
"""
|
||||
|
@ -51,90 +53,92 @@ class SeleneEndpoint(MethodView):
|
|||
|
||||
:raises: APIError()
|
||||
"""
|
||||
try:
|
||||
account_id = self._validate_auth_tokens()
|
||||
self._validate_account(account_id)
|
||||
except AuthenticationError as ae:
|
||||
if self.authentication_required:
|
||||
self.response = (str(ae), HTTPStatus.UNAUTHORIZED)
|
||||
else:
|
||||
self.authenticated = True
|
||||
self._validate_auth_tokens()
|
||||
account_id = self._get_account_id_from_tokens()
|
||||
self._get_account(account_id)
|
||||
self._validate_account(account_id)
|
||||
if self.access_token.is_expired:
|
||||
self._refresh_auth_tokens()
|
||||
|
||||
def _validate_auth_tokens(self) -> str:
|
||||
self.access_token_expired, account_id = self._validate_token(
|
||||
'seleneAccess',
|
||||
self.config['ACCESS_SECRET']
|
||||
def _validate_auth_tokens(self):
|
||||
"""Ensure the tokens are passed in request and are well formed."""
|
||||
self.access_token.jwt = self.request.cookies.get(
|
||||
ACCESS_TOKEN_COOKIE_NAME
|
||||
)
|
||||
if self.access_token_expired:
|
||||
self.refresh_token_expired, account_id = self._validate_token(
|
||||
'seleneRefresh',
|
||||
self.config['REFRESH_SECRET']
|
||||
)
|
||||
self.access_token.validate()
|
||||
self.refresh_token.jwt = self.request.cookies.get(
|
||||
REFRESH_TOKEN_COOKIE_NAME
|
||||
)
|
||||
self.refresh_token.validate()
|
||||
|
||||
if self.access_token.jwt is None and self.refresh_token.jwt is None:
|
||||
raise AuthenticationError('no authentication tokens found')
|
||||
|
||||
if self.access_token.is_expired and self.refresh_token.is_expired:
|
||||
raise AuthenticationError('authentication tokens expired')
|
||||
|
||||
def _get_account_id_from_tokens(self):
|
||||
"""Extract the account ID, which is encoded within the tokens"""
|
||||
if self.access_token.is_expired:
|
||||
account_id = self.refresh_token.account_id
|
||||
else:
|
||||
account_id = self.access_token.account_id
|
||||
|
||||
return account_id
|
||||
|
||||
def _validate_token(self, cookie_key, jwt_secret):
|
||||
"""Validate the access token is well-formed and not expired
|
||||
|
||||
:raises: AuthenticationError
|
||||
"""
|
||||
account_id = None
|
||||
token_expired = False
|
||||
|
||||
try:
|
||||
token = self.request.cookies[cookie_key]
|
||||
except KeyError:
|
||||
error_msg = 'no {} token found in request'
|
||||
raise AuthenticationError(error_msg.format(cookie_key))
|
||||
|
||||
validator = AuthenticationTokenValidator(token, jwt_secret)
|
||||
validator.validate_token()
|
||||
if validator.token_is_expired:
|
||||
token_expired = True
|
||||
elif validator.token_is_invalid:
|
||||
raise AuthenticationError('access token is invalid')
|
||||
else:
|
||||
account_id = validator.account_id
|
||||
|
||||
return token_expired, account_id
|
||||
|
||||
def _validate_account(self, account_id):
|
||||
"""The refresh token in the request must match the database value.
|
||||
|
||||
:raises: AuthenticationError
|
||||
"""
|
||||
def _get_account(self, account_id):
|
||||
"""Use account ID from decoded authentication token to get account."""
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
account_repository = AccountRepository(db)
|
||||
self.account = account_repository.get_account_by_id(account_id)
|
||||
|
||||
def _validate_account(self, account_id: str):
|
||||
"""Account must exist and contain have a refresh token matching request.
|
||||
|
||||
:raises: AuthenticationError
|
||||
"""
|
||||
if self.account is None:
|
||||
_log.error('account ID {} not on database'.format(account_id))
|
||||
raise AuthenticationError('account not found')
|
||||
|
||||
if self.access_token_expired:
|
||||
if self.refresh_token not in self.account.refresh_tokens:
|
||||
raise AuthenticationError('refresh token not found')
|
||||
if self.refresh_token.jwt not in self.account.refresh_tokens:
|
||||
log_msg = 'account ID {} does not have token {}'
|
||||
_log.error(log_msg.format(account_id, self.refresh_token.jwt))
|
||||
raise AuthenticationError(
|
||||
'refresh token does not exist for this account'
|
||||
)
|
||||
|
||||
def _refresh_auth_tokens(self):
|
||||
"""Steps necessary to refresh the tokens used for authentication."""
|
||||
old_refresh_token = self.refresh_token
|
||||
self._generate_tokens()
|
||||
self._update_refresh_token_on_db(old_refresh_token)
|
||||
self._set_token_cookies()
|
||||
|
||||
def _generate_tokens(self):
|
||||
token_generator = AuthenticationTokenGenerator(
|
||||
self.account.id,
|
||||
self.config['ACCESS_SECRET'],
|
||||
self.config['REFRESH_SECRET']
|
||||
)
|
||||
access_token = token_generator.access_token
|
||||
refresh_token = token_generator.refresh_token
|
||||
"""Generate an access token and refresh token."""
|
||||
self.access_token.generate()
|
||||
self.refresh_token.generate()
|
||||
|
||||
return access_token, refresh_token
|
||||
def _set_token_cookies(self, expire=False):
|
||||
"""Set the cookies that contain the authentication token.
|
||||
|
||||
def _set_token_cookies(self, access_token, refresh_token, expire=False):
|
||||
This method should be called when a user logs in, logs out, or when
|
||||
their access token expires.
|
||||
|
||||
:param expire: generate tokens that immediately expire, effectively
|
||||
logging a user out of the system.
|
||||
:return:
|
||||
"""
|
||||
access_token_cookie = dict(
|
||||
key='seleneAccess',
|
||||
value=str(access_token),
|
||||
value=str(self.access_token.jwt),
|
||||
domain=self.config['DOMAIN'],
|
||||
max_age=FIFTEEN_MINUTES,
|
||||
)
|
||||
refresh_token_cookie = dict(
|
||||
key='seleneRefresh',
|
||||
value=str(refresh_token),
|
||||
value=str(self.refresh_token.jwt),
|
||||
domain=self.config['DOMAIN'],
|
||||
max_age=ONE_MONTH,
|
||||
)
|
||||
|
@ -145,20 +149,21 @@ class SeleneEndpoint(MethodView):
|
|||
|
||||
@after_this_request
|
||||
def set_cookies(response):
|
||||
"""Use Flask after request hook to reset token cookies"""
|
||||
response.set_cookie(**access_token_cookie)
|
||||
response.set_cookie(**refresh_token_cookie)
|
||||
|
||||
return response
|
||||
|
||||
def _update_refresh_token_on_db(self, new_refresh_token):
|
||||
old_refresh_token = self.request.cookies['seleneRefresh']
|
||||
def _update_refresh_token_on_db(self, old_refresh_token):
|
||||
"""Replace the refresh token on the request with the newly minted one"""
|
||||
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
|
||||
token_repository = RefreshTokenRepository(db, self.account)
|
||||
if self.refresh_token_expired:
|
||||
token_repository = RefreshTokenRepository(db, self.account.id)
|
||||
if old_refresh_token.is_expired:
|
||||
token_repository.delete_refresh_token(old_refresh_token)
|
||||
raise AuthenticationError('refresh token expired')
|
||||
else:
|
||||
token_repository.update_refresh_token(
|
||||
new_refresh_token,
|
||||
old_refresh_token
|
||||
self.refresh_token.jwt,
|
||||
old_refresh_token.jwt
|
||||
)
|
||||
|
|
|
@ -3,9 +3,16 @@ from http import HTTPStatus
|
|||
from flask import Blueprint
|
||||
from schematics.exceptions import DataError
|
||||
|
||||
from selene.util.auth import AuthenticationError
|
||||
|
||||
selene_api = Blueprint('selene_api', __name__)
|
||||
|
||||
|
||||
@selene_api.app_errorhandler(DataError)
|
||||
def handle_data_error(error):
|
||||
return str(error.messages), HTTPStatus.BAD_REQUEST
|
||||
|
||||
|
||||
@selene_api.app_errorhandler(AuthenticationError)
|
||||
def handle_data_error(error):
|
||||
return dict(error=str(error)), HTTPStatus.UNAUTHORIZED
|
||||
|
|
|
@ -3,7 +3,7 @@ from dataclasses import asdict
|
|||
from datetime import date
|
||||
from http import HTTPStatus
|
||||
|
||||
from flask import json
|
||||
from flask import json, jsonify
|
||||
from schematics import Model
|
||||
from schematics.exceptions import ValidationError
|
||||
from schematics.types import BooleanType, EmailType, ModelType, StringType
|
||||
|
@ -67,7 +67,7 @@ class Support(Model):
|
|||
|
||||
|
||||
class AddAccountRequest(Model):
|
||||
display_name = StringType(required=True)
|
||||
username = StringType(required=True)
|
||||
privacy_policy = BooleanType(required=True, validators=[agreement_accepted])
|
||||
terms_of_use = BooleanType(required=True, validators=[agreement_accepted])
|
||||
login = ModelType(Login)
|
||||
|
@ -83,10 +83,9 @@ class AccountEndpoint(SeleneEndpoint):
|
|||
def get(self):
|
||||
"""Process HTTP GET request for an account."""
|
||||
self._authenticate()
|
||||
if self.authenticated:
|
||||
response_data = asdict(self.account)
|
||||
del (response_data['refresh_tokens'])
|
||||
self.response = response_data, HTTPStatus.OK
|
||||
response_data = asdict(self.account)
|
||||
del (response_data['refresh_tokens'])
|
||||
self.response = response_data, HTTPStatus.OK
|
||||
|
||||
return self.response
|
||||
|
||||
|
@ -96,11 +95,11 @@ class AccountEndpoint(SeleneEndpoint):
|
|||
email_address, password = self._determine_login_method()
|
||||
self._add_account(email_address, password)
|
||||
|
||||
return 'Account added successfully', HTTPStatus.OK
|
||||
return jsonify('Account added successfully'), HTTPStatus.OK
|
||||
|
||||
def _validate_request(self):
|
||||
add_request = AddAccountRequest(dict(
|
||||
display_name=self.request_data.get('displayName'),
|
||||
username=self.request_data.get('username'),
|
||||
privacy_policy=self.request_data.get('privacyPolicy'),
|
||||
terms_of_use=self.request_data.get('termsOfUse'),
|
||||
login=self._build_login_schematic(),
|
||||
|
@ -149,7 +148,7 @@ class AccountEndpoint(SeleneEndpoint):
|
|||
]
|
||||
account = Account(
|
||||
email_address=email_address,
|
||||
display_name=self.request_data['displayName'],
|
||||
username=self.request_data['username'],
|
||||
agreements=[
|
||||
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()),
|
||||
AccountAgreement(type=TERMS_OF_USE, accept_date=date.today())
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from .authentication import (
|
||||
ACCESS_TOKEN_COOKIE_KEY,
|
||||
generate_auth_tokens,
|
||||
generate_access_token,
|
||||
generate_refresh_token,
|
||||
get_account,
|
||||
REFRESH_TOKEN_COOKIE_KEY,
|
||||
validate_token_cookies
|
||||
|
|
|
@ -5,34 +5,54 @@ from selene.data.account import (
|
|||
AccountRepository,
|
||||
RefreshTokenRepository
|
||||
)
|
||||
from selene.util.auth import AuthenticationTokenGenerator
|
||||
from selene.util.auth import AuthenticationToken
|
||||
from selene.util.db import get_db_connection
|
||||
|
||||
ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess'
|
||||
ONE_MINUTE = 60
|
||||
TWO_MINUTES = 120
|
||||
REFRESH_TOKEN_COOKIE_KEY = 'seleneRefresh'
|
||||
|
||||
|
||||
def generate_auth_tokens(context):
|
||||
token_generator = AuthenticationTokenGenerator(
|
||||
context.account.id,
|
||||
def generate_access_token(context, expire=False):
|
||||
access_token = AuthenticationToken(
|
||||
context.client_config['ACCESS_SECRET'],
|
||||
context.client_config['REFRESH_SECRET']
|
||||
ONE_MINUTE
|
||||
)
|
||||
access_token.account_id = context.account.id
|
||||
if not expire:
|
||||
access_token.generate()
|
||||
context.access_token = access_token
|
||||
|
||||
context.client.set_cookie(
|
||||
context.client_config['DOMAIN'],
|
||||
ACCESS_TOKEN_COOKIE_KEY,
|
||||
token_generator.access_token
|
||||
access_token.jwt,
|
||||
max_age=0 if expire else ONE_MINUTE
|
||||
)
|
||||
|
||||
|
||||
def generate_refresh_token(context, expire=False):
|
||||
account_id = context.account.id
|
||||
refresh_token = AuthenticationToken(
|
||||
context.client_config['REFRESH_SECRET'],
|
||||
TWO_MINUTES
|
||||
)
|
||||
refresh_token.account_id = account_id
|
||||
if not expire:
|
||||
refresh_token.generate()
|
||||
context.refresh_token = refresh_token
|
||||
|
||||
context.client.set_cookie(
|
||||
context.client_config['DOMAIN'],
|
||||
REFRESH_TOKEN_COOKIE_KEY,
|
||||
token_generator.refresh_token
|
||||
refresh_token.jwt,
|
||||
max_age=0 if expire else TWO_MINUTES
|
||||
)
|
||||
context.request_refresh_token = token_generator.refresh_token
|
||||
|
||||
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
|
||||
token_repository = RefreshTokenRepository(db, context.account.id)
|
||||
token_repository.add_refresh_token(token_generator.refresh_token)
|
||||
token_repository = RefreshTokenRepository(db, account_id)
|
||||
token_repository.add_refresh_token(refresh_token.jwt)
|
||||
|
||||
|
||||
def validate_token_cookies(context, expired=False):
|
||||
|
|
|
@ -24,7 +24,7 @@ class AccountSubscription(object):
|
|||
class Account(object):
|
||||
"""Representation of a Mycroft user account."""
|
||||
email_address: str
|
||||
display_name: str
|
||||
username: str
|
||||
agreements: List[AccountAgreement]
|
||||
subscription: AccountSubscription
|
||||
id: str = None
|
||||
|
|
|
@ -48,7 +48,7 @@ class AccountRepository(object):
|
|||
args=dict(
|
||||
email_address=account.email_address,
|
||||
password=encrypted_password,
|
||||
display_name=account.display_name
|
||||
username=account.username
|
||||
)
|
||||
)
|
||||
result = self.cursor.insert_returning(request)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
INSERT INTO
|
||||
account.account (email_address, password, display_name)
|
||||
account.account (email_address, password, username)
|
||||
VALUES
|
||||
(%(email_address)s, %(password)s, %(display_name)s)
|
||||
(%(email_address)s, %(password)s, %(username)s)
|
||||
RETURNING
|
||||
id
|
||||
|
|
|
@ -41,7 +41,7 @@ SELECT
|
|||
json_build_object(
|
||||
'id', id,
|
||||
'email_address', email_address,
|
||||
'display_name', display_name,
|
||||
'username', username,
|
||||
'subscription', (SELECT * FROM subscription),
|
||||
'refresh_tokens', (SELECT * FROM refresh_tokens),
|
||||
'agreements', (SELECT * FROM agreements)
|
||||
|
|
|
@ -1,86 +1,50 @@
|
|||
"""Logic for generating and validating JWT authentication tokens."""
|
||||
from datetime import datetime
|
||||
from time import time
|
||||
|
||||
import jwt
|
||||
|
||||
FIFTEEN_MINUTES = 900
|
||||
ONE_MONTH = 2628000
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class AuthenticationTokenGenerator(object):
|
||||
_access_token = None
|
||||
_refresh_token = None
|
||||
class AuthenticationToken(object):
|
||||
def __init__(self, secret: str, duration: int):
|
||||
self.secret = secret
|
||||
self.duration = duration
|
||||
self.jwt: str = ''
|
||||
self.is_valid: bool = None
|
||||
self.is_expired: bool = None
|
||||
self.account_id: str = None
|
||||
|
||||
def __init__(self, account_id: str, access_secret, refresh_secret):
|
||||
self.account_id = account_id
|
||||
self.access_secret = access_secret
|
||||
self.refresh_secret = refresh_secret
|
||||
|
||||
def _generate_token(self, token_duration: int):
|
||||
def generate(self):
|
||||
"""
|
||||
Generates a JWT token
|
||||
"""
|
||||
token_expiration = time() + token_duration
|
||||
payload = dict(
|
||||
iat=datetime.utcnow(),
|
||||
exp=token_expiration,
|
||||
exp=time() + self.duration,
|
||||
sub=self.account_id
|
||||
)
|
||||
|
||||
if token_duration == FIFTEEN_MINUTES:
|
||||
secret = self.access_secret
|
||||
else:
|
||||
secret = self.refresh_secret
|
||||
|
||||
token = jwt.encode(
|
||||
payload,
|
||||
secret,
|
||||
algorithm='HS256'
|
||||
)
|
||||
token = jwt.encode(payload, self.secret, algorithm='HS256')
|
||||
|
||||
# convert the token from byte-array to string so that
|
||||
# it can be included in a JSON response object
|
||||
return token.decode()
|
||||
self.jwt = token.decode()
|
||||
|
||||
@property
|
||||
def access_token(self):
|
||||
"""
|
||||
Generates a JWT access token
|
||||
"""
|
||||
if self._access_token is None:
|
||||
self._access_token = self._generate_token(FIFTEEN_MINUTES)
|
||||
def validate(self):
|
||||
"""Decodes the auth token and performs some preliminary validation."""
|
||||
self.is_expired = False
|
||||
self.is_valid = True
|
||||
|
||||
return self._access_token
|
||||
|
||||
@property
|
||||
def refresh_token(self):
|
||||
"""
|
||||
Generates a JWT access token
|
||||
"""
|
||||
if self._refresh_token is None:
|
||||
self._refresh_token = self._generate_token(ONE_MONTH)
|
||||
|
||||
return self._refresh_token
|
||||
|
||||
|
||||
class AuthenticationTokenValidator(object):
|
||||
def __init__(self, token: str, secret: str):
|
||||
self.token = token
|
||||
self.secret = secret
|
||||
self.account_id = None
|
||||
self.token_is_expired = False
|
||||
self.token_is_invalid = False
|
||||
|
||||
def validate_token(self):
|
||||
"""Decodes the auth token"""
|
||||
try:
|
||||
payload = jwt.decode(self.token, self.secret)
|
||||
self.account_id = payload['sub']
|
||||
except jwt.ExpiredSignatureError:
|
||||
self.token_is_expired = True
|
||||
except jwt.InvalidTokenError:
|
||||
self.token_is_invalid = True
|
||||
if self.jwt is None:
|
||||
self.is_expired = True
|
||||
else:
|
||||
try:
|
||||
payload = jwt.decode(self.jwt, self.secret)
|
||||
self.account_id = payload['sub']
|
||||
except jwt.ExpiredSignatureError:
|
||||
self.is_expired = True
|
||||
except jwt.InvalidTokenError:
|
||||
self.is_valid = False
|
||||
|
|
|
@ -16,7 +16,6 @@ setup(
|
|||
'pyhamcrest',
|
||||
'pyjwt',
|
||||
'psycopg2-binary',
|
||||
'schematics',
|
||||
'validator-collection'
|
||||
'schematics'
|
||||
]
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue