From 06af880616e538df9b6a37f207f5fcffadc966e5 Mon Sep 17 00:00:00 2001 From: KPrasch Date: Fri, 28 Jun 2024 17:39:14 +0800 Subject: [PATCH] restrictive design: unexpose custom request headers and require https --- nucypher/policy/conditions/offchain.py | 20 +++++--- .../conditions/test_json_api_condition.py | 48 ++++++++----------- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/nucypher/policy/conditions/offchain.py b/nucypher/policy/conditions/offchain.py index 1d04ec2bb..dcf8915c2 100644 --- a/nucypher/policy/conditions/offchain.py +++ b/nucypher/policy/conditions/offchain.py @@ -32,6 +32,17 @@ class JSONPathField(Field): return value +class HTTPSField(Field): + default_error_messages = { + "invalid": "'{value}' is not a valid HTTPS endpoint", + } + + def _deserialize(self, value, attr, data, **kwargs): + if not value.startswith("https://"): + raise self.make_error("invalid", value=value) + return value + + class JsonApiCondition(AccessControlCondition): """ A JSON API condition is a condition that can be evaluated by reading from a JSON @@ -48,9 +59,8 @@ class JsonApiCondition(AccessControlCondition): condition_type = fields.Str( validate=validate.Equal(ConditionType.JSONAPI.value), required=True ) - headers = fields.Dict(required=False) parameters = fields.Dict(required=False) - endpoint = fields.Str(required=True) + endpoint = HTTPSField(required=True) query = JSONPathField(required=True) return_value_test = fields.Nested( ReturnValueTest.ReturnValueTestSchema(), required=True @@ -65,7 +75,6 @@ class JsonApiCondition(AccessControlCondition): endpoint: str, query: Optional[str], return_value_test: ReturnValueTest, - headers: Optional[dict] = None, parameters: Optional[dict] = None, condition_type: str = ConditionType.JSONAPI.value, ): @@ -75,7 +84,6 @@ class JsonApiCondition(AccessControlCondition): ) self.endpoint = endpoint - self.headers = headers self.parameters = parameters self.query = query self.return_value_test = return_value_test @@ -84,9 +92,7 @@ class JsonApiCondition(AccessControlCondition): def fetch(self) -> requests.Response: """Fetches data from the endpoint.""" try: - response = requests.get( - self.endpoint, params=self.parameters, headers=self.headers - ) + response = requests.get(self.endpoint, params=self.parameters, timeout=5) response.raise_for_status() except requests.exceptions.HTTPError as http_error: self.logger.error(f"HTTP error occurred: {http_error}") diff --git a/tests/unit/conditions/test_json_api_condition.py b/tests/unit/conditions/test_json_api_condition.py index f15e4ec32..7e0866f35 100644 --- a/tests/unit/conditions/test_json_api_condition.py +++ b/tests/unit/conditions/test_json_api_condition.py @@ -9,7 +9,26 @@ from nucypher.policy.conditions.exceptions import ( InvalidCondition, ) from nucypher.policy.conditions.lingo import ConditionLingo, ReturnValueTest -from nucypher.policy.conditions.offchain import JsonApiCondition, JSONPathField +from nucypher.policy.conditions.offchain import ( + HTTPSField, + JsonApiCondition, + JSONPathField, +) + + +def test_https_field_valid(): + field = HTTPSField() + valid_https = "https://api.example.com/data" + result = field.deserialize(valid_https) + assert result == valid_https + + +def test_https_field_invalid(): + field = HTTPSField() + invalid_https = "http://api.example.com/data" + with pytest.raises(ValidationError) as excinfo: + field.deserialize(invalid_https) + assert f"'{invalid_https}' is not a valid HTTPS endpoint" in str(excinfo.value) def test_jsonpath_field_valid(): @@ -154,30 +173,6 @@ def test_basic_json_api_condition_evaluation_with_parameters(mocker): assert mocked_get.call_count == 1 -def test_basic_json_api_condition_evaluation_with_headers(mocker): - mocked_get = mocker.patch( - "requests.get", - return_value=mocker.Mock( - status_code=200, json=lambda: {"ethereum": {"usd": 0.0}} - ), - ) - - condition = JsonApiCondition( - endpoint="https://api.coingecko.com/api/v3/simple/price", - parameters={ - "ids": "ethereum", - "vs_currencies": "usd", - }, - headers={"Authorization": "Bearer 1234567890"}, - query="ethereum.usd", - return_value_test=ReturnValueTest("==", 0.0), - ) - - assert condition.verify() == (True, 0.0) - assert mocked_get.call_count == 1 - assert mocked_get.call_args[1]["headers"]["Authorization"] == "Bearer 1234567890" - - def test_json_api_condition_from_lingo_expression(): lingo_dict = { "conditionType": "json-api", @@ -187,9 +182,6 @@ def test_json_api_condition_from_lingo_expression(): "ids": "ethereum", "vs_currencies": "usd", }, - "headers": { - "Authorization": "Bearer 1234567890", - }, "returnValueTest": { "comparator": "==", "value": "0xaDD9D957170dF6F33982001E4c22eCCdd5539118",