Simplify authentication logic for EIP1271.

pull/3576/head
derekpierre 2025-02-03 09:27:25 -05:00
parent 75e49a603f
commit f340e3b153
No known key found for this signature in database
2 changed files with 69 additions and 36 deletions

View File

@ -164,6 +164,52 @@ class EIP1271Auth(EvmAuth):
MAGIC_VALUE_BYTES = bytes(HexBytes("0x1626ba7e"))
LOG = Logger("EIP1271Auth")
@classmethod
def _extract_typed_data(cls, data):
try:
data_hash = bytes(HexBytes(data["dataHash"]))
chain = data["chain"]
return data_hash, chain
except Exception as e:
# data could not be processed
raise cls.InvalidData(
f"Invalid EIP1271 authentication data: {str(e) or e.__class__.__name__}"
)
@classmethod
def _validate_auth_data(
cls, data_hash, signature_bytes, expected_address, chain, providers
):
web3_endpoints = providers.web3_endpoints(chain_id=chain)
last_error = None
for web3_instance in web3_endpoints:
try:
# Interact with the EIP1271 contract
eip1271_contract = web3_instance.eth.contract(
address=expected_address, abi=cls.EIP1271_ABI
)
result = eip1271_contract.functions.isValidSignature(
data_hash,
signature_bytes,
).call()
if result == cls.MAGIC_VALUE_BYTES:
return # Successful authentication
break
except Exception as e:
last_error = f"EIP1271 contract call failed ({expected_address}): {e}"
cls.LOG.warn(f"{last_error}; attempting next provider")
else:
# If all providers fail
if last_error:
raise cls.AuthenticationFailed(
f"EIP1271 verification failed; {last_error}"
)
raise cls.AuthenticationFailed(
f"EIP1271 verification failed; signature not valid for contract address, {expected_address}"
)
@classmethod
def authenticate(
cls,
@ -172,46 +218,27 @@ class EIP1271Auth(EvmAuth):
expected_address: ChecksumAddress,
providers: Optional[ConditionProviderManager] = None,
):
try:
data_hash = bytes(HexBytes(data["dataHash"]))
chain = data["chain"]
signature_bytes = bytes(HexBytes(signature))
w3_instances = providers.web3_endpoints(chain_id=chain)
if not providers:
# should never happen
raise cls.AuthenticationFailed(
"EIP1271 verification failed; no endpoints provided"
)
latest_error = ""
for w3 in w3_instances:
try:
eip1271_contract = w3.eth.contract(
address=expected_address, abi=cls.EIP1271_ABI
)
result = eip1271_contract.functions.isValidSignature(
data_hash,
signature_bytes,
).call()
if result == cls.MAGIC_VALUE_BYTES:
return # Authentication successful
break
except Exception as e:
latest_error = (
f"EIP1271 contract call failed ({expected_address}): {e}"
)
cls.LOG.warn(f"{latest_error}; attempting next provider")
else:
raise cls.AuthenticationFailed(
f"EIP1271 verification failed; {latest_error}"
)
# Extract and validate input data
data_hash, chain = cls._extract_typed_data(data)
# Validate the signature
signature_bytes = bytes(HexBytes(signature))
try:
cls._validate_auth_data(
data_hash, signature_bytes, expected_address, chain, providers
)
except NoConnectionToChain:
raise cls.AuthenticationFailed(
f"EIP1271 verification failed; No connection to chain ID {data['chain']}"
f"EIP1271 verification failed; No connection to chain ID {chain}"
)
except cls.AuthenticationFailed:
raise
except Exception as e:
# data could not be processed
raise cls.InvalidData(
f"Invalid EIP1271 authentication data: {str(e) or e.__class__.__name__}"
)
raise cls.AuthenticationFailed(
f"EIP1271 verification failed; signature not valid for contract address, {expected_address}"
)
# catch all
raise cls.AuthenticationFailed(f"EIP1271 verification failed; {e}")

View File

@ -345,6 +345,12 @@ def test_authenticate_eip1271(mocker, get_random_checksum_address):
address=eip1271_mock_contract.address, abi=EIP1271Auth.EIP1271_ABI
)
# no providers
with pytest.raises(EvmAuth.AuthenticationFailed, match="no endpoints provided"):
EIP1271Auth.authenticate(
typedData, valid_message_signature, eip1271_mock_contract.address, None
)
# invalid typed data - no chain id
with pytest.raises(EvmAuth.InvalidData):
EIP1271Auth.authenticate(