192 lines
6.1 KiB
Python
192 lines
6.1 KiB
Python
"""Script to convert an old-format home-assistant.db to a new format one."""
|
|
|
|
import argparse
|
|
import os.path
|
|
import sqlite3
|
|
import sys
|
|
|
|
from datetime import datetime
|
|
from typing import Optional, List
|
|
|
|
import homeassistant.config as config_util
|
|
import homeassistant.util.dt as dt_util
|
|
# pylint: disable=unused-import
|
|
from homeassistant.components.recorder import REQUIREMENTS # NOQA
|
|
|
|
|
|
def ts_to_dt(timestamp: Optional[float]) -> Optional[datetime]:
    """Turn a timestamp stored in the legacy DB into a datetime.

    The value is converted with ``dt_util.utc_from_timestamp``.  ``None``
    (a column without a value) is passed through unchanged so callers can
    feed raw row values in without checking them first.
    """
    if timestamp is None:
        return None
    return dt_util.utc_from_timestamp(timestamp)
|
|
|
|
|
|
# Based on code at
|
|
# http://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
|
|
def print_progress(iteration: int, total: int, prefix: str = '',
                   suffix: str = '', decimals: int = 2,
                   bar_length: int = 68) -> None:
    """Print a one-line text progress bar to stdout.

    Call in a loop to create a terminal progress bar.  The line ends with
    a carriage return so the next call overwrites it; once ``iteration``
    equals ``total`` a line break is printed to finish the bar.

    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int)
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : number of decimals in percent complete (Int)
        bar_length  - Optional  : character length of bar (Int)
    """
    if total:
        filled_length = int(round(bar_length * iteration / float(total)))
        percents = round(100.00 * (iteration / float(total)), decimals)
    else:
        # Nothing to process (e.g. an empty table): show a full bar
        # instead of dividing by zero.
        filled_length = bar_length
        percents = round(100.00, decimals)
    line = '#' * filled_length + '-' * (bar_length - filled_length)
    sys.stdout.write('%s [%s] %s%s %s\r' % (prefix, line,
                                            percents, '%', suffix))
    sys.stdout.flush()
    if iteration == total:
        print("\n")
|
|
|
|
|
|
def _count_rows(conn, table: str) -> int:
    """Return the number of rows in ``table`` of the legacy database."""
    cursor = conn.cursor()
    try:
        # ``table`` is only ever one of the hard-coded names passed in by
        # _migrate(), never user input.
        cursor.execute("SELECT count(*) FROM {}".format(table))
        return cursor.fetchone()[0]
    finally:
        cursor.close()


def _migrate(conn, session, models, append) -> None:
    """Copy recorder_runs, events and states from ``conn`` into ``session``.

    ``conn`` is the sqlite3 connection to the legacy database, ``session``
    the SQLAlchemy session for the destination and ``models`` the recorder
    models module.  When ``append`` is truthy, every event is flushed as it
    is added so its new ``event_id`` can be recorded and used to re-point
    the states that reference it.
    """
    # pylint: disable=invalid-name
    num_rows = _count_rows(conn, "recorder_runs")
    print("Converting {} recorder_runs".format(num_rows))

    c = conn.cursor()
    n = 0
    for row in c.execute("SELECT * FROM recorder_runs"):  # type: ignore
        n += 1
        session.add(models.RecorderRuns(
            start=ts_to_dt(row[1]),
            end=ts_to_dt(row[2]),
            closed_incorrect=row[3],
            created=ts_to_dt(row[4])
        ))
        # Commit in batches of 1000 rows.
        if n % 1000 == 0:
            session.commit()
            print_progress(n, num_rows)
    print_progress(n, num_rows)
    session.commit()
    c.close()

    num_rows = _count_rows(conn, "events")
    print("Converting {} events".format(num_rows))

    # Old event id -> new event id; only filled in when appending.
    id_mapping = {}

    c = conn.cursor()
    n = 0
    for row in c.execute("SELECT * FROM events"):  # type: ignore
        n += 1
        o = models.Events(
            event_type=row[1],
            event_data=row[2],
            origin=row[3],
            created=ts_to_dt(row[4]),
            time_fired=ts_to_dt(row[5]),
        )
        session.add(o)
        if append:
            # Flush so that o.event_id is populated before it is recorded.
            session.flush()
            id_mapping[row[0]] = o.event_id
        if n % 1000 == 0:
            session.commit()
            print_progress(n, num_rows)
    print_progress(n, num_rows)
    session.commit()
    c.close()

    num_rows = _count_rows(conn, "states")
    print("Converting {} states".format(num_rows))

    c = conn.cursor()
    n = 0
    for row in c.execute("SELECT * FROM states"):  # type: ignore
        n += 1
        session.add(models.States(
            entity_id=row[1],
            state=row[2],
            attributes=row[3],
            last_changed=ts_to_dt(row[4]),
            last_updated=ts_to_dt(row[5]),
            # Fall back to the old id when no mapping was recorded.
            event_id=id_mapping.get(row[6], row[6]),
            domain=row[7]
        ))
        if n % 1000 == 0:
            session.commit()
            print_progress(n, num_rows)
    print_progress(n, num_rows)
    session.commit()
    c.close()


