Update code to handle E2EE changes to Coordinator contract, and use ThresholdRequestDecryptingPower to perform E2EE decryption requests.

pull/3123/head
derekpierre 2023-05-16 11:18:33 -04:00
parent 549252033c
commit 787d7d3e56
9 changed files with 155 additions and 57 deletions

View File

@ -4,8 +4,15 @@ from typing import List, Optional, Tuple, Union
import maya
from eth_typing import ChecksumAddress
from ferveo_py import AggregatedTranscript, Ciphertext, PublicKey, Validator
from ferveo_py import AggregatedTranscript, Ciphertext
from ferveo_py import PublicKey as FerveoPublicKey
from ferveo_py import Validator
from hexbytes import HexBytes
from nucypher_core import (
EncryptedThresholdDecryptionRequest,
ThresholdDecryptionRequest,
)
from nucypher_core.umbral import PublicKey
from web3 import Web3
from web3.types import TxReceipt
@ -27,7 +34,12 @@ from nucypher.blockchain.eth.token import NU
from nucypher.blockchain.eth.trackers.dkg import ActiveRitualTracker
from nucypher.blockchain.eth.trackers.pre import WorkTracker
from nucypher.crypto.ferveo.dkg import DecryptionShareSimple, FerveoVariant, Transcript
from nucypher.crypto.powers import CryptoPower, RitualisticPower, TransactingPower
from nucypher.crypto.powers import (
CryptoPower,
RitualisticPower,
ThresholdRequestDecryptingPower,
TransactingPower,
)
from nucypher.datastore.dkg import DKGStorage
from nucypher.network.trackers import OperatorBondedTracker
from nucypher.policy.conditions.lingo import ConditionLingo
@ -293,9 +305,18 @@ class Ritualist(BaseActor):
contract=self.coordinator_agent.contract
)
self.publish_finalization = publish_finalization # publish the DKG final key if True
self.dkg_storage = DKGStorage() # TODO: #3052 stores locally generated public DKG artifacts
self.ritual_power = crypto_power.power_ups(RitualisticPower) # ferveo material contained within
self.publish_finalization = (
publish_finalization # publish the DKG final key if True
)
self.dkg_storage = (
DKGStorage()
) # TODO: #3052 stores locally generated public DKG artifacts
self.ritual_power = crypto_power.power_ups(
RitualisticPower
) # ferveo material contained within
self.threshold_request_power = crypto_power.power_ups(
ThresholdRequestDecryptingPower
) # used for secure decryption request channel
def get_ritual(self, ritual_id: int) -> CoordinatorAgent.Ritual:
try:
@ -362,14 +383,18 @@ class Ritualist(BaseActor):
self,
ritual_id: int,
aggregated_transcript: AggregatedTranscript,
public_key: PublicKey,
public_key: FerveoPublicKey,
) -> TxReceipt:
"""Publish an aggregated transcript to publicly available storage."""
# look up the node index for this node on the blockchain
request_encrypting_key = self.threshold_request_power.get_pubkey_from_ritual_id(
ritual_id
).to_compressed_bytes()
receipt = self.coordinator_agent.post_aggregation(
ritual_id=ritual_id,
aggregated_transcript=bytes(aggregated_transcript),
public_key=public_key,
request_encrypting_key=request_encrypting_key,
transacting_power=self.transacting_power
)
return receipt
@ -553,6 +578,13 @@ class Ritualist(BaseActor):
return decryption_share
def decrypt_threshold_decryption_request(
self, encrypted_request: EncryptedThresholdDecryptionRequest
) -> Tuple[ThresholdDecryptionRequest, PublicKey]:
return self.threshold_request_power.decrypt_encrypted_request(
encrypted_request=encrypted_request
)
class PolicyAuthor(NucypherTokenActor):
"""Alice base class for blockchain operations, mocking up new policies!"""

View File

