Merge pull request #53 from MycroftAI/account-api

Cleanup and testing
pull/54/head
Chris Veilleux 2019-02-20 13:03:54 -06:00 committed by GitHub
commit 75fe9eed96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 512 additions and 304 deletions

View File

@ -0,0 +1,26 @@
Feature: Authentication with JWTs
Some of the API endpoints contain information that is specific to a user.
To ensure that information is seen only by the user that owns it, we will
use a login mechanism coupled with authentication tokens to securely identify
a user.
The code executed in these tests is embedded in every view call. These tests
apply to any endpoint that requires authentication. These tests are meant to
be the only place authentication logic needs to be tested.
Scenario: Request for user data includes valid access token
Given an authenticated user
When a user requests their profile
Then the request will be successful
And the authentication tokens will remain unchanged
Scenario: Access token expired
Given an authenticated user with an expired access token
When a user requests their profile
Then the request will be successful
And the authentication tokens will be refreshed
Scenario: Both access and refresh tokens expired
Given a previously authenticated user with expired tokens
When a user requests their profile
Then the request will fail with an unauthorized error

View File

@ -1,5 +1,4 @@
from datetime import date, timedelta from datetime import date, timedelta
import os
from behave import fixture, use_fixture from behave import fixture, use_fixture
@ -28,11 +27,9 @@ def acct_api_client(context):
def before_feature(context, _): def before_feature(context, _):
use_fixture(acct_api_client, context) use_fixture(acct_api_client, context)
os.environ['SALT'] = 'testsalt'
def before_scenario(context, _): def before_scenario(context, _):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
_add_agreements(context, db) _add_agreements(context, db)
_add_account(context, db) _add_account(context, db)
@ -61,7 +58,7 @@ def _add_agreements(context, db):
def _add_account(context, db): def _add_account(context, db):
context.account = Account( context.account = Account(
email_address='foo@mycroft.ai', email_address='foo@mycroft.ai',
display_name='foobar', username='foobar',
refresh_tokens=[], refresh_tokens=[],
subscription=AccountSubscription( subscription=AccountSubscription(
type='Monthly Supporter', type='Monthly Supporter',

View File

@ -4,5 +4,5 @@ Feature: Manage account profiles
Scenario: Retrieve authenticated user's account Scenario: Retrieve authenticated user's account
Given an authenticated user Given an authenticated user
When account endpoint is called to get user profile When a user requests their profile
Then user profile is returned Then user profile is returned

View File

@ -0,0 +1,36 @@
from behave import given, then
from hamcrest import assert_that, equal_to, is_not
from selene.api.testing import (
generate_access_token,
generate_refresh_token,
validate_token_cookies
)
@given('an authenticated user with an expired access token')
def generate_refresh_token_only(context):
generate_access_token(context, expire=True)
generate_refresh_token(context)
context.old_refresh_token = context.refresh_token.jwt
@given('a previously authenticated user with expired tokens')
def expire_both_tokens(context):
generate_access_token(context, expire=True)
generate_refresh_token(context, expire=True)
@then('the authentication tokens will remain unchanged')
def check_for_no_new_cookie(context):
cookies = context.response.headers.getlist('Set-Cookie')
assert_that(cookies, equal_to([]))
@then('the authentication tokens will be refreshed')
def check_for_new_cookies(context):
validate_token_cookies(context)
assert_that(
context.refresh_token,
is_not(equal_to(context.old_refresh_token))
)

View File

@ -0,0 +1,25 @@
from http import HTTPStatus
from behave import then
from hamcrest import assert_that, equal_to
@then('the request will be successful')
def check_request_success(context):
assert_that(context.response.status_code, equal_to(HTTPStatus.OK))
@then('the request will fail with {error_type} error')
def check_for_bad_request(context, error_type):
if error_type == 'a bad request':
assert_that(
context.response.status_code,
equal_to(HTTPStatus.BAD_REQUEST)
)
elif error_type == 'an unauthorized':
assert_that(
context.response.status_code,
equal_to(HTTPStatus.UNAUTHORIZED)
)
else:
raise ValueError('unsupported error_type')

View File

@ -9,7 +9,7 @@ from selene.data.account import AccountRepository, PRIVACY_POLICY, TERMS_OF_USE
from selene.util.db import get_db_connection from selene.util.db import get_db_connection
new_account_request = dict( new_account_request = dict(
displayName='barfoo', username='barfoo',
termsOfUse=True, termsOfUse=True,
privacyPolicy=True, privacyPolicy=True,
login=dict( login=dict(
@ -19,7 +19,7 @@ new_account_request = dict(
), ),
support=dict( support=dict(
openDataset=True, openDataset=True,
membership='Monthly Supporter', membership='MONTHLY SUPPORTER',
stripeCustomerId='barstripe' stripeCustomerId='barstripe'
) )
) )
@ -50,11 +50,6 @@ def create_account_without_email(context):
) )
@then('the request will be successful')
def check_request_success(context):
assert_that(context.response.status_code, equal_to(HTTPStatus.OK))
@then('the account will be added to the system') @then('the account will be added to the system')
def check_db_for_account(context): def check_db_for_account(context):
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
@ -64,7 +59,7 @@ def check_db_for_account(context):
assert_that( assert_that(
account.email_address, equal_to('bar@mycroft.ai') account.email_address, equal_to('bar@mycroft.ai')
) )
assert_that(account.display_name, equal_to('barfoo')) assert_that(account.username, equal_to('barfoo'))
assert_that(account.subscription.type, equal_to('Monthly Supporter')) assert_that(account.subscription.type, equal_to('Monthly Supporter'))
assert_that( assert_that(
account.subscription.stripe_customer_id, account.subscription.stripe_customer_id,
@ -74,8 +69,3 @@ def check_db_for_account(context):
for agreement in account.agreements: for agreement in account.agreements:
assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE))) assert_that(agreement.type, is_in((PRIVACY_POLICY, TERMS_OF_USE)))
assert_that(agreement.accept_date, equal_to(str(date.today()))) assert_that(agreement.accept_date, equal_to(str(date.today())))
@then('the request will fail with a bad request error')
def check_for_bad_request(context):
assert_that(context.response.status_code, equal_to(HTTPStatus.BAD_REQUEST))

View File

@ -5,16 +5,17 @@ import json
from behave import given, then, when from behave import given, then, when
from hamcrest import assert_that, equal_to, has_item from hamcrest import assert_that, equal_to, has_item
from selene.api.testing import generate_auth_tokens from selene.api.testing import generate_access_token, generate_refresh_token
from selene.data.account import PRIVACY_POLICY from selene.data.account import PRIVACY_POLICY
@given('an authenticated user') @given('an authenticated user')
def setup_authenticated_user(context): def setup_authenticated_user(context):
generate_auth_tokens(context) generate_access_token(context)
generate_refresh_token(context)
@when('account endpoint is called to get user profile') @when('a user requests their profile')
def call_account_endpoint(context): def call_account_endpoint(context):
context.response = context.client.get('/api/account') context.response = context.client.get('/api/account')

View File

@ -5,12 +5,11 @@ name = "pypi"
[packages] [packages]
flask = "*" flask = "*"
flask-restful = "*"
certifi = "*" certifi = "*"
uwsgi = "*" uwsgi = "*"
[dev-packages] [dev-packages]
selene = {path = "./../../shared"} selene = {editable = true,path = "./../../shared"}
behave = "*" behave = "*"
pyhamcrest = "*" pyhamcrest = "*"

216
api/sso/Pipfile.lock generated
View File

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "d82f5a7209d972cbf70f44f620e09c550dca2c32cea6fb419f90e9f898580c03" "sha256": "e27bc9018c42543c8594ffade1899d7d7c9cef2117f4c48462b0971310caeb0f"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@ -16,13 +16,6 @@
] ]
}, },
"default": { "default": {
"aniso8601": {
"hashes": [
"sha256:03c0ffeeb04edeca1ed59684cc6836dc377f58e52e315dc7be3af879909889f4",
"sha256:ac30cceff24aec920c37b8d74d7d8a5dd37b1f62a90b4f268a6234cabe147080"
],
"version": "==4.1.0"
},
"certifi": { "certifi": {
"hashes": [ "hashes": [
"sha256:47f9c83ef4c0c621eaef743f133f09fa8a74a9b75f037e8624f83bd1b6626cb7", "sha256:47f9c83ef4c0c621eaef743f133f09fa8a74a9b75f037e8624f83bd1b6626cb7",
@ -46,14 +39,6 @@
"index": "pypi", "index": "pypi",
"version": "==1.0.2" "version": "==1.0.2"
}, },
"flask-restful": {
"hashes": [
"sha256:ecd620c5cc29f663627f99e04f17d1f16d095c83dc1d618426e2ad68b03092f8",
"sha256:f8240ec12349afe8df1db168ea7c336c4e5b0271a36982bff7394f93275f2ca9"
],
"index": "pypi",
"version": "==0.3.7"
},
"itsdangerous": { "itsdangerous": {
"hashes": [ "hashes": [
"sha256:321b033d07f2a4136d3ec762eac9f16a10ccd60f53c0c91af90217ace7ba1f19", "sha256:321b033d07f2a4136d3ec762eac9f16a10ccd60f53c0c91af90217ace7ba1f19",
@ -101,26 +86,12 @@
], ],
"version": "==1.1.0" "version": "==1.1.0"
}, },
"pytz": {
"hashes": [
"sha256:32b0891edff07e28efe91284ed9c31e123d84bea3fd98e1f72be2508f43ef8d9",
"sha256:d5f05e487007e29e03409f9398d074e158d920d36eb82eaf66fb1136b0c5374c"
],
"version": "==2018.9"
},
"six": {
"hashes": [
"sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c",
"sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
],
"version": "==1.12.0"
},
"uwsgi": { "uwsgi": {
"hashes": [ "hashes": [
"sha256:d2318235c74665a60021a4fc7770e9c2756f9fc07de7b8c22805efe85b5ab277" "sha256:4972ac538800fb2d421027f49b4a1869b66048839507ccf0aa2fda792d99f583"
], ],
"index": "pypi", "index": "pypi",
"version": "==2.0.17.1" "version": "==2.0.18"
}, },
"werkzeug": { "werkzeug": {
"hashes": [ "hashes": [
@ -139,6 +110,97 @@
"index": "pypi", "index": "pypi",
"version": "==1.2.6" "version": "==1.2.6"
}, },
"certifi": {
"hashes": [
"sha256:47f9c83ef4c0c621eaef743f133f09fa8a74a9b75f037e8624f83bd1b6626cb7",
"sha256:993f830721089fef441cdfeb4b2c8c9df86f0c63239f06bd025a76a7daddb033"
],
"index": "pypi",
"version": "==2018.11.29"
},
"chardet": {
"hashes": [
"sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae",
"sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691"
],
"version": "==3.0.4"
},
"click": {
"hashes": [
"sha256:2335065e6395b9e67ca716de5f7526736bfa6ceead690adf616d925bdc622b13",
"sha256:5b94b49521f6456670fdb30cd82a4eca9412788a93fa6dd6df72c94d5a8ff2d7"
],
"version": "==7.0"
},
"deprecated": {
"hashes": [
"sha256:8bfeba6e630abf42b5d111b68a05f7fe3d6de7004391b3cd614947594f87a4ff",
"sha256:b784e0ca85a8c1e694d77e545c10827bd99772392e79d5f5442e761515a1246e"
],
"version": "==1.2.4"
},
"flask": {
"hashes": [
"sha256:2271c0070dbcb5275fad4a82e29f23ab92682dc45f9dfbc22c02ba9b9322ce48",
"sha256:a080b744b7e345ccfcbc77954861cb05b3c63786e93f2b3875e0913d44b43f05"
],
"index": "pypi",
"version": "==1.0.2"
},
"idna": {
"hashes": [
"sha256:c357b3f628cf53ae2c4c05627ecc484553142ca23264e593d327bcde5e9c3407",
"sha256:ea8b7f6188e6fa117537c3df7da9fc686d485087abf6ac197f9c46432f7e4a3c"
],
"version": "==2.8"
},
"itsdangerous": {
"hashes": [
"sha256:321b033d07f2a4136d3ec762eac9f16a10ccd60f53c0c91af90217ace7ba1f19",
"sha256:b12271b2047cb23eeb98c8b5622e2e5c5e9abd9784a153e9d8ef9cb4dd09d749"
],
"version": "==1.1.0"
},
"jinja2": {
"hashes": [
"sha256:74c935a1b8bb9a3947c50a54766a969d4846290e1e788ea44c1392163723c3bd",
"sha256:f84be1bb0040caca4cea721fcbbbbd61f9be9464ca236387158b0feea01914a4"
],
"version": "==2.10"
},
"markupsafe": {
"hashes": [
"sha256:048ef924c1623740e70204aa7143ec592504045ae4429b59c30054cb31e3c432",
"sha256:130f844e7f5bdd8e9f3f42e7102ef1d49b2e6fdf0d7526df3f87281a532d8c8b",
"sha256:19f637c2ac5ae9da8bfd98cef74d64b7e1bb8a63038a3505cd182c3fac5eb4d9",
"sha256:1b8a7a87ad1b92bd887568ce54b23565f3fd7018c4180136e1cf412b405a47af",
"sha256:1c25694ca680b6919de53a4bb3bdd0602beafc63ff001fea2f2fc16ec3a11834",
"sha256:1f19ef5d3908110e1e891deefb5586aae1b49a7440db952454b4e281b41620cd",
"sha256:1fa6058938190ebe8290e5cae6c351e14e7bb44505c4a7624555ce57fbbeba0d",
"sha256:31cbb1359e8c25f9f48e156e59e2eaad51cd5242c05ed18a8de6dbe85184e4b7",
"sha256:3e835d8841ae7863f64e40e19477f7eb398674da6a47f09871673742531e6f4b",
"sha256:4e97332c9ce444b0c2c38dd22ddc61c743eb208d916e4265a2a3b575bdccb1d3",
"sha256:525396ee324ee2da82919f2ee9c9e73b012f23e7640131dd1b53a90206a0f09c",
"sha256:52b07fbc32032c21ad4ab060fec137b76eb804c4b9a1c7c7dc562549306afad2",
"sha256:52ccb45e77a1085ec5461cde794e1aa037df79f473cbc69b974e73940655c8d7",
"sha256:5c3fbebd7de20ce93103cb3183b47671f2885307df4a17a0ad56a1dd51273d36",
"sha256:5e5851969aea17660e55f6a3be00037a25b96a9b44d2083651812c99d53b14d1",
"sha256:5edfa27b2d3eefa2210fb2f5d539fbed81722b49f083b2c6566455eb7422fd7e",
"sha256:7d263e5770efddf465a9e31b78362d84d015cc894ca2c131901a4445eaa61ee1",
"sha256:83381342bfc22b3c8c06f2dd93a505413888694302de25add756254beee8449c",
"sha256:857eebb2c1dc60e4219ec8e98dfa19553dae33608237e107db9c6078b1167856",
"sha256:98e439297f78fca3a6169fd330fbe88d78b3bb72f967ad9961bcac0d7fdd1550",
"sha256:bf54103892a83c64db58125b3f2a43df6d2cb2d28889f14c78519394feb41492",
"sha256:d9ac82be533394d341b41d78aca7ed0e0f4ba5a2231602e2f05aa87f25c51672",
"sha256:e982fe07ede9fada6ff6705af70514a52beb1b2c3d25d4e873e82114cf3c5401",
"sha256:edce2ea7f3dfc981c4ddc97add8a61381d9642dc3273737e756517cc03e84dd6",
"sha256:efdc45ef1afc238db84cb4963aa689c0408912a0239b0721cb172b4016eb31d6",
"sha256:f137c02498f8b935892d5c0172560d7ab54bc45039de8805075e19079c639a9c",
"sha256:f82e347a72f955b7017a39708a3667f106e6ad4d10b25f237396a7115d8ed5fd",
"sha256:fb7c206e01ad85ce57feeaaa0bf784b97fa3cad0d4a5737bc5295785f5c613a1"
],
"version": "==1.1.0"
},
"parse": { "parse": {
"hashes": [ "hashes": [
"sha256:870dd675c1ee8951db3e29b81ebe44fd131e3eb8c03a79483a58ea574f3145c2" "sha256:870dd675c1ee8951db3e29b81ebe44fd131e3eb8c03a79483a58ea574f3145c2"
@ -152,6 +214,54 @@
], ],
"version": "==0.4.2" "version": "==0.4.2"
}, },
"passlib": {
"hashes": [
"sha256:3d948f64138c25633613f303bcc471126eae67c04d5e3f6b7b8ce6242f8653e0",
"sha256:43526aea08fa32c6b6dbbbe9963c4c767285b78147b7437597f992812f69d280"
],
"version": "==1.7.1"
},
"psycopg2-binary": {
"hashes": [
"sha256:19a2d1f3567b30f6c2bb3baea23f74f69d51f0c06c2e2082d0d9c28b0733a4c2",
"sha256:2b69cf4b0fa2716fd977aa4e1fd39af6110eb47b2bb30b4e5a469d8fbecfc102",
"sha256:2e952fa17ba48cbc2dc063ddeec37d7dc4ea0ef7db0ac1eda8906365a8543f31",
"sha256:348b49dd737ff74cfb5e663e18cb069b44c64f77ec0523b5794efafbfa7df0b8",
"sha256:3d72a5fdc5f00ca85160915eb9a973cf9a0ab8148f6eda40708bf672c55ac1d1",
"sha256:4957452f7868f43f32c090dadb4188e9c74a4687323c87a882e943c2bd4780c3",
"sha256:5138cec2ee1e53a671e11cc519505eb08aaaaf390c508f25b09605763d48de4b",
"sha256:587098ca4fc46c95736459d171102336af12f0d415b3b865972a79c03f06259f",
"sha256:5b79368bcdb1da4a05f931b62760bea0955ee2c81531d8e84625df2defd3f709",
"sha256:5cf43807392247d9bc99737160da32d3fa619e0bfd85ba24d1c78db205f472a4",
"sha256:676d1a80b1eebc0cacae8dd09b2fde24213173bf65650d22b038c5ed4039f392",
"sha256:6b0211ecda389101a7d1d3df2eba0cf7ffbdd2480ca6f1d2257c7bd739e84110",
"sha256:79cde4660de6f0bb523c229763bd8ad9a93ac6760b72c369cf1213955c430934",
"sha256:7aba9786ac32c2a6d5fb446002ed936b47d5e1f10c466ef7e48f66eb9f9ebe3b",
"sha256:7c8159352244e11bdd422226aa17651110b600d175220c451a9acf795e7414e0",
"sha256:945f2eedf4fc6b2432697eb90bb98cc467de5147869e57405bfc31fa0b824741",
"sha256:96b4e902cde37a7fc6ab306b3ac089a3949e6ce3d824eeca5b19dc0bedb9f6e2",
"sha256:9a7bccb1212e63f309eb9fab47b6eaef796f59850f169a25695b248ca1bf681b",
"sha256:a3bfcac727538ec11af304b5eccadbac952d4cca1a551a29b8fe554e3ad535dc",
"sha256:b19e9f1b85c5d6136f5a0549abdc55dcbd63aba18b4f10d0d063eb65ef2c68b4",
"sha256:b664011bb14ca1f2287c17185e222f2098f7b4c857961dbcf9badb28786dbbf4",
"sha256:bde7959ef012b628868d69c474ec4920252656d0800835ed999ba5e4f57e3e2e",
"sha256:cb095a0657d792c8de9f7c9a0452385a309dfb1bbbb3357d6b1e216353ade6ca",
"sha256:d16d42a1b9772152c1fe606f679b2316551f7e1a1ce273e7f808e82a136cdb3d",
"sha256:d444b1545430ffc1e7a24ce5a9be122ccd3b135a7b7e695c5862c5aff0b11159",
"sha256:d93ccc7bf409ec0a23f2ac70977507e0b8a8d8c54e5ee46109af2f0ec9e411f3",
"sha256:df6444f952ca849016902662e1a47abf4fa0678d75f92fd9dd27f20525f809cd",
"sha256:e63850d8c52ba2b502662bf3c02603175c2397a9acc756090e444ce49508d41e",
"sha256:ec43358c105794bc2b6fd34c68d27f92bea7102393c01889e93f4b6a70975728",
"sha256:f4c6926d9c03dadce7a3b378b40d2fea912c1344ef9b29869f984fb3d2a2420b"
],
"version": "==2.7.7"
},
"pygithub": {
"hashes": [
"sha256:263102b43a83e2943900c1313109db7a00b3b78aeeae2c9137ba694982864872"
],
"version": "==1.43.5"
},
"pyhamcrest": { "pyhamcrest": {
"hashes": [ "hashes": [
"sha256:6b672c02fdf7470df9674ab82263841ce8333fb143f32f021f6cb26f0e512420", "sha256:6b672c02fdf7470df9674ab82263841ce8333fb143f32f021f6cb26f0e512420",
@ -160,7 +270,29 @@
"index": "pypi", "index": "pypi",
"version": "==1.9.0" "version": "==1.9.0"
}, },
"pyjwt": {
"hashes": [
"sha256:5c6eca3c2940464d106b99ba83b00c6add741c9becaec087fb7ccdefea71350e",
"sha256:8d59a976fb773f3e6a39c85636357c4f0e242707394cadadd9814f5cbaa20e96"
],
"version": "==1.7.1"
},
"requests": {
"hashes": [
"sha256:502a824f31acdacb3a35b6690b5fbf0bc41d63a24a45c4004352b0242707598e",
"sha256:7bf2a778576d825600030a110f3c0e3e8edc51dfaafe1c146e39a2027784957b"
],
"version": "==2.21.0"
},
"schematics": {
"hashes": [
"sha256:8fcc6182606fd0b24410a1dbb066d9bbddbe8da9c9509f47b743495706239283",
"sha256:a40b20635c0e43d18d3aff76220f6cd95ea4decb3f37765e49529b17d81b0439"
],
"version": "==2.1.0"
},
"selene": { "selene": {
"editable": true,
"path": "./../../shared" "path": "./../../shared"
}, },
"six": { "six": {
@ -169,6 +301,26 @@
"sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
], ],
"version": "==1.12.0" "version": "==1.12.0"
},
"urllib3": {
"hashes": [
"sha256:61bf29cada3fc2fbefad4fdf059ea4bd1b4a86d2b6d15e1c7c0b582b9752fe39",
"sha256:de9529817c93f27c8ccbfead6985011db27bd0ddfcdb2d86f3f663385c6a9c22"
],
"version": "==1.24.1"
},
"werkzeug": {
"hashes": [
"sha256:c3fd7a7d41976d9f44db327260e263132466836cef6f91512889ed60ad26557c",
"sha256:d5da73735293558eb1651ee2fddc4d0dedcfa06538b8813a2e20011583c9e49b"
],
"version": "==0.14.1"
},
"wrapt": {
"hashes": [
"sha256:4aea003270831cceb8a90ff27c4031da6ead7ec1886023b80ce0dfe0adf61533"
],
"version": "==1.11.1"
} }
} }
} }

View File

@ -1,11 +1,9 @@
"""Define the API that will support Mycroft single sign on (SSO).""" """Define the API that will support Mycroft single sign on (SSO)."""
from logging import getLogger
from flask import Flask, request from flask import Flask, request
from flask_restful import Api
from selene.api.base_config import get_base_config from selene.api import get_base_config, selene_api, SeleneResponse
from selene.util.log import configure_logger
from .endpoints import ( from .endpoints import (
AuthenticateInternalEndpoint, AuthenticateInternalEndpoint,
@ -13,18 +11,30 @@ from .endpoints import (
ValidateFederatedEndpoint ValidateFederatedEndpoint
) )
_log = getLogger('sso_api') _log = configure_logger('sso_api')
# Initialize the Flask application and the Flask Restful API # Define the Flask application
sso = Flask(__name__) sso = Flask(__name__)
sso.config.from_object(get_base_config()) sso.config.from_object(get_base_config())
sso.response_class = SeleneResponse
sso.register_blueprint(selene_api)
# Initialize the REST API and define the endpoints # Define the endpoints
sso_api = Api(sso, catch_all_404s=True) sso.add_url_rule(
sso_api.add_resource(AuthenticateInternalEndpoint, '/api/internal-login') '/api/internal-login',
sso_api.add_resource(ValidateFederatedEndpoint, '/api/validate-federated') view_func=AuthenticateInternalEndpoint.as_view('internal_login'),
methods=['GET']
sso_api.add_resource(LogoutEndpoint, '/api/logout') )
sso.add_url_rule(
'/api/validate-federated',
view_func=ValidateFederatedEndpoint.as_view('federated_login'),
methods=['POST']
)
sso.add_url_rule(
'/api/logout',
view_func=LogoutEndpoint.as_view('logout'),
methods=['GET']
)
def add_cors_headers(response): def add_cors_headers(response):

View File

@ -18,20 +18,16 @@ class AuthenticateInternalEndpoint(SeleneEndpoint):
"""Sign in a user with an email address and password.""" """Sign in a user with an email address and password."""
def __init__(self): def __init__(self):
super(AuthenticateInternalEndpoint, self).__init__() super(AuthenticateInternalEndpoint, self).__init__()
self.response_status_code = HTTPStatus.OK
self.account: Account = None self.account: Account = None
def get(self): def get(self):
"""Process HTTP GET request.""" """Process HTTP GET request."""
try:
self._authenticate_credentials() self._authenticate_credentials()
access_token, refresh_token = self._generate_tokens() self._generate_tokens()
self._add_refresh_token_to_db(refresh_token) self._add_refresh_token_to_db()
self._set_token_cookies(access_token, refresh_token) self._set_token_cookies()
except AuthenticationError as ae:
self.response = (str(ae), HTTPStatus.UNAUTHORIZED) self.response = dict(result='user authenticated'), HTTPStatus.OK
else:
self.response = ({}, HTTPStatus.OK)
return self.response return self.response
@ -52,15 +48,15 @@ class AuthenticateInternalEndpoint(SeleneEndpoint):
) )
if self.account is None: if self.account is None:
raise AuthenticationError('provided credentials not found') raise AuthenticationError('provided credentials not found')
self.access_token.account_id = self.account.id
self.refresh_token.account_id = self.account.id
def _add_refresh_token_to_db(self, refresh_token: str): def _add_refresh_token_to_db(self):
"""Track refresh tokens in the database. """Track refresh tokens in the database.
We need to store the value of the refresh token in the database so We need to store the value of the refresh token in the database so
that we can validate it when it is used to request new tokens. that we can validate it when it is used to request new tokens.
:param refresh_token: the token to install into the database.
""" """
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
token_repo = RefreshTokenRepository(db, self.account) token_repo = RefreshTokenRepository(db, self.account.id)
token_repo.add_refresh_token(refresh_token) token_repo.add_refresh_token(self.refresh_token.jwt)

View File

@ -13,7 +13,6 @@ _log = getLogger(__package__)
class LogoutEndpoint(SeleneEndpoint): class LogoutEndpoint(SeleneEndpoint):
def get(self): def get(self):
self._authenticate() self._authenticate()
if self.authenticated or self.refresh_token_expired:
self._logout() self._logout()
return self.response return self.response
@ -26,9 +25,9 @@ class LogoutEndpoint(SeleneEndpoint):
""" """
request_refresh_token = self.request.cookies['seleneRefresh'] request_refresh_token = self.request.cookies['seleneRefresh']
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
token_repository = RefreshTokenRepository(db, self.account) token_repository = RefreshTokenRepository(db, self.account.id)
token_repository.delete_refresh_token(request_refresh_token) token_repository.delete_refresh_token(request_refresh_token)
access_token, refresh_token = self._generate_tokens() self._generate_tokens()
self._set_token_cookies(access_token, refresh_token, expire=True) self._set_token_cookies(expire=True)
self.response = ('logged out', HTTPStatus.OK) self.response = ('logged out', HTTPStatus.OK)

View File

@ -16,19 +16,15 @@ from selene.util.db import get_db_connection
class ValidateFederatedEndpoint(SeleneEndpoint): class ValidateFederatedEndpoint(SeleneEndpoint):
def post(self): def post(self):
"""Process a HTTP POST request.""" """Process a HTTP POST request."""
try: self._get_account_by_email()
self._get_account() self._generate_tokens()
except AuthenticationError as ae: self._set_token_cookies()
self.response = str(ae), HTTPStatus.UNAUTHORIZED self._add_refresh_token_to_db()
else: self.response = dict(result='account validated'), HTTPStatus.OK
access_token, refresh_token = self._generate_tokens()
self._set_token_cookies(access_token, refresh_token)
self._add_refresh_token_to_db(refresh_token)
self.response = 'account validated', HTTPStatus.OK
return self.response return self.response
def _get_account(self): def _get_account_by_email(self):
"""Use email returned by the authentication platform for validation""" """Use email returned by the authentication platform for validation"""
email_address = self.request.form['email'] email_address = self.request.form['email']
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
@ -38,14 +34,13 @@ class ValidateFederatedEndpoint(SeleneEndpoint):
if self.account is None: if self.account is None:
raise AuthenticationError('account not found') raise AuthenticationError('account not found')
def _add_refresh_token_to_db(self, refresh_token): def _add_refresh_token_to_db(self):
"""Track refresh tokens in the database. """Track refresh tokens in the database.
We need to store the value of the refresh token in the database so We need to store the value of the refresh token in the database so
that we can validate it when it is used to request new tokens. that we can validate it when it is used to request new tokens.
:param refresh_token: the token to install into the database.
""" """
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
token_repo = RefreshTokenRepository(db, self.account) token_repo = RefreshTokenRepository(db, self.account.id)
token_repo.add_refresh_token(refresh_token) token_repo.add_refresh_token(self.refresh_token.jwt)

View File

@ -38,29 +38,28 @@ def before_scenario(context, _):
def _add_agreement(context, db): def _add_agreement(context, db):
context.agreement = Agreement( agreement = Agreement(
type='Privacy Policy', type='Privacy Policy',
version='1', version='999',
content='this is Privacy Policy version 1', content='this is Privacy Policy version 999',
effective_date=date.today() - timedelta(days=5) effective_date=date.today() - timedelta(days=5)
) )
agreement_repository = AgreementRepository(db) agreement_repository = AgreementRepository(db)
agreement_repository.add(context.agreement) agreement_repository.add(agreement)
context.agreement = agreement_repository.get_active_for_type(PRIVACY_POLICY)
def _add_account(context, db): def _add_account(context, db):
test_account = Account( test_account = Account(
id=None,
email_address='foo@mycroft.ai', email_address='foo@mycroft.ai',
username='foobar', display_name='foobar',
refresh_tokens=None,
subscription=AccountSubscription( subscription=AccountSubscription(
type='monthly supporter', type='Monthly Supporter',
start_date=None, start_date=date.today(),
stripe_customer_id='foo' stripe_customer_id='foo'
), ),
agreements=[ agreements=[
AccountAgreement(name=PRIVACY_POLICY, accept_date=None) AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today())
] ]
) )
acct_repository = AccountRepository(db) acct_repository = AccountRepository(db)

