Be more strict with chain field in Schema.

Update tests accordingly.
pull/3334/head
derekpierre 2023-11-07 10:52:51 -05:00 committed by KPrasch
parent ed5d32203e
commit 099a72e39e
4 changed files with 53 additions and 4 deletions

View File

@ -11,6 +11,7 @@ from typing import (
from eth_typing import ChecksumAddress
from eth_utils import to_checksum_address
from marshmallow import ValidationError, fields, post_load, validate, validates_schema
from marshmallow.validate import OneOf
from web3 import HTTPProvider, Web3
from web3.contract.contract import ContractFunction
from web3.middleware import geth_poa_middleware
@ -137,7 +138,9 @@ class RPCCondition(AccessControlCondition):
condition_type = fields.Str(
validate=validate.Equal(ConditionType.RPC.value), required=True
)
chain = fields.Int(required=True)
chain = fields.Int(
required=True, strict=True, validate=OneOf(_CONDITION_CHAINS)
)
method = fields.Str(required=True)
parameters = fields.List(fields.Field, attribute="parameters", required=False)
return_value_test = fields.Nested(

View File

@ -1,8 +1,9 @@
from typing import Any, List, Optional
from marshmallow import fields, post_load, validate
from marshmallow.validate import Equal, OneOf
from nucypher.policy.conditions.evm import RPCCondition
from nucypher.policy.conditions.evm import _CONDITION_CHAINS, RPCCondition
from nucypher.policy.conditions.exceptions import InvalidCondition
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
from nucypher.policy.conditions.utils import CamelCaseSchema
@ -18,8 +19,12 @@ class TimeCondition(RPCCondition):
validate=validate.Equal(ConditionType.TIME.value), required=True
)
name = fields.Str(required=False)
chain = fields.Int(required=True)
method = fields.Str(dump_default="blocktime", required=True)
chain = fields.Int(
required=True, strict=True, validate=OneOf(_CONDITION_CHAINS)
)
method = fields.Str(
dump_default="blocktime", required=True, validate=Equal("blocktime")
)
return_value_test = fields.Nested(
ReturnValueTest.ReturnValueTestSchema(), required=True
)

View File

@ -96,6 +96,16 @@ def test_rpc_condition_schema_validation(rpc_condition):
del condition_dict["returnValueTest"]
RPCCondition.validate(condition_dict)
with pytest.raises(InvalidCondition):
# chain id not an integer
condition_dict["chain"] = str(TESTERCHAIN_CHAIN_ID)
RPCCondition.validate(condition_dict)
with pytest.raises(InvalidCondition):
# chain id not a permitted chain
condition_dict["chain"] = 90210 # Beverly Hills Chain :)
RPCCondition.validate(condition_dict)
def test_rpc_condition_repr(rpc_condition):
rpc_condition_str = f"{rpc_condition}"

View File

@ -24,6 +24,22 @@ def test_invalid_time_condition():
method="time_after_time",
)
# invalid chain id
with pytest.raises(InvalidCondition):
_ = TimeCondition(
return_value_test=ReturnValueTest(">", 0),
chain="mychain",
method="time_after_time",
)
# chain id not permitted
with pytest.raises(InvalidCondition):
_ = TimeCondition(
return_value_test=ReturnValueTest(">", 0),
chain=90210, # Beverly Hills Chain :)
method="time_after_time",
)
def test_time_condition_schema_validation(time_condition):
condition_dict = time_condition.to_dict()
@ -47,6 +63,21 @@ def test_time_condition_schema_validation(time_condition):
del condition_dict["returnValueTest"]
TimeCondition.validate(condition_dict)
with pytest.raises(InvalidCondition):
# invalid method name
condition_dict["method"] = "my_blocktime"
TimeCondition.validate(condition_dict)
with pytest.raises(InvalidCondition):
# chain id not an integer
condition_dict["chain"] = str(TESTERCHAIN_CHAIN_ID)
TimeCondition.validate(condition_dict)
with pytest.raises(InvalidCondition):
# chain id not a permitted chain
condition_dict["chain"] = 90210 # Beverly Hills Chain :)
TimeCondition.validate(condition_dict)
@pytest.mark.parametrize(
"invalid_value", ["0x123456", 10.15, [1], [1, 2, 3], [True, [1, 2], "0x0"]]