def run(script_args: List) -> int:
    """Run the actual script.

    Reads ``home-assistant.db`` from the configuration directory and
    writes its recorder_runs, events and states rows either to
    ``home-assistant_v2.db`` in the same directory or to the database
    given with ``--uri``.

    NOTE(review): ``script_args`` is not consulted; ``parse_args()`` is
    called without arguments and therefore reads ``sys.argv``.  The
    ``--script`` option presumably exists to swallow the launcher's own
    argument -- confirm against the caller before changing this.

    Returns 0 on success and 1 when the configuration directory or source
    database is missing, or when the destination database already exists
    and neither ``--append`` nor ``--uri`` was given.
    """
    # pylint: disable=invalid-name
    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
    from homeassistant.components.recorder import models

    parser = argparse.ArgumentParser(
        description="Migrate legacy DB to SQLAlchemy format.")
    parser.add_argument(
        '-c', '--config',
        metavar='path_to_config_dir',
        default=config_util.get_default_config_dir(),
        help="Directory that contains the Home Assistant configuration")
    parser.add_argument(
        '-a', '--append',
        action='store_true',
        default=False,
        help="Append to existing new format SQLite database")
    parser.add_argument(
        '--uri',
        type=str,
        help="Connect to URI and import (implies --append) "
             "eg: mysql://localhost/homeassistant")
    parser.add_argument(
        '--script',
        choices=['db_migrator'])

    args = parser.parse_args()

    config_dir = os.path.join(os.getcwd(), args.config)  # type: str

    # Test if configuration directory exists
    if not os.path.isdir(config_dir):
        if config_dir != config_util.get_default_config_dir():
            print(('Fatal Error: Specified configuration directory does '
                   'not exist {} ').format(config_dir))
            return 1

    src_db = '{}/home-assistant.db'.format(config_dir)
    dst_db = '{}/home-assistant_v2.db'.format(config_dir)

    if not os.path.exists(src_db):
        print("Fatal Error: Old format database '{}' does not exist".format(
            src_db))
        return 1
    if not args.uri and (os.path.exists(dst_db) and not args.append):
        print("Fatal Error: New format database '{}' exists already - "
              "Remove it or use --append".format(dst_db))
        print("Note: --append must maintain an ID mapping and is much slower "
              "and requires sufficient memory to track all event IDs")
        return 1

    conn = sqlite3.connect(src_db)
    session = None
    try:
        uri = args.uri or "sqlite:///{}".format(dst_db)

        engine = create_engine(uri, echo=False)
        models.Base.metadata.create_all(engine)
        session_factory = sessionmaker(bind=engine)
        session = session_factory()

        # Importing into an explicit URI always behaves like --append.
        append = args.append or args.uri

        _migrate(conn, session, models, append)
    finally:
        # Release both database handles even when the migration fails.
        if session is not None:
            session.close()
        conn.close()
    return 0
|