View File

@ -62,4 +62,5 @@ def check_for_login_fail(context, error_message):
equal_to('*') equal_to('*')
) )
assert_that(context.response.is_json, equal_to(True)) assert_that(context.response.is_json, equal_to(True))
assert_that(context.response.get_json(), equal_to(error_message)) response_json = context.response.get_json()
assert_that(response_json['error'], equal_to(error_message))

View File

@ -3,7 +3,8 @@ from behave import given, then, when
from hamcrest import assert_that, equal_to, has_item, is_not from hamcrest import assert_that, equal_to, has_item, is_not
from selene.api.testing import ( from selene.api.testing import (
generate_auth_tokens, generate_access_token,
generate_refresh_token,
get_account, get_account,
validate_token_cookies validate_token_cookies
) )
@ -16,7 +17,8 @@ def save_email(context, email):
@when('user attempts to logout') @when('user attempts to logout')
def call_logout_endpoint(context): def call_logout_endpoint(context):
generate_auth_tokens(context) generate_access_token(context)
generate_refresh_token(context)
context.response = context.client.get('/api/logout') context.response = context.client.get('/api/logout')
@ -39,7 +41,7 @@ def check_refresh_token_removed(context):
account = get_account(context) account = get_account(context)
assert_that( assert_that(
account.refresh_tokens, account.refresh_tokens,
is_not(has_item(context.request_refresh_token)) is_not(has_item(context.refresh_token))
) )

