Move KEM/DEM logic to ThresholdMessageKit, and add ACP.aad() function - the AAD can be controlled by versioning but the TMK dictates the AAD and so must be linked somehow with the ACP.aad() function. For now this is done via a compatibility check function.

pull/3194/head
derekpierre 2023-08-02 12:46:13 -04:00 committed by Kieran Prasch
parent 51e78a14cf
commit fc9edd9ed4
5 changed files with 98 additions and 54 deletions

View File

@ -7,7 +7,6 @@ import time
from eth_typing import ChecksumAddress from eth_typing import ChecksumAddress
from hexbytes import HexBytes from hexbytes import HexBytes
from nucypher_core import ( from nucypher_core import (
Conditions,
EncryptedThresholdDecryptionResponse, EncryptedThresholdDecryptionResponse,
SessionStaticKey, SessionStaticKey,
ThresholdDecryptionResponse, ThresholdDecryptionResponse,
@ -630,11 +629,7 @@ class Ritualist(BaseActor):
return tx_hash return tx_hash
def derive_decryption_share( def derive_decryption_share(
self, self, ritual_id: int, ciphertext: Ciphertext, aad: bytes, variant: FerveoVariant
ritual_id: int,
ciphertext: Ciphertext,
conditions: Conditions,
variant: FerveoVariant
) -> Union[DecryptionShareSimple, DecryptionSharePrecomputed]: ) -> Union[DecryptionShareSimple, DecryptionSharePrecomputed]:
ritual = self.coordinator_agent.get_ritual(ritual_id) ritual = self.coordinator_agent.get_ritual(ritual_id)
status = self.coordinator_agent.get_ritual_status(ritual_id=ritual_id) status = self.coordinator_agent.get_ritual_status(ritual_id=ritual_id)
@ -648,7 +643,6 @@ class Ritualist(BaseActor):
) )
threshold = (ritual.shares // 2) + 1 threshold = (ritual.shares // 2) + 1
conditions = str(conditions).encode()
# TODO: consider the usage of local DKG artifact storage here #3052 # TODO: consider the usage of local DKG artifact storage here #3052
# aggregated_transcript_bytes = self.dkg_storage.get_aggregated_transcript(ritual_id) # aggregated_transcript_bytes = self.dkg_storage.get_aggregated_transcript(ritual_id)
aggregated_transcript = AggregatedTranscript.from_bytes(bytes(ritual.aggregated_transcript)) aggregated_transcript = AggregatedTranscript.from_bytes(bytes(ritual.aggregated_transcript))
@ -660,7 +654,7 @@ class Ritualist(BaseActor):
ritual_id=ritual_id, ritual_id=ritual_id,
aggregated_transcript=aggregated_transcript, aggregated_transcript=aggregated_transcript,
ciphertext=ciphertext, ciphertext=ciphertext,
conditions=conditions, aad=aad,
variant=variant variant=variant
) )

View File

