Move KEM/DEM logic to ThresholdMessageKit, and add ACP.aad() function - the AAD can be controlled by versioning but the TMK dictates the AAD and so must be linked somehow with the ACP.aad() function. For now this is done via a compatibility check function.

pull/3194/head
derekpierre 2023-08-02 12:46:13 -04:00 committed by Kieran Prasch
parent 51e78a14cf
commit fc9edd9ed4
5 changed files with 98 additions and 54 deletions

View File

@ -7,7 +7,6 @@ import time
from eth_typing import ChecksumAddress
from hexbytes import HexBytes
from nucypher_core import (
Conditions,
EncryptedThresholdDecryptionResponse,
SessionStaticKey,
ThresholdDecryptionResponse,
@ -630,11 +629,7 @@ class Ritualist(BaseActor):
return tx_hash
def derive_decryption_share(
self,
ritual_id: int,
ciphertext: Ciphertext,
conditions: Conditions,
variant: FerveoVariant
self, ritual_id: int, ciphertext: Ciphertext, aad: bytes, variant: FerveoVariant
) -> Union[DecryptionShareSimple, DecryptionSharePrecomputed]:
ritual = self.coordinator_agent.get_ritual(ritual_id)
status = self.coordinator_agent.get_ritual_status(ritual_id=ritual_id)
@ -648,7 +643,6 @@ class Ritualist(BaseActor):
)
threshold = (ritual.shares // 2) + 1
conditions = str(conditions).encode()
# TODO: consider the usage of local DKG artifact storage here #3052
# aggregated_transcript_bytes = self.dkg_storage.get_aggregated_transcript(ritual_id)
aggregated_transcript = AggregatedTranscript.from_bytes(bytes(ritual.aggregated_transcript))
@ -660,7 +654,7 @@ class Ritualist(BaseActor):
ritual_id=ritual_id,
aggregated_transcript=aggregated_transcript,
ciphertext=ciphertext,
conditions=conditions,
aad=aad,
variant=variant
)

View File

