mirror of https://github.com/nucypher/nucypher.git
Simplify authentication logic for EIP1271.
parent
75e49a603f
commit
f340e3b153
|
@ -164,6 +164,52 @@ class EIP1271Auth(EvmAuth):
|
|||
MAGIC_VALUE_BYTES = bytes(HexBytes("0x1626ba7e"))
|
||||
LOG = Logger("EIP1271Auth")
|
||||
|
||||
@classmethod
|
||||
def _extract_typed_data(cls, data):
|
||||
try:
|
||||
data_hash = bytes(HexBytes(data["dataHash"]))
|
||||
chain = data["chain"]
|
||||
return data_hash, chain
|
||||
except Exception as e:
|
||||
# data could not be processed
|
||||
raise cls.InvalidData(
|
||||
f"Invalid EIP1271 authentication data: {str(e) or e.__class__.__name__}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate_auth_data(
|
||||
cls, data_hash, signature_bytes, expected_address, chain, providers
|
||||
):
|
||||
web3_endpoints = providers.web3_endpoints(chain_id=chain)
|
||||
last_error = None
|
||||
for web3_instance in web3_endpoints:
|
||||
try:
|
||||
# Interact with the EIP1271 contract
|
||||
eip1271_contract = web3_instance.eth.contract(
|
||||
address=expected_address, abi=cls.EIP1271_ABI
|
||||
)
|
||||
result = eip1271_contract.functions.isValidSignature(
|
||||
data_hash,
|
||||
signature_bytes,
|
||||
).call()
|
||||
if result == cls.MAGIC_VALUE_BYTES:
|
||||
return # Successful authentication
|
||||
|
||||
break
|
||||
except Exception as e:
|
||||
last_error = f"EIP1271 contract call failed ({expected_address}): {e}"
|
||||
cls.LOG.warn(f"{last_error}; attempting next provider")
|
||||
else:
|
||||
# If all providers fail
|
||||
if last_error:
|
||||
raise cls.AuthenticationFailed(
|
||||
f"EIP1271 verification failed; {last_error}"
|
||||
)
|
||||
|
||||
raise cls.AuthenticationFailed(
|
||||
f"EIP1271 verification failed; signature not valid for contract address, {expected_address}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def authenticate(
|
||||
cls,
|
||||
|
@ -172,46 +218,27 @@ class EIP1271Auth(EvmAuth):
|
|||
expected_address: ChecksumAddress,
|
||||
providers: Optional[ConditionProviderManager] = None,
|
||||
):
|
||||
try:
|
||||
data_hash = bytes(HexBytes(data["dataHash"]))
|
||||
chain = data["chain"]
|
||||
signature_bytes = bytes(HexBytes(signature))
|
||||
w3_instances = providers.web3_endpoints(chain_id=chain)
|
||||
if not providers:
|
||||
# should never happen
|
||||
raise cls.AuthenticationFailed(
|
||||
"EIP1271 verification failed; no endpoints provided"
|
||||
)
|
||||
|
||||
latest_error = ""
|
||||
for w3 in w3_instances:
|
||||
try:
|
||||
eip1271_contract = w3.eth.contract(
|
||||
address=expected_address, abi=cls.EIP1271_ABI
|
||||
)
|
||||
result = eip1271_contract.functions.isValidSignature(
|
||||
data_hash,
|
||||
signature_bytes,
|
||||
).call()
|
||||
if result == cls.MAGIC_VALUE_BYTES:
|
||||
return # Authentication successful
|
||||
break
|
||||
except Exception as e:
|
||||
latest_error = (
|
||||
f"EIP1271 contract call failed ({expected_address}): {e}"
|
||||
)
|
||||
cls.LOG.warn(f"{latest_error}; attempting next provider")
|
||||
else:
|
||||
raise cls.AuthenticationFailed(
|
||||
f"EIP1271 verification failed; {latest_error}"
|
||||
)
|
||||
# Extract and validate input data
|
||||
data_hash, chain = cls._extract_typed_data(data)
|
||||
|
||||
# Validate the signature
|
||||
signature_bytes = bytes(HexBytes(signature))
|
||||
try:
|
||||
cls._validate_auth_data(
|
||||
data_hash, signature_bytes, expected_address, chain, providers
|
||||
)
|
||||
except NoConnectionToChain:
|
||||
raise cls.AuthenticationFailed(
|
||||
f"EIP1271 verification failed; No connection to chain ID {data['chain']}"
|
||||
f"EIP1271 verification failed; No connection to chain ID {chain}"
|
||||
)
|
||||
except cls.AuthenticationFailed:
|
||||
raise
|
||||
except Exception as e:
|
||||
# data could not be processed
|
||||
raise cls.InvalidData(
|
||||
f"Invalid EIP1271 authentication data: {str(e) or e.__class__.__name__}"
|
||||
)
|
||||
|
||||
raise cls.AuthenticationFailed(
|
||||
f"EIP1271 verification failed; signature not valid for contract address, {expected_address}"
|
||||
)
|
||||
# catch all
|
||||
raise cls.AuthenticationFailed(f"EIP1271 verification failed; {e}")
|
||||
|
|
|
@ -345,6 +345,12 @@ def test_authenticate_eip1271(mocker, get_random_checksum_address):
|
|||
address=eip1271_mock_contract.address, abi=EIP1271Auth.EIP1271_ABI
|
||||
)
|
||||
|
||||
# no providers
|
||||
with pytest.raises(EvmAuth.AuthenticationFailed, match="no endpoints provided"):
|
||||
EIP1271Auth.authenticate(
|
||||
typedData, valid_message_signature, eip1271_mock_contract.address, None
|
||||
)
|
||||
|
||||
# invalid typed data - no chain id
|
||||
with pytest.raises(EvmAuth.InvalidData):
|
||||
EIP1271Auth.authenticate(
|
||||
|
|
Loading…
Reference in New Issue