@ -555,6 +555,7 @@ class CoordinatorAgent(EthereumContractAgent):
provider: ChecksumAddress
aggregated: bool = False
transcript: bytes = bytes()
requestEncryptingKey: bytes = bytes()
class G1Point(NamedTuple):
"""Coordinator contract representation of DkgPublicKey."""
@ -648,7 +649,10 @@ class CoordinatorAgent(EthereumContractAgent):
participants = list()
for r in result:
participant = self.Ritual.Participant(
provider=ChecksumAddress(r[0]), aggregated=r[1], transcript=bytes(r[2])
provider=ChecksumAddress(r[0]),
aggregated=r[1],
transcript=bytes(r[2]),
requestEncryptingKey=bytes(r[3]),
)
participants.append(participant)
return participants
@ -669,6 +673,7 @@ class CoordinatorAgent(EthereumContractAgent):
provider=ChecksumAddress(result[0]),
aggregated=result[1],
transcript=bytes(result[2]),
requestEncryptingKey=bytes(result[3]),
)
return participant
@ -705,12 +710,14 @@ class CoordinatorAgent(EthereumContractAgent):
ritual_id: int,
aggregated_transcript: bytes,
public_key: DkgPublicKey,
request_encrypting_key: bytes,
transacting_power: TransactingPower,
) -> TxReceipt:
contract_function: ContractFunction = self.contract.functions.postAggregation(
ritualId=ritual_id,
aggregatedTranscript=aggregated_transcript,
publicKey=self.Ritual.G1Point.from_dkg_public_key(public_key),
requestEncryptingKey=request_encrypting_key,
)
receipt = self.blockchain.send_transaction(
contract_function=contract_function,

View File

@ -47,6 +47,7 @@ from nucypher_core import (
Conditions,
Context,
EncryptedKeyFrag,
EncryptedThresholdDecryptionResponse,
EncryptedTreasureMap,
MessageKit,
NodeMetadata,
@ -59,6 +60,7 @@ from nucypher_core import (
from nucypher_core.umbral import (
PublicKey,
RecoverableSignature,
SecretKey,
VerifiedKeyFrag,
reencrypt,
)
@ -94,6 +96,7 @@ from nucypher.crypto.powers import (
PowerUpError,
RitualisticPower,
SigningPower,
ThresholdRequestDecryptingPower,
TLSHostingPower,
TransactingPower,
)
@ -582,22 +585,35 @@ class Bob(Character):
)
return decryption_request
def get_decryption_shares_using_existing_decryption_request(self,
decryption_request: ThresholdDecryptionRequest,
variant: FerveoVariant,
cohort: List["Ursula"],
threshold: int,
):
def get_decryption_shares_using_existing_decryption_request(
self,
decryption_request: ThresholdDecryptionRequest,
request_encrypting_keys: Dict[ChecksumAddress, PublicKey],
variant: FerveoVariant,
cohort: List["Ursula"],
threshold: int,
):
if variant == FerveoVariant.PRECOMPUTED:
share_type = DecryptionSharePrecomputed
elif variant == FerveoVariant.SIMPLE:
share_type = DecryptionShareSimple
# use ephemeral key for request
# TODO don't use Umbral in the long-run
response_sk = SecretKey.random()
response_encrypting_key = response_sk.public_key()
decryption_request_mapping = {}
for ursula in cohort:
decryption_request_mapping[
to_checksum_address(ursula.checksum_address)
] = bytes(decryption_request)
ursula_checksum_address = to_checksum_address(ursula.checksum_address)
request_encrypting_key = request_encrypting_keys[ursula_checksum_address]
encrypted_decryption_request = decryption_request.encrypt(
request_encrypting_key=request_encrypting_key,
response_encrypting_key=response_encrypting_key,
)
decryption_request_mapping[ursula_checksum_address] = bytes(
encrypted_decryption_request
)
decryption_client = ThresholdDecryptionClient(learner=self)
successes, failures = decryption_client.gather_encrypted_decryption_shares(
@ -605,12 +621,15 @@ class Bob(Character):
)
if len(successes) < threshold:
raise Ursula.NotEnoughUrsulas(f"Not enough Ursulas to decrypt")
raise Ursula.NotEnoughUrsulas(f"Not enough Ursulas to decrypt: {failures}")
self.log.debug(f"Got enough shares to decrypt.")
gathered_shares = {}
for provider_address, response_bytes in successes.items():
decryption_response = ThresholdDecryptionResponse.from_bytes(response_bytes)
encrypted_decryption_response = (
EncryptedThresholdDecryptionResponse.from_bytes(response_bytes)
)
decryption_response = encrypted_decryption_response.decrypt(sk=response_sk)
decryption_share = share_type.from_bytes(
decryption_response.decryption_share
)
@ -618,37 +637,40 @@ class Bob(Character):
return gathered_shares
def gather_decryption_shares(
self,
ritual_id: int,
cohort: List["Ursula"],
ciphertext: Ciphertext,
lingo: LingoList,
threshold: int,
variant: FerveoVariant,
context: Optional[dict] = None,
self,
ritual_id: int,
cohort: List["Ursula"],
ciphertext: Ciphertext,
lingo: LingoList,
threshold: int,
variant: FerveoVariant,
request_encrypting_keys: Dict[ChecksumAddress, PublicKey],
context: Optional[dict] = None,
) -> Dict[
ChecksumAddress, Union[DecryptionShareSimple, DecryptionSharePrecomputed]
]:
decryption_request = self.make_decryption_request(
ritual_id=ritual_id,
ciphertext=ciphertext,
lingo=lingo,
variant=variant,
context=context,
)
return self.get_decryption_shares_using_existing_decryption_request(
decryption_request, request_encrypting_keys, variant, cohort, threshold
)
decryption_request = self.make_decryption_request(ritual_id=ritual_id,
ciphertext=ciphertext,
lingo=lingo,
variant=variant,
context=context)
return self.get_decryption_shares_using_existing_decryption_request(decryption_request, variant, cohort,
threshold)
def threshold_decrypt(self,
ritual_id: int,
ciphertext: Ciphertext,
conditions: LingoList,
context: Optional[dict] = None,
params: Optional[DkgPublicParameters] = None,
ursulas: Optional[List['Ursula']] = None,
variant: str = 'simple',
peering_timeout: int = 60,
) -> bytes:
def threshold_decrypt(
self,
ritual_id: int,
ciphertext: Ciphertext,
conditions: LingoList,
context: Optional[dict] = None,
params: Optional[DkgPublicParameters] = None,
ursulas: Optional[List["Ursula"]] = None,
variant: str = "simple",
peering_timeout: int = 60,
) -> bytes:
# blockchain reads: get the DKG parameters and the cohort.
coordinator_agent = ContractAgency.get_agent(CoordinatorAgent, registry=self.registry)
ritual = coordinator_agent.get_ritual(ritual_id, with_participants=True)
@ -660,19 +682,32 @@ class Bob(Character):
ursulas = self.resolve_cohort(ritual=ritual, timeout=peering_timeout)
else:
for ursula in ursulas:
if ursula.staking_provider_address not in ritual.participants:
raise ValueError(f"{ursula} is not part of the cohort")
if ursula.staking_provider_address not in ritual.providers:
raise ValueError(
f"{ursula} ({ursula.staking_provider_address}) is not part of the cohort"
)
self.remember_node(ursula)
try:
variant = FerveoVariant(getattr(FerveoVariant, variant.upper()).value)
except AttributeError:
raise ValueError(f"Invalid variant: {variant}; Options are: {list(v.name.lower() for v in list(FerveoVariant))}")
raise ValueError(
f"Invalid variant: {variant}; Options are: {list(v.name.lower() for v in list(FerveoVariant))}"
)
threshold = (
(ritual.shares // 2) + 1
if variant == FerveoVariant.SIMPLE
else ritual.shares
) # TODO: #3095 get this from the ritual / put it on-chain?
request_encrypting_keys = {}
participants = ritual.participants
for p in participants:
# TODO don't use Umbral in the long-run
request_encrypting_keys[p.provider] = PublicKey.from_compressed_bytes(
p.requestEncryptingKey
)
decryption_shares = self.gather_decryption_shares(
ritual_id=ritual_id,
cohort=ursulas,
@ -681,6 +716,7 @@ class Bob(Character):
lingo=conditions,
threshold=threshold,
variant=variant,
request_encrypting_keys=request_encrypting_keys,
)
if not params:
@ -746,6 +782,7 @@ class Ursula(Teacher, Character, Operator, Ritualist):
SigningPower,
DecryptingPower,
RitualisticPower,
ThresholdRequestDecryptingPower,
# TLSHostingPower # Still considered a default for Ursula, but needs the host context
]

View File

@ -1,12 +1,11 @@
import ferveo_py
import time
from collections import defaultdict, deque
from contextlib import suppress
from pathlib import Path
from queue import Queue
from typing import Callable, List, Optional, Set, Tuple, Union
import ferveo_py
import maya
import requests
from constant_sorrow.constants import (
@ -37,7 +36,8 @@ from nucypher.crypto.powers import (
CryptoPower,
DecryptingPower,
NoSigningPower,
SigningPower, RitualisticPower,
RitualisticPower,
SigningPower,
)
from nucypher.crypto.signing import InvalidSignature, SignatureStamp
from nucypher.network.exceptions import NodeSeemsToBeDown

View File

@ -10,6 +10,7 @@ from flask import Flask, Response, jsonify, request
from mako import exceptions as mako_exceptions
from mako.template import Template
from nucypher_core import (
EncryptedThresholdDecryptionRequest,
MetadataRequest,
MetadataResponse,
MetadataResponsePayload,
@ -145,7 +146,13 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
def threshold_decrypt():
# Deserialize and instantiate ThresholdDecryptionRequest from the request data
decryption_request = ThresholdDecryptionRequest.from_bytes(request.data)
encrypted_decryption_request = EncryptedThresholdDecryptionRequest.from_bytes(
request.data
)
(
decryption_request,
response_encrypting_key,
) = this_node.decrypt_threshold_decryption_request(encrypted_decryption_request)
log.info(f"Threshold decryption request for ritual ID #{decryption_request.id}")
@ -189,7 +196,11 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
# TODO: #3079 #3081 encrypt the response with the requester's public key
# TODO: #3098 nucypher-core#49 Use DecryptionShare type
response = ThresholdDecryptionResponse(decryption_share=bytes(decryption_share))
return Response(bytes(response), headers={'Content-Type': 'application/octet-stream'})
encrypted_response = response.encrypt(encrypting_key=response_encrypting_key)
return Response(
bytes(encrypted_response),
headers={"Content-Type": "application/octet-stream"},
)
@rest_app.route('/reencrypt', methods=["POST"])
def reencrypt():

View File

@ -129,11 +129,13 @@ def test_post_aggregation(
agent, aggregated_transcript, dkg_public_key, transacting_powers
):
ritual_id = agent.number_of_rituals() - 1
request_encrypting_keys = [os.urandom(32) for t in transacting_powers]
for i, transacting_power in enumerate(transacting_powers):
receipt = agent.post_aggregation(
ritual_id=ritual_id,
aggregated_transcript=aggregated_transcript,
public_key=dkg_public_key,
request_encrypting_key=request_encrypting_keys[i],
transacting_power=transacting_power,
)
assert receipt["status"] == 1
@ -149,7 +151,9 @@ def test_post_aggregation(
)
participants = agent.get_participants(ritual_id)
assert all([p.aggregated for p in participants])
for i, p in enumerate(participants):
assert p.aggregated
assert p.requestEncryptingKey == request_encrypting_keys[i]
assert agent.get_ritual_status(ritual_id=ritual_id) == agent.Ritual.Status.FINALIZED

View File

@ -5,8 +5,8 @@ plugins:
dependencies:
- name: nucypher-contracts
github: nucypher/nucypher-contracts
ref: main
github: derekpierre/nucypher-contracts
ref: e2e-dkg
- name: openzeppelin
github: OpenZeppelin/openzeppelin-contracts
version: 4.8.1

View File

@ -111,6 +111,7 @@ class MockCoordinatorAgent(MockContractAgent):
ritual_id: int,
aggregated_transcript: bytes,
public_key: DkgPublicKey,
request_encrypting_key: bytes,
transacting_power: TransactingPower,
) -> TxReceipt:
ritual = self.rituals[ritual_id]
@ -122,6 +123,7 @@ class MockCoordinatorAgent(MockContractAgent):
)
participant = self.get_participant_from_provider(ritual_id, provider)
participant.aggregated = True
participant.requestEncryptingKey = request_encrypting_key
g1_point = self.Ritual.G1Point.from_dkg_public_key(public_key)
if len(ritual.aggregated_transcript) == 0:

View File

@ -103,23 +103,28 @@ def test_mock_coordinator_round_2(
assert p.transcript == FAKE_TRANSCRIPT
aggregated_transcript = os.urandom(len(FAKE_TRANSCRIPT))
request_encrypting_keys = []
for index, node_address in enumerate(nodes_transacting_powers):
request_encrypting_key = os.urandom(32)
coordinator.post_aggregation(
ritual_id=0,
aggregated_transcript=aggregated_transcript,
public_key=dkg_public_key,
request_encrypting_key=request_encrypting_key,
transacting_power=nodes_transacting_powers[node_address]
)
request_encrypting_keys.append(request_encrypting_key)
if index == len(nodes_transacting_powers) - 1:
assert len(coordinator.EVENTS) == 2
assert ritual.aggregated_transcript == aggregated_transcript
assert bytes(ritual.public_key) == bytes(dkg_public_key)
for p in ritual.participants:
for index, p in enumerate(ritual.participants):
# unchanged
assert p.transcript == FAKE_TRANSCRIPT
assert p.transcript != aggregated_transcript
assert p.requestEncryptingKey == request_encrypting_keys[index]
assert len(coordinator.EVENTS) == 2 # no additional event emitted here?
assert (