diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py
index 98a5aea4a57b8e..01fda0f02faa8c 100644
--- a/homeassistant/components/recorder/core.py
+++ b/homeassistant/components/recorder/core.py
@@ -1459,6 +1459,7 @@ def _setup_connection(self) -> None:
         self.__dict__.pop("dialect_name", None)
         sqlalchemy_event.listen(self.engine, "connect", self._setup_recorder_connection)
+        migration.pre_migrate_schema(self.engine)
         Base.metadata.create_all(self.engine)
         self._get_session = scoped_session(sessionmaker(bind=self.engine, future=True))
         _LOGGER.debug("Connected to recorder database")
diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py
index 915fd4a4bb816f..57f113b779e98a 100644
--- a/homeassistant/components/recorder/db_schema.py
+++ b/homeassistant/components/recorder/db_schema.py
@@ -73,7 +73,11 @@
 class Base(DeclarativeBase):
     """Base class for tables."""
 
 
-SCHEMA_VERSION = 43
+class LegacyBase(DeclarativeBase):
+    """Base class for tables, used for schema migration."""
+
+
+SCHEMA_VERSION = 44
 
 _LOGGER = logging.getLogger(__name__)
@@ -187,6 +191,9 @@ def result_processor(self, dialect, coltype):  # type: ignore[no-untyped-def]
         return None
 
 
+# Although all integers are the same in SQLite, it does not allow an identity column to be BIGINT
+# https://sqlite.org/forum/info/2dfa968a702e1506e885cb06d92157d492108b22bf39459506ab9f7125bca7fd
+ID_TYPE = BigInteger().with_variant(sqlite.INTEGER, "sqlite")
 # For MariaDB and MySQL we can use an unsigned integer type since it will fit 2**32
 # for sqlite and postgresql we use a bigint
 UINT_32_TYPE = BigInteger().with_variant(
@@ -217,6 +224,7 @@
 UNUSED_LEGACY_DATETIME_COLUMN = UnusedDateTime(timezone=True)
 UNUSED_LEGACY_INTEGER_COLUMN = SmallInteger()
 DOUBLE_PRECISION_TYPE_SQL = "DOUBLE PRECISION"
+BIG_INTEGER_SQL = "BIGINT"
 CONTEXT_BINARY_TYPE = LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH).with_variant(
     NativeLargeBinary(CONTEXT_ID_BIN_MAX_LENGTH), "mysql", "mariadb", "sqlite"
 )
@@ -258,7 +266,7 @@ class Events(Base):
         _DEFAULT_TABLE_ARGS,
     )
     __tablename__ = TABLE_EVENTS
-    event_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    event_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     event_type: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
     event_data: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
     origin: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
@@ -269,13 +277,13 @@ class Events(Base):
     context_user_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
     context_parent_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
     data_id: Mapped[int | None] = mapped_column(
-        Integer, ForeignKey("event_data.data_id"), index=True
+        ID_TYPE, ForeignKey("event_data.data_id"), index=True
     )
     context_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
     context_user_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
     context_parent_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
     event_type_id: Mapped[int | None] = mapped_column(
-        Integer, ForeignKey("event_types.event_type_id")
+        ID_TYPE, ForeignKey("event_types.event_type_id")
     )
     event_data_rel: Mapped[EventData | None] = relationship("EventData")
     event_type_rel: Mapped[EventTypes | None] = relationship("EventTypes")
@@ -347,7 +355,7 @@ class EventData(Base):
 
     __table_args__ = (_DEFAULT_TABLE_ARGS,)
     __tablename__ = TABLE_EVENT_DATA
-    data_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    data_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     hash: Mapped[int | None] = mapped_column(UINT_32_TYPE, index=True)
     # Note that this is not named attributes to avoid confusion with the states table
     shared_data: Mapped[str | None] = mapped_column(
@@ -403,7 +411,7 @@ class EventTypes(Base):
 
     __table_args__ = (_DEFAULT_TABLE_ARGS,)
     __tablename__ = TABLE_EVENT_TYPES
-    event_type_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    event_type_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     event_type: Mapped[str | None] = mapped_column(
         String(MAX_LENGTH_EVENT_EVENT_TYPE), index=True, unique=True
     )
@@ -433,7 +441,7 @@ class States(Base):
         _DEFAULT_TABLE_ARGS,
     )
     __tablename__ = TABLE_STATES
-    state_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    state_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     entity_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
     state: Mapped[str | None] = mapped_column(String(MAX_LENGTH_STATE_STATE))
     attributes: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
@@ -446,10 +454,10 @@ class States(Base):
         TIMESTAMP_TYPE, default=time.time, index=True
     )
     old_state_id: Mapped[int | None] = mapped_column(
-        Integer, ForeignKey("states.state_id"), index=True
+        ID_TYPE, ForeignKey("states.state_id"), index=True
     )
     attributes_id: Mapped[int | None] = mapped_column(
-        Integer, ForeignKey("state_attributes.attributes_id"), index=True
+        ID_TYPE, ForeignKey("state_attributes.attributes_id"), index=True
     )
     context_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
     context_user_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
@@ -463,7 +471,7 @@ class States(Base):
     context_user_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
     context_parent_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
     metadata_id: Mapped[int | None] = mapped_column(
-        Integer, ForeignKey("states_meta.metadata_id")
+        ID_TYPE, ForeignKey("states_meta.metadata_id")
     )
     states_meta_rel: Mapped[StatesMeta | None] = relationship("StatesMeta")
@@ -573,7 +581,7 @@ class StateAttributes(Base):
 
     __table_args__ = (_DEFAULT_TABLE_ARGS,)
     __tablename__ = TABLE_STATE_ATTRIBUTES
-    attributes_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    attributes_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     hash: Mapped[int | None] = mapped_column(UINT_32_TYPE, index=True)
     # Note that this is not named attributes to avoid confusion with the states table
     shared_attrs: Mapped[str | None] = mapped_column(
@@ -647,7 +655,7 @@ class StatesMeta(Base):
 
     __table_args__ = (_DEFAULT_TABLE_ARGS,)
     __tablename__ = TABLE_STATES_META
-    metadata_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    metadata_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     entity_id: Mapped[str | None] = mapped_column(
         String(MAX_LENGTH_STATE_ENTITY_ID), index=True, unique=True
     )
@@ -664,11 +672,11 @@ def __repr__(self) -> str:
 class StatisticsBase:
     """Statistics base class."""
 
-    id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     created: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN)
     created_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE, default=time.time)
     metadata_id: Mapped[int | None] = mapped_column(
-        Integer,
+        ID_TYPE,
         ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
     )
     start: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN)
@@ -738,11 +746,17 @@ class Statistics(Base, StatisticsBase):
     __tablename__ = TABLE_STATISTICS
 
 
-class StatisticsShortTerm(Base, StatisticsBase):
+class _StatisticsShortTerm(StatisticsBase):
     """Short term statistics."""
 
     duration = timedelta(minutes=5)
 
+    __tablename__ = TABLE_STATISTICS_SHORT_TERM
+
+
+class StatisticsShortTerm(Base, _StatisticsShortTerm):
+    """Short term statistics."""
+
     __table_args__ = (
         # Used for fetching statistics for a certain entity at a specific time
         Index(
@@ -753,15 +767,35 @@ class StatisticsShortTerm(Base, StatisticsBase):
         ),
         _DEFAULT_TABLE_ARGS,
     )
-    __tablename__ = TABLE_STATISTICS_SHORT_TERM
 
 
-class StatisticsMeta(Base):
+class LegacyStatisticsShortTerm(LegacyBase, _StatisticsShortTerm):
+    """Short term statistics with 32-bit index, used for schema migration."""
+
+    __table_args__ = (
+        # Used for fetching statistics for a certain entity at a specific time
+        Index(
+            "ix_statistics_short_term_statistic_id_start_ts",
+            "metadata_id",
+            "start_ts",
+            unique=True,
+        ),
+        _DEFAULT_TABLE_ARGS,
+    )
+
+    metadata_id: Mapped[int | None] = mapped_column(
+        Integer,
+        ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"),
+        use_existing_column=True,
+    )
+
+
+class _StatisticsMeta:
     """Statistics meta data."""
 
     __table_args__ = (_DEFAULT_TABLE_ARGS,)
     __tablename__ = TABLE_STATISTICS_META
-    id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     statistic_id: Mapped[str | None] = mapped_column(
         String(255), index=True, unique=True
     )
@@ -777,6 +811,21 @@ def from_meta(meta: StatisticMetaData) -> StatisticsMeta:
         return StatisticsMeta(**meta)
 
 
+class StatisticsMeta(Base, _StatisticsMeta):
+    """Statistics meta data."""
+
+
+class LegacyStatisticsMeta(LegacyBase, _StatisticsMeta):
+    """Statistics meta data with 32-bit index, used for schema migration."""
+
+    id: Mapped[int] = mapped_column(
+        Integer,
+        Identity(),
+        primary_key=True,
+        use_existing_column=True,
+    )
+
+
 class RecorderRuns(Base):
     """Representation of recorder run."""
 
@@ -785,7 +834,7 @@ class RecorderRuns(Base):
         _DEFAULT_TABLE_ARGS,
     )
     __tablename__ = TABLE_RECORDER_RUNS
-    run_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    run_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     start: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
     end: Mapped[datetime | None] = mapped_column(DATETIME_TYPE)
     closed_incorrect: Mapped[bool] = mapped_column(Boolean, default=False)
@@ -824,7 +873,7 @@ class SchemaChanges(Base):
     __tablename__ = TABLE_SCHEMA_CHANGES
     __table_args__ = (_DEFAULT_TABLE_ARGS,)
 
-    change_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    change_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     schema_version: Mapped[int | None] = mapped_column(Integer)
     changed: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow)
@@ -844,7 +893,7 @@ class StatisticsRuns(Base):
     __tablename__ = TABLE_STATISTICS_RUNS
     __table_args__ = (_DEFAULT_TABLE_ARGS,)
 
-    run_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    run_id: Mapped[int] = mapped_column(ID_TYPE, Identity(), primary_key=True)
     start: Mapped[datetime] = mapped_column(DATETIME_TYPE, index=True)
 
     def __repr__(self) -> str:
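A note on the new ID_TYPE (an illustrative sketch, not part of the patch): `with_variant` makes the same column render as BIGINT on MySQL, MariaDB, and PostgreSQL but as INTEGER on SQLite, whose ROWID-backed autoincrement primary keys must be declared INTEGER even though they already hold 64-bit values:

    # Illustrative only: how ID_TYPE compiles on each dialect.
    from sqlalchemy import BigInteger
    from sqlalchemy.dialects import mysql, postgresql, sqlite

    ID_TYPE = BigInteger().with_variant(sqlite.INTEGER, "sqlite")

    print(ID_TYPE.compile(dialect=postgresql.dialect()))  # BIGINT
    print(ID_TYPE.compile(dialect=mysql.dialect()))       # BIGINT
    print(ID_TYPE.compile(dialect=sqlite.dialect()))      # INTEGER

This is also why the migration below skips the column rewrite entirely on SQLite: nothing would change there.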
diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py
index 83f89fa899540d..517ea4ca5cbe2f 100644
--- a/homeassistant/components/recorder/migration.py
+++ b/homeassistant/components/recorder/migration.py
@@ -15,6 +15,7 @@
 import sqlalchemy
 from sqlalchemy import ForeignKeyConstraint, MetaData, Table, func, text, update
 from sqlalchemy.engine import CursorResult, Engine
+from sqlalchemy.engine.interfaces import ReflectedForeignKeyConstraint
 from sqlalchemy.exc import (
     DatabaseError,
     IntegrityError,
@@ -55,6 +56,7 @@
     SupportedDialect,
 )
 from .db_schema import (
+    BIG_INTEGER_SQL,
     CONTEXT_ID_BIN_MAX_LENGTH,
     DOUBLE_PRECISION_TYPE_SQL,
     LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX,
@@ -67,6 +69,7 @@
     Base,
     Events,
     EventTypes,
+    LegacyBase,
     MigrationChanges,
     SchemaChanges,
     States,
@@ -243,6 +246,23 @@ def live_migration(schema_status: SchemaValidationStatus) -> bool:
     return schema_status.current_version >= LIVE_MIGRATION_MIN_SCHEMA_VERSION
 
 
+def pre_migrate_schema(engine: Engine) -> None:
+    """Prepare for migration.
+
+    This function is called before calling Base.metadata.create_all.
+    """
+    inspector = sqlalchemy.inspect(engine)
+
+    if inspector.has_table("statistics_meta") and not inspector.has_table(
+        "statistics_short_term"
+    ):
+        # Prepare for migration from a schema with a statistics_meta table but no
+        # statistics_short_term table
+        LegacyBase.metadata.create_all(
+            engine, (LegacyBase.metadata.tables["statistics_short_term"],)
+        )
+
+
 def migrate_schema(
     instance: Recorder,
     hass: HomeAssistant,
@@ -572,14 +592,19 @@ def _update_states_table_with_foreign_key_options(
 
 
 def _drop_foreign_key_constraints(
-    session_maker: Callable[[], Session], engine: Engine, table: str, columns: list[str]
-) -> None:
-    """Drop foreign key constraints for a table on specific columns."""
+    session_maker: Callable[[], Session], engine: Engine, table: str, column: str
+) -> list[tuple[str, str, ReflectedForeignKeyConstraint]]:
+    """Drop foreign key constraints for a table on a specific column."""
     inspector = sqlalchemy.inspect(engine)
+    dropped_constraints = [
+        (table, column, foreign_key)
+        for foreign_key in inspector.get_foreign_keys(table)
+        if foreign_key["name"] and foreign_key["constrained_columns"] == [column]
+    ]
     drops = [
         ForeignKeyConstraint((), (), name=foreign_key["name"])
         for foreign_key in inspector.get_foreign_keys(table)
-        if foreign_key["name"] and foreign_key["constrained_columns"] == columns
+        if foreign_key["name"] and foreign_key["constrained_columns"] == [column]
     ]
 
     # Bind the ForeignKeyConstraints to the table
@@ -594,9 +619,36 @@ def _drop_foreign_key_constraints(
             _LOGGER.exception(
                 "Could not drop foreign constraints in %s table on %s",
                 TABLE_STATES,
-                columns,
+                column,
             )
 
+    return dropped_constraints
+
+
+def _restore_foreign_key_constraints(
+    session_maker: Callable[[], Session],
+    engine: Engine,
+    dropped_constraints: list[tuple[str, str, ReflectedForeignKeyConstraint]],
+) -> None:
+    """Restore foreign key constraints."""
+    for table, column, dropped_constraint in dropped_constraints:
+        constraints = Base.metadata.tables[table].foreign_key_constraints
+        for constraint in constraints:
+            if constraint.column_keys == [column]:
+                break
+        else:
+            _LOGGER.info(
+                "Did not find a matching constraint for %s", dropped_constraint
+            )
+            continue
+
+        with session_scope(session=session_maker()) as session:
+            try:
+                connection = session.connection()
+                connection.execute(AddConstraint(constraint))  # type: ignore[no-untyped-call]
+            except (InternalError, OperationalError):
+                _LOGGER.exception("Could not update foreign options in %s table", table)
+
 
 @database_job_retry_wrapper("Apply migration update", 10)
 def _apply_update(  # noqa: C901
@@ -722,7 +774,7 @@ def _apply_update(  # noqa: C901
         pass
     elif new_version == 16:
         _drop_foreign_key_constraints(
-            session_maker, engine, TABLE_STATES, ["old_state_id"]
+            session_maker, engine, TABLE_STATES, "old_state_id"
         )
     elif new_version == 17:
         # This dropped the statistics table, done again in version 18.
@@ -1089,6 +1141,66 @@ def _apply_update(  # noqa: C901
             "states",
             [f"last_reported_ts {_column_types.timestamp_type}"],
         )
+    elif new_version == 44:
+        # We skip this step for SQLite, it does not have differently sized integers
+        if engine.dialect.name == SupportedDialect.SQLITE:
+            return
+        identity_sql = (
+            "NOT NULL AUTO_INCREMENT"
+            if engine.dialect.name == SupportedDialect.MYSQL
+            else ""
+        )
+        # First drop foreign key constraints
+        foreign_columns = (
+            ("events", ("data_id", "event_type_id")),
+            ("states", ("event_id", "old_state_id", "attributes_id", "metadata_id")),
+            ("statistics", ("metadata_id",)),
+            ("statistics_short_term", ("metadata_id",)),
+        )
+        dropped_constraints = [
+            dropped_constraint
+            for table, columns in foreign_columns
+            for column in columns
+            for dropped_constraint in _drop_foreign_key_constraints(
+                session_maker, engine, table, column
+            )
+        ]
+        _LOGGER.debug("Dropped foreign key constraints: %s", dropped_constraints)
+
+        # Then modify the constrained columns
+        for table, columns in foreign_columns:
+            _modify_columns(
+                session_maker,
+                engine,
+                table,
+                [f"{column} {BIG_INTEGER_SQL}" for column in columns],
+            )
+
+        # Then modify the ID columns
+        id_columns = (
+            ("events", "event_id"),
+            ("event_data", "data_id"),
+            ("event_types", "event_type_id"),
+            ("states", "state_id"),
+            ("state_attributes", "attributes_id"),
+            ("states_meta", "metadata_id"),
+            ("statistics", "id"),
+            ("statistics_short_term", "id"),
+            ("statistics_meta", "id"),
+            ("recorder_runs", "run_id"),
+            ("schema_changes", "change_id"),
+            ("statistics_runs", "run_id"),
+        )
+        for table, column in id_columns:
+            _modify_columns(
+                session_maker,
+                engine,
+                table,
+                [f"{column} {BIG_INTEGER_SQL} {identity_sql}"],
+            )
+        # Finally restore dropped constraints
+        _restore_foreign_key_constraints(session_maker, engine, dropped_constraints)
+
     else:
         raise ValueError(f"No schema migration defined for version {new_version}")
@@ -1744,7 +1856,7 @@ def cleanup_legacy_states_event_ids(instance: Recorder) -> bool:
             rebuild_sqlite_table(session_maker, instance.engine, States)
         else:
             _drop_foreign_key_constraints(
-                session_maker, instance.engine, TABLE_STATES, ["event_id"]
+                session_maker, instance.engine, TABLE_STATES, "event_id"
             )
             _drop_index(session_maker, "states", LEGACY_STATES_EVENT_ID_INDEX)
             instance.use_legacy_events_index = False
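The version 44 step relies on the `_drop_foreign_key_constraints` / `_restore_foreign_key_constraints` pair above to round-trip reflected constraints around the column type change. A minimal sketch (table and constraint names here are hypothetical, not taken from the patch) of the DDL SQLAlchemy emits for such a round trip, compiled to strings so it runs without a live server:

    # Hedged sketch: compile the DROP/ADD DDL for a named FK, the way the
    # migration helpers do with drops and AddConstraint.
    from sqlalchemy import BigInteger, Column, ForeignKeyConstraint, MetaData, Table
    from sqlalchemy.dialects import mysql
    from sqlalchemy.schema import AddConstraint, DropConstraint

    metadata = MetaData()
    Table("event_data", metadata, Column("data_id", BigInteger, primary_key=True))
    events = Table(
        "events",
        metadata,
        Column("event_id", BigInteger, primary_key=True),
        Column("data_id", BigInteger),
        # Hypothetical constraint name for illustration
        ForeignKeyConstraint(["data_id"], ["event_data.data_id"], name="fk_events_data_id"),
    )

    fk = next(iter(events.foreign_key_constraints))
    print(DropConstraint(fk).compile(dialect=mysql.dialect()))  # ALTER TABLE events DROP ...
    print(AddConstraint(fk).compile(dialect=mysql.dialect()))   # ALTER TABLE events ADD ...

Dropping the constraints first matters because MySQL and PostgreSQL refuse to ALTER a column that still participates in a foreign key; the reflected dicts returned by `_drop_foreign_key_constraints` carry enough information to find the matching constraint in the current `Base` metadata afterwards.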
+""" + +from __future__ import annotations + +from collections.abc import Callable +from datetime import datetime, timedelta +import logging +import time +from typing import Any, Self, cast + +import ciso8601 +from fnv_hash_fast import fnv1a_32 +from sqlalchemy import ( + CHAR, + JSON, + BigInteger, + Boolean, + ColumnElement, + DateTime, + Float, + ForeignKey, + Identity, + Index, + Integer, + LargeBinary, + SmallInteger, + String, + Text, + case, + type_coerce, +) +from sqlalchemy.dialects import mysql, oracle, postgresql, sqlite +from sqlalchemy.engine.interfaces import Dialect +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.orm import DeclarativeBase, Mapped, aliased, mapped_column, relationship +from sqlalchemy.types import TypeDecorator + +from homeassistant.components.recorder.const import ( + ALL_DOMAIN_EXCLUDE_ATTRS, + SupportedDialect, +) +from homeassistant.components.recorder.models import ( + StatisticData, + StatisticDataTimestamp, + StatisticMetaData, + bytes_to_ulid_or_none, + bytes_to_uuid_hex_or_none, + datetime_to_timestamp_or_none, + process_timestamp, + ulid_to_bytes_or_none, + uuid_hex_to_bytes_or_none, +) +from homeassistant.components.sensor import ATTR_STATE_CLASS +from homeassistant.const import ( + ATTR_DEVICE_CLASS, + ATTR_FRIENDLY_NAME, + ATTR_UNIT_OF_MEASUREMENT, + MATCH_ALL, + MAX_LENGTH_EVENT_EVENT_TYPE, + MAX_LENGTH_STATE_ENTITY_ID, + MAX_LENGTH_STATE_STATE, +) +from homeassistant.core import Context, Event, EventOrigin, EventStateChangedData, State +from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null +import homeassistant.util.dt as dt_util +from homeassistant.util.json import ( + JSON_DECODE_EXCEPTIONS, + json_loads, + json_loads_object, +) + + +# SQLAlchemy Schema +class Base(DeclarativeBase): + """Base class for tables.""" + + +SCHEMA_VERSION = 43 + +_LOGGER = logging.getLogger(__name__) + +TABLE_EVENTS = "events" +TABLE_EVENT_DATA = "event_data" +TABLE_EVENT_TYPES = "event_types" +TABLE_STATES = "states" +TABLE_STATE_ATTRIBUTES = "state_attributes" +TABLE_STATES_META = "states_meta" +TABLE_RECORDER_RUNS = "recorder_runs" +TABLE_SCHEMA_CHANGES = "schema_changes" +TABLE_STATISTICS = "statistics" +TABLE_STATISTICS_META = "statistics_meta" +TABLE_STATISTICS_RUNS = "statistics_runs" +TABLE_STATISTICS_SHORT_TERM = "statistics_short_term" +TABLE_MIGRATION_CHANGES = "migration_changes" + +STATISTICS_TABLES = ("statistics", "statistics_short_term") + +MAX_STATE_ATTRS_BYTES = 16384 +MAX_EVENT_DATA_BYTES = 32768 + +PSQL_DIALECT = SupportedDialect.POSTGRESQL + +ALL_TABLES = [ + TABLE_STATES, + TABLE_STATE_ATTRIBUTES, + TABLE_EVENTS, + TABLE_EVENT_DATA, + TABLE_EVENT_TYPES, + TABLE_RECORDER_RUNS, + TABLE_SCHEMA_CHANGES, + TABLE_MIGRATION_CHANGES, + TABLE_STATES_META, + TABLE_STATISTICS, + TABLE_STATISTICS_META, + TABLE_STATISTICS_RUNS, + TABLE_STATISTICS_SHORT_TERM, +] + +TABLES_TO_CHECK = [ + TABLE_STATES, + TABLE_EVENTS, + TABLE_RECORDER_RUNS, + TABLE_SCHEMA_CHANGES, +] + +LAST_UPDATED_INDEX_TS = "ix_states_last_updated_ts" +METADATA_ID_LAST_UPDATED_INDEX_TS = "ix_states_metadata_id_last_updated_ts" +EVENTS_CONTEXT_ID_BIN_INDEX = "ix_events_context_id_bin" +STATES_CONTEXT_ID_BIN_INDEX = "ix_states_context_id_bin" +LEGACY_STATES_EVENT_ID_INDEX = "ix_states_event_id" +LEGACY_STATES_ENTITY_ID_LAST_UPDATED_INDEX = "ix_states_entity_id_last_updated_ts" +CONTEXT_ID_BIN_MAX_LENGTH = 16 + +MYSQL_COLLATE = "utf8mb4_unicode_ci" +MYSQL_DEFAULT_CHARSET = "utf8mb4" +MYSQL_ENGINE = "InnoDB" + +_DEFAULT_TABLE_ARGS = { + 
"mysql_default_charset": MYSQL_DEFAULT_CHARSET, + "mysql_collate": MYSQL_COLLATE, + "mysql_engine": MYSQL_ENGINE, + "mariadb_default_charset": MYSQL_DEFAULT_CHARSET, + "mariadb_collate": MYSQL_COLLATE, + "mariadb_engine": MYSQL_ENGINE, +} + +_MATCH_ALL_KEEP = { + ATTR_DEVICE_CLASS, + ATTR_STATE_CLASS, + ATTR_UNIT_OF_MEASUREMENT, + ATTR_FRIENDLY_NAME, +} + + +class UnusedDateTime(DateTime): + """An unused column type that behaves like a datetime.""" + + +class Unused(CHAR): + """An unused column type that behaves like a string.""" + + +@compiles(UnusedDateTime, "mysql", "mariadb", "sqlite") # type: ignore[misc,no-untyped-call] +@compiles(Unused, "mysql", "mariadb", "sqlite") # type: ignore[misc,no-untyped-call] +def compile_char_zero(type_: TypeDecorator, compiler: Any, **kw: Any) -> str: + """Compile UnusedDateTime and Unused as CHAR(0) on mysql, mariadb, and sqlite.""" + return "CHAR(0)" # Uses 1 byte on MySQL (no change on sqlite) + + +@compiles(Unused, "postgresql") # type: ignore[misc,no-untyped-call] +def compile_char_one(type_: TypeDecorator, compiler: Any, **kw: Any) -> str: + """Compile Unused as CHAR(1) on postgresql.""" + return "CHAR(1)" # Uses 1 byte + + +class FAST_PYSQLITE_DATETIME(sqlite.DATETIME): + """Use ciso8601 to parse datetimes instead of sqlalchemy built-in regex.""" + + def result_processor(self, dialect, coltype): # type: ignore[no-untyped-def] + """Offload the datetime parsing to ciso8601.""" + return lambda value: None if value is None else ciso8601.parse_datetime(value) + + +class NativeLargeBinary(LargeBinary): + """A faster version of LargeBinary for engines that support python bytes natively.""" + + def result_processor(self, dialect, coltype): # type: ignore[no-untyped-def] + """No conversion needed for engines that support native bytes.""" + return None + + +# For MariaDB and MySQL we can use an unsigned integer type since it will fit 2**32 +# for sqlite and postgresql we use a bigint +UINT_32_TYPE = BigInteger().with_variant( + mysql.INTEGER(unsigned=True), # type: ignore[no-untyped-call] + "mysql", + "mariadb", +) +JSON_VARIANT_CAST = Text().with_variant( + postgresql.JSON(none_as_null=True), # type: ignore[no-untyped-call] + "postgresql", +) +JSONB_VARIANT_CAST = Text().with_variant( + postgresql.JSONB(none_as_null=True), # type: ignore[no-untyped-call] + "postgresql", +) +DATETIME_TYPE = ( + DateTime(timezone=True) + .with_variant(mysql.DATETIME(timezone=True, fsp=6), "mysql", "mariadb") # type: ignore[no-untyped-call] + .with_variant(FAST_PYSQLITE_DATETIME(), "sqlite") # type: ignore[no-untyped-call] +) +DOUBLE_TYPE = ( + Float() + .with_variant(mysql.DOUBLE(asdecimal=False), "mysql", "mariadb") # type: ignore[no-untyped-call] + .with_variant(oracle.DOUBLE_PRECISION(), "oracle") + .with_variant(postgresql.DOUBLE_PRECISION(), "postgresql") +) +UNUSED_LEGACY_COLUMN = Unused(0) +UNUSED_LEGACY_DATETIME_COLUMN = UnusedDateTime(timezone=True) +UNUSED_LEGACY_INTEGER_COLUMN = SmallInteger() +DOUBLE_PRECISION_TYPE_SQL = "DOUBLE PRECISION" +CONTEXT_BINARY_TYPE = LargeBinary(CONTEXT_ID_BIN_MAX_LENGTH).with_variant( + NativeLargeBinary(CONTEXT_ID_BIN_MAX_LENGTH), "mysql", "mariadb", "sqlite" +) + +TIMESTAMP_TYPE = DOUBLE_TYPE + + +class JSONLiteral(JSON): + """Teach SA how to literalize json.""" + + def literal_processor(self, dialect: Dialect) -> Callable[[Any], str]: + """Processor to convert a value to JSON.""" + + def process(value: Any) -> str: + """Dump json.""" + return JSON_DUMP(value) + + return process + + +EVENT_ORIGIN_ORDER = [EventOrigin.local, 
+
+
+class Events(Base):
+    """Event history data."""
+
+    __table_args__ = (
+        # Used for fetching events at a specific time
+        # see logbook
+        Index(
+            "ix_events_event_type_id_time_fired_ts", "event_type_id", "time_fired_ts"
+        ),
+        Index(
+            EVENTS_CONTEXT_ID_BIN_INDEX,
+            "context_id_bin",
+            mysql_length=CONTEXT_ID_BIN_MAX_LENGTH,
+            mariadb_length=CONTEXT_ID_BIN_MAX_LENGTH,
+        ),
+        _DEFAULT_TABLE_ARGS,
+    )
+    __tablename__ = TABLE_EVENTS
+    event_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    event_type: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    event_data: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    origin: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    origin_idx: Mapped[int | None] = mapped_column(SmallInteger)
+    time_fired: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN)
+    time_fired_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE, index=True)
+    context_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    context_user_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    context_parent_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    data_id: Mapped[int | None] = mapped_column(
+        Integer, ForeignKey("event_data.data_id"), index=True
+    )
+    context_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
+    context_user_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
+    context_parent_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
+    event_type_id: Mapped[int | None] = mapped_column(
+        Integer, ForeignKey("event_types.event_type_id")
+    )
+    event_data_rel: Mapped[EventData | None] = relationship("EventData")
+    event_type_rel: Mapped[EventTypes | None] = relationship("EventTypes")
+
+    def __repr__(self) -> str:
+        """Return string representation of instance for debugging."""
+        return (
+            f"<recorder.Events(id={self.event_id}, event_type_id='{self.event_type_id}', "
+            f"origin_idx='{self.origin_idx}', time_fired='{self._time_fired_isotime}', "
+            f"data_id={self.data_id})>"
+        )
+
+    @property
+    def _time_fired_isotime(self) -> str | None:
+        """Return time_fired as an isotime string."""
+        date_time: datetime | None
+        if self.time_fired_ts is not None:
+            date_time = dt_util.utc_from_timestamp(self.time_fired_ts)
+        else:
+            date_time = process_timestamp(self.time_fired)
+        if date_time is None:
+            return None
+        return date_time.isoformat(sep=" ", timespec="seconds")
+
+    @staticmethod
+    def from_event(event: Event) -> Events:
+        """Create an event database object from a native event."""
+        context = event.context
+        return Events(
+            event_type=None,
+            event_data=None,
+            origin_idx=event.origin.idx,
+            time_fired=None,
+            time_fired_ts=event.time_fired_timestamp,
+            context_id=None,
+            context_id_bin=ulid_to_bytes_or_none(context.id),
+            context_user_id=None,
+            context_user_id_bin=uuid_hex_to_bytes_or_none(context.user_id),
+            context_parent_id=None,
+            context_parent_id_bin=ulid_to_bytes_or_none(context.parent_id),
+        )
+
+    def to_native(self, validate_entity_id: bool = True) -> Event | None:
+        """Convert to a native HA Event."""
+        context = Context(
+            id=bytes_to_ulid_or_none(self.context_id_bin),
+            user_id=bytes_to_uuid_hex_or_none(self.context_user_id_bin),
+            parent_id=bytes_to_ulid_or_none(self.context_parent_id_bin),
+        )
+        try:
+            return Event(
+                self.event_type or "",
+                json_loads_object(self.event_data) if self.event_data else {},
+                EventOrigin(self.origin)
+                if self.origin
+                else EVENT_ORIGIN_ORDER[self.origin_idx or 0],
+                self.time_fired_ts or 0,
+                context=context,
+            )
+        except JSON_DECODE_EXCEPTIONS:
+            # When json_loads fails
+            _LOGGER.exception("Error converting to event: %s", self)
+            return None
+
+
+class EventData(Base):
+    """Event data history."""
+
+    __table_args__ = (_DEFAULT_TABLE_ARGS,)
+    __tablename__ = TABLE_EVENT_DATA
+    data_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    hash: Mapped[int | None] = mapped_column(UINT_32_TYPE, index=True)
+    # Note that this is not named attributes to avoid confusion with the states table
+    shared_data: Mapped[str | None] = mapped_column(
+        Text().with_variant(mysql.LONGTEXT, "mysql", "mariadb")
+    )
+
+    def __repr__(self) -> str:
+        """Return string representation of instance for debugging."""
+        return (
+            f"<recorder.EventData(id={self.data_id}, hash='{self.hash}', "
+            f"data='{self.shared_data}')>"
+        )
+
+    @staticmethod
+    def shared_data_bytes_from_event(
+        event: Event, dialect: SupportedDialect | None
+    ) -> bytes:
+        """Create shared_data from an event."""
+        if dialect == SupportedDialect.POSTGRESQL:
+            bytes_result = json_bytes_strip_null(event.data)
+        else:
+            bytes_result = json_bytes(event.data)
+        if len(bytes_result) > MAX_EVENT_DATA_BYTES:
+            _LOGGER.warning(
+                "Event data for %s exceeds maximum size of %s bytes. "
+                "This can cause database performance issues; Event data "
+                "will not be stored",
+                event.event_type,
+                MAX_EVENT_DATA_BYTES,
+            )
+            return b"{}"
+        return bytes_result
+
+    @staticmethod
+    def hash_shared_data_bytes(shared_data_bytes: bytes) -> int:
+        """Return the hash of json encoded shared data."""
+        return fnv1a_32(shared_data_bytes)
+
+    def to_native(self) -> dict[str, Any]:
+        """Convert to an event data dictionary."""
+        shared_data = self.shared_data
+        if shared_data is None:
+            return {}
+        try:
+            return cast(dict[str, Any], json_loads(shared_data))
+        except JSON_DECODE_EXCEPTIONS:
+            _LOGGER.exception("Error converting row to event data: %s", self)
+            return {}
+
+
+class EventTypes(Base):
+    """Event type history."""
+
+    __table_args__ = (_DEFAULT_TABLE_ARGS,)
+    __tablename__ = TABLE_EVENT_TYPES
+    event_type_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    event_type: Mapped[str | None] = mapped_column(
+        String(MAX_LENGTH_EVENT_EVENT_TYPE), index=True, unique=True
+    )
+
+    def __repr__(self) -> str:
+        """Return string representation of instance for debugging."""
+        return (
+            f"<recorder.EventTypes(id={self.event_type_id}, "
+            f"event_type='{self.event_type}')>"
+        )
+
+
+class States(Base):
+    """State change history."""
+
+    __table_args__ = (
+        # Used for fetching the state of entities at a specific time
+        # (get_states in history.py)
+        Index(METADATA_ID_LAST_UPDATED_INDEX_TS, "metadata_id", "last_updated_ts"),
+        Index(
+            STATES_CONTEXT_ID_BIN_INDEX,
+            "context_id_bin",
+            mysql_length=CONTEXT_ID_BIN_MAX_LENGTH,
+            mariadb_length=CONTEXT_ID_BIN_MAX_LENGTH,
+        ),
+        _DEFAULT_TABLE_ARGS,
+    )
+    __tablename__ = TABLE_STATES
+    state_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    entity_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    state: Mapped[str | None] = mapped_column(String(MAX_LENGTH_STATE_STATE))
+    attributes: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    event_id: Mapped[int | None] = mapped_column(UNUSED_LEGACY_INTEGER_COLUMN)
+    last_changed: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN)
+    last_changed_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE)
+    last_reported_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE)
+    last_updated: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN)
+    last_updated_ts: Mapped[float | None] = mapped_column(
+        TIMESTAMP_TYPE, default=time.time, index=True
+    )
+    old_state_id: Mapped[int | None] = mapped_column(
+        Integer, ForeignKey("states.state_id"), index=True
+    )
+    attributes_id: Mapped[int | None] = mapped_column(
+        Integer, ForeignKey("state_attributes.attributes_id"), index=True
+    )
+    context_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    context_user_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    context_parent_id: Mapped[str | None] = mapped_column(UNUSED_LEGACY_COLUMN)
+    origin_idx: Mapped[int | None] = mapped_column(
+        SmallInteger
+    )  # 0 is local, 1 is remote
+    old_state: Mapped[States | None] = relationship("States", remote_side=[state_id])
+    state_attributes: Mapped[StateAttributes | None] = relationship("StateAttributes")
+    context_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
+    context_user_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
+    context_parent_id_bin: Mapped[bytes | None] = mapped_column(CONTEXT_BINARY_TYPE)
+    metadata_id: Mapped[int | None] = mapped_column(
+        Integer, ForeignKey("states_meta.metadata_id")
+    )
+    states_meta_rel: Mapped[StatesMeta | None] = relationship("StatesMeta")
+
+    def __repr__(self) -> str:
+        """Return string representation of instance for debugging."""
+        return (
+            f"<recorder.States(id={self.state_id}, entity_id='{self.entity_id}', "
+            f"state='{self.state}', last_updated='{self._last_updated_isotime}', "
+            f"old_state_id={self.old_state_id}, attributes_id={self.attributes_id})>"
+        )
+
+    @property
+    def _last_updated_isotime(self) -> str | None:
+        """Return last_updated as an isotime string."""
+        date_time: datetime | None
+        if self.last_updated_ts is not None:
+            date_time = dt_util.utc_from_timestamp(self.last_updated_ts)
+        else:
+            date_time = process_timestamp(self.last_updated)
+        if date_time is None:
+            return None
+        return date_time.isoformat(sep=" ", timespec="seconds")
+
+    @staticmethod
+    def from_event(event: Event[EventStateChangedData]) -> States:
+        """Create object from a state_changed event."""
+        state = event.data["new_state"]
+        # None state means the state was removed from the state machine
+        if state is None:
+            state_value = ""
+            last_updated_ts = event.time_fired_timestamp
+            last_changed_ts = None
+            last_reported_ts = None
+        else:
+            state_value = state.state
+            last_updated_ts = state.last_updated_timestamp
+            if state.last_updated == state.last_changed:
+                last_changed_ts = None
+            else:
+                last_changed_ts = state.last_changed_timestamp
+            if state.last_updated == state.last_reported:
+                last_reported_ts = None
+            else:
+                last_reported_ts = state.last_reported_timestamp
+        context = event.context
+        return States(
+            state=state_value,
+            entity_id=event.data["entity_id"],
+            attributes=None,
+            context_id=None,
+            context_id_bin=ulid_to_bytes_or_none(context.id),
+            context_user_id=None,
+            context_user_id_bin=uuid_hex_to_bytes_or_none(context.user_id),
+            context_parent_id=None,
+            context_parent_id_bin=ulid_to_bytes_or_none(context.parent_id),
+            origin_idx=event.origin.idx,
+            last_updated=None,
+            last_changed=None,
+            last_updated_ts=last_updated_ts,
+            last_changed_ts=last_changed_ts,
+            last_reported_ts=last_reported_ts,
+        )
+
+    def to_native(self, validate_entity_id: bool = True) -> State | None:
+        """Convert to an HA state object."""
+        context = Context(
+            id=bytes_to_ulid_or_none(self.context_id_bin),
+            user_id=bytes_to_uuid_hex_or_none(self.context_user_id_bin),
+            parent_id=bytes_to_ulid_or_none(self.context_parent_id_bin),
+        )
+        try:
+            attrs = json_loads_object(self.attributes) if self.attributes else {}
+        except JSON_DECODE_EXCEPTIONS:
+            # When json_loads fails
+            _LOGGER.exception("Error converting row to state: %s", self)
+            return None
+        last_updated = dt_util.utc_from_timestamp(self.last_updated_ts or 0)
+        if self.last_changed_ts is None or self.last_changed_ts == self.last_updated_ts:
+            last_changed = dt_util.utc_from_timestamp(self.last_updated_ts or 0)
+        else:
+            last_changed = dt_util.utc_from_timestamp(self.last_changed_ts or 0)
+        if (
+            self.last_reported_ts is None
+            or self.last_reported_ts == self.last_updated_ts
+        ):
+            last_reported = dt_util.utc_from_timestamp(self.last_updated_ts or 0)
+        else:
+            last_reported = dt_util.utc_from_timestamp(self.last_reported_ts or 0)
+        return State(
+            self.entity_id or "",
+            self.state,  # type: ignore[arg-type]
+            # Join the state_attributes table on attributes_id to get the attributes
+            # for newer states
+            attrs,
+            last_changed=last_changed,
+            last_reported=last_reported,
+            last_updated=last_updated,
+            context=context,
+            validate_entity_id=validate_entity_id,
+        )
+
+
+class StateAttributes(Base):
+    """State attribute change history."""
+
+    __table_args__ = (_DEFAULT_TABLE_ARGS,)
+    __tablename__ = TABLE_STATE_ATTRIBUTES
+    attributes_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True)
+    hash: Mapped[int | None] = mapped_column(UINT_32_TYPE, index=True)
+    # Note that this is not named attributes to avoid confusion with the states table
+    shared_attrs: Mapped[str | None] = mapped_column(
+        Text().with_variant(mysql.LONGTEXT, "mysql", "mariadb")
+    )
+
+    def __repr__(self) -> str:
+        """Return string representation of instance for debugging."""
+        return (
+            f"<recorder.StateAttributes(id={self.attributes_id}, hash='{self.hash}', "
+            f"attributes='{self.shared_attrs}')>"
+        )
+
+    @staticmethod
+    def shared_attrs_bytes_from_event(
+        event: Event[EventStateChangedData],
+        dialect: SupportedDialect | None,
+    ) -> bytes:
+        """Create shared_attrs from a state_changed event."""
+        # None state means the state was removed from the state machine
+        if (state := event.data["new_state"]) is None:
+            return b"{}"
+        if state_info := state.state_info:
+            unrecorded_attributes = state_info["unrecorded_attributes"]
+            exclude_attrs = {
+                *ALL_DOMAIN_EXCLUDE_ATTRS,
+                *unrecorded_attributes,
+            }
+            if MATCH_ALL in unrecorded_attributes:
+                # Don't exclude device class, state class, unit of measurement
+                # or friendly name when using the MATCH_ALL exclude constant
+                exclude_attrs.update(state.attributes)
+                exclude_attrs -= _MATCH_ALL_KEEP
+        else:
+            exclude_attrs = ALL_DOMAIN_EXCLUDE_ATTRS
+        encoder = json_bytes_strip_null if dialect == PSQL_DIALECT else json_bytes
+        bytes_result = encoder(
+            {k: v for k, v in state.attributes.items() if k not in exclude_attrs}
+        )
+        if len(bytes_result) > MAX_STATE_ATTRS_BYTES:
+            _LOGGER.warning(
+                "State attributes for %s exceed maximum size of %s bytes. "
" + "This can cause database performance issues; Attributes " + "will not be stored", + state.entity_id, + MAX_STATE_ATTRS_BYTES, + ) + return b"{}" + return bytes_result + + @staticmethod + def hash_shared_attrs_bytes(shared_attrs_bytes: bytes) -> int: + """Return the hash of json encoded shared attributes.""" + return fnv1a_32(shared_attrs_bytes) + + def to_native(self) -> dict[str, Any]: + """Convert to a state attributes dictionary.""" + shared_attrs = self.shared_attrs + if shared_attrs is None: + return {} + try: + return cast(dict[str, Any], json_loads(shared_attrs)) + except JSON_DECODE_EXCEPTIONS: + # When json_loads fails + _LOGGER.exception("Error converting row to state attributes: %s", self) + return {} + + +class StatesMeta(Base): + """Metadata for states.""" + + __table_args__ = (_DEFAULT_TABLE_ARGS,) + __tablename__ = TABLE_STATES_META + metadata_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True) + entity_id: Mapped[str | None] = mapped_column( + String(MAX_LENGTH_STATE_ENTITY_ID), index=True, unique=True + ) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + "" + ) + + +class StatisticsBase: + """Statistics base class.""" + + id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True) + created: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN) + created_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE, default=time.time) + metadata_id: Mapped[int | None] = mapped_column( + Integer, + ForeignKey(f"{TABLE_STATISTICS_META}.id", ondelete="CASCADE"), + ) + start: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN) + start_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE, index=True) + mean: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + min: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + max: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + last_reset: Mapped[datetime | None] = mapped_column(UNUSED_LEGACY_DATETIME_COLUMN) + last_reset_ts: Mapped[float | None] = mapped_column(TIMESTAMP_TYPE) + state: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + sum: Mapped[float | None] = mapped_column(DOUBLE_TYPE) + + duration: timedelta + + @classmethod + def from_stats(cls, metadata_id: int, stats: StatisticData) -> Self: + """Create object from a statistics with datatime objects.""" + return cls( # type: ignore[call-arg] + metadata_id=metadata_id, + created=None, + created_ts=time.time(), + start=None, + start_ts=dt_util.utc_to_timestamp(stats["start"]), + mean=stats.get("mean"), + min=stats.get("min"), + max=stats.get("max"), + last_reset=None, + last_reset_ts=datetime_to_timestamp_or_none(stats.get("last_reset")), + state=stats.get("state"), + sum=stats.get("sum"), + ) + + @classmethod + def from_stats_ts(cls, metadata_id: int, stats: StatisticDataTimestamp) -> Self: + """Create object from a statistics with timestamps.""" + return cls( # type: ignore[call-arg] + metadata_id=metadata_id, + created=None, + created_ts=time.time(), + start=None, + start_ts=stats["start_ts"], + mean=stats.get("mean"), + min=stats.get("min"), + max=stats.get("max"), + last_reset=None, + last_reset_ts=stats.get("last_reset_ts"), + state=stats.get("state"), + sum=stats.get("sum"), + ) + + +class Statistics(Base, StatisticsBase): + """Long term statistics.""" + + duration = timedelta(hours=1) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index( + "ix_statistics_statistic_id_start_ts", + "metadata_id", + 
"start_ts", + unique=True, + ), + _DEFAULT_TABLE_ARGS, + ) + __tablename__ = TABLE_STATISTICS + + +class StatisticsShortTerm(Base, StatisticsBase): + """Short term statistics.""" + + duration = timedelta(minutes=5) + + __table_args__ = ( + # Used for fetching statistics for a certain entity at a specific time + Index( + "ix_statistics_short_term_statistic_id_start_ts", + "metadata_id", + "start_ts", + unique=True, + ), + _DEFAULT_TABLE_ARGS, + ) + __tablename__ = TABLE_STATISTICS_SHORT_TERM + + +class StatisticsMeta(Base): + """Statistics meta data.""" + + __table_args__ = (_DEFAULT_TABLE_ARGS,) + __tablename__ = TABLE_STATISTICS_META + id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True) + statistic_id: Mapped[str | None] = mapped_column( + String(255), index=True, unique=True + ) + source: Mapped[str | None] = mapped_column(String(32)) + unit_of_measurement: Mapped[str | None] = mapped_column(String(255)) + has_mean: Mapped[bool | None] = mapped_column(Boolean) + has_sum: Mapped[bool | None] = mapped_column(Boolean) + name: Mapped[str | None] = mapped_column(String(255)) + + @staticmethod + def from_meta(meta: StatisticMetaData) -> StatisticsMeta: + """Create object from meta data.""" + return StatisticsMeta(**meta) + + +class RecorderRuns(Base): + """Representation of recorder run.""" + + __table_args__ = ( + Index("ix_recorder_runs_start_end", "start", "end"), + _DEFAULT_TABLE_ARGS, + ) + __tablename__ = TABLE_RECORDER_RUNS + run_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True) + start: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow) + end: Mapped[datetime | None] = mapped_column(DATETIME_TYPE) + closed_incorrect: Mapped[bool] = mapped_column(Boolean, default=False) + created: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + end = ( + f"'{self.end.isoformat(sep=' ', timespec='seconds')}'" if self.end else None + ) + return ( + f"" + ) + + def to_native(self, validate_entity_id: bool = True) -> Self: + """Return self, native format is this model.""" + return self + + +class MigrationChanges(Base): + """Representation of migration changes.""" + + __tablename__ = TABLE_MIGRATION_CHANGES + __table_args__ = (_DEFAULT_TABLE_ARGS,) + + migration_id: Mapped[str] = mapped_column(String(255), primary_key=True) + version: Mapped[int] = mapped_column(SmallInteger) + + +class SchemaChanges(Base): + """Representation of schema version changes.""" + + __tablename__ = TABLE_SCHEMA_CHANGES + __table_args__ = (_DEFAULT_TABLE_ARGS,) + + change_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True) + schema_version: Mapped[int | None] = mapped_column(Integer) + changed: Mapped[datetime] = mapped_column(DATETIME_TYPE, default=dt_util.utcnow) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + "" + ) + + +class StatisticsRuns(Base): + """Representation of statistics run.""" + + __tablename__ = TABLE_STATISTICS_RUNS + __table_args__ = (_DEFAULT_TABLE_ARGS,) + + run_id: Mapped[int] = mapped_column(Integer, Identity(), primary_key=True) + start: Mapped[datetime] = mapped_column(DATETIME_TYPE, index=True) + + def __repr__(self) -> str: + """Return string representation of instance for debugging.""" + return ( + f"" + ) + + +EVENT_DATA_JSON = type_coerce( + EventData.shared_data.cast(JSONB_VARIANT_CAST), JSONLiteral(none_as_null=True) +) 
diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py
index cd9650779b53d4..25fe8993cfb6f5 100644
--- a/tests/components/recorder/test_migrate.py
+++ b/tests/components/recorder/test_migrate.py
@@ -327,7 +327,7 @@ async def test_events_during_migration_queue_exhausted(
 
 @pytest.mark.parametrize(
     ("start_version", "live"),
-    [(0, True), (16, True), (18, True), (22, True), (25, True)],
+    [(0, True), (16, True), (18, True), (22, True), (25, True), (43, True)],
 )
 async def test_schema_migrate(
     hass: HomeAssistant,
@@ -682,3 +682,43 @@ def test_rebuild_sqlite_states_table_extra_columns(
         assert session.query(States).first().state == "on"
 
     engine.dispose()
+
+
+def test_restore_foreign_key_constraints_with_error(
+    caplog: pytest.LogCaptureFixture,
+) -> None:
+    """Test we catch and log errors when restoring foreign keys fails.
+
+    This is not supported on SQLite.
+    """
+
+    constraints_to_restore = [
+        (
+            "events",
+            "data_id",
+            {
+                "comment": None,
+                "constrained_columns": ["data_id"],
+                "name": "events_data_id_fkey",
+                "options": {},
+                "referred_columns": ["data_id"],
+                "referred_schema": None,
+                "referred_table": "event_data",
+            },
+        ),
+    ]
+
+    connection = Mock()
+    connection.execute = Mock(side_effect=InternalError(None, None, None))
+    session = Mock()
+    session.connection = Mock(return_value=connection)
+    instance = Mock()
+    instance.get_session = Mock(return_value=session)
+    engine = Mock()
+
+    session_maker = Mock(return_value=session)
+    migration._restore_foreign_key_constraints(
+        session_maker, engine, constraints_to_restore
+    )
+
+    assert "Could not update foreign options in events table" in caplog.text
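Finally, some back-of-the-envelope context for why the ids are being widened at all (the write rate below is an assumption for illustration, not a measured figure): a signed 32-bit identity column overflows at 2**31 - 1 rows, a limit that long-running recorders with many entities can realistically approach.

    # When does a signed 32-bit id run out? (100 rows/minute is an assumption.)
    MAX_INT32 = 2**31 - 1  # 2147483647
    rows_per_year = 100 * 60 * 24 * 365  # 52,560,000 rows/year at 100 rows/minute
    print(MAX_INT32 / rows_per_year)  # ~40.9 years; a 10x busier install: ~4 years

Once a table hits that ceiling, inserts fail and the recorder stops; widening to BIGINT (about 9.2 * 10**18 values) removes the limit for any practical lifetime.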