Add WS command backup/can_decrypt_on_download (#135662)

* Add WS command backup/can_decrypt_on_download

* Wrap errors

* Add default messages to exceptions

* Improve test coverage
pull/134148/head
Erik Montnemery 2025-01-15 19:40:29 +01:00 committed by GitHub
parent 3622e8331b
commit f36a10126c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 278 additions and 7 deletions

View File

@ -14,7 +14,7 @@ from pathlib import Path, PurePath
import shutil
import tarfile
import time
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
from typing import IO, TYPE_CHECKING, Any, Protocol, TypedDict, cast
import aiohttp
from securetar import SecureTarFile, atomic_contents_add
@ -31,6 +31,7 @@ from homeassistant.helpers import (
from homeassistant.helpers.json import json_bytes
from homeassistant.util import dt as dt_util
from . import util as backup_util
from .agent import (
BackupAgent,
BackupAgentError,
@ -48,7 +49,13 @@ from .const import (
)
from .models import AgentBackup, BackupManagerError, Folder
from .store import BackupStore
from .util import make_backup_dir, read_backup, validate_password
from .util import (
AsyncIteratorReader,
make_backup_dir,
read_backup,
validate_password,
validate_password_stream,
)
@dataclass(frozen=True, kw_only=True, slots=True)
@ -248,6 +255,14 @@ class BackupReaderWriterError(HomeAssistantError):
class IncorrectPasswordError(BackupReaderWriterError):
"""Raised when the password is incorrect."""
_message = "The password provided is incorrect."
class DecryptOnDowloadNotSupported(BackupManagerError):
"""Raised when on-the-fly decryption is not supported."""
_message = "On-the-fly decryption is not supported for this backup."
class BackupManager:
"""Define the format that backup managers can have."""
@ -990,6 +1005,39 @@ class BackupManager:
translation_placeholders={"failed_agents": ", ".join(agent_errors)},
)
async def async_can_decrypt_on_download(
self,
backup_id: str,
*,
agent_id: str,
password: str | None,
) -> None:
"""Check if we are able to decrypt the backup on download."""
try:
agent = self.backup_agents[agent_id]
except KeyError as err:
raise BackupManagerError(f"Invalid agent selected: {agent_id}") from err
if not await agent.async_get_backup(backup_id):
raise BackupManagerError(
f"Backup {backup_id} not found in agent {agent_id}"
)
reader: IO[bytes]
if agent_id in self.local_backup_agents:
local_agent = self.local_backup_agents[agent_id]
path = local_agent.get_backup_path(backup_id)
reader = await self.hass.async_add_executor_job(open, path.as_posix(), "rb")
else:
backup_stream = await agent.async_download_backup(backup_id)
reader = cast(IO[bytes], AsyncIteratorReader(self.hass, backup_stream))
try:
validate_password_stream(reader, password)
except backup_util.IncorrectPassword as err:
raise IncorrectPasswordError from err
except backup_util.UnsuppertedSecureTarVersion as err:
raise DecryptOnDowloadNotSupported from err
except backup_util.DecryptError as err:
raise BackupManagerError(str(err)) from err
class KnownBackups:
"""Track known backups."""
@ -1372,7 +1420,7 @@ class CoreBackupReaderWriter(BackupReaderWriter):
validate_password, path, password
)
if not password_valid:
raise IncorrectPasswordError("The password provided is incorrect.")
raise IncorrectPasswordError
def _write_restore_file() -> None:
"""Write the restore file."""

View File

@ -3,13 +3,14 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from pathlib import Path
from queue import SimpleQueue
import tarfile
from typing import cast
from typing import IO, cast
import aiohttp
from securetar import SecureTarFile
from securetar import VERSION_HEADER, SecureTarFile, SecureTarReadError
from homeassistant.backup_restore import password_to_key
from homeassistant.core import HomeAssistant
@ -19,6 +20,22 @@ from .const import BUF_SIZE, LOGGER
from .models import AddonInfo, AgentBackup, Folder
class DecryptError(Exception):
"""Error during decryption."""
class UnsuppertedSecureTarVersion(DecryptError):
"""Unsupported securetar version."""
class IncorrectPassword(DecryptError):
"""Invalid password or corrupted backup."""
class BackupEmpty(DecryptError):
"""No tar files found in the backup."""
def make_backup_dir(path: Path) -> None:
"""Create a backup directory if it does not exist."""
path.mkdir(exist_ok=True)
@ -106,6 +123,70 @@ def validate_password(path: Path, password: str | None) -> bool:
return False
class AsyncIteratorReader:
"""Wrap an AsyncIterator."""
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
"""Initialize the wrapper."""
self._hass = hass
self._stream = stream
self._buffer: bytes | None = None
self._pos: int = 0
async def _next(self) -> bytes | None:
"""Get the next chunk from the iterator."""
return await anext(self._stream, None)
def read(self, n: int = -1, /) -> bytes:
"""Read data from the iterator."""
result = bytearray()
while n < 0 or len(result) < n:
if not self._buffer:
self._buffer = asyncio.run_coroutine_threadsafe(
self._next(), self._hass.loop
).result()
self._pos = 0
if not self._buffer:
# The stream is exhausted
break
chunk = self._buffer[self._pos : self._pos + n]
result.extend(chunk)
n -= len(chunk)
self._pos += len(chunk)
if self._pos == len(self._buffer):
self._buffer = None
return bytes(result)
def validate_password_stream(
input_stream: IO[bytes],
password: str | None,
) -> None:
"""Decrypt a backup."""
with (
tarfile.open(fileobj=input_stream, mode="r|", bufsize=BUF_SIZE) as input_tar,
):
for obj in input_tar:
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
continue
if obj.pax_headers.get(VERSION_HEADER) != "2.0":
raise UnsuppertedSecureTarVersion
istf = SecureTarFile(
None, # Not used
gzip=False,
key=password_to_key(password) if password is not None else None,
mode="r",
fileobj=input_tar.extractfile(obj),
)
with istf.decrypt(obj) as decrypted:
try:
decrypted.read(1) # Read a single byte to trigger the decryption
except SecureTarReadError as err:
raise IncorrectPassword from err
return
raise BackupEmpty
async def receive_file(
hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path
) -> None:

View File

@ -9,7 +9,11 @@ from homeassistant.core import HomeAssistant, callback
from .config import ScheduleState
from .const import DATA_MANAGER, LOGGER
from .manager import IncorrectPasswordError, ManagerStateEvent
from .manager import (
DecryptOnDowloadNotSupported,
IncorrectPasswordError,
ManagerStateEvent,
)
from .models import Folder
@ -24,6 +28,7 @@ def async_register_websocket_handlers(hass: HomeAssistant, with_hassio: bool) ->
websocket_api.async_register_command(hass, handle_details)
websocket_api.async_register_command(hass, handle_info)
websocket_api.async_register_command(hass, handle_can_decrypt_on_download)
websocket_api.async_register_command(hass, handle_create)
websocket_api.async_register_command(hass, handle_create_with_automatic_settings)
websocket_api.async_register_command(hass, handle_delete)
@ -147,6 +152,38 @@ async def handle_restore(
connection.send_result(msg["id"])
@websocket_api.require_admin
@websocket_api.websocket_command(
{
vol.Required("type"): "backup/can_decrypt_on_download",
vol.Required("backup_id"): str,
vol.Required("agent_id"): str,
vol.Required("password"): str,
}
)
@websocket_api.async_response
async def handle_can_decrypt_on_download(
hass: HomeAssistant,
connection: websocket_api.ActiveConnection,
msg: dict[str, Any],
) -> None:
"""Check if the supplied password is correct."""
try:
await hass.data[DATA_MANAGER].async_can_decrypt_on_download(
msg["backup_id"],
agent_id=msg["agent_id"],
password=msg.get("password"),
)
except IncorrectPasswordError:
connection.send_error(msg["id"], "password_incorrect", "Incorrect password")
except DecryptOnDowloadNotSupported:
connection.send_error(
msg["id"], "decrypt_not_supported", "Decrypt on download not supported"
)
else:
connection.send_result(msg["id"])
@websocket_api.require_admin
@websocket_api.websocket_command(
{

View File

@ -175,6 +175,58 @@
'type': 'result',
})
# ---
# name: test_can_decrypt_on_download[backup.local-2bcb3113-hunter2]
dict({
'error': dict({
'code': 'decrypt_not_supported',
'message': 'Decrypt on download not supported',
}),
'id': 1,
'success': False,
'type': 'result',
})
# ---
# name: test_can_decrypt_on_download[backup.local-ed1608a9-hunter2]
dict({
'id': 1,
'result': None,
'success': True,
'type': 'result',
})
# ---
# name: test_can_decrypt_on_download[backup.local-ed1608a9-wrong_password]
dict({
'error': dict({
'code': 'password_incorrect',
'message': 'Incorrect password',
}),
'id': 1,
'success': False,
'type': 'result',
})
# ---
# name: test_can_decrypt_on_download[backup.local-no_such_backup-hunter2]
dict({
'error': dict({
'code': 'home_assistant_error',
'message': 'Backup no_such_backup not found in agent backup.local',
}),
'id': 1,
'success': False,
'type': 'result',
})
# ---
# name: test_can_decrypt_on_download[no_such_agent-ed1608a9-hunter2]
dict({
'error': dict({
'code': 'home_assistant_error',
'message': 'Invalid agent selected: no_such_agent',
}),
'id': 1,
'success': False,
'type': 'result',
})
# ---
# name: test_config_info[None]
dict({
'id': 1,

View File

@ -36,7 +36,7 @@ from .common import (
setup_backup_platform,
)
from tests.common import async_fire_time_changed, async_mock_service
from tests.common import async_fire_time_changed, async_mock_service, get_fixture_path
from tests.typing import WebSocketGenerator
BACKUP_CALL = call(
@ -2554,3 +2554,56 @@ async def test_subscribe_event(
CreateBackupEvent(stage=None, state=CreateBackupState.IN_PROGRESS)
)
assert await client.receive_json() == snapshot
@pytest.fixture
def mock_backups() -> Generator[None]:
"""Fixture to setup test backups."""
# pylint: disable-next=import-outside-toplevel
from homeassistant.components.backup import backup as core_backup
class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent):
def __init__(self, hass: HomeAssistant) -> None:
super().__init__(hass)
self._backup_dir = get_fixture_path("test_backups", DOMAIN)
with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent):
yield
@pytest.mark.parametrize(
("agent_id", "backup_id", "password"),
[
# Invalid agent or backup
("no_such_agent", "ed1608a9", "hunter2"),
("backup.local", "no_such_backup", "hunter2"),
# Legacy backup, which can't be streamed
("backup.local", "2bcb3113", "hunter2"),
# New backup, which can be streamed, try with correct and wrong password
("backup.local", "ed1608a9", "hunter2"),
("backup.local", "ed1608a9", "wrong_password"),
],
)
@pytest.mark.usefixtures("mock_backups")
async def test_can_decrypt_on_download(
hass: HomeAssistant,
hass_ws_client: WebSocketGenerator,
snapshot: SnapshotAssertion,
agent_id: str,
backup_id: str,
password: str,
) -> None:
"""Test can decrypt on download."""
await setup_backup_integration(hass, with_hassio=False)
client = await hass_ws_client(hass)
await client.send_json_auto_id(
{
"type": "backup/can_decrypt_on_download",
"backup_id": backup_id,
"agent_id": agent_id,
"password": password,
}
)
assert await client.receive_json() == snapshot