@ -55,7 +55,6 @@ from nucypher_core.ferveo import (
combine_decryption_shares_precomputed, combine_decryption_shares_precomputed,
combine_decryption_shares_simple, combine_decryption_shares_simple,
decrypt_with_shared_secret, decrypt_with_shared_secret,
encrypt,
) )
from nucypher_core.umbral import ( from nucypher_core.umbral import (
PublicKey, PublicKey,
@ -90,7 +89,6 @@ from nucypher.characters.banners import (
from nucypher.characters.base import Character, Learner from nucypher.characters.base import Character, Learner
from nucypher.config.storages import NodeStorage from nucypher.config.storages import NodeStorage
from nucypher.core import ( from nucypher.core import (
AccessControlPolicy,
ThresholdDecryptionRequest, ThresholdDecryptionRequest,
ThresholdMessageKit, ThresholdMessageKit,
) )
@ -105,7 +103,6 @@ from nucypher.crypto.powers import (
TLSHostingPower, TLSHostingPower,
TransactingPower, TransactingPower,
) )
from nucypher.crypto.utils import keccak_digest
from nucypher.network import trackers from nucypher.network import trackers
from nucypher.network.decryption import ThresholdDecryptionClient from nucypher.network.decryption import ThresholdDecryptionClient
from nucypher.network.exceptions import NodeSeemsToBeDown from nucypher.network.exceptions import NodeSeemsToBeDown
@ -778,12 +775,12 @@ class Bob(Character):
shared_secret = combine_decryption_shares_simple(shares) shared_secret = combine_decryption_shares_simple(shares)
else: else:
raise ValueError(f"Invalid variant: {variant}.") raise ValueError(f"Invalid variant: {variant}.")
conditions = str(threshold_message_kit.acp.conditions).encode() # aad aad = threshold_message_kit.acp.aad()
# TODO this ferveo call should probably take the kem_ciphertext and the dem_ciphertext # TODO this ferveo call should probably take the kem_ciphertext and the dem_ciphertext
# to actually obtain the cleartext # to actually obtain the cleartext
symmetric_key = decrypt_with_shared_secret( symmetric_key = decrypt_with_shared_secret(
threshold_message_kit.kem_ciphertext, threshold_message_kit.kem_ciphertext,
conditions, # aad aad, # aad
shared_secret, shared_secret,
) )
@ -1490,29 +1487,17 @@ class Enrico:
) -> ThresholdMessageKit: ) -> ThresholdMessageKit:
validate_condition_lingo(conditions) validate_condition_lingo(conditions)
conditions_json = json.dumps(conditions) conditions_json = json.dumps(conditions)
aad = json.dumps(conditions).encode() access_conditions = Conditions(conditions_json)
# let's assume we get back dem_ciphertext, kem_ciphertext from ferveo # TODO perhaps the `Callable[[bytes]bytes]` for signing should be passed as a param?
# TODO use Fernet for now def signer(data: bytes) -> bytes:
symmetric_key = Fernet.generate_key() return self.signing_power.keypair.sign(data).to_be_bytes()
fernet = Fernet(symmetric_key)
dem_ciphertext = fernet.encrypt(plaintext)
kem_ciphertext = encrypt(symmetric_key, aad, self.policy_pubkey)
kem_ciphertext_hash = keccak_digest(bytes(kem_ciphertext)) message_kit = ThresholdMessageKit.encrypt_data(
authorization = self.signing_power.keypair.sign( plaintext=plaintext,
kem_ciphertext_hash conditions=access_conditions,
).to_be_bytes() dkg_public_key=self.policy_pubkey,
signer=signer,
acp = AccessControlPolicy(
public_key=self.policy_pubkey,
conditions=Conditions(conditions_json),
authorization=authorization,
)
message_kit = ThresholdMessageKit(
kem_ciphertext=kem_ciphertext,
dem_ciphertext=dem_ciphertext,
acp=acp,
) )
return message_kit return message_kit

View File