@ -55,7 +55,6 @@ from nucypher_core.ferveo import (
combine_decryption_shares_precomputed,
combine_decryption_shares_simple,
decrypt_with_shared_secret,
encrypt,
)
from nucypher_core.umbral import (
PublicKey,
@ -90,7 +89,6 @@ from nucypher.characters.banners import (
from nucypher.characters.base import Character, Learner
from nucypher.config.storages import NodeStorage
from nucypher.core import (
AccessControlPolicy,
ThresholdDecryptionRequest,
ThresholdMessageKit,
)
@ -105,7 +103,6 @@ from nucypher.crypto.powers import (
TLSHostingPower,
TransactingPower,
)
from nucypher.crypto.utils import keccak_digest
from nucypher.network import trackers
from nucypher.network.decryption import ThresholdDecryptionClient
from nucypher.network.exceptions import NodeSeemsToBeDown
@ -778,12 +775,12 @@ class Bob(Character):
shared_secret = combine_decryption_shares_simple(shares)
else:
raise ValueError(f"Invalid variant: {variant}.")
conditions = str(threshold_message_kit.acp.conditions).encode() # aad
aad = threshold_message_kit.acp.aad()
# TODO this ferveo call should probably take the kem_ciphertext and the dem_ciphertext
# to actually obtain the cleartext
symmetric_key = decrypt_with_shared_secret(
threshold_message_kit.kem_ciphertext,
conditions, # aad
aad, # aad
shared_secret,
)
@ -1490,29 +1487,17 @@ class Enrico:
) -> ThresholdMessageKit:
validate_condition_lingo(conditions)
conditions_json = json.dumps(conditions)
aad = json.dumps(conditions).encode()
access_conditions = Conditions(conditions_json)
# let's assume we get back dem_ciphertext, kem_ciphertext from ferveo
# TODO use Fernet for now
symmetric_key = Fernet.generate_key()
fernet = Fernet(symmetric_key)
dem_ciphertext = fernet.encrypt(plaintext)
kem_ciphertext = encrypt(symmetric_key, aad, self.policy_pubkey)
# TODO perhaps the `Callable[[bytes]bytes]` for signing should be passed as a param?
def signer(data: bytes) -> bytes:
return self.signing_power.keypair.sign(data).to_be_bytes()
kem_ciphertext_hash = keccak_digest(bytes(kem_ciphertext))
authorization = self.signing_power.keypair.sign(
kem_ciphertext_hash
).to_be_bytes()
acp = AccessControlPolicy(
public_key=self.policy_pubkey,
conditions=Conditions(conditions_json),
authorization=authorization,
)
message_kit = ThresholdMessageKit(
kem_ciphertext=kem_ciphertext,
dem_ciphertext=dem_ciphertext,
acp=acp,
message_kit = ThresholdMessageKit.encrypt_data(
plaintext=plaintext,
conditions=access_conditions,
dkg_public_key=self.policy_pubkey,
signer=signer,
)
return message_kit

View File

@ -1,18 +1,26 @@
import base64
import json
from typing import Dict, NamedTuple, Optional
from typing import Callable, Dict, NamedTuple, Optional
from cryptography.fernet import Fernet
from nucypher_core import Conditions, Context, SessionSharedSecret, SessionStaticKey
from nucypher_core.ferveo import Ciphertext, DkgPublicKey
from nucypher_core.ferveo import Ciphertext, DkgPublicKey, encrypt
from nucypher.crypto.utils import keccak_digest
class AccessControlPolicy(NamedTuple):
public_key: DkgPublicKey
conditions: Conditions # should this be folded into aad?
conditions: Conditions
authorization: bytes
version: int = 1
def aad(self) -> bytes:
return str(self.conditions).encode()
def to_dict(self):
d = {
"version": self.version,
"public_key": base64.b64encode(bytes(self.public_key)).decode(),
"access_conditions": str(self.conditions),
"authorization": {
@ -25,6 +33,7 @@ class AccessControlPolicy(NamedTuple):
@classmethod
def from_dict(cls, acp_dict: Dict) -> "AccessControlPolicy":
return cls(
version=acp_dict["version"],
public_key=DkgPublicKey.from_bytes(
base64.b64decode(acp_dict["public_key"])
),
@ -44,14 +53,69 @@ class AccessControlPolicy(NamedTuple):
return instance
class ThresholdMessageKit(NamedTuple):
# one entry for now: thin ferveo ciphertext + symmetric ciphertext; ferveo#147
kem_ciphertext: Ciphertext
dem_ciphertext: bytes
acp: AccessControlPolicy
class ThresholdMessageKit:
VERSION = 1
def __init__(
self,
kem_ciphertext: Ciphertext,
dem_ciphertext: bytes,
acp: AccessControlPolicy,
version: int = VERSION,
):
self.version = version
self.kem_ciphertext = kem_ciphertext
self.dem_ciphertext = dem_ciphertext
self.acp = acp
@staticmethod
def _validate_aad_compatibility(tmk_aad: bytes, acp_aad: bytes):
if tmk_aad != acp_aad:
raise ValueError("Incompatible ThresholdMessageKit and AccessControlPolicy")
@classmethod
def encrypt_data(
cls,
plaintext: bytes,
conditions: Conditions,
dkg_public_key: DkgPublicKey,
signer: Callable[[bytes], bytes],
):
symmetric_key = Fernet.generate_key()
fernet = Fernet(symmetric_key)
dem_ciphertext = fernet.encrypt(plaintext)
aad = str(conditions).encode()
kem_ciphertext = encrypt(symmetric_key, aad, dkg_public_key)
kem_ciphertext_hash = keccak_digest(bytes(kem_ciphertext))
authorization = signer(kem_ciphertext_hash)
acp = AccessControlPolicy(
public_key=dkg_public_key,
conditions=conditions,
authorization=authorization,
)
# we need to link the ThresholdMessageKit to a specific version of the ACP
# because the ACP.aad() function should return the same value as the aad used
# for encryption. Since the ACP version can change independently of
# ThresholdMessageKit this check is good for code maintenance and ensuring
# compatibility - unless we find a better way to link TMK and ACP.
#
# TODO: perhaps this can be improved. You could have ACP be an inner class of TMK,
# but not sure how that plays out with rust and python bindings... OR ...?
cls._validate_aad_compatibility(aad, acp.aad())
return ThresholdMessageKit(
kem_ciphertext,
dem_ciphertext,
acp,
)
def to_dict(self):
d = {
"version": self.version,
"kem_ciphertext": base64.b64encode(bytes(self.kem_ciphertext)).decode(),
"dem_ciphertext": base64.b64encode(self.dem_ciphertext).decode(),
"acp": self.acp.to_dict(),
@ -62,6 +126,7 @@ class ThresholdMessageKit(NamedTuple):
@classmethod
def from_dict(cls, message_kit: Dict) -> "ThresholdMessageKit":
return cls(
version=message_kit["version"],
kem_ciphertext=Ciphertext.from_bytes(
base64.b64decode(message_kit["kem_ciphertext"])
),

View File

@ -277,8 +277,8 @@ class RitualisticPower(KeyPairBasedPower):
nodes: list,
aggregated_transcript: AggregatedTranscript,
ciphertext: Ciphertext,
conditions: bytes,
variant: FerveoVariant
aad: bytes,
variant: FerveoVariant,
) -> Union[DecryptionShareSimple, DecryptionSharePrecomputed]:
decryption_share = dkg.derive_decryption_share(
ritual_id=ritual_id,
@ -289,7 +289,7 @@ class RitualisticPower(KeyPairBasedPower):
aggregated_transcript=aggregated_transcript,
keypair=self.keypair._privkey,
ciphertext=ciphertext,
aad=conditions,
aad=aad,
variant=variant
)
return decryption_share

View File

@ -222,7 +222,7 @@ def _make_rest_app(this_node, log: Logger) -> Flask:
decryption_share = this_node.derive_decryption_share(
ritual_id=decryption_request.ritual_id,
ciphertext=decryption_request.ciphertext,
conditions=decryption_request.access_control_policy.conditions,
aad=decryption_request.access_control_policy.aad(),
variant=decryption_request.variant,
)