Merge pull request #3126 from derekpierre/bob-porter-conjoined

Ensure that Bob and Porter can use common code for making Threshold Decryption Requests
pull/3133/head
Derek Pierre 2023-05-23 10:07:54 -04:00 committed by GitHub
commit 8b205ad443
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 430 additions and 96 deletions

View File

View File

@ -97,11 +97,12 @@ from nucypher.crypto.powers import (
TLSHostingPower, TLSHostingPower,
TransactingPower, TransactingPower,
) )
from nucypher.network.decryption import ThresholdDecryptionClient
from nucypher.network.exceptions import NodeSeemsToBeDown from nucypher.network.exceptions import NodeSeemsToBeDown
from nucypher.network.middleware import RestMiddleware from nucypher.network.middleware import RestMiddleware
from nucypher.network.nodes import TEACHER_NODES, NodeSprout, Teacher from nucypher.network.nodes import TEACHER_NODES, NodeSprout, Teacher
from nucypher.network.protocols import parse_node_uri from nucypher.network.protocols import parse_node_uri
from nucypher.network.retrieval import RetrievalClient from nucypher.network.retrieval import PRERetrievalClient
from nucypher.network.server import ProxyRESTServer, make_rest_app from nucypher.network.server import ProxyRESTServer, make_rest_app
from nucypher.network.trackers import AvailabilityTracker from nucypher.network.trackers import AvailabilityTracker
from nucypher.policy.conditions.types import LingoList from nucypher.policy.conditions.types import LingoList
@ -502,7 +503,7 @@ class Bob(Character):
retrieval_kits = [message_kit.as_retrieval_kit() for message_kit in message_kits] retrieval_kits = [message_kit.as_retrieval_kit() for message_kit in message_kits]
# Retrieve capsule frags # Retrieve capsule frags
client = RetrievalClient(learner=self) client = PRERetrievalClient(learner=self)
retrieval_results, _ = client.retrieve_cfrags( retrieval_results, _ = client.retrieve_cfrags(
treasure_map=treasure_map, treasure_map=treasure_map,
retrieval_kits=retrieval_kits, retrieval_kits=retrieval_kits,
@ -570,51 +571,46 @@ class Bob(Character):
threshold: int, threshold: int,
variant: FerveoVariant, variant: FerveoVariant,
context: Optional[dict] = None, context: Optional[dict] = None,
) -> List[DecryptionShareSimple]: ) -> Dict[
ChecksumAddress, Union[DecryptionShareSimple, DecryptionSharePrecomputed]
]:
if variant == FerveoVariant.PRECOMPUTED: if variant == FerveoVariant.PRECOMPUTED:
share_type = DecryptionSharePrecomputed share_type = DecryptionSharePrecomputed
elif variant == FerveoVariant.SIMPLE: elif variant == FerveoVariant.SIMPLE:
share_type = DecryptionShareSimple share_type = DecryptionShareSimple
gathered_shares = list() decryption_request_mapping = {}
for ursula in cohort: for ursula in cohort:
conditions = Conditions(json.dumps(lingo)) conditions = Conditions(json.dumps(lingo))
if context: if context:
context = Context(json.dumps(context)) context = Context(json.dumps(context))
decryption_request = ThresholdDecryptionRequest( decryption_request = ThresholdDecryptionRequest(
id=ritual_id, id=ritual_id,
variant=int(variant.value),
ciphertext=bytes(ciphertext), ciphertext=bytes(ciphertext),
conditions=conditions, conditions=conditions,
context=context, context=context,
variant=int(variant.value),
) )
decryption_request_mapping[
to_checksum_address(ursula.checksum_address)
] = bytes(decryption_request)
try: decryption_client = ThresholdDecryptionClient(learner=self)
response = self.network_middleware.get_decryption_share(ursula, bytes(decryption_request)) successes, failures = decryption_client.gather_encrypted_decryption_shares(
except NodeSeemsToBeDown as e: encrypted_requests=decryption_request_mapping, threshold=threshold
self.log.warn(f"Node {ursula} is unreachable. {e}") )
continue
if response.status_code != 200:
self.log.warn(f"Node {ursula} returned {response.status_code}.")
continue
decryption_response = ThresholdDecryptionResponse.from_bytes( if len(successes) < threshold:
response.content
)
decryption_share = share_type.from_bytes(
decryption_response.decryption_share
)
gathered_shares.append(decryption_share)
self.log.debug(f"Got {len(gathered_shares)}/{threshold} shares so far...")
if variant == FerveoVariant.SIMPLE and (len(gathered_shares) == threshold):
# security threshold reached
break
if len(gathered_shares) < threshold:
raise Ursula.NotEnoughUrsulas(f"Not enough Ursulas to decrypt") raise Ursula.NotEnoughUrsulas(f"Not enough Ursulas to decrypt")
self.log.debug(f"Got enough shares to decrypt.") self.log.debug(f"Got enough shares to decrypt.")
gathered_shares = {}
for provider_address, response_bytes in successes.items():
decryption_response = ThresholdDecryptionResponse.from_bytes(response_bytes)
decryption_share = share_type.from_bytes(
decryption_response.decryption_share
)
gathered_shares[provider_address] = decryption_share
return gathered_shares return gathered_shares
def threshold_decrypt(self, def threshold_decrypt(self,
@ -647,8 +643,12 @@ class Bob(Character):
except AttributeError: except AttributeError:
raise ValueError(f"Invalid variant: {variant}; Options are: {list(v.name.lower() for v in list(FerveoVariant))}") raise ValueError(f"Invalid variant: {variant}; Options are: {list(v.name.lower() for v in list(FerveoVariant))}")
threshold = (ritual.shares // 2) + 1 # TODO: #3095 get this from the ritual / put it on-chain? threshold = (
shares = self.gather_decryption_shares( (ritual.shares // 2) + 1
if variant == FerveoVariant.SIMPLE
else ritual.shares
) # TODO: #3095 get this from the ritual / put it on-chain?
decryption_shares = self.gather_decryption_shares(
ritual_id=ritual_id, ritual_id=ritual_id,
cohort=ursulas, cohort=ursulas,
ciphertext=ciphertext, ciphertext=ciphertext,
@ -662,10 +662,15 @@ class Bob(Character):
# TODO: Bob can call.verify here instead of aggregating the shares. # TODO: Bob can call.verify here instead of aggregating the shares.
# if the DKG parameters are not provided, we need to # if the DKG parameters are not provided, we need to
# aggregate the transcripts and derive them. # aggregate the transcripts and derive them.
# TODO we don't need all ursulas, only threshold of them
# ursulas = [u for u in ursulas if u.checksum_address in decryption_shares]
params = self.__derive_dkg_parameters(ritual_id, ursulas, ritual, threshold) params = self.__derive_dkg_parameters(ritual_id, ursulas, ritual, threshold)
# TODO: compare the results with the on-chain records (Coordinator). # TODO: compare the results with the on-chain records (Coordinator).
return self.__decrypt(shares, ciphertext, conditions, params, variant) return self.__decrypt(
list(decryption_shares.values()), ciphertext, conditions, params, variant
)
@staticmethod @staticmethod
def __decrypt( def __decrypt(

View File

@ -1 +0,0 @@

View File

@ -0,0 +1,50 @@
from typing import List
from eth_typing import ChecksumAddress
from nucypher.network.nodes import Learner
from nucypher.utilities.logging import Logger
class ThresholdAccessControlClient:
"""
Client for communicating with access control nodes on the Threshold Network.
"""
def __init__(self, learner: Learner):
self._learner = learner
self.log = Logger(self.__class__.__name__)
def _ensure_ursula_availability(
self, ursulas: List[ChecksumAddress], threshold: int, timeout=10
):
"""
Make sure we know enough nodes from the treasure map to decrypt;
otherwise block and wait for them to come online.
"""
# OK, so we're going to need to do some network activity for this retrieval.
# Let's make sure we've seeded.
if not self._learner.done_seeding:
self._learner.learn_from_teacher_node()
all_known_ursulas = self._learner.known_nodes.addresses()
# Push all unknown Ursulas from the map in the queue for learning
unknown_ursulas = ursulas - all_known_ursulas
# If we know enough to decrypt, we can proceed.
known_ursulas = ursulas & all_known_ursulas
if len(known_ursulas) >= threshold:
return
# | <--- shares ---> |
# | <--- threshold ---> | <--- allow_missing ---> |
# | <--- known_ursulas ---> | <--- unknown_ursulas ---> |
allow_missing = len(ursulas) - threshold
self._learner.block_until_specific_nodes_are_known(
unknown_ursulas,
timeout=timeout,
allow_missing=allow_missing,
learn_on_this_thread=True,
)

View File

@ -0,0 +1,74 @@
from typing import Dict, List, Tuple
from eth_typing import ChecksumAddress
from nucypher.network.client import ThresholdAccessControlClient
from nucypher.utilities.concurrency import BatchValueFactory, WorkerPool
class ThresholdDecryptionClient(ThresholdAccessControlClient):
class DecryptionRequestFailed(Exception):
"""Raised when a decryption request returns a non-zero status code."""
class DecryptionRequestFactory(BatchValueFactory):
def __init__(self, ursula_to_contact: List[ChecksumAddress], threshold: int):
# TODO should we batch the ursulas to contact i.e. pass `batch_size` parameter
super().__init__(values=ursula_to_contact, required_successes=threshold)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def gather_encrypted_decryption_shares(
self,
encrypted_requests: Dict[ChecksumAddress, bytes],
threshold: int,
timeout: float = 10,
) -> Tuple[Dict[ChecksumAddress, bytes], Dict[ChecksumAddress, str]]:
self._ensure_ursula_availability(
ursulas=list(encrypted_requests.keys()),
threshold=threshold,
timeout=timeout,
)
def worker(ursula_address: ChecksumAddress) -> bytes:
encrypted_request = encrypted_requests[ursula_address]
try:
node_or_sprout = self._learner.known_nodes[ursula_address]
node_or_sprout.mature()
response = (
self._learner.network_middleware.get_encrypted_decryption_share(
node_or_sprout, encrypted_request
)
)
except Exception as e:
self.log.warn(f"Node {ursula_address} raised {e}")
raise
else:
if response.status_code != 200:
message = f"Node {ursula_address} returned {response.status_code} - {response.content}."
self.log.warn(message)
raise self.DecryptionRequestFailed(message)
return response.content
worker_pool = WorkerPool(
worker=worker,
value_factory=self.DecryptionRequestFactory(
ursula_to_contact=list(encrypted_requests.keys()), threshold=threshold
),
target_successes=threshold,
timeout=timeout,
)
worker_pool.start()
try:
successes = worker_pool.block_until_target_successes()
except (WorkerPool.OutOfValues, WorkerPool.TimedOut):
# It's possible to raise some other exceptions here but we will use the logic below.
successes = worker_pool.get_successes()
finally:
worker_pool.cancel()
worker_pool.join()
failures = worker_pool.get_failures()
return successes, failures

View File

@ -262,7 +262,9 @@ class RestMiddleware:
) )
return response return response
def get_decryption_share(self, ursula: 'Ursula', decryption_request_bytes: bytes): def get_encrypted_decryption_share(
self, ursula: "Ursula", decryption_request_bytes: bytes
):
response = self.client.post( response = self.client.post(
node_or_sprout=ursula, node_or_sprout=ursula,
path=f"decrypt", path=f"decrypt",

View File

@ -2,7 +2,6 @@
import json import json
import random import random
from collections import defaultdict from collections import defaultdict
from json import JSONDecodeError
from typing import Dict, List, Sequence, Tuple from typing import Dict, List, Sequence, Tuple
from eth_typing.evm import ChecksumAddress from eth_typing.evm import ChecksumAddress
@ -22,11 +21,10 @@ from nucypher_core.umbral import (
VerificationError, VerificationError,
VerifiedCapsuleFrag, VerifiedCapsuleFrag,
) )
from twisted.logger import Logger
from nucypher.crypto.signing import InvalidSignature from nucypher.crypto.signing import InvalidSignature
from nucypher.network.client import ThresholdAccessControlClient
from nucypher.network.exceptions import NodeSeemsToBeDown from nucypher.network.exceptions import NodeSeemsToBeDown
from nucypher.network.nodes import Learner
from nucypher.policy.conditions.exceptions import InvalidConditionContext from nucypher.policy.conditions.exceptions import InvalidConditionContext
from nucypher.policy.conditions.rust_shims import _serialize_rust_lingos from nucypher.policy.conditions.rust_shims import _serialize_rust_lingos
from nucypher.policy.kits import RetrievalResult from nucypher.policy.kits import RetrievalResult
@ -39,7 +37,7 @@ class RetrievalError:
class RetrievalPlan: class RetrievalPlan:
""" """
An emphemeral object providing a service of selecting Ursulas for reencryption requests An ephemeral object providing a service of selecting Ursulas for re-encryption requests
during retrieval. during retrieval.
""" """
@ -166,49 +164,13 @@ class RetrievalWorkOrder:
return rust_lingos return rust_lingos
class RetrievalClient: class PRERetrievalClient(ThresholdAccessControlClient):
""" """
Capsule frag retrieval machinery shared between Bob and Porter. Capsule frag retrieval machinery shared between Bob and Porter.
""" """
def __init__(self, learner: Learner): def __init__(self, *args, **kwargs):
self._learner = learner super().__init__(*args, **kwargs)
self.log = Logger(self.__class__.__name__)
def _ensure_ursula_availability(self, treasure_map: TreasureMap, timeout=10):
"""
Make sure we know enough nodes from the treasure map to decrypt;
otherwise block and wait for them to come online.
"""
# OK, so we're going to need to do some network activity for this retrieval.
# Let's make sure we've seeded.
if not self._learner.done_seeding:
self._learner.learn_from_teacher_node()
ursulas_in_map = treasure_map.destinations.keys()
# TODO (#1995): when that issue is fixed, conversion is no longer needed
ursulas_in_map = [to_checksum_address(bytes(address)) for address in ursulas_in_map]
all_known_ursulas = self._learner.known_nodes.addresses()
# Push all unknown Ursulas from the map in the queue for learning
unknown_ursulas = ursulas_in_map - all_known_ursulas
# If we know enough to decrypt, we can proceed.
known_ursulas = ursulas_in_map & all_known_ursulas
if len(known_ursulas) >= treasure_map.threshold:
return
# | <--- shares ---> |
# | <--- threshold ---> | <--- allow_missing ---> |
# | <--- known_ursulas ---> | <--- unknown_ursulas ---> |
allow_missing = len(treasure_map.destinations) - treasure_map.threshold
self._learner.block_until_specific_nodes_are_known(unknown_ursulas,
timeout=timeout,
allow_missing=allow_missing,
learn_on_this_thread=True)
def _request_reencryption(self, def _request_reencryption(self,
ursula: 'Ursula', ursula: 'Ursula',
@ -284,7 +246,16 @@ class RetrievalClient:
bob_verifying_key: PublicKey, bob_verifying_key: PublicKey,
**context) -> Tuple[List[RetrievalResult], List[RetrievalError]]: **context) -> Tuple[List[RetrievalResult], List[RetrievalError]]:
self._ensure_ursula_availability(treasure_map) ursulas_in_map = treasure_map.destinations.keys()
# TODO (#1995): when that issue is fixed, conversion is no longer needed
ursulas_in_map = [
to_checksum_address(bytes(address)) for address in ursulas_in_map
]
self._ensure_ursula_availability(
ursulas=ursulas_in_map, threshold=treasure_map.threshold
)
retrieval_plan = RetrievalPlan(treasure_map=treasure_map, retrieval_kits=retrieval_kits) retrieval_plan = RetrievalPlan(treasure_map=treasure_map, retrieval_kits=retrieval_kits)

View File

@ -330,3 +330,48 @@ class WorkerPool:
break break
self._result_queue.put(PRODUCER_STOPPED) self._result_queue.put(PRODUCER_STOPPED)
class BatchValueFactory:
def __init__(
self, values: List[Any], required_successes: int, batch_size: int = None
):
if not values:
raise ValueError(f"No available values provided")
if required_successes <= 0:
raise ValueError(
f"Invalid number of successes required ({required_successes})"
)
self.values = values
self.required_successes = required_successes
if len(self.values) < self.required_successes:
raise ValueError(
f"Available values ({len(self.values)} less than required successes {self.required_successes}"
)
self._batch_start_index = 0
if batch_size is not None and batch_size <= 0:
raise ValueError(f"Invalid batch size specified ({batch_size})")
self.batch_size = batch_size if batch_size else required_successes
def __call__(self, successes) -> Optional[List[Any]]:
if successes >= self.required_successes:
# no more work needed to be done
return None
if self._batch_start_index == len(self.values):
# no more values to process
return None
batch_end_index = self._batch_start_index + self.batch_size
if batch_end_index <= len(self.values):
batch = self.values[self._batch_start_index : batch_end_index]
self._batch_start_index = batch_end_index
return batch
else:
# return all remaining values
batch = self.values[self._batch_start_index :]
self._batch_start_index = len(self.values)
return batch

View File

@ -1,9 +1,10 @@
import pytest
import random import random
import time import time
from typing import Iterable, Tuple from typing import Iterable, Tuple
from nucypher.utilities.concurrency import WorkerPool import pytest
from nucypher.utilities.concurrency import BatchValueFactory, WorkerPool
class AllAtOnceFactory: class AllAtOnceFactory:
@ -213,27 +214,18 @@ def test_join(join_worker_pool):
assert t_end - t_start < 3 assert t_end - t_start < 3
class BatchFactory: class TestBatchValueFactory(BatchValueFactory):
def __init__(self, values): def __init__(self, *args, **kwargs):
self.values = values super().__init__(*args, **kwargs)
self.batch_sizes = [] self.batch_sizes = []
def __call__(self, successes): def __call__(self, successes):
if successes == 10: result = super().__call__(successes)
return None if result:
batch_size = 10 - successes self.batch_sizes.append(len(result))
if len(self.values) >= batch_size:
batch = self.values[:batch_size] return result
self.batch_sizes.append(len(batch))
self.values = self.values[batch_size:]
return batch
elif len(self.values) > 0:
self.batch_sizes.append(len(self.values))
return self.values
self.values = None
else:
return None
def test_batched_value_generation(join_worker_pool): def test_batched_value_generation(join_worker_pool):
@ -248,7 +240,7 @@ def test_batched_value_generation(join_worker_pool):
], ],
seed=123) seed=123)
factory = BatchFactory(list(outcomes)) factory = TestBatchValueFactory(values=list(outcomes), required_successes=10)
pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=0.5) pool = WorkerPool(worker, factory, target_successes=10, timeout=10, threadpool_size=10, stagger_timeout=0.5)
join_worker_pool(pool) join_worker_pool(pool)

View File

@ -0,0 +1,196 @@
import pytest
from nucypher.utilities.concurrency import BatchValueFactory
NUM_VALUES = 20
@pytest.fixture(scope="module")
def values():
values = []
for i in range(0, NUM_VALUES):
values.append(i)
return values
def test_batch_value_factory_invalid_values(values):
with pytest.raises(ValueError):
BatchValueFactory(values=[], required_successes=0)
with pytest.raises(ValueError):
BatchValueFactory(values=[], required_successes=1)
with pytest.raises(ValueError):
BatchValueFactory(values=[1, 2, 3, 4], required_successes=5)
with pytest.raises(ValueError):
BatchValueFactory(values=[1, 2, 3, 4], required_successes=2, batch_size=0)
def test_batch_value_factory_all_successes_no_specified_batching(values):
target_successes = NUM_VALUES
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
# number of successes returned since no batching provided
value_list = value_factory(successes=0)
assert len(value_list) == target_successes, "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=NUM_VALUES) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"
def test_batch_value_factory_no_specified_batching_no_more_values_after_target_successes(
values,
):
target_successes = 1
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
for i in range(0, NUM_VALUES // 3):
value_list = value_factory(successes=0)
assert (
len(value_list) == target_successes
), "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
for i in range(NUM_VALUES // 3, NUM_VALUES):
value_list = value_factory(successes=target_successes)
assert (
not value_list
), "there are more values but no more is needed since target successes attained"
def test_batch_value_factory_no_batching_no_success_multiple_calls(values):
target_successes = 4
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
for i in range(0, NUM_VALUES // target_successes):
value_list = value_factory(successes=0)
assert (
len(value_list) == target_successes
), "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
# list all done but get list again
value_list = value_factory(successes=target_successes) # successes achieved
assert not value_list, "successes achieved"
# list all done but get list again
value_list = value_factory(
successes=1
) # not enough successes but list is now empty
assert not value_list, "successes not achieved, but no more values available"
def test_batch_value_factory_no_batching_no_success_multiple_calls_non_divisible_successes(
values,
):
target_successes = 6
value_factory = BatchValueFactory(
values=values, required_successes=target_successes
)
# should be able to get 4 lists
for i in range(0, NUM_VALUES // target_successes):
value_list = value_factory(successes=0)
assert (
len(value_list) == target_successes
), "list returned is based on successes"
assert len(values) == NUM_VALUES, "values remained unchanged"
# last request
value_list = value_factory(successes=0)
assert len(value_list) == NUM_VALUES % target_successes, "remaining list returned"
# get list again
value_list = value_factory(successes=target_successes) # successes achieved
assert not value_list, "successes achieved"
# get list again
value_list = value_factory(
successes=target_successes - 1
) # not enough successes but list is now empty
assert not value_list, "successes not achieved, but no more values available"
def test_batch_value_factory_batching_individual(values):
target_successes = NUM_VALUES
batch_size = 1
value_factory = BatchValueFactory(
values=values, required_successes=target_successes, batch_size=batch_size
)
# number of successes returned since no batching provided
for i in range(0, NUM_VALUES // batch_size):
value_list = value_factory(successes=0)
assert len(value_list) == batch_size, "list returned is based on batch size"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=NUM_VALUES) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"
def test_batch_value_factory_batching_divisible(values):
target_successes = NUM_VALUES
batch_size = 5
value_factory = BatchValueFactory(
values=values, required_successes=target_successes, batch_size=batch_size
)
# number of successes returned since no batching provided (3x here)
for i in range(0, NUM_VALUES // batch_size):
value_list = value_factory(successes=target_successes - 1)
assert len(value_list) == batch_size, "list returned is based on batch size"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=NUM_VALUES) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"
def test_batch_value_factory_batching_non_divisible(values):
target_successes = NUM_VALUES
batch_size = 7
value_factory = BatchValueFactory(
values=values, required_successes=target_successes, batch_size=batch_size
)
# number of successes returned since no batching provided
for i in range(0, NUM_VALUES // batch_size):
value_list = value_factory(successes=0)
assert len(value_list) == batch_size, "list returned is based on batch size"
assert len(values) == NUM_VALUES, "values remained unchanged"
# one more
value_list = value_factory(successes=0)
assert len(value_list) == NUM_VALUES % batch_size, "remainder of list returned"
assert len(values) == NUM_VALUES, "values remained unchanged"
# get list again
value_list = value_factory(successes=target_successes) # successes achieved
assert not value_list, "successes achieved and no more values available"
# get list again
value_list = value_factory(successes=0) # successes not achieved
assert not value_list, "no successes achieved but no more values available"