@ -1,18 +1,26 @@
import base64 import base64
import json import json
from typing import Dict, NamedTuple, Optional from typing import Callable, Dict, NamedTuple, Optional
from cryptography.fernet import Fernet
from nucypher_core import Conditions, Context, SessionSharedSecret, SessionStaticKey from nucypher_core import Conditions, Context, SessionSharedSecret, SessionStaticKey
from nucypher_core.ferveo import Ciphertext, DkgPublicKey from nucypher_core.ferveo import Ciphertext, DkgPublicKey, encrypt
from nucypher.crypto.utils import keccak_digest
class AccessControlPolicy(NamedTuple): class AccessControlPolicy(NamedTuple):
public_key: DkgPublicKey public_key: DkgPublicKey
conditions: Conditions # should this be folded into aad? conditions: Conditions
authorization: bytes authorization: bytes
version: int = 1
def aad(self) -> bytes:
return str(self.conditions).encode()
def to_dict(self): def to_dict(self):
d = { d = {
"version": self.version,
"public_key": base64.b64encode(bytes(self.public_key)).decode(), "public_key": base64.b64encode(bytes(self.public_key)).decode(),
"access_conditions": str(self.conditions), "access_conditions": str(self.conditions),
"authorization": { "authorization": {
@ -25,6 +33,7 @@ class AccessControlPolicy(NamedTuple):
@classmethod @classmethod
def from_dict(cls, acp_dict: Dict) -> "AccessControlPolicy": def from_dict(cls, acp_dict: Dict) -> "AccessControlPolicy":
return cls( return cls(
version=acp_dict["version"],
public_key=DkgPublicKey.from_bytes( public_key=DkgPublicKey.from_bytes(
base64.b64decode(acp_dict["public_key"]) base64.b64decode(acp_dict["public_key"])
), ),
@ -44,14 +53,69 @@ class AccessControlPolicy(NamedTuple):
return instance return instance
class ThresholdMessageKit(NamedTuple): class ThresholdMessageKit:
# one entry for now: thin ferveo ciphertext + symmetric ciphertext; ferveo#147 VERSION = 1
kem_ciphertext: Ciphertext
dem_ciphertext: bytes def __init__(
acp: AccessControlPolicy self,
kem_ciphertext: Ciphertext,
dem_ciphertext: bytes,
acp: AccessControlPolicy,
version: int = VERSION,
):
self.version = version
self.kem_ciphertext = kem_ciphertext
self.dem_ciphertext = dem_ciphertext
self.acp = acp
@staticmethod
def _validate_aad_compatibility(tmk_aad: bytes, acp_aad: bytes):
if tmk_aad != acp_aad:
raise ValueError("Incompatible ThresholdMessageKit and AccessControlPolicy")
@classmethod
def encrypt_data(
cls,
plaintext: bytes,
conditions: Conditions,
dkg_public_key: DkgPublicKey,
signer: Callable[[bytes], bytes],
):
symmetric_key = Fernet.generate_key()
fernet = Fernet(symmetric_key)
dem_ciphertext = fernet.encrypt(plaintext)
aad = str(conditions).encode()
kem_ciphertext = encrypt(symmetric_key, aad, dkg_public_key)
kem_ciphertext_hash = keccak_digest(bytes(kem_ciphertext))
authorization = signer(kem_ciphertext_hash)
acp = AccessControlPolicy(
public_key=dkg_public_key,
conditions=conditions,
authorization=authorization,
)
# we need to link the ThresholdMessageKit to a specific version of the ACP
# because the ACP.aad() function should return the same value as the aad used
# for encryption. Since the ACP version can change independently of
# ThresholdMessageKit this check is good for code maintenance and ensuring
# compatibility - unless we find a better way to link TMK and ACP.
#
# TODO: perhaps this can be improved. You could have ACP be an inner class of TMK,
# but not sure how that plays out with rust and python bindings... OR ...?
cls._validate_aad_compatibility(aad, acp.aad())
return ThresholdMessageKit(
kem_ciphertext,
dem_ciphertext,
acp,
)
def to_dict(self): def to_dict(self):
d = { d = {
"version": self.version,
"kem_ciphertext": base64.b64encode(bytes(self.kem_ciphertext)).decode(), "kem_ciphertext": base64.b64encode(bytes(self.kem_ciphertext)).decode(),
"dem_ciphertext": base64.b64encode(self.dem_ciphertext).decode(), "dem_ciphertext": base64.b64encode(self.dem_ciphertext).decode(),
"acp": self.acp.to_dict(), "acp": self.acp.to_dict(),
@ -62,6 +126,7 @@ class ThresholdMessageKit(NamedTuple):
@classmethod @classmethod
def from_dict(cls, message_kit: Dict) -> "ThresholdMessageKit": def from_dict(cls, message_kit: Dict) -> "ThresholdMessageKit":
return cls( return cls(
version=message_kit["version"],
kem_ciphertext=Ciphertext.from_bytes( kem_ciphertext=Ciphertext.from_bytes(
base64.b64decode(message_kit["kem_ciphertext"]) base64.b64decode(message_kit["kem_ciphertext"])
), ),

View File

@ -277,8 +277,8 @@ class RitualisticPower(KeyPairBasedPower):
nodes: list, nodes: list,
aggregated_transcript: AggregatedTranscript, aggregated_transcript: AggregatedTranscript,
ciphertext: Ciphertext, ciphertext: Ciphertext,
conditions: bytes, aad: bytes,
variant: FerveoVariant variant: FerveoVariant,
) -> Union[DecryptionShareSimple, DecryptionSharePrecomputed]: ) -> Union[DecryptionShareSimple, DecryptionSharePrecomputed]:
decryption_share = dkg.derive_decryption_share( decryption_share = dkg.derive_decryption_share(
ritual_id=ritual_id, ritual_id=ritual_id,
@ -289,7 +289,7 @@ class RitualisticPower(KeyPairBasedPower):
aggregated_transcript=aggregated_transcript, aggregated_transcript=aggregated_transcript,
keypair=self.keypair._privkey, keypair=self.keypair._privkey,
ciphertext=ciphertext, ciphertext=ciphertext,
aad=conditions, aad=aad,
variant=variant variant=variant
) )
return decryption_share return decryption_share

View File

@ -222,7 +222,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
decryption_share = this_node.derive_decryption_share( decryption_share = this_node.derive_decryption_share(
ritual_id=decryption_request.ritual_id, ritual_id=decryption_request.ritual_id,
ciphertext=decryption_request.ciphertext, ciphertext=decryption_request.ciphertext,
conditions=decryption_request.access_control_policy.conditions, aad=decryption_request.access_control_policy.aad(),
variant=decryption_request.variant, variant=decryption_request.variant,
) )