mirror of https://github.com/nucypher/nucypher.git
parent
ed5d32203e
commit
099a72e39e
|
@ -11,6 +11,7 @@ from typing import (
|
|||
from eth_typing import ChecksumAddress
|
||||
from eth_utils import to_checksum_address
|
||||
from marshmallow import ValidationError, fields, post_load, validate, validates_schema
|
||||
from marshmallow.validate import OneOf
|
||||
from web3 import HTTPProvider, Web3
|
||||
from web3.contract.contract import ContractFunction
|
||||
from web3.middleware import geth_poa_middleware
|
||||
|
@ -137,7 +138,9 @@ class RPCCondition(AccessControlCondition):
|
|||
condition_type = fields.Str(
|
||||
validate=validate.Equal(ConditionType.RPC.value), required=True
|
||||
)
|
||||
chain = fields.Int(required=True)
|
||||
chain = fields.Int(
|
||||
required=True, strict=True, validate=OneOf(_CONDITION_CHAINS)
|
||||
)
|
||||
method = fields.Str(required=True)
|
||||
parameters = fields.List(fields.Field, attribute="parameters", required=False)
|
||||
return_value_test = fields.Nested(
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Any, List, Optional
|
||||
|
||||
from marshmallow import fields, post_load, validate
|
||||
from marshmallow.validate import Equal, OneOf
|
||||
|
||||
from nucypher.policy.conditions.evm import RPCCondition
|
||||
from nucypher.policy.conditions.evm import _CONDITION_CHAINS, RPCCondition
|
||||
from nucypher.policy.conditions.exceptions import InvalidCondition
|
||||
from nucypher.policy.conditions.lingo import ConditionType, ReturnValueTest
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema
|
||||
|
@ -18,8 +19,12 @@ class TimeCondition(RPCCondition):
|
|||
validate=validate.Equal(ConditionType.TIME.value), required=True
|
||||
)
|
||||
name = fields.Str(required=False)
|
||||
chain = fields.Int(required=True)
|
||||
method = fields.Str(dump_default="blocktime", required=True)
|
||||
chain = fields.Int(
|
||||
required=True, strict=True, validate=OneOf(_CONDITION_CHAINS)
|
||||
)
|
||||
method = fields.Str(
|
||||
dump_default="blocktime", required=True, validate=Equal("blocktime")
|
||||
)
|
||||
return_value_test = fields.Nested(
|
||||
ReturnValueTest.ReturnValueTestSchema(), required=True
|
||||
)
|
||||
|
|
|
@ -96,6 +96,16 @@ def test_rpc_condition_schema_validation(rpc_condition):
|
|||
del condition_dict["returnValueTest"]
|
||||
RPCCondition.validate(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
# chain id not an integer
|
||||
condition_dict["chain"] = str(TESTERCHAIN_CHAIN_ID)
|
||||
RPCCondition.validate(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
# chain id not a permitted chain
|
||||
condition_dict["chain"] = 90210 # Beverly Hills Chain :)
|
||||
RPCCondition.validate(condition_dict)
|
||||
|
||||
|
||||
def test_rpc_condition_repr(rpc_condition):
|
||||
rpc_condition_str = f"{rpc_condition}"
|
||||
|
|
|
@ -24,6 +24,22 @@ def test_invalid_time_condition():
|
|||
method="time_after_time",
|
||||
)
|
||||
|
||||
# invalid chain id
|
||||
with pytest.raises(InvalidCondition):
|
||||
_ = TimeCondition(
|
||||
return_value_test=ReturnValueTest(">", 0),
|
||||
chain="mychain",
|
||||
method="time_after_time",
|
||||
)
|
||||
|
||||
# chain id not permitted
|
||||
with pytest.raises(InvalidCondition):
|
||||
_ = TimeCondition(
|
||||
return_value_test=ReturnValueTest(">", 0),
|
||||
chain=90210, # Beverly Hills Chain :)
|
||||
method="time_after_time",
|
||||
)
|
||||
|
||||
|
||||
def test_time_condition_schema_validation(time_condition):
|
||||
condition_dict = time_condition.to_dict()
|
||||
|
@ -47,6 +63,21 @@ def test_time_condition_schema_validation(time_condition):
|
|||
del condition_dict["returnValueTest"]
|
||||
TimeCondition.validate(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
# invalid method name
|
||||
condition_dict["method"] = "my_blocktime"
|
||||
TimeCondition.validate(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
# chain id not an integer
|
||||
condition_dict["chain"] = str(TESTERCHAIN_CHAIN_ID)
|
||||
TimeCondition.validate(condition_dict)
|
||||
|
||||
with pytest.raises(InvalidCondition):
|
||||
# chain id not a permitted chain
|
||||
condition_dict["chain"] = 90210 # Beverly Hills Chain :)
|
||||
TimeCondition.validate(condition_dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_value", ["0x123456", 10.15, [1], [1, 2, 3], [True, [1, 2], "0x0"]]
|
||||
|
|
Loading…
Reference in New Issue