Add WS command backup/can_decrypt_on_download (#135662)
* Add WS command backup/can_decrypt_on_download * Wrap errors * Add default messages to exceptions * Improve test coveragepull/134148/head
parent
3622e8331b
commit
f36a10126c
|
@ -14,7 +14,7 @@ from pathlib import Path, PurePath
|
|||
import shutil
|
||||
import tarfile
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypedDict
|
||||
from typing import IO, TYPE_CHECKING, Any, Protocol, TypedDict, cast
|
||||
|
||||
import aiohttp
|
||||
from securetar import SecureTarFile, atomic_contents_add
|
||||
|
@ -31,6 +31,7 @@ from homeassistant.helpers import (
|
|||
from homeassistant.helpers.json import json_bytes
|
||||
from homeassistant.util import dt as dt_util
|
||||
|
||||
from . import util as backup_util
|
||||
from .agent import (
|
||||
BackupAgent,
|
||||
BackupAgentError,
|
||||
|
@ -48,7 +49,13 @@ from .const import (
|
|||
)
|
||||
from .models import AgentBackup, BackupManagerError, Folder
|
||||
from .store import BackupStore
|
||||
from .util import make_backup_dir, read_backup, validate_password
|
||||
from .util import (
|
||||
AsyncIteratorReader,
|
||||
make_backup_dir,
|
||||
read_backup,
|
||||
validate_password,
|
||||
validate_password_stream,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True, kw_only=True, slots=True)
|
||||
|
@ -248,6 +255,14 @@ class BackupReaderWriterError(HomeAssistantError):
|
|||
class IncorrectPasswordError(BackupReaderWriterError):
|
||||
"""Raised when the password is incorrect."""
|
||||
|
||||
_message = "The password provided is incorrect."
|
||||
|
||||
|
||||
class DecryptOnDowloadNotSupported(BackupManagerError):
|
||||
"""Raised when on-the-fly decryption is not supported."""
|
||||
|
||||
_message = "On-the-fly decryption is not supported for this backup."
|
||||
|
||||
|
||||
class BackupManager:
|
||||
"""Define the format that backup managers can have."""
|
||||
|
@ -990,6 +1005,39 @@ class BackupManager:
|
|||
translation_placeholders={"failed_agents": ", ".join(agent_errors)},
|
||||
)
|
||||
|
||||
async def async_can_decrypt_on_download(
|
||||
self,
|
||||
backup_id: str,
|
||||
*,
|
||||
agent_id: str,
|
||||
password: str | None,
|
||||
) -> None:
|
||||
"""Check if we are able to decrypt the backup on download."""
|
||||
try:
|
||||
agent = self.backup_agents[agent_id]
|
||||
except KeyError as err:
|
||||
raise BackupManagerError(f"Invalid agent selected: {agent_id}") from err
|
||||
if not await agent.async_get_backup(backup_id):
|
||||
raise BackupManagerError(
|
||||
f"Backup {backup_id} not found in agent {agent_id}"
|
||||
)
|
||||
reader: IO[bytes]
|
||||
if agent_id in self.local_backup_agents:
|
||||
local_agent = self.local_backup_agents[agent_id]
|
||||
path = local_agent.get_backup_path(backup_id)
|
||||
reader = await self.hass.async_add_executor_job(open, path.as_posix(), "rb")
|
||||
else:
|
||||
backup_stream = await agent.async_download_backup(backup_id)
|
||||
reader = cast(IO[bytes], AsyncIteratorReader(self.hass, backup_stream))
|
||||
try:
|
||||
validate_password_stream(reader, password)
|
||||
except backup_util.IncorrectPassword as err:
|
||||
raise IncorrectPasswordError from err
|
||||
except backup_util.UnsuppertedSecureTarVersion as err:
|
||||
raise DecryptOnDowloadNotSupported from err
|
||||
except backup_util.DecryptError as err:
|
||||
raise BackupManagerError(str(err)) from err
|
||||
|
||||
|
||||
class KnownBackups:
|
||||
"""Track known backups."""
|
||||
|
@ -1372,7 +1420,7 @@ class CoreBackupReaderWriter(BackupReaderWriter):
|
|||
validate_password, path, password
|
||||
)
|
||||
if not password_valid:
|
||||
raise IncorrectPasswordError("The password provided is incorrect.")
|
||||
raise IncorrectPasswordError
|
||||
|
||||
def _write_restore_file() -> None:
|
||||
"""Write the restore file."""
|
||||
|
|
|
@ -3,13 +3,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from pathlib import Path
|
||||
from queue import SimpleQueue
|
||||
import tarfile
|
||||
from typing import cast
|
||||
from typing import IO, cast
|
||||
|
||||
import aiohttp
|
||||
from securetar import SecureTarFile
|
||||
from securetar import VERSION_HEADER, SecureTarFile, SecureTarReadError
|
||||
|
||||
from homeassistant.backup_restore import password_to_key
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -19,6 +20,22 @@ from .const import BUF_SIZE, LOGGER
|
|||
from .models import AddonInfo, AgentBackup, Folder
|
||||
|
||||
|
||||
class DecryptError(Exception):
|
||||
"""Error during decryption."""
|
||||
|
||||
|
||||
class UnsuppertedSecureTarVersion(DecryptError):
|
||||
"""Unsupported securetar version."""
|
||||
|
||||
|
||||
class IncorrectPassword(DecryptError):
|
||||
"""Invalid password or corrupted backup."""
|
||||
|
||||
|
||||
class BackupEmpty(DecryptError):
|
||||
"""No tar files found in the backup."""
|
||||
|
||||
|
||||
def make_backup_dir(path: Path) -> None:
|
||||
"""Create a backup directory if it does not exist."""
|
||||
path.mkdir(exist_ok=True)
|
||||
|
@ -106,6 +123,70 @@ def validate_password(path: Path, password: str | None) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
class AsyncIteratorReader:
|
||||
"""Wrap an AsyncIterator."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
|
||||
"""Initialize the wrapper."""
|
||||
self._hass = hass
|
||||
self._stream = stream
|
||||
self._buffer: bytes | None = None
|
||||
self._pos: int = 0
|
||||
|
||||
async def _next(self) -> bytes | None:
|
||||
"""Get the next chunk from the iterator."""
|
||||
return await anext(self._stream, None)
|
||||
|
||||
def read(self, n: int = -1, /) -> bytes:
|
||||
"""Read data from the iterator."""
|
||||
result = bytearray()
|
||||
while n < 0 or len(result) < n:
|
||||
if not self._buffer:
|
||||
self._buffer = asyncio.run_coroutine_threadsafe(
|
||||
self._next(), self._hass.loop
|
||||
).result()
|
||||
self._pos = 0
|
||||
if not self._buffer:
|
||||
# The stream is exhausted
|
||||
break
|
||||
chunk = self._buffer[self._pos : self._pos + n]
|
||||
result.extend(chunk)
|
||||
n -= len(chunk)
|
||||
self._pos += len(chunk)
|
||||
if self._pos == len(self._buffer):
|
||||
self._buffer = None
|
||||
return bytes(result)
|
||||
|
||||
|
||||
def validate_password_stream(
|
||||
input_stream: IO[bytes],
|
||||
password: str | None,
|
||||
) -> None:
|
||||
"""Decrypt a backup."""
|
||||
with (
|
||||
tarfile.open(fileobj=input_stream, mode="r|", bufsize=BUF_SIZE) as input_tar,
|
||||
):
|
||||
for obj in input_tar:
|
||||
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
|
||||
continue
|
||||
if obj.pax_headers.get(VERSION_HEADER) != "2.0":
|
||||
raise UnsuppertedSecureTarVersion
|
||||
istf = SecureTarFile(
|
||||
None, # Not used
|
||||
gzip=False,
|
||||
key=password_to_key(password) if password is not None else None,
|
||||
mode="r",
|
||||
fileobj=input_tar.extractfile(obj),
|
||||
)
|
||||
with istf.decrypt(obj) as decrypted:
|
||||
try:
|
||||
decrypted.read(1) # Read a single byte to trigger the decryption
|
||||
except SecureTarReadError as err:
|
||||
raise IncorrectPassword from err
|
||||
return
|
||||
raise BackupEmpty
|
||||
|
||||
|
||||
async def receive_file(
|
||||
hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path
|
||||
) -> None:
|
||||
|
|
|
@ -9,7 +9,11 @@ from homeassistant.core import HomeAssistant, callback
|
|||
|
||||
from .config import ScheduleState
|
||||
from .const import DATA_MANAGER, LOGGER
|
||||
from .manager import IncorrectPasswordError, ManagerStateEvent
|
||||
from .manager import (
|
||||
DecryptOnDowloadNotSupported,
|
||||
IncorrectPasswordError,
|
||||
ManagerStateEvent,
|
||||
)
|
||||
from .models import Folder
|
||||
|
||||
|
||||
|
@ -24,6 +28,7 @@ def async_register_websocket_handlers(hass: HomeAssistant, with_hassio: bool) ->
|
|||
|
||||
websocket_api.async_register_command(hass, handle_details)
|
||||
websocket_api.async_register_command(hass, handle_info)
|
||||
websocket_api.async_register_command(hass, handle_can_decrypt_on_download)
|
||||
websocket_api.async_register_command(hass, handle_create)
|
||||
websocket_api.async_register_command(hass, handle_create_with_automatic_settings)
|
||||
websocket_api.async_register_command(hass, handle_delete)
|
||||
|
@ -147,6 +152,38 @@ async def handle_restore(
|
|||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "backup/can_decrypt_on_download",
|
||||
vol.Required("backup_id"): str,
|
||||
vol.Required("agent_id"): str,
|
||||
vol.Required("password"): str,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def handle_can_decrypt_on_download(
|
||||
hass: HomeAssistant,
|
||||
connection: websocket_api.ActiveConnection,
|
||||
msg: dict[str, Any],
|
||||
) -> None:
|
||||
"""Check if the supplied password is correct."""
|
||||
try:
|
||||
await hass.data[DATA_MANAGER].async_can_decrypt_on_download(
|
||||
msg["backup_id"],
|
||||
agent_id=msg["agent_id"],
|
||||
password=msg.get("password"),
|
||||
)
|
||||
except IncorrectPasswordError:
|
||||
connection.send_error(msg["id"], "password_incorrect", "Incorrect password")
|
||||
except DecryptOnDowloadNotSupported:
|
||||
connection.send_error(
|
||||
msg["id"], "decrypt_not_supported", "Decrypt on download not supported"
|
||||
)
|
||||
else:
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
@websocket_api.require_admin
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -175,6 +175,58 @@
|
|||
'type': 'result',
|
||||
})
|
||||
# ---
|
||||
# name: test_can_decrypt_on_download[backup.local-2bcb3113-hunter2]
|
||||
dict({
|
||||
'error': dict({
|
||||
'code': 'decrypt_not_supported',
|
||||
'message': 'Decrypt on download not supported',
|
||||
}),
|
||||
'id': 1,
|
||||
'success': False,
|
||||
'type': 'result',
|
||||
})
|
||||
# ---
|
||||
# name: test_can_decrypt_on_download[backup.local-ed1608a9-hunter2]
|
||||
dict({
|
||||
'id': 1,
|
||||
'result': None,
|
||||
'success': True,
|
||||
'type': 'result',
|
||||
})
|
||||
# ---
|
||||
# name: test_can_decrypt_on_download[backup.local-ed1608a9-wrong_password]
|
||||
dict({
|
||||
'error': dict({
|
||||
'code': 'password_incorrect',
|
||||
'message': 'Incorrect password',
|
||||
}),
|
||||
'id': 1,
|
||||
'success': False,
|
||||
'type': 'result',
|
||||
})
|
||||
# ---
|
||||
# name: test_can_decrypt_on_download[backup.local-no_such_backup-hunter2]
|
||||
dict({
|
||||
'error': dict({
|
||||
'code': 'home_assistant_error',
|
||||
'message': 'Backup no_such_backup not found in agent backup.local',
|
||||
}),
|
||||
'id': 1,
|
||||
'success': False,
|
||||
'type': 'result',
|
||||
})
|
||||
# ---
|
||||
# name: test_can_decrypt_on_download[no_such_agent-ed1608a9-hunter2]
|
||||
dict({
|
||||
'error': dict({
|
||||
'code': 'home_assistant_error',
|
||||
'message': 'Invalid agent selected: no_such_agent',
|
||||
}),
|
||||
'id': 1,
|
||||
'success': False,
|
||||
'type': 'result',
|
||||
})
|
||||
# ---
|
||||
# name: test_config_info[None]
|
||||
dict({
|
||||
'id': 1,
|
||||
|
|
|
@ -36,7 +36,7 @@ from .common import (
|
|||
setup_backup_platform,
|
||||
)
|
||||
|
||||
from tests.common import async_fire_time_changed, async_mock_service
|
||||
from tests.common import async_fire_time_changed, async_mock_service, get_fixture_path
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
BACKUP_CALL = call(
|
||||
|
@ -2554,3 +2554,56 @@ async def test_subscribe_event(
|
|||
CreateBackupEvent(stage=None, state=CreateBackupState.IN_PROGRESS)
|
||||
)
|
||||
assert await client.receive_json() == snapshot
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_backups() -> Generator[None]:
|
||||
"""Fixture to setup test backups."""
|
||||
# pylint: disable-next=import-outside-toplevel
|
||||
from homeassistant.components.backup import backup as core_backup
|
||||
|
||||
class CoreLocalBackupAgent(core_backup.CoreLocalBackupAgent):
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
super().__init__(hass)
|
||||
self._backup_dir = get_fixture_path("test_backups", DOMAIN)
|
||||
|
||||
with patch.object(core_backup, "CoreLocalBackupAgent", CoreLocalBackupAgent):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("agent_id", "backup_id", "password"),
|
||||
[
|
||||
# Invalid agent or backup
|
||||
("no_such_agent", "ed1608a9", "hunter2"),
|
||||
("backup.local", "no_such_backup", "hunter2"),
|
||||
# Legacy backup, which can't be streamed
|
||||
("backup.local", "2bcb3113", "hunter2"),
|
||||
# New backup, which can be streamed, try with correct and wrong password
|
||||
("backup.local", "ed1608a9", "hunter2"),
|
||||
("backup.local", "ed1608a9", "wrong_password"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.usefixtures("mock_backups")
|
||||
async def test_can_decrypt_on_download(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
snapshot: SnapshotAssertion,
|
||||
agent_id: str,
|
||||
backup_id: str,
|
||||
password: str,
|
||||
) -> None:
|
||||
"""Test can decrypt on download."""
|
||||
await setup_backup_integration(hass, with_hassio=False)
|
||||
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "backup/can_decrypt_on_download",
|
||||
"backup_id": backup_id,
|
||||
"agent_id": agent_id,
|
||||
"password": password,
|
||||
}
|
||||
)
|
||||
assert await client.receive_json() == snapshot
|
||||
|
|
Loading…
Reference in New Issue