Merge pull request #3560 from derekpierre/json-auth

JSON API Authorized Requests (OAuth and JWT Support)
pull/3563/head
Derek Pierre 2024-10-30 11:12:33 -04:00 committed by GitHub
commit 2d85ccf07b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 303 additions and 62 deletions

View File

@ -0,0 +1 @@
Enable support for Bearer authorization tokens (e.g., OAuth, JWT) within HTTP GET requests for ``JsonApiCondition``.

View File

@ -1,6 +1,6 @@
import re
from functools import partial
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Union
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
@ -90,12 +90,12 @@ _DIRECTIVES = {
def is_context_variable(variable) -> bool:
if isinstance(variable, str) and variable.startswith(CONTEXT_PREFIX):
if CONTEXT_REGEX.fullmatch(variable):
return True
else:
raise ValueError(f"Context variable name '{variable}' is not valid.")
return False
return isinstance(variable, str) and CONTEXT_REGEX.fullmatch(variable)
def string_contains_context_variable(variable: str) -> bool:
matches = re.findall(CONTEXT_REGEX, variable)
return bool(matches)
def get_context_value(context_variable: str, **context) -> Any:
@ -116,20 +116,30 @@ def get_context_value(context_variable: str, **context) -> Any:
return value
def _resolve_context_variable(param: Union[Any, List[Any]], **context):
def resolve_any_context_variables(
param: Union[Any, List[Any], Dict[Any, Any]], **context
):
if isinstance(param, list):
return [_resolve_context_variable(item, **context) for item in param]
elif is_context_variable(param):
return get_context_value(context_variable=param, **context)
return [resolve_any_context_variables(item, **context) for item in param]
elif isinstance(param, dict):
return {
k: resolve_any_context_variables(v, **context) for k, v in param.items()
}
elif isinstance(param, str):
# either it is a context variable OR contains a context variable within it
# TODO separating the two cases for now out of concern of regex searching
# within strings (case 2)
if is_context_variable(param):
return get_context_value(context_variable=param, **context)
else:
matches = re.findall(CONTEXT_REGEX, param)
for context_var in matches:
# checking out of concern for faulty regex search within string
if context_var in context:
resolved_var = get_context_value(
context_variable=context_var, **context
)
param = param.replace(context_var, str(resolved_var))
return param
else:
return param
def resolve_parameter_context_variables(parameters: Optional[List[Any]], **context):
if not parameters:
processed_parameters = [] # produce empty list
else:
processed_parameters = [
_resolve_context_variable(param, **context) for param in parameters
]
return processed_parameters

View File

@ -31,7 +31,7 @@ from nucypher.policy.conditions.base import (
)
from nucypher.policy.conditions.context import (
is_context_variable,
resolve_parameter_context_variables,
resolve_any_context_variables,
)
from nucypher.policy.conditions.exceptions import (
NoConnectionToChain,
@ -169,9 +169,11 @@ class RPCCall(ExecutionCall):
yield provider
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
resolved_parameters = resolve_parameter_context_variables(
self.parameters, **context
)
resolved_parameters = []
if self.parameters:
resolved_parameters = resolve_any_context_variables(
self.parameters, **context
)
endpoints = self._next_endpoint(providers=providers)
latest_error = ""

View File

@ -28,8 +28,8 @@ from nucypher.policy.conditions.base import (
_Serializable,
)
from nucypher.policy.conditions.context import (
_resolve_context_variable,
is_context_variable,
resolve_any_context_variables,
)
from nucypher.policy.conditions.exceptions import (
InvalidCondition,
@ -611,7 +611,7 @@ class ReturnValueTest:
return result
def with_resolved_context(self, **context):
value = _resolve_context_variable(self.value, **context)
value = resolve_any_context_variables(self.value, **context)
return ReturnValueTest(self.comparator, value=value, index=self.index)

View File

@ -3,10 +3,15 @@ from typing import Any, Optional, Tuple
import requests
from jsonpath_ng.exceptions import JsonPathLexerError, JsonPathParserError
from jsonpath_ng.ext import parse
from marshmallow import fields, post_load, validate
from marshmallow import ValidationError, fields, post_load, validate, validates
from marshmallow.fields import Field, Url
from nucypher.policy.conditions.base import ExecutionCall
from nucypher.policy.conditions.context import (
is_context_variable,
resolve_any_context_variables,
string_contains_context_variable,
)
from nucypher.policy.conditions.exceptions import (
ConditionEvaluationFailed,
InvalidCondition,
@ -29,7 +34,8 @@ class JSONPathField(Field):
if not isinstance(value, str):
raise self.make_error("invalidType", value=type(value))
try:
parse(value)
if not string_contains_context_variable(value):
parse(value)
except (JsonPathLexerError, JsonPathParserError):
raise self.make_error("invalid", value=value)
return value
@ -42,56 +48,80 @@ class JsonApiCall(ExecutionCall):
endpoint = Url(required=True, relative=False, schemes=["https"])
parameters = fields.Dict(required=False, allow_none=True)
query = JSONPathField(required=False, allow_none=True)
authorization_token = fields.Str(required=False, allow_none=True)
@post_load
def make(self, data, **kwargs):
return JsonApiCall(**data)
@validates("authorization_token")
def validate_auth_token(self, value):
if value and not is_context_variable(value):
raise ValidationError(
f"Invalid value for authorization token; expected a context variable, but got '{value}'"
)
def __init__(
self,
endpoint: str,
parameters: Optional[dict] = None,
query: Optional[str] = None,
authorization_token: Optional[str] = None,
):
self.endpoint = endpoint
self.parameters = parameters or {}
self.query = query
self.authorization_token = authorization_token
self.timeout = self.TIMEOUT
self.logger = Logger(__name__)
super().__init__()
def execute(self, *args, **kwargs) -> Any:
response = self._fetch()
def execute(self, **context) -> Any:
response = self._fetch(**context)
data = self._deserialize_response(response)
result = self._query_response(data)
result = self._query_response(data, **context)
return result
def _fetch(self) -> requests.Response:
def _fetch(self, **context) -> requests.Response:
"""Fetches data from the endpoint."""
resolved_endpoint = resolve_any_context_variables(self.endpoint, **context)
resolved_parameters = resolve_any_context_variables(self.parameters, **context)
headers = None
if self.authorization_token:
resolved_authorization_token = resolve_any_context_variables(
self.authorization_token, **context
)
headers = {"Authorization": f"Bearer {resolved_authorization_token}"}
try:
response = requests.get(
self.endpoint, params=self.parameters, timeout=self.timeout
resolved_endpoint,
params=resolved_parameters,
timeout=self.timeout,
headers=headers,
)
response.raise_for_status()
except requests.exceptions.HTTPError as http_error:
self.logger.error(f"HTTP error occurred: {http_error}")
raise ConditionEvaluationFailed(
f"Failed to fetch endpoint {self.endpoint}: {http_error}"
f"Failed to fetch endpoint {resolved_endpoint}: {http_error}"
)
except requests.exceptions.RequestException as request_error:
self.logger.error(f"Request exception occurred: {request_error}")
raise InvalidCondition(
f"Failed to fetch endpoint {self.endpoint}: {request_error}"
f"Failed to fetch endpoint {resolved_endpoint}: {request_error}"
)
if response.status_code != 200:
self.logger.error(
f"Failed to fetch endpoint {self.endpoint}: {response.status_code}"
f"Failed to fetch endpoint {resolved_endpoint}: {response.status_code}"
)
raise ConditionEvaluationFailed(
f"Failed to fetch endpoint {self.endpoint}: {response.status_code}"
f"Failed to fetch endpoint {resolved_endpoint}: {response.status_code}"
)
return response
@ -107,16 +137,18 @@ class JsonApiCall(ExecutionCall):
)
return data
def _query_response(self, data: Any) -> Any:
def _query_response(self, data: Any, **context) -> Any:
if not self.query:
return data # primitive value
resolved_query = resolve_any_context_variables(self.query, **context)
try:
expression = parse(self.query)
expression = parse(resolved_query)
matches = expression.find(data)
if not matches:
message = f"No matches found for the JSONPath query: {self.query}"
message = f"No matches found for the JSONPath query: {resolved_query}"
self.logger.info(message)
raise ConditionEvaluationFailed(message)
except (JsonPathLexerError, JsonPathParserError) as jsonpath_err:
@ -124,9 +156,7 @@ class JsonApiCall(ExecutionCall):
raise ConditionEvaluationFailed(f"JSONPath error: {jsonpath_err}")
if len(matches) > 1:
message = (
f"Ambiguous JSONPath query - Multiple matches found for: {self.query}"
)
message = f"Ambiguous JSONPath query - Multiple matches found for: {resolved_query}"
self.logger.info(message)
raise ConditionEvaluationFailed(message)
result = matches[0].value
@ -159,6 +189,7 @@ class JsonApiCondition(ExecutionCallAccessControlCondition):
return_value_test: ReturnValueTest,
query: Optional[str] = None,
parameters: Optional[dict] = None,
authorization_token: Optional[str] = None,
condition_type: str = ConditionType.JSONAPI.value,
name: Optional[str] = None,
):
@ -167,6 +198,7 @@ class JsonApiCondition(ExecutionCallAccessControlCondition):
return_value_test=return_value_test,
query=query,
parameters=parameters,
authorization_token=authorization_token,
condition_type=condition_type,
name=name,
)
@ -187,6 +219,10 @@ class JsonApiCondition(ExecutionCallAccessControlCondition):
def timeout(self):
return self.execution_call.timeout
@property
def authorization_token(self):
return self.execution_call.authorization_token
@staticmethod
def _process_result_for_eval(result: Any):
# strings that are not already quoted will cause a problem for literal_eval

View File

@ -70,11 +70,12 @@ def lingo_with_all_condition_types(get_random_checksum_address):
# JSON API
"conditionType": ConditionType.JSONAPI.value,
"endpoint": "https://api.example.com/data",
"query": "$.store.book[0].price",
"parameters": {
"ids": "ethereum",
"vs_currencies": "usd",
},
"authorizationToken": ":authToken",
"query": "$.store.book[0].price",
"returnValueTest": {
"comparator": "==",
"value": 2,

View File

@ -1,17 +1,15 @@
import copy
import itertools
import re
import pytest
from nucypher.policy.conditions.context import (
USER_ADDRESS_CONTEXT,
USER_ADDRESS_EIP4361_EXTERNAL_CONTEXT,
_resolve_context_variable,
_resolve_user_address,
get_context_value,
is_context_variable,
resolve_parameter_context_variables,
resolve_any_context_variables,
)
from nucypher.policy.conditions.exceptions import (
ContextVariableVerificationFailed,
@ -49,6 +47,7 @@ DEFINITELY_NOT_CONTEXT_PARAM_NAMES = ["1234", "foo", "", 123]
CONTEXT = {":foo": 1234, ":bar": "'BAR'"}
VALUES_WITH_RESOLUTION = [
([], []),
(42, 42),
(True, True),
("'bar'", "'bar'"),
@ -67,16 +66,12 @@ def test_is_context_variable():
assert not is_context_variable(variable)
for variable in INVALID_CONTEXT_PARAM_NAMES:
expected_message = re.escape(
f"Context variable name '{variable}' is not valid."
)
with pytest.raises(ValueError, match=expected_message):
_ = is_context_variable(variable)
assert not is_context_variable(variable)
def test_resolve_context_variable():
for value, resolution in VALUES_WITH_RESOLUTION:
assert resolution == _resolve_context_variable(value, **CONTEXT)
assert resolution == resolve_any_context_variables(value, **CONTEXT)
def test_resolve_any_context_variables():
@ -86,7 +81,7 @@ def test_resolve_any_context_variables():
params, resolved_params = params_with_resolution
value, resolved_value = value_with_resolution
return_value_test = ReturnValueTest(comparator="==", value=value)
resolved_parameters = resolve_parameter_context_variables([params], **CONTEXT)
resolved_parameters = resolve_any_context_variables([params], **CONTEXT)
resolved_return_value = return_value_test.with_resolved_context(**CONTEXT)
assert resolved_parameters == [resolved_params]
assert resolved_return_value.comparator == return_value_test.comparator
@ -94,6 +89,78 @@ def test_resolve_any_context_variables():
assert resolved_return_value.value == resolved_value
@pytest.mark.parametrize(
"value, expected_resolution",
[
(
"https://api.github.com/user/:foo/:bar",
"https://api.github.com/user/1234/BAR",
),
(
"The cost of :bar is $:foo; $:foo is too expensive for :bar",
"The cost of BAR is $1234; $1234 is too expensive for BAR",
),
# graphql query
(
"""{
organization(login: ":bar") {
teams(first: :foo, userLogins: [":bar"]) {
totalCount
edges {
node {
id
name
description
}
}
}
}
}""",
"""{
organization(login: "BAR") {
teams(first: 1234, userLogins: ["BAR"]) {
totalCount
edges {
node {
id
name
description
}
}
}
}
}""",
),
],
)
def test_resolve_context_variable_within_substring(value, expected_resolution):
context = {":foo": 1234, ":bar": "BAR"}
resolved_value = resolve_any_context_variables(value, **context)
assert expected_resolution == resolved_value
@pytest.mark.parametrize(
"value, expected_resolution",
[
(
{
"book_name": ":bar",
"price": "$:foo",
"description": ":bar is a book about foo and bar.",
},
{
"book_name": "BAR",
"price": "$1234",
"description": "BAR is a book about foo and bar.",
},
)
],
)
def test_resolve_context_variable_within_dictionary(value, expected_resolution):
context = {":foo": 1234, ":bar": "BAR"}
resolved_value = resolve_any_context_variables(value, **context)
assert expected_resolution == resolved_value
@pytest.mark.parametrize("expected_entry", ["address", "signature", "typedData"])
@pytest.mark.parametrize(
"context_variable_name, valid_user_address_fixture",

View File

@ -304,6 +304,12 @@ def test_contract_condition_schema_validation():
del condition_dict["contractAddress"]
ContractCondition.from_dict(condition_dict)
with pytest.raises(InvalidConditionLingo):
# invalid contract address
contract_dict = contract_condition.to_dict()
contract_dict["contractAddress"] = "0xABCD"
ContractCondition.from_dict(condition_dict)
balanceOf_abi = {
"constant": True,
"inputs": [{"name": "_owner", "type": "address"}],

View File

@ -48,10 +48,10 @@ def test_json_api_condition_invalid_type():
InvalidCondition, match="'condition_type' field - Must be equal to json-api"
):
_ = JsonApiCondition(
condition_type="INVALID_TYPE",
endpoint="https://api.example.com/data",
query="$.store.book[0].price",
return_value_test=ReturnValueTest("==", 0),
condition_type="INVALID_TYPE",
)
@ -64,6 +64,16 @@ def test_https_enforcement():
)
def test_invalid_authorization_token():
with pytest.raises(InvalidCondition, match="Invalid value for authorization token"):
_ = JsonApiCondition(
endpoint="https://api.example.com/data",
authorization_token="1234", # doesn't make sense hardcoding the token
query="$.store.book[0].price",
return_value_test=ReturnValueTest("==", 0),
)
def test_json_api_condition_with_primitive_response(mocker):
mock_response = mocker.Mock(status_code=200)
mock_response.json.return_value = 1
@ -109,10 +119,16 @@ def test_json_api_condition_fetch_failure(mocker):
def test_json_api_condition_verify(mocker):
mock_response = mocker.Mock(status_code=200)
mock_response.json.return_value = {"store": {"book": [{"price": 1}]}}
mocker.patch("requests.get", return_value=mock_response)
mocked_method = mocker.patch("requests.get", return_value=mock_response)
parameters = {
"arg1": "val1",
"arg2": "val2",
}
condition = JsonApiCondition(
endpoint="https://api.example.com/data",
parameters=parameters,
query="$.store.book[0].price",
return_value_test=ReturnValueTest("==", 1),
)
@ -120,6 +136,10 @@ def test_json_api_condition_verify(mocker):
assert result is True
assert value == 1
# check that appropriate kwarg used for respective http method
assert mocked_method.call_count == 1
assert mocked_method.call_args.kwargs["params"] == parameters
@pytest.mark.parametrize(
"json_return, value_test",
@ -208,18 +228,90 @@ def test_basic_json_api_condition_evaluation_with_parameters(mocker):
assert mocked_get.call_count == 1
def test_json_api_condition_evaluation_with_auth_token(mocker):
mocked_get = mocker.patch(
"requests.get",
return_value=mocker.Mock(
status_code=200, json=lambda: {"ethereum": {"usd": 0.0}}
),
)
condition = JsonApiCondition(
endpoint="https://api.coingecko.com/api/v3/simple/price",
parameters={
"ids": "ethereum",
"vs_currencies": "usd",
},
authorization_token=":authToken",
query="ethereum.usd",
return_value_test=ReturnValueTest("==", 0.0),
)
assert condition.authorization_token == ":authToken"
auth_token = "1234567890"
context = {":authToken": f"{auth_token}"}
assert condition.verify(**context) == (True, 0.0)
assert mocked_get.call_count == 1
assert (
mocked_get.call_args.kwargs["headers"]["Authorization"]
== f"Bearer {auth_token}"
)
def test_json_api_condition_evaluation_with_various_context_variables(mocker):
mocked_get = mocker.patch(
"requests.get",
return_value=mocker.Mock(
status_code=200, json=lambda: {"ethereum": {"cad": 0.0}}
),
)
condition = JsonApiCondition(
endpoint="https://api.coingecko.com/api/:version/simple/:endpointPath",
parameters={
"ids": "ethereum",
"vs_currencies": ":vsCurrency",
},
authorization_token=":authToken",
query="ethereum.:vsCurrency",
return_value_test=ReturnValueTest("==", ":expectedPrice"),
)
assert condition.authorization_token == ":authToken"
auth_token = "1234567890"
context = {
":endpointPath": "price",
":version": "v3",
":vsCurrency": "cad",
":authToken": f"{auth_token}",
":expectedPrice": 0.0,
}
assert condition.verify(**context) == (True, 0.0)
assert mocked_get.call_count == 1
call_args = mocked_get.call_args
assert call_args.args == (
f"https://api.coingecko.com/api/{context[':version']}/simple/{context[':endpointPath']}",
)
assert call_args.kwargs["headers"]["Authorization"] == f"Bearer {auth_token}"
assert call_args.kwargs["params"] == {
"ids": "ethereum",
"vs_currencies": context[":vsCurrency"],
}
def test_json_api_condition_from_lingo_expression():
lingo_dict = {
"conditionType": "json-api",
"endpoint": "https://api.example.com/data",
"query": "$.store.book[0].price",
"parameters": {
"ids": "ethereum",
"vs_currencies": "usd",
},
"query": "$.store.book[0].price",
"returnValueTest": {
"comparator": "==",
"value": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
"value": 1.0,
},
}
@ -229,6 +321,32 @@ def test_json_api_condition_from_lingo_expression():
lingo_json = json.dumps(lingo_dict)
condition = JsonApiCondition.from_json(lingo_json)
assert isinstance(condition, JsonApiCondition)
assert condition.to_dict() == lingo_dict
def test_json_api_condition_from_lingo_expression_with_authorization():
lingo_dict = {
"conditionType": "json-api",
"endpoint": "https://api.example.com/data",
"parameters": {
"ids": "ethereum",
"vs_currencies": "usd",
},
"authorizationToken": ":authorizationToken",
"query": "$.store.book[0].price",
"returnValueTest": {
"comparator": "==",
"value": 1.0,
},
}
cls = ConditionLingo.resolve_condition_class(lingo_dict, version=1)
assert cls == JsonApiCondition
lingo_json = json.dumps(lingo_dict)
condition = JsonApiCondition.from_json(lingo_json)
assert isinstance(condition, JsonApiCondition)
assert condition.to_dict() == lingo_dict
def test_ambiguous_json_path_multiple_results(mocker):

View File

@ -6,7 +6,7 @@ from typing import NamedTuple
import pytest
from hexbytes import HexBytes
from nucypher.policy.conditions.context import _resolve_context_variable
from nucypher.policy.conditions.context import resolve_any_context_variables
from nucypher.policy.conditions.exceptions import ReturnValueEvaluationError
from nucypher.policy.conditions.lingo import ReturnValueTest
@ -150,14 +150,14 @@ def test_return_value_test_with_resolved_context():
resolved = test.with_resolved_context(**context)
assert resolved.comparator == test.comparator
assert resolved.index == test.index
assert resolved.value == _resolve_context_variable(test.value, **context)
assert resolved.value == resolve_any_context_variables(test.value, **context)
test = ReturnValueTest(comparator="==", value=[42, ":foo"])
resolved = test.with_resolved_context(**context)
assert resolved.comparator == test.comparator
assert resolved.index == test.index
assert resolved.value == _resolve_context_variable(test.value, **context)
assert resolved.value == resolve_any_context_variables(test.value, **context)
def test_return_value_test_integer():