View File

@ -1,7 +1,7 @@
CREATE TABLE account.account ( CREATE TABLE account.account (
id uuid PRIMARY KEY DEFAULT gen_random_uuid(), id uuid PRIMARY KEY DEFAULT gen_random_uuid(),
email_address text NOT NULL UNIQUE, email_address text NOT NULL UNIQUE,
display_name text NOT NULL UNIQUE, username text NOT NULL UNIQUE,
password text, password text,
insert_ts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP insert_ts TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
); );

View File

@ -9,7 +9,6 @@ pygithub = "*"
psycopg2-binary = "*" psycopg2-binary = "*"
passlib = "*" passlib = "*"
pyhamcrest = "*" pyhamcrest = "*"
validator-collection = "*"
schematics = "*" schematics = "*"
[dev-packages] [dev-packages]

16
shared/Pipfile.lock generated
View File

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "4d6582a80eebdbca3cd6aedd17561dede1e8d01a1315d05248225350afcd82ef" "sha256": "63164ec5172150b56a6a3930e1eb1bfa195837d6dbc019864a1b5475cdfdf590"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@ -44,13 +44,6 @@
], ],
"version": "==2.8" "version": "==2.8"
}, },
"jsonschema": {
"hashes": [
"sha256:000e68abd33c972a5248544925a0cae7d1125f9bf6c58280d37546b946769a08",
"sha256:6ff5f3180870836cae40f06fa10419f557208175f13ad7bc26caa77beb1f6e02"
],
"version": "==2.6.0"
},
"passlib": { "passlib": {
"hashes": [ "hashes": [
"sha256:3d948f64138c25633613f303bcc471126eae67c04d5e3f6b7b8ce6242f8653e0", "sha256:3d948f64138c25633613f303bcc471126eae67c04d5e3f6b7b8ce6242f8653e0",
@ -147,13 +140,6 @@
], ],
"version": "==1.24.1" "version": "==1.24.1"
}, },
"validator-collection": {
"hashes": [
"sha256:e8ddec6d301bd3be40cacb9d4f9f85573bc003e3e17a66ba7267ef46b9a8e3d2"
],
"index": "pypi",
"version": "==1.3.2"
},
"wrapt": { "wrapt": {
"hashes": [ "hashes": [
"sha256:4aea003270831cceb8a90ff27c4031da6ead7ec1886023b80ce0dfe0adf61533" "sha256:4aea003270831cceb8a90ff27c4031da6ead7ec1886023b80ce0dfe0adf61533"

View File

@ -1,7 +1,5 @@
"""Base class for Flask API endpoints""" """Base class for Flask API endpoints"""
from logging import getLogger
from http import HTTPStatus
from flask import after_this_request, current_app, request from flask import after_this_request, current_app, request
from flask.views import MethodView from flask.views import MethodView
@ -10,15 +8,16 @@ from selene.data.account import (
AccountRepository, AccountRepository,
RefreshTokenRepository RefreshTokenRepository
) )
from selene.util.auth import ( from selene.util.auth import AuthenticationError, AuthenticationToken
AuthenticationError,
AuthenticationTokenGenerator,
AuthenticationTokenValidator,
FIFTEEN_MINUTES,
ONE_MONTH
)
from selene.util.db import get_db_connection from selene.util.db import get_db_connection
ACCESS_TOKEN_COOKIE_NAME = 'seleneAccess'
FIFTEEN_MINUTES = 900
ONE_MONTH = 2628000
REFRESH_TOKEN_COOKIE_NAME = 'seleneRefresh'
_log = getLogger()
class APIError(Exception): class APIError(Exception):
"""Raise this exception whenever a non-successful response is built""" """Raise this exception whenever a non-successful response is built"""
@ -34,16 +33,19 @@ class SeleneEndpoint(MethodView):
HTTP methods. Each list member must be a HTTPMethod enum HTTP methods. Each list member must be a HTTPMethod enum
- override the _build_response_data method - override the _build_response_data method
""" """
authentication_required: bool = True
def __init__(self): def __init__(self):
self.config: dict = current_app.config self.config: dict = current_app.config
self.authenticated = False
self.request = request self.request = request
self.response: tuple = None self.response: tuple = None
self.access_token_expired: bool = False
self.refresh_token_expired: bool = False
self.account: Account = None self.account: Account = None
self.access_token = AuthenticationToken(
self.config['ACCESS_SECRET'],
FIFTEEN_MINUTES
)
self.refresh_token = AuthenticationToken(
self.config['REFRESH_SECRET'],
ONE_MONTH
)
def _authenticate(self): def _authenticate(self):
""" """
@ -51,90 +53,92 @@ class SeleneEndpoint(MethodView):
:raises: APIError() :raises: APIError()
""" """
try: self._validate_auth_tokens()
account_id = self._validate_auth_tokens() account_id = self._get_account_id_from_tokens()
self._get_account(account_id)
self._validate_account(account_id) self._validate_account(account_id)
except AuthenticationError as ae: if self.access_token.is_expired:
if self.authentication_required: self._refresh_auth_tokens()
self.response = (str(ae), HTTPStatus.UNAUTHORIZED)
else:
self.authenticated = True
def _validate_auth_tokens(self) -> str: def _validate_auth_tokens(self):
self.access_token_expired, account_id = self._validate_token( """Ensure the tokens are passed in request and are well formed."""
'seleneAccess', self.access_token.jwt = self.request.cookies.get(
self.config['ACCESS_SECRET'] ACCESS_TOKEN_COOKIE_NAME
) )
if self.access_token_expired: self.access_token.validate()
self.refresh_token_expired, account_id = self._validate_token( self.refresh_token.jwt = self.request.cookies.get(
'seleneRefresh', REFRESH_TOKEN_COOKIE_NAME
self.config['REFRESH_SECRET']
) )
self.refresh_token.validate()
if self.access_token.jwt is None and self.refresh_token.jwt is None:
raise AuthenticationError('no authentication tokens found')
if self.access_token.is_expired and self.refresh_token.is_expired:
raise AuthenticationError('authentication tokens expired')
def _get_account_id_from_tokens(self):
"""Extract the account ID, which is encoded within the tokens"""
if self.access_token.is_expired:
account_id = self.refresh_token.account_id
else:
account_id = self.access_token.account_id
return account_id return account_id
def _validate_token(self, cookie_key, jwt_secret): def _get_account(self, account_id):
"""Validate the access token is well-formed and not expired """Use account ID from decoded authentication token to get account."""
:raises: AuthenticationError
"""
account_id = None
token_expired = False
try:
token = self.request.cookies[cookie_key]
except KeyError:
error_msg = 'no {} token found in request'
raise AuthenticationError(error_msg.format(cookie_key))
validator = AuthenticationTokenValidator(token, jwt_secret)
validator.validate_token()
if validator.token_is_expired:
token_expired = True
elif validator.token_is_invalid:
raise AuthenticationError('access token is invalid')
else:
account_id = validator.account_id
return token_expired, account_id
def _validate_account(self, account_id):
"""The refresh token in the request must match the database value.
:raises: AuthenticationError
"""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
account_repository = AccountRepository(db) account_repository = AccountRepository(db)
self.account = account_repository.get_account_by_id(account_id) self.account = account_repository.get_account_by_id(account_id)
def _validate_account(self, account_id: str):
"""Account must exist and contain have a refresh token matching request.
:raises: AuthenticationError
"""
if self.account is None: if self.account is None:
_log.error('account ID {} not on database'.format(account_id))
raise AuthenticationError('account not found') raise AuthenticationError('account not found')
if self.access_token_expired: if self.refresh_token.jwt not in self.account.refresh_tokens:
if self.refresh_token not in self.account.refresh_tokens: log_msg = 'account ID {} does not have token {}'
raise AuthenticationError('refresh token not found') _log.error(log_msg.format(account_id, self.refresh_token.jwt))
raise AuthenticationError(
'refresh token does not exist for this account'
)
def _refresh_auth_tokens(self):
"""Steps necessary to refresh the tokens used for authentication."""
old_refresh_token = self.refresh_token
self._generate_tokens()
self._update_refresh_token_on_db(old_refresh_token)
self._set_token_cookies()
def _generate_tokens(self): def _generate_tokens(self):
token_generator = AuthenticationTokenGenerator( """Generate an access token and refresh token."""
self.account.id, self.access_token.generate()
self.config['ACCESS_SECRET'], self.refresh_token.generate()
self.config['REFRESH_SECRET']
)
access_token = token_generator.access_token
refresh_token = token_generator.refresh_token
return access_token, refresh_token def _set_token_cookies(self, expire=False):
"""Set the cookies that contain the authentication token.
def _set_token_cookies(self, access_token, refresh_token, expire=False): This method should be called when a user logs in, logs out, or when
their access token expires.
:param expire: generate tokens that immediately expire, effectively
logging a user out of the system.
:return:
"""
access_token_cookie = dict( access_token_cookie = dict(
key='seleneAccess', key='seleneAccess',
value=str(access_token), value=str(self.access_token.jwt),
domain=self.config['DOMAIN'], domain=self.config['DOMAIN'],
max_age=FIFTEEN_MINUTES, max_age=FIFTEEN_MINUTES,
) )
refresh_token_cookie = dict( refresh_token_cookie = dict(
key='seleneRefresh', key='seleneRefresh',
value=str(refresh_token), value=str(self.refresh_token.jwt),
domain=self.config['DOMAIN'], domain=self.config['DOMAIN'],
max_age=ONE_MONTH, max_age=ONE_MONTH,
) )
@ -145,20 +149,21 @@ class SeleneEndpoint(MethodView):
@after_this_request @after_this_request
def set_cookies(response): def set_cookies(response):
"""Use Flask after request hook to reset token cookies"""
response.set_cookie(**access_token_cookie) response.set_cookie(**access_token_cookie)
response.set_cookie(**refresh_token_cookie) response.set_cookie(**refresh_token_cookie)
return response return response
def _update_refresh_token_on_db(self, new_refresh_token): def _update_refresh_token_on_db(self, old_refresh_token):
old_refresh_token = self.request.cookies['seleneRefresh'] """Replace the refresh token on the request with the newly minted one"""
with get_db_connection(self.config['DB_CONNECTION_POOL']) as db: with get_db_connection(self.config['DB_CONNECTION_POOL']) as db:
token_repository = RefreshTokenRepository(db, self.account) token_repository = RefreshTokenRepository(db, self.account.id)
if self.refresh_token_expired: if old_refresh_token.is_expired:
token_repository.delete_refresh_token(old_refresh_token) token_repository.delete_refresh_token(old_refresh_token)
raise AuthenticationError('refresh token expired') raise AuthenticationError('refresh token expired')
else: else:
token_repository.update_refresh_token( token_repository.update_refresh_token(
new_refresh_token, self.refresh_token.jwt,
old_refresh_token old_refresh_token.jwt
) )

View File

@ -3,9 +3,16 @@ from http import HTTPStatus
from flask import Blueprint from flask import Blueprint
from schematics.exceptions import DataError from schematics.exceptions import DataError
from selene.util.auth import AuthenticationError
selene_api = Blueprint('selene_api', __name__) selene_api = Blueprint('selene_api', __name__)
@selene_api.app_errorhandler(DataError) @selene_api.app_errorhandler(DataError)
def handle_data_error(error): def handle_data_error(error):
return str(error.messages), HTTPStatus.BAD_REQUEST return str(error.messages), HTTPStatus.BAD_REQUEST
@selene_api.app_errorhandler(AuthenticationError)
def handle_data_error(error):
return dict(error=str(error)), HTTPStatus.UNAUTHORIZED

View File

@ -3,7 +3,7 @@ from dataclasses import asdict
from datetime import date from datetime import date
from http import HTTPStatus from http import HTTPStatus
from flask import json from flask import json, jsonify
from schematics import Model from schematics import Model
from schematics.exceptions import ValidationError from schematics.exceptions import ValidationError
from schematics.types import BooleanType, EmailType, ModelType, StringType from schematics.types import BooleanType, EmailType, ModelType, StringType
@ -67,7 +67,7 @@ class Support(Model):
class AddAccountRequest(Model): class AddAccountRequest(Model):
display_name = StringType(required=True) username = StringType(required=True)
privacy_policy = BooleanType(required=True, validators=[agreement_accepted]) privacy_policy = BooleanType(required=True, validators=[agreement_accepted])
terms_of_use = BooleanType(required=True, validators=[agreement_accepted]) terms_of_use = BooleanType(required=True, validators=[agreement_accepted])
login = ModelType(Login) login = ModelType(Login)
@ -83,7 +83,6 @@ class AccountEndpoint(SeleneEndpoint):
def get(self): def get(self):
"""Process HTTP GET request for an account.""" """Process HTTP GET request for an account."""
self._authenticate() self._authenticate()
if self.authenticated:
response_data = asdict(self.account) response_data = asdict(self.account)
del (response_data['refresh_tokens']) del (response_data['refresh_tokens'])
self.response = response_data, HTTPStatus.OK self.response = response_data, HTTPStatus.OK
@ -96,11 +95,11 @@ class AccountEndpoint(SeleneEndpoint):
email_address, password = self._determine_login_method() email_address, password = self._determine_login_method()
self._add_account(email_address, password) self._add_account(email_address, password)
return 'Account added successfully', HTTPStatus.OK return jsonify('Account added successfully'), HTTPStatus.OK
def _validate_request(self): def _validate_request(self):
add_request = AddAccountRequest(dict( add_request = AddAccountRequest(dict(
display_name=self.request_data.get('displayName'), username=self.request_data.get('username'),
privacy_policy=self.request_data.get('privacyPolicy'), privacy_policy=self.request_data.get('privacyPolicy'),
terms_of_use=self.request_data.get('termsOfUse'), terms_of_use=self.request_data.get('termsOfUse'),
login=self._build_login_schematic(), login=self._build_login_schematic(),
@ -149,7 +148,7 @@ class AccountEndpoint(SeleneEndpoint):
] ]
account = Account( account = Account(
email_address=email_address, email_address=email_address,
display_name=self.request_data['displayName'], username=self.request_data['username'],
agreements=[ agreements=[
AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()), AccountAgreement(type=PRIVACY_POLICY, accept_date=date.today()),
AccountAgreement(type=TERMS_OF_USE, accept_date=date.today()) AccountAgreement(type=TERMS_OF_USE, accept_date=date.today())

View File

@ -1,6 +1,7 @@
from .authentication import ( from .authentication import (
ACCESS_TOKEN_COOKIE_KEY, ACCESS_TOKEN_COOKIE_KEY,
generate_auth_tokens, generate_access_token,
generate_refresh_token,
get_account, get_account,
REFRESH_TOKEN_COOKIE_KEY, REFRESH_TOKEN_COOKIE_KEY,
validate_token_cookies validate_token_cookies

View File

@ -5,34 +5,54 @@ from selene.data.account import (
AccountRepository, AccountRepository,
RefreshTokenRepository RefreshTokenRepository
) )
from selene.util.auth import AuthenticationTokenGenerator from selene.util.auth import AuthenticationToken
from selene.util.db import get_db_connection from selene.util.db import get_db_connection
ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess' ACCESS_TOKEN_COOKIE_KEY = 'seleneAccess'
ONE_MINUTE = 60
TWO_MINUTES = 120
REFRESH_TOKEN_COOKIE_KEY = 'seleneRefresh' REFRESH_TOKEN_COOKIE_KEY = 'seleneRefresh'
def generate_auth_tokens(context): def generate_access_token(context, expire=False):
token_generator = AuthenticationTokenGenerator( access_token = AuthenticationToken(
context.account.id,
context.client_config['ACCESS_SECRET'], context.client_config['ACCESS_SECRET'],
context.client_config['REFRESH_SECRET'] ONE_MINUTE
) )
access_token.account_id = context.account.id
if not expire:
access_token.generate()
context.access_token = access_token
context.client.set_cookie( context.client.set_cookie(
context.client_config['DOMAIN'], context.client_config['DOMAIN'],
ACCESS_TOKEN_COOKIE_KEY, ACCESS_TOKEN_COOKIE_KEY,
token_generator.access_token access_token.jwt,
max_age=0 if expire else ONE_MINUTE
) )
def generate_refresh_token(context, expire=False):
account_id = context.account.id
refresh_token = AuthenticationToken(
context.client_config['REFRESH_SECRET'],
TWO_MINUTES
)
refresh_token.account_id = account_id
if not expire:
refresh_token.generate()
context.refresh_token = refresh_token
context.client.set_cookie( context.client.set_cookie(
context.client_config['DOMAIN'], context.client_config['DOMAIN'],
REFRESH_TOKEN_COOKIE_KEY, REFRESH_TOKEN_COOKIE_KEY,
token_generator.refresh_token refresh_token.jwt,
max_age=0 if expire else TWO_MINUTES
) )
context.request_refresh_token = token_generator.refresh_token
with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db: with get_db_connection(context.client_config['DB_CONNECTION_POOL']) as db:
token_repository = RefreshTokenRepository(db, context.account.id) token_repository = RefreshTokenRepository(db, account_id)
token_repository.add_refresh_token(token_generator.refresh_token) token_repository.add_refresh_token(refresh_token.jwt)
def validate_token_cookies(context, expired=False): def validate_token_cookies(context, expired=False):

View File

@ -24,7 +24,7 @@ class AccountSubscription(object):
class Account(object): class Account(object):
"""Representation of a Mycroft user account.""" """Representation of a Mycroft user account."""
email_address: str email_address: str
display_name: str username: str
agreements: List[AccountAgreement] agreements: List[AccountAgreement]
subscription: AccountSubscription subscription: AccountSubscription
id: str = None id: str = None

View File

@ -48,7 +48,7 @@ class AccountRepository(object):
args=dict( args=dict(
email_address=account.email_address, email_address=account.email_address,
password=encrypted_password, password=encrypted_password,
display_name=account.display_name username=account.username
) )
) )
result = self.cursor.insert_returning(request) result = self.cursor.insert_returning(request)

View File

@ -1,6 +1,6 @@
INSERT INTO INSERT INTO
account.account (email_address, password, display_name) account.account (email_address, password, username)
VALUES VALUES
(%(email_address)s, %(password)s, %(display_name)s) (%(email_address)s, %(password)s, %(username)s)
RETURNING RETURNING
id id

View File

@ -41,7 +41,7 @@ SELECT
json_build_object( json_build_object(
'id', id, 'id', id,
'email_address', email_address, 'email_address', email_address,
'display_name', display_name, 'username', username,
'subscription', (SELECT * FROM subscription), 'subscription', (SELECT * FROM subscription),
'refresh_tokens', (SELECT * FROM refresh_tokens), 'refresh_tokens', (SELECT * FROM refresh_tokens),
'agreements', (SELECT * FROM agreements) 'agreements', (SELECT * FROM agreements)

View File

@ -1,86 +1,50 @@
"""Logic for generating and validating JWT authentication tokens."""
from datetime import datetime from datetime import datetime
from time import time from time import time
import jwt import jwt
FIFTEEN_MINUTES = 900
ONE_MONTH = 2628000
class AuthenticationError(Exception): class AuthenticationError(Exception):
pass pass
class AuthenticationTokenGenerator(object): class AuthenticationToken(object):
_access_token = None def __init__(self, secret: str, duration: int):
_refresh_token = None self.secret = secret
self.duration = duration
self.jwt: str = ''
self.is_valid: bool = None
self.is_expired: bool = None
self.account_id: str = None
def __init__(self, account_id: str, access_secret, refresh_secret): def generate(self):
self.account_id = account_id
self.access_secret = access_secret
self.refresh_secret = refresh_secret
def _generate_token(self, token_duration: int):
""" """
Generates a JWT token Generates a JWT token
""" """
token_expiration = time() + token_duration
payload = dict( payload = dict(
iat=datetime.utcnow(), iat=datetime.utcnow(),
exp=token_expiration, exp=time() + self.duration,
sub=self.account_id sub=self.account_id
) )
token = jwt.encode(payload, self.secret, algorithm='HS256')
if token_duration == FIFTEEN_MINUTES:
secret = self.access_secret
else:
secret = self.refresh_secret
token = jwt.encode(
payload,
secret,
algorithm='HS256'
)
# convert the token from byte-array to string so that # convert the token from byte-array to string so that
# it can be included in a JSON response object # it can be included in a JSON response object
return token.decode() self.jwt = token.decode()
@property def validate(self):
def access_token(self): """Decodes the auth token and performs some preliminary validation."""
""" self.is_expired = False
Generates a JWT access token self.is_valid = True
"""
if self._access_token is None:
self._access_token = self._generate_token(FIFTEEN_MINUTES)
return self._access_token if self.jwt is None:
self.is_expired = True
@property else:
def refresh_token(self):
"""
Generates a JWT access token
"""
if self._refresh_token is None:
self._refresh_token = self._generate_token(ONE_MONTH)
return self._refresh_token
class AuthenticationTokenValidator(object):
def __init__(self, token: str, secret: str):
self.token = token
self.secret = secret
self.account_id = None
self.token_is_expired = False
self.token_is_invalid = False
def validate_token(self):
"""Decodes the auth token"""
try: try:
payload = jwt.decode(self.token, self.secret) payload = jwt.decode(self.jwt, self.secret)
self.account_id = payload['sub'] self.account_id = payload['sub']
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
self.token_is_expired = True self.is_expired = True
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
self.token_is_invalid = True self.is_valid = False

View File

@ -16,7 +16,6 @@ setup(
'pyhamcrest', 'pyhamcrest',
'pyjwt', 'pyjwt',
'psycopg2-binary', 'psycopg2-binary',
'schematics', 'schematics'
'validator-collection'
] ]
) )