Add SequentialContractCondition condition type, that allows contract calls to be made sequentially, and results of one contract call can be passed into a subsequent call.

pull/3500/head
derekpierre 2024-05-06 15:34:34 -04:00
parent 9fa873630c
commit 4fdffe975f
No known key found for this signature in database
2 changed files with 127 additions and 1 deletions

View File

@ -6,6 +6,8 @@ from typing import (
Optional,
Set,
Tuple,
Type,
Union,
)
from eth_typing import ChecksumAddress
@ -475,3 +477,125 @@ class ContractCondition(RPCCondition):
abi=self.contract_function.contract_abi[0],
return_value_test=return_value_test,
)
class SequentialContractCondition:
CONDITION_TYPE = ConditionType.SEQUENTIAL_CONTRACTS.value
MAX_NUM_CALLS = 5
@classmethod
def _validate_contract_calls(
cls,
contract_calls: List,
exception_class: Union[Type[ValidationError], Type[InvalidCondition]],
):
num_contract_calls = len(contract_calls)
if num_contract_calls == 0:
raise exception_class("Must be at least one contract call")
if num_contract_calls > cls.MAX_NUM_CALLS:
raise exception_class(
f"Maximum of {cls.MAX_NUM_CALLS} contract calls are allowed"
)
class Schema(CamelCaseSchema):
SKIP_VALUES = (None,)
condition_type = fields.Str(
validate=validate.Equal(ConditionType.SEQUENTIAL_CONTRACTS.value),
required=True,
)
contract_calls = fields.List(
fields.Nested(ContractCall.Schema()), required=True
)
return_value_test = fields.Nested(
ReturnValueTest.ReturnValueTestSchema(), required=True
)
# maintain field declaration ordering
class Meta:
ordered = True
@validates_schema
def validate_contract_calls(self, data, **kwargs):
contract_calls = data["contract_calls"]
SequentialContractCondition._validate_contract_calls(
contract_calls, ValidationError
)
@post_load
def make(self, data, **kwargs):
return SequentialContractCondition(**data)
def __init__(
self,
contract_calls: List[ContractCall],
return_value_test: ReturnValueTest,
condition_type: str = CONDITION_TYPE,
):
# internal
if condition_type != self.CONDITION_TYPE:
raise InvalidCondition(
f"{self.__class__.__name__} must be instantiated with the {self.CONDITION_TYPE} type."
)
self._validate_contract_calls(
contract_calls=contract_calls, exception_class=InvalidCondition
)
self.contract_calls = contract_calls
self.condition_type = condition_type
self.return_value_test = return_value_test # output
final_contract_function = self.contract_calls[-1].contract_function
_validate_contract_function_expected_return_type(
contract_function=final_contract_function,
return_value_test=self.return_value_test,
)
def __repr__(self):
final_call = self.contract_calls[-1]
r = f"{self.__class__.__name__}({len(self.contract_calls)} contract calls, final_function={final_call.method} on chain={final_call.chain})"
return r
def verify(
self, providers: Dict[int, Set[HTTPProvider]], **context
) -> Tuple[bool, Any]:
resolved_return_value_test = self.return_value_test.with_resolved_context(
**context
)
return_value_test = _align_comparator_value_with_abi(
abi=self.contract_calls[-1].contract_function.contract_abi[0],
return_value_test=resolved_return_value_test,
)
current_value = None
for index, contract_call in enumerate(self.contract_calls):
current_value = contract_call.execute(providers=providers, **context)
context[f":multi_contract_condition_{index+1}_result"] = current_value
final_result = current_value
eval_result = return_value_test.eval(final_result) # test
return eval_result, final_result
def _validate_contract_function_expected_return_type(
contract_function: ContractFunction, return_value_test: ReturnValueTest
) -> None:
output_abi_types = _get_abi_types(contract_function.contract_abi[0])
comparator_value = return_value_test.value
comparator_index = return_value_test.index
index_string = f"@index={comparator_index}" if comparator_index is not None else ""
failure_message = (
f"Invalid return value comparison type '{type(comparator_value)}' for "
f"'{contract_function.fn_name}'{index_string} based on ABI types {output_abi_types}"
)
if len(output_abi_types) == 1:
_validate_single_output_type(
output_abi_types[0], comparator_value, comparator_index, failure_message
)
else:
_validate_multiple_output_types(
output_abi_types, comparator_value, comparator_index, failure_message
)

View File

@ -75,6 +75,7 @@ class ConditionType(Enum):
RPC = "rpc"
JSONAPI = "json-api"
COMPOUND = "compound"
SEQUENTIAL_CONTRACTS = "sequentialContracts"
@classmethod
def values(cls) -> List[str]:
@ -425,7 +426,7 @@ class ConditionLingo(_Serializable):
Inspects a given bloc of JSON and attempts to resolve it's intended datatype within the
conditions expression framework.
"""
from nucypher.policy.conditions.evm import ContractCondition, RPCCondition
from nucypher.policy.conditions.evm import ContractCondition, RPCCondition, SequentialContractCondition
from nucypher.policy.conditions.offchain import JsonApiCondition
from nucypher.policy.conditions.time import TimeCondition
@ -438,6 +439,7 @@ class ConditionLingo(_Serializable):
RPCCondition,
CompoundAccessControlCondition,
JsonApiCondition,
SequentialContractCondition,
):
if condition.CONDITION_TYPE == condition_type:
return condition