Add tests for rpc hierarchy and unsupported chains

pull/3566/head
James Campbell 2024-12-05 15:21:59 +01:00
parent 17042eb847
commit afb7f0deb8
No known key found for this signature in database
2 changed files with 72 additions and 7 deletions

View File

@ -87,13 +87,17 @@ class RPCCall(ExecutionCall):
parameters = fields.List(
fields.Field, attribute="parameters", required=False, allow_none=True
)
rpc_endpoint = fields.Url(required=False, relative=False, allow_none=True)
rpc_endpoint = fields.Url(
attribute="rpc_endpoint", required=False, relative=False, allow_none=True
)
@validates("chain")
def validate_chain(self, value):
if value not in _CONDITION_CHAINS:
@validates_schema
def validate_chain(self, data, **kwargs):
chain = data.get("chain")
rpc_endpoint = data.get("rpc_endpoint")
if not rpc_endpoint and chain not in _CONDITION_CHAINS:
raise ValidationError(
f"chain ID {value} is not a permitted blockchain for condition evaluation"
f"chain ID {chain} is not a permitted blockchain for condition evaluation"
)
@validates("method")
@ -295,6 +299,10 @@ class RPCCondition(ExecutionCallAccessControlCondition):
def parameters(self):
return self.execution_call.parameters
@property
def rpc_endpoint(self):
return self.execution_call.rpc_endpoint
def _align_comparator_value_with_abi(
self, return_value_test: ReturnValueTest
) -> ReturnValueTest:

View File

@ -1,4 +1,5 @@
import pytest
from web3 import HTTPProvider
from nucypher.policy.conditions.evm import RPCCall, RPCCondition
from nucypher.policy.conditions.exceptions import (
@ -145,7 +146,7 @@ def test_rpc_condition_uses_provided_endpoint(mocker):
# Mock eth module
mock_eth = mocker.Mock()
mock_eth.get_balance.return_value = 0
mock_eth.chain_id = TESTERCHAIN_CHAIN_ID
mock_eth.chain_id = 8453
# Create Web3 mock with required attributes
mock_w3 = mocker.Mock()
@ -158,7 +159,7 @@ def test_rpc_condition_uses_provided_endpoint(mocker):
# Mock _next_endpoint method
_ = mocker.patch.object(RPCCall, "_next_endpoint")
rpc_endpoint = "https://eth-mainnet.example.com"
rpc_endpoint = "https://base.example.com"
condition = RPCCondition(
method="eth_getBalance",
chain=TESTERCHAIN_CHAIN_ID,
@ -173,3 +174,59 @@ def test_rpc_condition_uses_provided_endpoint(mocker):
# Verify the endpoint was used
mock_http_provider.assert_called_once_with(rpc_endpoint)
assert not condition.execution_call._next_endpoint.called
def test_rpc_condition_execution_priority(mocker):
# Mock HTTPProvider
mock_provider = mocker.Mock()
mock_http_provider = mocker.patch(
"nucypher.policy.conditions.evm.HTTPProvider", return_value=mock_provider
)
# Mock eth module with successful response
mock_eth = mocker.Mock()
mock_eth.get_balance.return_value = 100 # Set a non-zero balance
mock_eth.chain_id = TESTERCHAIN_CHAIN_ID
mock_w3 = mocker.Mock()
mock_w3.eth = mock_eth
mock_w3.middleware_onion = mocker.Mock()
mocker.patch("nucypher.policy.conditions.evm.Web3", return_value=mock_w3)
# Test Case 1: Chain in providers - should use local provider only
local_provider = HTTPProvider("https://local-provider.example.com")
providers = {TESTERCHAIN_CHAIN_ID: {local_provider}}
condition = RPCCondition(
method="eth_getBalance",
chain=TESTERCHAIN_CHAIN_ID,
return_value_test=ReturnValueTest("==", 100), # Match the mock response
parameters=["0xaDD9D957170dF6F33982001E4c22eCCdd5539118"],
rpc_endpoint="https://fallback.example.com",
)
condition.verify(providers=providers)
mock_http_provider.assert_not_called() # Fallback endpoint not used
# Test Case 2: Unsupported chain - should use rpc_endpoint
unsupported_chain = 99999 # Chain not in _CONDITION_CHAINS
condition = RPCCondition(
method="eth_getBalance",
chain=unsupported_chain,
return_value_test=ReturnValueTest("==", 0),
parameters=["0xaDD9D957170dF6F33982001E4c22eCCdd5539118"],
rpc_endpoint="https://fallback.example.com",
)
condition.verify(providers={})
mock_http_provider.assert_called_once_with("https://fallback.example.com")
# Test Case 3: Unsupported chain with no rpc_endpoint - should raise errorq
with pytest.raises(InvalidCondition):
condition = RPCCondition(
method="eth_getBalance",
chain=unsupported_chain,
return_value_test=ReturnValueTest("==", 0),
parameters=["0xaDD9D957170dF6F33982001E4c22eCCdd5539118"],
)