Add use of ConditionProviderManager across lingo evaluation / condition verification.

pull/3576/head
derekpierre 2025-01-24 14:35:47 -05:00
parent 27a02837c4
commit 24d0669940
No known key found for this signature in database
5 changed files with 47 additions and 72 deletions

View File

@ -4,7 +4,7 @@ import time
import traceback
from collections import defaultdict
from decimal import Decimal
from typing import DefaultDict, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union
import maya
from atxm.exceptions import InsufficientFunds
@ -65,7 +65,10 @@ from nucypher.crypto.powers import (
TransactingPower,
)
from nucypher.datastore.dkg import DKGStorage
from nucypher.policy.conditions.utils import evaluate_condition_lingo
from nucypher.policy.conditions.utils import (
ConditionProviderManager,
evaluate_condition_lingo,
)
from nucypher.policy.payment import ContractPayment
from nucypher.types import PhaseId
from nucypher.utilities.emitters import StdoutEmitter
@ -247,7 +250,7 @@ class Operator(BaseActor):
ThresholdRequestDecryptingPower
) # used for secure decryption request channel
self.condition_providers = self.connect_condition_providers(
self.condition_provider_manager = self.get_condition_provider_manager(
condition_blockchain_endpoints
)
@ -269,9 +272,9 @@ class Operator(BaseActor):
provider = HTTPProvider(endpoint_uri=uri)
return provider
def connect_condition_providers(
def get_condition_provider_manager(
self, operator_configured_endpoints: Dict[int, List[str]]
) -> DefaultDict[int, List[HTTPProvider]]:
) -> ConditionProviderManager:
# check that we have mandatory user configured endpoints
mandatory_configured_chains = {
@ -336,7 +339,7 @@ class Operator(BaseActor):
f"checking on chain IDs {providers.keys()}"
)
return providers
return ConditionProviderManager(providers=providers)
def _resolve_ritual(self, ritual_id: int) -> Coordinator.Ritual:
if not self.coordinator_agent.is_ritual_active(ritual_id=ritual_id):
@ -845,7 +848,7 @@ class Operator(BaseActor):
evaluate_condition_lingo(
condition_lingo=condition_lingo,
context=context,
providers=self.condition_providers,
providers=self.condition_provider_manager,
)
def _verify_decryption_request_authorization(

View File

@ -260,7 +260,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
try:
evaluate_condition_lingo(
condition_lingo=condition_lingo,
providers=this_node.condition_providers,
providers=this_node.condition_provider_manager,
context=context,
)
except ConditionEvalError as error:

View File

@ -1,10 +1,7 @@
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Set,
Tuple,
)
@ -20,9 +17,7 @@ from marshmallow import (
)
from marshmallow.validate import OneOf
from typing_extensions import override
from web3 import HTTPProvider, Web3
from web3.middleware import geth_poa_middleware
from web3.providers import BaseProvider
from web3 import Web3
from web3.types import ABIFunction
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES
@ -34,8 +29,6 @@ from nucypher.policy.conditions.context import (
resolve_any_context_variables,
)
from nucypher.policy.conditions.exceptions import (
InvalidConnectionToChain,
NoConnectionToChain,
RequiredContextVariable,
RPCExecutionFailed,
)
@ -44,7 +37,10 @@ from nucypher.policy.conditions.lingo import (
ExecutionCallAccessControlCondition,
ReturnValueTest,
)
from nucypher.policy.conditions.utils import camel_case_to_snake
from nucypher.policy.conditions.utils import (
ConditionProviderManager,
camel_case_to_snake,
)
from nucypher.policy.conditions.validation import (
align_comparator_value_with_abi,
get_unbound_contract_function,
@ -106,58 +102,18 @@ class RPCCall(ExecutionCall):
) # bind contract function (only exposes the eth API)
return rpc_function
def _configure_w3(self, provider: BaseProvider) -> Web3:
# Instantiate a local web3 instance
w3 = Web3(provider)
# inject web3 middleware to handle POA chain extra_data field.
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
return w3
def _check_chain_id(self, w3: Web3) -> None:
"""
Validates that the actual web3 provider is *actually*
connected to the condition's chain ID by reading its RPC endpoint.
"""
provider_chain = w3.eth.chain_id
if provider_chain != self.chain:
raise InvalidConnectionToChain(
expected_chain=self.chain,
actual_chain=provider_chain,
message=f"This rpc call can only be evaluated on chain ID {self.chain} but the provider's "
f"connection is to chain ID {provider_chain}",
)
def _configure_provider(self, provider: BaseProvider):
"""Binds the condition's contract function to a blockchain provider for evaluation"""
w3 = self._configure_w3(provider=provider)
self._check_chain_id(w3)
return w3
def _next_endpoint(
self, providers: Dict[int, Set[HTTPProvider]]
) -> Iterator[HTTPProvider]:
"""Yields the next web3 provider to try for a given chain ID"""
rpc_providers = providers.get(self.chain, None)
if not rpc_providers:
raise NoConnectionToChain(chain=self.chain)
for provider in rpc_providers:
# Someday, we might make this whole function async, and then we can knock on
# each endpoint here to see if it's alive and only yield it if it is.
yield provider
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
def execute(self, providers: ConditionProviderManager, **context) -> Any:
resolved_parameters = []
if self.parameters:
resolved_parameters = resolve_any_context_variables(
self.parameters, **context
)
endpoints = self._next_endpoint(providers=providers)
endpoints = providers.web3_endpoints(self.chain)
latest_error = ""
for provider in endpoints:
for w3 in endpoints:
try:
w3 = self._configure_provider(provider)
result = self._execute(w3, resolved_parameters)
break
except RequiredContextVariable:
@ -257,7 +213,7 @@ class RPCCondition(ExecutionCallAccessControlCondition):
return return_value_test
def verify(
self, providers: Dict[int, Set[HTTPProvider]], **context
self, providers: ConditionProviderManager, **context
) -> Tuple[bool, Any]:
resolved_return_value_test = self.return_value_test.with_resolved_context(
**context

View File

@ -4,7 +4,7 @@ import json
import operator as pyoperator
from enum import Enum
from hashlib import md5
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, List, Optional, Tuple, Type, Union
from hexbytes import HexBytes
from marshmallow import (
@ -19,7 +19,6 @@ from marshmallow import (
)
from marshmallow.validate import OneOf, Range
from packaging.version import parse as parse_version
from web3 import HTTPProvider
from nucypher.policy.conditions.base import (
AccessControlCondition,
@ -37,7 +36,7 @@ from nucypher.policy.conditions.exceptions import (
ReturnValueEvaluationError,
)
from nucypher.policy.conditions.types import ConditionDict, Lingo
from nucypher.policy.conditions.utils import CamelCaseSchema
from nucypher.policy.conditions.utils import CamelCaseSchema, ConditionProviderManager
class _ConditionField(fields.Dict):
@ -339,7 +338,7 @@ class SequentialAccessControlCondition(MultiConditionAccessControl):
# TODO - think about not dereferencing context but using a dict;
# may allows more freedom for params
def verify(
self, providers: Dict[int, Set[HTTPProvider]], **context
self, providers: ConditionProviderManager, **context
) -> Tuple[bool, Any]:
values = []
latest_success = False

View File

@ -1,6 +1,6 @@
import re
from http import HTTPStatus
from typing import Dict, Iterator, List, Optional, Set, Tuple
from typing import Dict, Iterator, List, Optional, Tuple
from marshmallow import Schema, post_dump
from marshmallow.exceptions import SCHEMA
@ -28,16 +28,33 @@ __LOGGER = Logger("condition-eval")
class ConditionProviderManager:
def __init__(self, providers: Dict[int, List[HTTPProvider]]):
self.providers = providers
self.logger = Logger(__name__)
def web3_endpoints(self, chain_id: int) -> Iterator[Web3]:
rpc_providers = self.providers.get(chain_id, None)
if not rpc_providers:
raise NoConnectionToChain(chain=chain_id)
iterator_returned_at_least_one = False
for provider in rpc_providers:
w3 = self._configure_w3(provider=provider)
self._check_chain_id(chain_id, w3)
yield w3
try:
w3 = self._configure_w3(provider=provider)
self._check_chain_id(chain_id, w3)
yield w3
iterator_returned_at_least_one = True
except InvalidConnectionToChain as e:
# don't expect to happen but must account
# for any misconfigurations of public endpoints
self.logger.warn(
f"Invalid blockchain connection; expected chain ID {e.expected_chain}, but detected {e.actual_chain}"
)
# if we get here, it is because there were endpoints, but issue with configuring them
if not iterator_returned_at_least_one:
raise NoConnectionToChain(
chain=chain_id,
message=f"Problematic provider connections for chain ID {chain_id}",
)
@staticmethod
def _configure_w3(provider: BaseProvider) -> Web3:
@ -97,7 +114,7 @@ class CamelCaseSchema(Schema):
def evaluate_condition_lingo(
condition_lingo: Lingo,
providers: Optional[Dict[int, Set[BaseProvider]]] = None,
providers: Optional[ConditionProviderManager] = None,
context: Optional[ContextDict] = None,
log: Logger = __LOGGER,
):
@ -113,7 +130,7 @@ def evaluate_condition_lingo(
# Setup (don't use mutable defaults)
context = context or dict()
providers = providers or dict()
providers = providers or ConditionProviderManager(providers=dict())
error = None
# Evaluate