mirror of https://github.com/nucypher/nucypher.git
restrictive design: unexpose custom request headers and require https
parent
f2c7337483
commit
06af880616
|
@ -32,6 +32,17 @@ class JSONPathField(Field):
|
|||
return value
|
||||
|
||||
|
||||
class HTTPSField(Field):
|
||||
default_error_messages = {
|
||||
"invalid": "'{value}' is not a valid HTTPS endpoint",
|
||||
}
|
||||
|
||||
def _deserialize(self, value, attr, data, **kwargs):
|
||||
if not value.startswith("https://"):
|
||||
raise self.make_error("invalid", value=value)
|
||||
return value
|
||||
|
||||
|
||||
class JsonApiCondition(AccessControlCondition):
|
||||
"""
|
||||
A JSON API condition is a condition that can be evaluated by reading from a JSON
|
||||
|
@ -48,9 +59,8 @@ class JsonApiCondition(AccessControlCondition):
|
|||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.JSONAPI.value), required=True
|
||||
)
|
||||
headers = fields.Dict(required=False)
|
||||
parameters = fields.Dict(required=False)
|
||||
endpoint = fields.Str(required=True)
|
||||
endpoint = HTTPSField(required=True)
|
||||
query = JSONPathField(required=True)
|
||||
return_value_test = fields.Nested(
|
||||
ReturnValueTest.ReturnValueTestSchema(), required=True
|
||||
|
@ -65,7 +75,6 @@ class JsonApiCondition(AccessControlCondition):
|
|||
endpoint: str,
|
||||
query: Optional[str],
|
||||
return_value_test: ReturnValueTest,
|
||||
headers: Optional[dict] = None,
|
||||
parameters: Optional[dict] = None,
|
||||
condition_type: str = ConditionType.JSONAPI.value,
|
||||
):
|
||||
|
@ -75,7 +84,6 @@ class JsonApiCondition(AccessControlCondition):
|
|||
)
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.headers = headers
|
||||
self.parameters = parameters
|
||||
self.query = query
|
||||
self.return_value_test = return_value_test
|
||||
|
@ -84,9 +92,7 @@ class JsonApiCondition(AccessControlCondition):
|
|||
def fetch(self) -> requests.Response:
|
||||
"""Fetches data from the endpoint."""
|
||||
try:
|
||||
response = requests.get(
|
||||
self.endpoint, params=self.parameters, headers=self.headers
|
||||
)
|
||||
response = requests.get(self.endpoint, params=self.parameters, timeout=5)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as http_error:
|
||||
self.logger.error(f"HTTP error occurred: {http_error}")
|
||||
|
|
|
@ -9,7 +9,26 @@ from nucypher.policy.conditions.exceptions import (
|
|||
InvalidCondition,
|
||||
)
|
||||
from nucypher.policy.conditions.lingo import ConditionLingo, ReturnValueTest
|
||||
from nucypher.policy.conditions.offchain import JsonApiCondition, JSONPathField
|
||||
from nucypher.policy.conditions.offchain import (
|
||||
HTTPSField,
|
||||
JsonApiCondition,
|
||||
JSONPathField,
|
||||
)
|
||||
|
||||
|
||||
def test_https_field_valid():
|
||||
field = HTTPSField()
|
||||
valid_https = "https://api.example.com/data"
|
||||
result = field.deserialize(valid_https)
|
||||
assert result == valid_https
|
||||
|
||||
|
||||
def test_https_field_invalid():
|
||||
field = HTTPSField()
|
||||
invalid_https = "http://api.example.com/data"
|
||||
with pytest.raises(ValidationError) as excinfo:
|
||||
field.deserialize(invalid_https)
|
||||
assert f"'{invalid_https}' is not a valid HTTPS endpoint" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_jsonpath_field_valid():
|
||||
|
@ -154,30 +173,6 @@ def test_basic_json_api_condition_evaluation_with_parameters(mocker):
|
|||
assert mocked_get.call_count == 1
|
||||
|
||||
|
||||
def test_basic_json_api_condition_evaluation_with_headers(mocker):
|
||||
mocked_get = mocker.patch(
|
||||
"requests.get",
|
||||
return_value=mocker.Mock(
|
||||
status_code=200, json=lambda: {"ethereum": {"usd": 0.0}}
|
||||
),
|
||||
)
|
||||
|
||||
condition = JsonApiCondition(
|
||||
endpoint="https://api.coingecko.com/api/v3/simple/price",
|
||||
parameters={
|
||||
"ids": "ethereum",
|
||||
"vs_currencies": "usd",
|
||||
},
|
||||
headers={"Authorization": "Bearer 1234567890"},
|
||||
query="ethereum.usd",
|
||||
return_value_test=ReturnValueTest("==", 0.0),
|
||||
)
|
||||
|
||||
assert condition.verify() == (True, 0.0)
|
||||
assert mocked_get.call_count == 1
|
||||
assert mocked_get.call_args[1]["headers"]["Authorization"] == "Bearer 1234567890"
|
||||
|
||||
|
||||
def test_json_api_condition_from_lingo_expression():
|
||||
lingo_dict = {
|
||||
"conditionType": "json-api",
|
||||
|
@ -187,9 +182,6 @@ def test_json_api_condition_from_lingo_expression():
|
|||
"ids": "ethereum",
|
||||
"vs_currencies": "usd",
|
||||
},
|
||||
"headers": {
|
||||
"Authorization": "Bearer 1234567890",
|
||||
},
|
||||
"returnValueTest": {
|
||||
"comparator": "==",
|
||||
"value": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",
|
||||
|
|
Loading…
Reference in New Issue