core/homeassistant/components/backup/util.py

"""Local backup support for Core and Container installations."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator, Callable, Coroutine
from concurrent.futures import CancelledError, Future
import copy
from dataclasses import dataclass, replace
from io import BytesIO
import json
import os
from pathlib import Path, PurePath
from queue import SimpleQueue
import tarfile
import threading
from typing import IO, Any, Self, cast
import aiohttp
from securetar import SecureTarError, SecureTarFile, SecureTarReadError
from homeassistant.backup_restore import password_to_key
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util
from homeassistant.util.json import JsonObjectType, json_loads_object
from .const import BUF_SIZE, LOGGER
from .models import AddonInfo, AgentBackup, Folder
class DecryptError(HomeAssistantError):
"""Error during decryption."""
_message = "Unexpected error during decryption."
class EncryptError(HomeAssistantError):
"""Error during encryption."""
_message = "Unexpected error during encryption."
class UnsupportedSecureTarVersion(DecryptError):
"""Unsupported securetar version."""
_message = "Unsupported securetar version."
class IncorrectPassword(DecryptError):
"""Invalid password or corrupted backup."""
_message = "Invalid password or corrupted backup."
class BackupEmpty(DecryptError):
"""No tar files found in the backup."""
_message = "No tar files found in the backup."
class AbortCipher(HomeAssistantError):
"""Abort the cipher operation."""
_message = "Abort cipher operation."
def make_backup_dir(path: Path) -> None:
"""Create a backup directory if it does not exist."""
path.mkdir(exist_ok=True)
def read_backup(backup_path: Path) -> AgentBackup:
"""Read a backup from disk."""
with tarfile.open(backup_path, "r:", bufsize=BUF_SIZE) as backup_file:
if not (data_file := backup_file.extractfile("./backup.json")):
raise KeyError("backup.json not found in tar file")
data = json_loads_object(data_file.read())
addons = [
AddonInfo(
name=cast(str, addon["name"]),
slug=cast(str, addon["slug"]),
version=cast(str, addon["version"]),
)
for addon in cast(list[JsonObjectType], data.get("addons", []))
]
folders = [
Folder(folder)
for folder in cast(list[str], data.get("folders", []))
if folder != "homeassistant"
]
homeassistant_included = False
homeassistant_version: str | None = None
database_included = False
if (
homeassistant := cast(JsonObjectType, data.get("homeassistant"))
) and "version" in homeassistant:
homeassistant_included = True
homeassistant_version = cast(str, homeassistant["version"])
database_included = not cast(
bool, homeassistant.get("exclude_database", False)
)
return AgentBackup(
addons=addons,
backup_id=cast(str, data["slug"]),
database_included=database_included,
date=cast(str, data["date"]),
extra_metadata=cast(dict[str, bool | str], data.get("extra", {})),
folders=folders,
homeassistant_included=homeassistant_included,
homeassistant_version=homeassistant_version,
name=cast(str, data["name"]),
protected=cast(bool, data.get("protected", False)),
size=backup_path.stat().st_size,
)
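# Illustrative use of read_backup() (the path below is hypothetical, not part of
# this module):
#
#     backup = read_backup(Path("/config/backups/abc123.tar"))
#
# backup.backup_id comes from the "slug" field of backup.json inside the tar,
# while backup.size is the on-disk size of the tar file itself.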
def suggested_filename_from_name_date(name: str, date_str: str) -> str:
"""Suggest a filename for the backup."""
date = dt_util.parse_datetime(date_str, raise_on_error=True)
return "_".join(f"{name} {date.strftime('%Y-%m-%d %H.%M %S%f')}.tar".split())
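# Worked example for the helper above (name and date are hypothetical):
#
#     suggested_filename_from_name_date("My backup", "2025-01-01T03:04:05.678901+00:00")
#     # -> "My_backup_2025-01-01_03.04_05678901.tar"
#
# The trailing split()/"_".join() collapses any whitespace in the name into
# single underscores.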
def suggested_filename(backup: AgentBackup) -> str:
"""Suggest a filename for the backup."""
return suggested_filename_from_name_date(backup.name, backup.date)
def validate_password(path: Path, password: str | None) -> bool:
"""Validate the password."""
with tarfile.open(path, "r:", bufsize=BUF_SIZE) as backup_file:
compressed = False
ha_tar_name = "homeassistant.tar"
try:
ha_tar = backup_file.extractfile(ha_tar_name)
except KeyError:
compressed = True
ha_tar_name = "homeassistant.tar.gz"
try:
ha_tar = backup_file.extractfile(ha_tar_name)
except KeyError:
LOGGER.error("No homeassistant.tar or homeassistant.tar.gz found")
return False
try:
with SecureTarFile(
path, # Not used
gzip=compressed,
key=password_to_key(password) if password is not None else None,
mode="r",
fileobj=ha_tar,
):
# If we can read the tar file, the password is correct
return True
except tarfile.ReadError:
LOGGER.debug("Invalid password")
return False
except Exception: # noqa: BLE001
LOGGER.exception("Unexpected error validating password")
return False
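# Note: validate_password() only checks that the (possibly gzip-compressed)
# homeassistant inner tar can be opened with a key derived from the password;
# tarfile.ReadError is interpreted as a wrong password. Illustrative use (the
# path is hypothetical):
#
#     if not validate_password(Path("/config/backups/abc123.tar"), password):
#         ...  # reject the supplied password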
class AsyncIteratorReader:
"""Wrap an AsyncIterator so it can be read like a blocking file object."""
def __init__(self, hass: HomeAssistant, stream: AsyncIterator[bytes]) -> None:
"""Initialize the wrapper."""
self._aborted = False
self._hass = hass
self._stream = stream
self._buffer: bytes | None = None
self._next_future: Future[bytes | None] | None = None
self._pos: int = 0
async def _next(self) -> bytes | None:
"""Get the next chunk from the iterator."""
return await anext(self._stream, None)
def abort(self) -> None:
"""Abort the reader."""
self._aborted = True
if self._next_future is not None:
self._next_future.cancel()
def read(self, n: int = -1, /) -> bytes:
"""Read data from the iterator."""
result = bytearray()
while n < 0 or len(result) < n:
if not self._buffer:
self._next_future = asyncio.run_coroutine_threadsafe(
self._next(), self._hass.loop
)
if self._aborted:
self._next_future.cancel()
raise AbortCipher
try:
self._buffer = self._next_future.result()
except CancelledError as err:
raise AbortCipher from err
self._pos = 0
if not self._buffer:
# The stream is exhausted
break
chunk = self._buffer[self._pos : self._pos + n]
result.extend(chunk)
n -= len(chunk)
self._pos += len(chunk)
if self._pos == len(self._buffer):
self._buffer = None
return bytes(result)
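# Design note: AsyncIteratorReader bridges an async byte stream to the blocking,
# file-like interface tarfile expects in the cipher worker thread. Each read()
# pulls chunks by scheduling _next() on the event loop with
# asyncio.run_coroutine_threadsafe() and blocking on the resulting future; once
# the iterator is exhausted (anext() returns None) read() returns whatever it
# has, and an empty result signals end-of-file to the caller. abort() cancels
# the pending future so a blocked read() raises AbortCipher instead of hanging.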
def close(self) -> None:
"""Close the iterator."""
class AsyncIteratorWriter:
"""Expose a blocking, file-like writer as an AsyncIterator of bytes."""
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize the wrapper."""
self._aborted = False
self._hass = hass
self._pos: int = 0
self._queue: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=1)
self._write_future: Future[bytes | None] | None = None
def __aiter__(self) -> Self:
"""Return the iterator."""
return self
async def __anext__(self) -> bytes:
"""Get the next chunk from the iterator."""
if data := await self._queue.get():
return data
raise StopAsyncIteration
def abort(self) -> None:
"""Abort the writer."""
self._aborted = True
if self._write_future is not None:
self._write_future.cancel()
def tell(self) -> int:
"""Return the current position in the iterator."""
return self._pos
def write(self, s: bytes, /) -> int:
"""Write data to the iterator."""
self._write_future = asyncio.run_coroutine_threadsafe(
self._queue.put(s), self._hass.loop
)
if self._aborted:
self._write_future.cancel()
raise AbortCipher
try:
self._write_future.result()
except CancelledError as err:
raise AbortCipher from err
self._pos += len(s)
return len(s)
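# Design note: AsyncIteratorWriter is the mirror image of AsyncIteratorReader.
# The cipher worker thread calls the blocking write(), which schedules
# queue.put() on the event loop and waits for it to complete; because the queue
# has maxsize=1, the thread cannot run ahead of the async consumer, which gives
# natural backpressure. On the async side the object is consumed with
# "async for"; __anext__() stops as soon as it receives a falsy chunk, which is
# why decrypt_backup() and encrypt_backup() write b"" as an end-of-stream
# sentinel.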
def validate_password_stream(
input_stream: IO[bytes],
password: str | None,
) -> None:
"""Validate the password of a backup stream."""
with (
tarfile.open(fileobj=input_stream, mode="r|", bufsize=BUF_SIZE) as input_tar,
):
for obj in input_tar:
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
continue
istf = SecureTarFile(
None, # Not used
gzip=False,
key=password_to_key(password) if password is not None else None,
mode="r",
fileobj=input_tar.extractfile(obj),
)
with istf.decrypt(obj) as decrypted:
if istf.securetar_header.plaintext_size is None:
raise UnsupportedSecureTarVersion
try:
decrypted.read(1) # Read a single byte to trigger the decryption
except SecureTarReadError as err:
raise IncorrectPassword from err
return
raise BackupEmpty
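# Unlike validate_password() above, validate_password_stream() works on a
# non-seekable stream (tar mode "r|"). It decrypts the first inner tar it finds:
# a missing plaintext_size header means the securetar version is too old to
# validate, and a SecureTarReadError on the first byte means the password is
# wrong (or the data is corrupted).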
def decrypt_backup(
input_stream: IO[bytes],
output_stream: IO[bytes],
password: str | None,
on_done: Callable[[Exception | None], None],
minimum_size: int,
nonces: list[bytes],
) -> None:
"""Decrypt a backup."""
error: Exception | None = None
try:
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_decrypt_backup(input_tar, output_tar, password)
except (DecryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error decrypting backup: %s", err)
error = err
else:
# Pad the output stream to the requested minimum size
padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding)
finally:
# Write an empty chunk to signal the end of the stream
output_stream.write(b"")
except AbortCipher:
LOGGER.debug("Cipher operation aborted")
finally:
on_done(error)
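# Flow of decrypt_backup() (encrypt_backup() below mirrors it): it runs in a
# worker thread started by _CipherBackupStreamer.open_stream(), reading the
# outer tar from input_stream and writing a rewritten outer tar to
# output_stream. On success the output is padded with NUL bytes up to
# minimum_size, presumably so the stream never falls short of the size reported
# by _CipherBackupStreamer.size(); the final b"" write is the end-of-stream
# sentinel consumed by AsyncIteratorWriter. on_done() is always called, with
# the error (if any) as its argument.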
def _decrypt_backup(
input_tar: tarfile.TarFile,
output_tar: tarfile.TarFile,
password: str | None,
) -> None:
"""Decrypt a backup."""
for obj in input_tar:
# We compare with PurePath to avoid issues with different path separators,
# for example when backup.json is added as "./backup.json"
if PurePath(obj.name) == PurePath("backup.json"):
# Rewrite the backup.json file to indicate that the backup is decrypted
if not (reader := input_tar.extractfile(obj)):
raise DecryptError
metadata = json_loads_object(reader.read())
metadata["protected"] = False
updated_metadata_b = json.dumps(metadata).encode()
metadata_obj = copy.deepcopy(obj)
metadata_obj.size = len(updated_metadata_b)
output_tar.addfile(metadata_obj, BytesIO(updated_metadata_b))
continue
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
output_tar.addfile(obj, input_tar.extractfile(obj))
continue
istf = SecureTarFile(
None, # Not used
gzip=False,
key=password_to_key(password) if password is not None else None,
mode="r",
fileobj=input_tar.extractfile(obj),
)
with istf.decrypt(obj) as decrypted:
if (plaintext_size := istf.securetar_header.plaintext_size) is None:
raise UnsupportedSecureTarVersion
decrypted_obj = copy.deepcopy(obj)
decrypted_obj.size = plaintext_size
output_tar.addfile(decrypted_obj, decrypted)
def encrypt_backup(
input_stream: IO[bytes],
output_stream: IO[bytes],
password: str | None,
on_done: Callable[[Exception | None], None],
minimum_size: int,
nonces: list[bytes],
) -> None:
"""Encrypt a backup."""
error: Exception | None = None
try:
try:
with (
tarfile.open(
fileobj=input_stream, mode="r|", bufsize=BUF_SIZE
) as input_tar,
tarfile.open(
fileobj=output_stream, mode="w|", bufsize=BUF_SIZE
) as output_tar,
):
_encrypt_backup(input_tar, output_tar, password, nonces)
except (EncryptError, SecureTarError, tarfile.TarError) as err:
LOGGER.warning("Error encrypting backup: %s", err)
error = err
else:
# Pad the output stream to the requested minimum size
padding = max(minimum_size - output_stream.tell(), 0)
output_stream.write(b"\0" * padding)
finally:
# Write an empty chunk to signal the end of the stream
output_stream.write(b"")
except AbortCipher:
LOGGER.debug("Cipher operation aborted")
finally:
on_done(error)
def _encrypt_backup(
input_tar: tarfile.TarFile,
output_tar: tarfile.TarFile,
password: str | None,
nonces: list[bytes],
) -> None:
"""Encrypt a backup."""
inner_tar_idx = 0
for obj in input_tar:
# We compare with PurePath to avoid issues with different path separators,
# for example when backup.json is added as "./backup.json"
if PurePath(obj.name) == PurePath("backup.json"):
# Rewrite the backup.json file to indicate that the backup is encrypted
if not (reader := input_tar.extractfile(obj)):
raise EncryptError
metadata = json_loads_object(reader.read())
metadata["protected"] = True
updated_metadata_b = json.dumps(metadata).encode()
metadata_obj = copy.deepcopy(obj)
metadata_obj.size = len(updated_metadata_b)
output_tar.addfile(metadata_obj, BytesIO(updated_metadata_b))
continue
if not obj.name.endswith((".tar", ".tgz", ".tar.gz")):
output_tar.addfile(obj, input_tar.extractfile(obj))
continue
istf = SecureTarFile(
None, # Not used
gzip=False,
key=password_to_key(password) if password is not None else None,
mode="r",
fileobj=input_tar.extractfile(obj),
nonce=nonces[inner_tar_idx],
)
inner_tar_idx += 1
with istf.encrypt(obj) as encrypted:
encrypted_obj = copy.deepcopy(obj)
encrypted_obj.size = encrypted.encrypted_size
output_tar.addfile(encrypted_obj, encrypted)
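# Note on nonces: _encrypt_backup() consumes one caller-supplied nonce per inner
# tar file (nonces[inner_tar_idx]). EncryptedBackupStreamer generates these once
# in its __init__, so every stream opened from the same streamer instance
# encrypts with the same nonces, presumably to keep repeated streams of the same
# backup byte-for-byte identical.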
@dataclass(kw_only=True)
class _CipherWorkerStatus:
done: asyncio.Event
error: Exception | None = None
reader: AsyncIteratorReader
thread: threading.Thread
writer: AsyncIteratorWriter
class _CipherBackupStreamer:
"""Encrypt or decrypt a backup."""
_cipher_func: Callable[
[
IO[bytes],
IO[bytes],
str | None,
Callable[[Exception | None], None],
int,
list[bytes],
],
None,
]
def __init__(
self,
hass: HomeAssistant,
backup: AgentBackup,
open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
password: str | None,
) -> None:
"""Initialize."""
self._workers: list[_CipherWorkerStatus] = []
self._backup = backup
self._hass = hass
self._open_stream = open_stream
self._password = password
self._nonces: list[bytes] = []
def size(self) -> int:
"""Return the maximum size of the decrypted or encrypted backup."""
return self._backup.size + self._num_tar_files() * tarfile.RECORDSIZE
def _num_tar_files(self) -> int:
"""Return the number of inner tar files."""
b = self._backup
return len(b.addons) + len(b.folders) + b.homeassistant_included + 1
async def open_stream(self) -> AsyncIterator[bytes]:
"""Open a stream."""
def on_done(error: Exception | None) -> None:
"""Called by the worker thread when it's done."""
worker_status.error = error
self._hass.loop.call_soon_threadsafe(worker_status.done.set)
stream = await self._open_stream()
reader = AsyncIteratorReader(self._hass, stream)
writer = AsyncIteratorWriter(self._hass)
worker = threading.Thread(
target=self._cipher_func,
args=[reader, writer, self._password, on_done, self.size(), self._nonces],
)
worker_status = _CipherWorkerStatus(
done=asyncio.Event(), reader=reader, thread=worker, writer=writer
)
self._workers.append(worker_status)
worker.start()
return writer
async def wait(self) -> None:
"""Wait for the worker threads to finish."""
for worker in self._workers:
worker.reader.abort()
worker.writer.abort()
await asyncio.gather(*(worker.done.wait() for worker in self._workers))
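# Worker lifecycle: each call to open_stream() opens the source stream, wraps it
# in an AsyncIteratorReader, creates an AsyncIteratorWriter for the output and
# starts a thread running _cipher_func (decrypt_backup or encrypt_backup)
# between the two; the writer is returned to the caller for iteration. wait()
# aborts any blocked reads and writes, then waits for every worker to signal
# completion through its done event.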
class DecryptedBackupStreamer(_CipherBackupStreamer):
"""Decrypt a backup."""
_cipher_func = staticmethod(decrypt_backup)
def backup(self) -> AgentBackup:
"""Return the decrypted backup."""
return replace(self._backup, protected=False, size=self.size())
class EncryptedBackupStreamer(_CipherBackupStreamer):
"""Encrypt a backup."""
def __init__(
self,
hass: HomeAssistant,
backup: AgentBackup,
open_stream: Callable[[], Coroutine[Any, Any, AsyncIterator[bytes]]],
password: str | None,
) -> None:
"""Initialize."""
super().__init__(hass, backup, open_stream, password)
self._nonces = [os.urandom(16) for _ in range(self._num_tar_files())]
_cipher_func = staticmethod(encrypt_backup)
def backup(self) -> AgentBackup:
"""Return the encrypted backup."""
return replace(self._backup, protected=True, size=self.size())
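# Illustrative sketch of how an agent might use EncryptedBackupStreamer (the
# names "agent_backup", "open_original_stream" and "upload_chunk" are
# hypothetical):
#
#     streamer = EncryptedBackupStreamer(
#         hass, agent_backup, open_original_stream, password
#     )
#     metadata = streamer.backup()           # protected=True, padded size
#     stream = await streamer.open_stream()  # AsyncIterator[bytes], encrypted tar
#     async for chunk in stream:
#         await upload_chunk(chunk)
#     await streamer.wait()                  # let the worker thread finish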
async def receive_file(
hass: HomeAssistant, contents: aiohttp.BodyPartReader, path: Path
) -> None:
"""Receive a file from a stream and write it to a file."""
queue: SimpleQueue[tuple[bytes, asyncio.Future[None] | None] | None] = SimpleQueue()
def _sync_queue_consumer() -> None:
with path.open("wb") as file_handle:
while True:
if (_chunk_future := queue.get()) is None:
break
_chunk, _future = _chunk_future
if _future is not None:
hass.loop.call_soon_threadsafe(_future.set_result, None)
file_handle.write(_chunk)
fut: asyncio.Future[None] | None = None
try:
fut = hass.async_add_executor_job(_sync_queue_consumer)
megabytes_sending = 0
while chunk := await contents.read_chunk(BUF_SIZE):
megabytes_sending += 1
if megabytes_sending % 5 != 0:
queue.put_nowait((chunk, None))
continue
chunk_future = hass.loop.create_future()
queue.put_nowait((chunk, chunk_future))
await asyncio.wait(
(fut, chunk_future),
return_when=asyncio.FIRST_COMPLETED,
)
if fut.done():
# The executor job failed
break
queue.put_nowait(None) # terminate queue consumer
finally:
if fut is not None:
await fut
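# Backpressure in receive_file(): chunks read from the aiohttp body on the event
# loop are handed to a single executor job that writes them to disk through a
# SimpleQueue. Every fifth chunk (BUF_SIZE bytes each, hence the
# "megabytes_sending" counter) carries a future the producer awaits, so the
# upload cannot run unboundedly ahead of the disk writer and a failed writer
# (fut.done()) stops the loop early. The trailing None terminates the consumer,
# and the finally block awaits the executor job so its errors propagate.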