mirror of https://github.com/nucypher/nucypher.git
Add use of ConditionProviderManager across lingo evaluation / condition verification.
parent
27a02837c4
commit
24d0669940
|
@ -4,7 +4,7 @@ import time
|
|||
import traceback
|
||||
from collections import defaultdict
|
||||
from decimal import Decimal
|
||||
from typing import DefaultDict, Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import maya
|
||||
from atxm.exceptions import InsufficientFunds
|
||||
|
@ -65,7 +65,10 @@ from nucypher.crypto.powers import (
|
|||
TransactingPower,
|
||||
)
|
||||
from nucypher.datastore.dkg import DKGStorage
|
||||
from nucypher.policy.conditions.utils import evaluate_condition_lingo
|
||||
from nucypher.policy.conditions.utils import (
|
||||
ConditionProviderManager,
|
||||
evaluate_condition_lingo,
|
||||
)
|
||||
from nucypher.policy.payment import ContractPayment
|
||||
from nucypher.types import PhaseId
|
||||
from nucypher.utilities.emitters import StdoutEmitter
|
||||
|
@ -247,7 +250,7 @@ class Operator(BaseActor):
|
|||
ThresholdRequestDecryptingPower
|
||||
) # used for secure decryption request channel
|
||||
|
||||
self.condition_providers = self.connect_condition_providers(
|
||||
self.condition_provider_manager = self.get_condition_provider_manager(
|
||||
condition_blockchain_endpoints
|
||||
)
|
||||
|
||||
|
@ -269,9 +272,9 @@ class Operator(BaseActor):
|
|||
provider = HTTPProvider(endpoint_uri=uri)
|
||||
return provider
|
||||
|
||||
def connect_condition_providers(
|
||||
def get_condition_provider_manager(
|
||||
self, operator_configured_endpoints: Dict[int, List[str]]
|
||||
) -> DefaultDict[int, List[HTTPProvider]]:
|
||||
) -> ConditionProviderManager:
|
||||
|
||||
# check that we have mandatory user configured endpoints
|
||||
mandatory_configured_chains = {
|
||||
|
@ -336,7 +339,7 @@ class Operator(BaseActor):
|
|||
f"checking on chain IDs {providers.keys()}"
|
||||
)
|
||||
|
||||
return providers
|
||||
return ConditionProviderManager(providers=providers)
|
||||
|
||||
def _resolve_ritual(self, ritual_id: int) -> Coordinator.Ritual:
|
||||
if not self.coordinator_agent.is_ritual_active(ritual_id=ritual_id):
|
||||
|
@ -845,7 +848,7 @@ class Operator(BaseActor):
|
|||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
context=context,
|
||||
providers=self.condition_providers,
|
||||
providers=self.condition_provider_manager,
|
||||
)
|
||||
|
||||
def _verify_decryption_request_authorization(
|
||||
|
|
|
@ -260,7 +260,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
|
|||
try:
|
||||
evaluate_condition_lingo(
|
||||
condition_lingo=condition_lingo,
|
||||
providers=this_node.condition_providers,
|
||||
providers=this_node.condition_provider_manager,
|
||||
context=context,
|
||||
)
|
||||
except ConditionEvalError as error:
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
|
@ -20,9 +17,7 @@ from marshmallow import (
|
|||
)
|
||||
from marshmallow.validate import OneOf
|
||||
from typing_extensions import override
|
||||
from web3 import HTTPProvider, Web3
|
||||
from web3.middleware import geth_poa_middleware
|
||||
from web3.providers import BaseProvider
|
||||
from web3 import Web3
|
||||
from web3.types import ABIFunction
|
||||
|
||||
from nucypher.policy.conditions import STANDARD_ABI_CONTRACT_TYPES
|
||||
|
@ -34,8 +29,6 @@ from nucypher.policy.conditions.context import (
|
|||
resolve_any_context_variables,
|
||||
)
|
||||
from nucypher.policy.conditions.exceptions import (
|
||||
InvalidConnectionToChain,
|
||||
NoConnectionToChain,
|
||||
RequiredContextVariable,
|
||||
RPCExecutionFailed,
|
||||
)
|
||||
|
@ -44,7 +37,10 @@ from nucypher.policy.conditions.lingo import (
|
|||
ExecutionCallAccessControlCondition,
|
||||
ReturnValueTest,
|
||||
)
|
||||
from nucypher.policy.conditions.utils import camel_case_to_snake
|
||||
from nucypher.policy.conditions.utils import (
|
||||
ConditionProviderManager,
|
||||
camel_case_to_snake,
|
||||
)
|
||||
from nucypher.policy.conditions.validation import (
|
||||
align_comparator_value_with_abi,
|
||||
get_unbound_contract_function,
|
||||
|
@ -106,58 +102,18 @@ class RPCCall(ExecutionCall):
|
|||
) # bind contract function (only exposes the eth API)
|
||||
return rpc_function
|
||||
|
||||
def _configure_w3(self, provider: BaseProvider) -> Web3:
|
||||
# Instantiate a local web3 instance
|
||||
w3 = Web3(provider)
|
||||
# inject web3 middleware to handle POA chain extra_data field.
|
||||
w3.middleware_onion.inject(geth_poa_middleware, layer=0, name="poa")
|
||||
return w3
|
||||
|
||||
def _check_chain_id(self, w3: Web3) -> None:
|
||||
"""
|
||||
Validates that the actual web3 provider is *actually*
|
||||
connected to the condition's chain ID by reading its RPC endpoint.
|
||||
"""
|
||||
provider_chain = w3.eth.chain_id
|
||||
if provider_chain != self.chain:
|
||||
raise InvalidConnectionToChain(
|
||||
expected_chain=self.chain,
|
||||
actual_chain=provider_chain,
|
||||
message=f"This rpc call can only be evaluated on chain ID {self.chain} but the provider's "
|
||||
f"connection is to chain ID {provider_chain}",
|
||||
)
|
||||
|
||||
def _configure_provider(self, provider: BaseProvider):
|
||||
"""Binds the condition's contract function to a blockchain provider for evaluation"""
|
||||
w3 = self._configure_w3(provider=provider)
|
||||
self._check_chain_id(w3)
|
||||
return w3
|
||||
|
||||
def _next_endpoint(
|
||||
self, providers: Dict[int, Set[HTTPProvider]]
|
||||
) -> Iterator[HTTPProvider]:
|
||||
"""Yields the next web3 provider to try for a given chain ID"""
|
||||
rpc_providers = providers.get(self.chain, None)
|
||||
if not rpc_providers:
|
||||
raise NoConnectionToChain(chain=self.chain)
|
||||
|
||||
for provider in rpc_providers:
|
||||
# Someday, we might make this whole function async, and then we can knock on
|
||||
# each endpoint here to see if it's alive and only yield it if it is.
|
||||
yield provider
|
||||
|
||||
def execute(self, providers: Dict[int, Set[HTTPProvider]], **context) -> Any:
|
||||
def execute(self, providers: ConditionProviderManager, **context) -> Any:
|
||||
resolved_parameters = []
|
||||
if self.parameters:
|
||||
resolved_parameters = resolve_any_context_variables(
|
||||
self.parameters, **context
|
||||
)
|
||||
|
||||
endpoints = self._next_endpoint(providers=providers)
|
||||
endpoints = providers.web3_endpoints(self.chain)
|
||||
|
||||
latest_error = ""
|
||||
for provider in endpoints:
|
||||
for w3 in endpoints:
|
||||
try:
|
||||
w3 = self._configure_provider(provider)
|
||||
result = self._execute(w3, resolved_parameters)
|
||||
break
|
||||
except RequiredContextVariable:
|
||||
|
@ -257,7 +213,7 @@ class RPCCondition(ExecutionCallAccessControlCondition):
|
|||
return return_value_test
|
||||
|
||||
def verify(
|
||||
self, providers: Dict[int, Set[HTTPProvider]], **context
|
||||
self, providers: ConditionProviderManager, **context
|
||||
) -> Tuple[bool, Any]:
|
||||
resolved_return_value_test = self.return_value_test.with_resolved_context(
|
||||
**context
|
||||
|
|
|
@ -4,7 +4,7 @@ import json
|
|||
import operator as pyoperator
|
||||
from enum import Enum
|
||||
from hashlib import md5
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Any, List, Optional, Tuple, Type, Union
|
||||
|
||||
from hexbytes import HexBytes
|
||||
from marshmallow import (
|
||||
|
@ -19,7 +19,6 @@ from marshmallow import (
|
|||
)
|
||||
from marshmallow.validate import OneOf, Range
|
||||
from packaging.version import parse as parse_version
|
||||
from web3 import HTTPProvider
|
||||
|
||||
from nucypher.policy.conditions.base import (
|
||||
AccessControlCondition,
|
||||
|
@ -37,7 +36,7 @@ from nucypher.policy.conditions.exceptions import (
|
|||
ReturnValueEvaluationError,
|
||||
)
|
||||
from nucypher.policy.conditions.types import ConditionDict, Lingo
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema
|
||||
from nucypher.policy.conditions.utils import CamelCaseSchema, ConditionProviderManager
|
||||
|
||||
|
||||
class _ConditionField(fields.Dict):
|
||||
|
@ -339,7 +338,7 @@ class SequentialAccessControlCondition(MultiConditionAccessControl):
|
|||
# TODO - think about not dereferencing context but using a dict;
|
||||
# may allows more freedom for params
|
||||
def verify(
|
||||
self, providers: Dict[int, Set[HTTPProvider]], **context
|
||||
self, providers: ConditionProviderManager, **context
|
||||
) -> Tuple[bool, Any]:
|
||||
values = []
|
||||
latest_success = False
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import re
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, Iterator, List, Optional, Set, Tuple
|
||||
from typing import Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
from marshmallow import Schema, post_dump
|
||||
from marshmallow.exceptions import SCHEMA
|
||||
|
@ -28,16 +28,33 @@ __LOGGER = Logger("condition-eval")
|
|||
class ConditionProviderManager:
|
||||
def __init__(self, providers: Dict[int, List[HTTPProvider]]):
|
||||
self.providers = providers
|
||||
self.logger = Logger(__name__)
|
||||
|
||||
def web3_endpoints(self, chain_id: int) -> Iterator[Web3]:
|
||||
rpc_providers = self.providers.get(chain_id, None)
|
||||
if not rpc_providers:
|
||||
raise NoConnectionToChain(chain=chain_id)
|
||||
|
||||
iterator_returned_at_least_one = False
|
||||
for provider in rpc_providers:
|
||||
w3 = self._configure_w3(provider=provider)
|
||||
self._check_chain_id(chain_id, w3)
|
||||
yield w3
|
||||
try:
|
||||
w3 = self._configure_w3(provider=provider)
|
||||
self._check_chain_id(chain_id, w3)
|
||||
yield w3
|
||||
iterator_returned_at_least_one = True
|
||||
except InvalidConnectionToChain as e:
|
||||
# don't expect to happen but must account
|
||||
# for any misconfigurations of public endpoints
|
||||
self.logger.warn(
|
||||
f"Invalid blockchain connection; expected chain ID {e.expected_chain}, but detected {e.actual_chain}"
|
||||
)
|
||||
|
||||
# if we get here, it is because there were endpoints, but issue with configuring them
|
||||
if not iterator_returned_at_least_one:
|
||||
raise NoConnectionToChain(
|
||||
chain=chain_id,
|
||||
message=f"Problematic provider connections for chain ID {chain_id}",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_w3(provider: BaseProvider) -> Web3:
|
||||
|
@ -97,7 +114,7 @@ class CamelCaseSchema(Schema):
|
|||
|
||||
def evaluate_condition_lingo(
|
||||
condition_lingo: Lingo,
|
||||
providers: Optional[Dict[int, Set[BaseProvider]]] = None,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
context: Optional[ContextDict] = None,
|
||||
log: Logger = __LOGGER,
|
||||
):
|
||||
|
@ -113,7 +130,7 @@ def evaluate_condition_lingo(
|
|||
|
||||
# Setup (don't use mutable defaults)
|
||||
context = context or dict()
|
||||
providers = providers or dict()
|
||||
providers = providers or ConditionProviderManager(providers=dict())
|
||||
error = None
|
||||
|
||||
# Evaluate
|
||||
|
|
Loading…
Reference in New Issue