diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index d7f7832b..5e9823f2 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -111,6 +111,9 @@ jobs: run: | gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run test for AWS + run: hatch run test-dialect-aws + if: matrix.cloud-provider == 'aws' - name: Run tests run: hatch run test-dialect - uses: actions/upload-artifact@v4 @@ -203,6 +206,9 @@ jobs: python -m pip install -U uv python -m uv pip install -U hatch python -m hatch env create default + - name: Run test for AWS + run: hatch run sa14:test-dialect-aws + if: matrix.cloud-provider == 'aws' - name: Run tests run: hatch run sa14:test-dialect - uses: actions/upload-artifact@v4 diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index a1eb1a0c..52f43106 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -57,14 +57,14 @@ jobs: --signature "${dist_base}.sig" \ --cert "${dist_base}.crt" \ --cert-oidc-issuer https://token.actions.githubusercontent.com \ - --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF} + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} # Verify using `.sigstore` bundle; python -m \ sigstore verify identity "${dist}" \ --bundle "${dist_base}.sigstore" \ --cert-oidc-issuer https://token.actions.githubusercontent.com \ - --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF} + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} done - name: List artifacts after sign run: ls ./dist diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 
38cd70f7..58c2dfe2 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,12 @@ Source code is also available at: # Release Notes +- (Unreleased) + + - Add support for dynamic tables and required options + - Add support for hybrid tables + - Fixed SAWarning when registering functions with existing name in default namespace + - v1.6.1(July 9, 2024) - Update internal project workflow with pypi publishing @@ -24,7 +30,7 @@ Source code is also available at: - v1.5.3(April 16, 2024) - - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 + - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 - v1.5.2(April 11, 2024) @@ -33,7 +39,7 @@ Source code is also available at: - v1.5.1(November 03, 2023) - - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057. + - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check <https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057>. - Fixed credentials with `externalbrowser` authentication not caching due to incorrect parsing of boolean query parameters. - This fixes other boolean parameter passing to driver as well. 
diff --git a/pyproject.toml b/pyproject.toml index 9cdd9fb4..6c72f683 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ development = [ "pytz", "numpy", "mock", + "syrupy==4.6.1", ] pandas = ["snowflake-connector-python[pandas]"] @@ -91,6 +92,7 @@ SQLACHEMY_WARN_20 = "1" check = "pre-commit run --all-files" test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" +test-dialect-aws = "pytest -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" @@ -109,6 +111,7 @@ line-length = 88 line-length = 88 [tool.pytest.ini_options] +addopts = "-m 'not feature_max_lob_size and not aws'" markers = [ # Optional dependency groups markers "lambda: AWS lambda tests", @@ -126,4 +129,5 @@ markers = [ "timeout: tests that need a timeout time", "internal: tests that could but should only run on our internal CI", "external: tests that could but should only run on our external CI", + "feature_max_lob_size: tests that could but should only run on our external CI", ] diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 9df6aaa2..0afd44a5 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -61,6 +61,8 @@ VARBINARY, VARIANT, ) +from .sql.custom_schema import DynamicTable, HybridTable +from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse from .util import _url as URL base.dialect = dialect = snowdialect.dialect @@ -113,4 +115,10 @@ "ExternalStage", "CreateStage", "CreateFileFormat", + 
"DynamicTable", + "AsQuery", + "TargetLag", + "TimeUnit", + "Warehouse", + "HybridTable", ) diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index 46af4454..839745ee 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -10,3 +10,4 @@ APPLICATION_NAME = "SnowflakeSQLAlchemy" SNOWFLAKE_SQLALCHEMY_VERSION = VERSION +DIALECT_NAME = "snowflake" diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 1aaa881e..56631728 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -18,9 +18,16 @@ from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.selectable import Lateral, SelectState -from .compat import IS_VERSION_20, args_reducer, string_types -from .custom_commands import AWSBucket, AzureContainer, ExternalStage +from snowflake.sqlalchemy._constants import DIALECT_NAME +from snowflake.sqlalchemy.compat import IS_VERSION_20, args_reducer, string_types +from snowflake.sqlalchemy.custom_commands import ( + AWSBucket, + AzureContainer, + ExternalStage, +) + from .functions import flatten +from .sql.custom_schema.options.table_option_base import TableOptionBase from .util import ( _find_left_clause_to_join_from, _set_connection_interpolate_empty_sequences, @@ -184,7 +191,6 @@ def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause) [element._from_objects for element in statement._where_criteria] ), ): - potential[from_clause] = () all_clauses = list(potential.keys()) @@ -879,7 +885,7 @@ def get_column_specification(self, column, **kwargs): return " ".join(colspec) - def post_create_table(self, table): + def handle_cluster_by(self, table): """ Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax. 
def post_create_table(self, table):
    """
    Builds the text appended after the column list of ``CREATE TABLE``.

    The ``CLUSTER BY`` clause produced by ``handle_cluster_by`` comes first.
    It is followed by every ``TableOptionBase`` instance stored in the table's
    snowflake dialect options, each one prefixed with a tab character.

    Options are emitted from highest to lowest ``__priority__``; options that
    share a priority are ordered by ``__option_name__`` in descending order.
    """
    clause = self.handle_cluster_by(table)
    dialect_options = table.dialect_options[DIALECT_NAME]
    ordered = sorted(
        (
            candidate
            for candidate in dialect_options.values()
            if isinstance(candidate, TableOptionBase)
        ),
        key=lambda opt: (opt.__priority__.value, opt.__option_name__),
        reverse=True,
    )
    return clause + "".join("\t" + opt.render_option(self) for opt in ordered)
def get_multi_indexes(
    self,
    connection,
    *,
    schema,
    filter_names,
    **kw,
):
    """
    Gets the indexes definition of the hybrid tables found in a schema.

    Only tables that ``get_multi_prefixes`` reports with the ``HYBRID`` prefix
    are considered, and the implicit ``SYS_INDEX_<table>_PRIMARY`` index is
    skipped.

    :param connection: connection used to run ``SHOW INDEXES``.
    :param schema: schema to inspect; falls back to ``default_schema_name``.
    :param filter_names: normalized table names to keep; ``None`` keeps all.
    :return: list of ``((schema, table), [index_dict, ...])`` pairs, or an
        empty list when the schema holds no hybrid table.
    """
    hybrid_prefix = CustomTablePrefix.HYBRID.name
    table_prefixes = self.get_multi_prefixes(
        connection, schema, filter_prefix=hybrid_prefix
    )
    if not table_prefixes:
        return []

    schema = schema or self.default_schema_name
    statement = "SHOW /* sqlalchemy:get_multi_indexes */ INDEXES"
    if schema:
        statement += f" IN SCHEMA {self._denormalize_quote_join(schema)}"
    result = connection.execute(text(statement))

    n2i = self.__class__._map_name_to_idx(result)
    indexes = {}

    for row in result.cursor.fetchall():
        raw_table = row[n2i["table"]]
        table = self.normalize_name(str(raw_table))
        key = (schema, table)
        if (
            row[n2i["name"]] == f"SYS_INDEX_{raw_table}_PRIMARY"
            or (filter_names is not None and table not in filter_names)
            or hybrid_prefix not in table_prefixes.get(key, ())
        ):
            continue
        index = {
            "name": row[n2i["name"]],
            "unique": row[n2i["is_unique"]] == "Y",
            "column_names": row[n2i["columns"]],
            "include_columns": row[n2i["included_columns"]],
            "dialect_options": {},
        }
        # list.append() returns None, so its result must never be stored back:
        # doing so replaced the list with None on the second index of a table.
        indexes.setdefault(key, []).append(index)

    return list(indexes.items())


def _value_or_default(self, data, table, schema):
    """
    Returns the entry stored under ``(schema, table)`` in ``data``.

    :param data: iterable of ``((schema, table), value)`` pairs.
    :param table: table name; it is normalized before the lookup.
    :param schema: schema part of the key, used as given.
    :return: the matching value, or an empty list when there is none.
    """
    table = self.normalize_name(str(table))
    return dict(data).get((schema, table), [])


def get_prefixes_from_data(self, n2i, row, **kw):
    """
    Lists the ``CustomTablePrefix`` names flagged in a ``SHOW TABLES`` row.

    A prefix is reported when the row has an ``is_<prefix>`` column (name in
    lower case) whose value is ``"Y"``.
    """
    prefixes_found = []
    for valid_prefix in CustomTablePrefix:
        key = f"is_{valid_prefix.name.lower()}"
        if key in n2i and row[n2i[key]] == "Y":
            prefixes_found.append(valid_prefix.name)
    return prefixes_found


@reflection.cache
def get_multi_prefixes(
    self, connection, schema, table_name=None, filter_prefix=None, **kw
):
    """
    Gets all table prefixes

    :param connection: connection used to run ``SHOW TABLES``.
    :param schema: schema to inspect; falls back to ``default_schema_name``.
    :param table_name: optional ``LIKE`` pattern restricting the tables.
    :param filter_prefix: when given, only tables carrying that prefix name
        are returned.
    :return: dict mapping ``(schema, table)`` to the list of prefix names.
    """
    schema = schema or self.default_schema_name
    # Snowflake syntax is SHOW TABLES [LIKE '<pattern>'] [IN SCHEMA <name>]:
    # the LIKE clause follows TABLES and is left out when no pattern is given
    # (the previous text produced "LIKE ... TABLES" and "LIKE 'None'").
    statement = "SHOW /* sqlalchemy:get_multi_prefixes */ TABLES"
    if table_name:
        statement += f" LIKE '{table_name}'"
    if schema:
        # Same quoting as get_multi_indexes uses for its schema.
        statement += f" IN SCHEMA {self._denormalize_quote_join(schema)}"
    result = connection.execute(text(statement))

    n2i = self.__class__._map_name_to_idx(result)
    tables_prefixes = {}
    for row in result.cursor.fetchall():
        table = self.normalize_name(str(row[n2i["name"]]))
        table_prefixes = self.get_prefixes_from_data(n2i, row)
        if filter_prefix and filter_prefix not in table_prefixes:
            continue
        # extend() keeps a flat list of names if a key is ever seen twice.
        tables_prefixes.setdefault((schema, table), []).extend(table_prefixes)

    return tables_prefixes


@reflection.cache
def get_indexes(self, connection, tablename, schema=None, **kw):
    """
    Gets the indexes definition

    :param connection: connection used for the reflection queries.
    :param tablename: table whose indexes are requested.
    :param schema: schema of the table; falls back to ``default_schema_name``.
    :return: list of index dictionaries, empty when the table has none.
    """
    table_name = self.normalize_name(str(tablename))
    # get_multi_indexes keys its result by the resolved schema, so resolve it
    # here as well; otherwise a lookup with schema=None can never match.
    schema = schema or self.default_schema_name
    data = self.get_multi_indexes(
        connection=connection, schema=schema, filter_names=[table_name], **kw
    )

    return self._value_or_default(data, table_name, schema)
return True if isinstance(_ddl_runner.dialect, SnowflakeDialect) and table.indexes: - raise NotImplementedError("Snowflake does not support indexes") + raise NotImplementedError("Only Snowflake Hybrid Tables supports indexes") dialect = SnowflakeDialect diff --git a/src/snowflake/sqlalchemy/sql/__init__.py b/src/snowflake/sqlalchemy/sql/__init__.py new file mode 100644 index 00000000..ef416f64 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py new file mode 100644 index 00000000..66b9270f --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py @@ -0,0 +1,7 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from .dynamic_table import DynamicTable +from .hybrid_table import HybridTable + +__all__ = ["DynamicTable", "HybridTable"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py new file mode 100644 index 00000000..b61c270d --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -0,0 +1,64 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
class CustomTableBase(Table):
    """
    Base class of the Snowflake specific table types.

    Subclasses list their ``CustomTablePrefix`` members in
    ``__table_prefixes__``; the prefix names are appended to the ``prefixes``
    argument handed to ``Table``. ``_support_primary_and_foreign_keys`` states
    whether the table type accepts primary and foreign keys.
    """

    __table_prefixes__: typing.List[CustomTablePrefix] = []
    _support_primary_and_foreign_keys: bool = True

    @property
    def table_prefixes(self) -> typing.List[str]:
        """Names of the prefixes declared in ``__table_prefixes__``."""
        return [member.name for member in self.__table_prefixes__]

    def __init__(
        self,
        name: str,
        metadata: MetaData,
        *args: SchemaItem,
        **kw: Any,
    ) -> None:
        """
        Builds the table and validates it.

        Validation is skipped when a truthy ``autoload_with`` is passed.
        """
        if self.__table_prefixes__:
            kw["prefixes"] = kw.get("prefixes", []) + self.table_prefixes

        # Outside SQLAlchemy 2.0 the parent is initialised through ``_init``
        # whenever it offers that hook.
        use_legacy_init = not IS_VERSION_20 and hasattr(super(), "_init")
        if use_legacy_init:
            super()._init(name, metadata, *args, **kw)
        else:
            super().__init__(name, metadata, *args, **kw)

        if kw.get("autoload_with", False):
            return
        self._validate_table()

    def _validate_table(self):
        """
        Checks that keys are only used by table types that support them.

        :raises ArgumentError: when keys are unsupported and the table has a
            primary key or foreign keys.
        :return: ``True`` when the table is valid.
        """
        if self._support_primary_and_foreign_keys:
            return True
        if self.primary_key or self.foreign_keys:
            raise ArgumentError(
                f"Primary key and foreign keys are not supported in {' '.join(self.table_prefixes)} TABLE."
            )
        return True

    def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]:
        """
        Looks up ``option_name`` in the snowflake dialect options.

        A missing option is reported with the ``NoneType`` class itself (not
        ``None``); callers test the result with ``is NoneType``.
        """
        snowflake_options = self.dialect_options[DIALECT_NAME]
        return (
            snowflake_options[option_name]
            if option_name in snowflake_options
            else NoneType
        )

    @classmethod
    def is_equal_type(cls, table: Table) -> bool:
        """Tells whether ``table`` carries every prefix declared by ``cls``."""
        return all(
            member.name in table._prefixes for member in cls.__table_prefixes__
        )


class CustomTablePrefix(Enum):
    """Table kinds recognised as prefixes of ``CREATE <prefix> TABLE``."""

    DEFAULT = 0
    EXTERNAL = 1
    EVENT = 2
    HYBRID = 3
    ICEBERG = 4
    DYNAMIC = 5
class DynamicTable(TableFromQueryBase):
    """
    Table rendered with the ``DYNAMIC`` prefix.

    A dynamic table is described by three options that must all be present:
    ``TargetLag``, ``Warehouse`` and ``AsQuery``. Primary and foreign keys are
    rejected.

    Reflection is not supported at this time; the class covers the creation
    and the management of dynamic tables.
    """

    __table_prefixes__ = [CustomTablePrefix.DYNAMIC]

    _support_primary_and_foreign_keys = False

    @property
    def warehouse(self) -> typing.Optional[Warehouse]:
        """``Warehouse`` option of the table, or the ``NoneType`` sentinel."""
        return self._get_dialect_option(Warehouse.__option_name__)

    @property
    def target_lag(self) -> typing.Optional[TargetLag]:
        """``TargetLag`` option of the table, or the ``NoneType`` sentinel."""
        return self._get_dialect_option(TargetLag.__option_name__)

    def __init__(
        self,
        name: str,
        metadata: MetaData,
        *args: SchemaItem,
        **kw: Any,
    ) -> None:
        # Nothing is initialised unless ``_no_init`` is passed with a falsy
        # value.
        if kw.get("_no_init", True):
            return
        super().__init__(name, metadata, *args, **kw)

    def _init(
        self,
        name: str,
        metadata: MetaData,
        *args: SchemaItem,
        **kw: Any,
    ) -> None:
        # Same initialisation chain as ``__init__`` without the ``_no_init``
        # guard.
        super().__init__(name, metadata, *args, **kw)

    def _validate_table(self):
        """
        Checks that the three mandatory options were supplied.

        :raises ArgumentError: listing every missing option.
        """
        required_options = (
            ("TargetLag", self.target_lag),
            ("Warehouse", self.warehouse),
            ("AsQuery", self.as_query),
        )
        missing_attributes = [
            label for label, value in required_options if value is NoneType
        ]
        if missing_attributes:
            raise ArgumentError(
                "DYNAMIC TABLE must have the following arguments: %s"
                % ", ".join(missing_attributes)
            )
        super()._validate_table()

    def __repr__(self) -> str:
        parts = [repr(self.name), repr(self.metadata)]
        parts.extend(repr(column) for column in self.columns)
        parts.append(repr(self.target_lag))
        parts.append(repr(self.warehouse))
        parts.append(repr(self.as_query))
        parts.append(f"schema={repr(self.schema)}")
        return "DynamicTable(%s)" % ", ".join(parts)
class HybridTable(CustomTableBase):
    """
    A class representing a Snowflake hybrid table.

    The `HybridTable` class allows for the creation and querying of OLTP
    Snowflake tables. It is rendered with the ``HYBRID`` prefix, supports
    primary and foreign keys, and must declare a primary key.
    """

    __table_prefixes__ = [CustomTablePrefix.HYBRID]

    _support_primary_and_foreign_keys = True

    def __init__(
        self,
        name: str,
        metadata: MetaData,
        *args: SchemaItem,
        **kw: Any,
    ) -> None:
        # Nothing is initialised unless ``_no_init`` is passed with a falsy
        # value.
        if kw.get("_no_init", True):
            return
        super().__init__(name, metadata, *args, **kw)

    def _init(
        self,
        name: str,
        metadata: MetaData,
        *args: SchemaItem,
        **kw: Any,
    ) -> None:
        # Same initialisation chain as ``__init__`` without the ``_no_init``
        # guard.
        super().__init__(name, metadata, *args, **kw)

    def _validate_table(self):
        """
        Checks that the table declares a primary key.

        :raises ArgumentError: when the table has no primary key column.
        """
        missing_attributes = []
        # ``self.key`` is the table's string key and can never be the NoneType
        # sentinel, so testing it never detected anything. Inspect the primary
        # key columns instead; ``len`` is used on purpose so the check does not
        # depend on the truthiness of the column collection.
        if len(self.primary_key.columns) == 0:
            missing_attributes.append("Primary Key")
        if missing_attributes:
            raise ArgumentError(
                "HYBRID TABLE must have the following arguments: %s"
                % ", ".join(missing_attributes)
            )
        super()._validate_table()

    def __repr__(self) -> str:
        return "HybridTable(%s)" % ", ".join(
            [repr(self.name)]
            + [repr(self.metadata)]
            + [repr(x) for x in self.columns]
            + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]]
        )
class AsQuery(TableOption):
    """Table option holding the query a table is created from.

    The option renders as ``AS <query>``. For further information on this
    clause, please refer to:
    https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas

    Usage with a query given as a string:
        DynamicTable(
            "sometable", metadata,
            Column("name", String(50)),
            Column("address", String(100)),
            AsQuery("select name, address from existing_table where name = 'test'")
        )

    Usage with a selectable statement:
        DynamicTable(
            "sometable",
            Base.metadata,
            TargetLag(10, TimeUnit.SECONDS),
            Warehouse("warehouse"),
            AsQuery(select(test_table_1).where(test_table_1.c.id == 23))
        )

    """

    __option_name__ = "as_query"
    __priority__ = Priority.LOWEST

    def __init__(self, query: Union[str, Selectable]) -> None:
        """Store the query.

        :param query: SQL text or selectable placed after ``AS``.
        """
        self.query = query

    @staticmethod
    def template() -> str:
        return "AS %s"

    def get_expression(self):
        """Return the query; a selectable is compiled with literal binds."""
        if not isinstance(self.query, Selectable):
            return self.query
        return self.query.compile(compile_kwargs={"literal_binds": True})

    def render_option(self, compiler) -> str:
        expression = self.get_expression()
        return AsQuery.template() % expression

    def __repr__(self) -> str:
        expression = self.get_expression()
        return "Query(%s)" % expression


class TableOption(TableOptionBase, SchemaItem):
    """Table option that is passed to a ``Table`` as a schema item."""

    def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
        """Register the option in the parent table's snowflake dialect options.

        The option is stored under its ``__option_name__``.

        :raises SQLAlchemyError: when ``__option_name__`` is still
            ``"default"`` or when ``parent`` is not a ``Table``.
        """
        option_name = self.__option_name__
        if option_name == "default":
            raise exc.SQLAlchemyError(f"{self.__class__.__name__} does not has a name")
        if isinstance(parent, Table):
            parent.dialect_options[DIALECT_NAME][option_name] = self
            return
        raise exc.SQLAlchemyError(
            f"{self.__class__.__name__} option can only be applied to Table"
        )

    def _set_table_option_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
        # Intentionally does nothing.
        pass
class Priority(Enum):
    """Relative position of a table option in the rendered statement.

    ``post_create_table`` sorts the options by ``value`` in descending order,
    so higher priorities are rendered first.
    """

    LOWEST = 0
    VERY_LOW = 1
    LOW = 2
    MEDIUM = 4
    HIGH = 6
    VERY_HIGH = 7
    HIGHEST = 8


class TableOptionBase:
    """Interface of the options rendered after ``CREATE TABLE``.

    Subclasses set ``__option_name__`` and ``__priority__`` and implement the
    three methods below; every default implementation raises
    ``NotImplementedError``.
    """

    __option_name__ = "default"
    __visit_name__ = __option_name__
    __priority__ = Priority.MEDIUM

    @staticmethod
    def template() -> str:
        """Return the ``%``-style format string of the option."""
        raise NotImplementedError

    def get_expression(self):
        """Return the value substituted into ``template()``."""
        raise NotImplementedError

    def render_option(self, compiler) -> str:
        """Return the SQL text of the option."""
        raise NotImplementedError


class TimeUnit(Enum):
    """Units accepted by ``TargetLag``; the value is the rendered SQL keyword."""

    SECONDS = "seconds"
    MINUTES = "minutes"
    # Plural like every other unit and like the documented TARGET_LAG syntax
    # (seconds | minutes | hours | days); this member used to render "hour".
    HOURS = "hours"
    DAYS = "days"
class TargetLag(TableOption):
    """Table option holding the target lag of a dynamic table.

    The option renders as ``TARGET_LAG = '<time> <unit>'`` or, when
    ``down_stream`` is set, as ``TARGET_LAG = DOWNSTREAM``.
    For further information on this clause, please refer to:
    https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table

    Target lag example usage:
        DynamicTable("sometable", metadata,
            Column("name", String(50)),
            Column("address", String(100)),
            TargetLag(20, TimeUnit.MINUTES),
        )
    """

    __option_name__ = "target_lag"
    __priority__ = Priority.HIGH

    def __init__(
        self,
        time: Optional[int] = 0,
        unit: Optional[TimeUnit] = TimeUnit.MINUTES,
        down_stream: Optional[bool] = False,
    ) -> None:
        """Store the lag settings.

        :param time: amount of ``unit``.
        :param unit: time unit of the lag.
        :param down_stream: when truthy, ``DOWNSTREAM`` is rendered and
            ``time`` and ``unit`` are not used.
        """
        self.time = time
        self.unit = unit
        self.down_stream = down_stream

    @staticmethod
    def template() -> str:
        return "TARGET_LAG = %s"

    def get_expression(self):
        """Return ``DOWNSTREAM`` or the quoted ``'<time> <unit>'`` literal."""
        if self.down_stream:
            return "DOWNSTREAM"
        return f"'{self.time} {self.unit.value}'"

    def render_option(self, compiler) -> str:
        expression = self.get_expression()
        return TargetLag.template() % expression

    def __repr__(self) -> str:
        expression = self.get_expression()
        return "TargetLag(%s)" % expression


class Warehouse(TableOption):
    """Table option holding the warehouse of a dynamic table.

    The option renders as ``WAREHOUSE = <name>``.
    For further information on this clause, please refer to:
    https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table

    Warehouse example usage:
        DynamicTable("sometable", metadata,
            Column("name", String(50)),
            Column("address", String(100)),
            Warehouse('my_warehouse_name')
        )
    """

    __option_name__ = "warehouse"
    __priority__ = Priority.HIGH

    def __init__(
        self,
        name: Optional[str],
    ) -> None:
        """Store the warehouse name.

        :param name: value placed after ``WAREHOUSE =``.
        """
        self.name = name

    @staticmethod
    def template() -> str:
        return "WAREHOUSE = %s"

    def get_expression(self):
        """Return the warehouse name as given to the constructor."""
        return self.name

    def render_option(self, compiler) -> str:
        expression = self.get_expression()
        return Warehouse.template() % expression

    def __repr__(self) -> str:
        expression = self.get_expression()
        return "Warehouse(%s)" % expression
class TableFromQueryBase(CustomTableBase):
    """
    Base class of the tables whose content comes from an ``AsQuery`` option.

    When the ``AsQuery`` option wraps a selectable and no ``Column`` was
    passed, the columns are derived from the columns exported by the
    selectable.
    """

    @property
    def as_query(self):
        """``AsQuery`` option of the table, or the ``NoneType`` sentinel."""
        return self._get_dialect_option(AsQuery.__option_name__)

    def __init__(
        self,
        name: str,
        metadata: MetaData,
        *args: SchemaItem,
        **kw: Any,
    ) -> None:
        items = list(args)
        as_query: AsQuery = self.__get_as_query_from_items(items)
        if (
            as_query is not NoneType
            and isinstance(as_query.query, Selectable)
            and not self.__has_defined_columns(items)
        ):
            # No explicit columns: mirror the ones exported by the query.
            columns = self.__create_columns_from_selectable(as_query.query)
            args = items + columns
        super().__init__(name, metadata, *args, **kw)

    def __get_as_query_from_items(
        self, items: typing.List[SchemaItem]
    ) -> Optional[AsQuery]:
        """Return the first ``AsQuery`` item, or the ``NoneType`` sentinel."""
        return next((item for item in items if isinstance(item, AsQuery)), NoneType)

    def __has_defined_columns(self, items: typing.List[SchemaItem]) -> bool:
        """Tell whether at least one ``Column`` is among ``items``."""
        # any() always yields a bool; the previous loop fell through and
        # returned None when no column was present.
        return any(isinstance(item, Column) for item in items)

    def __create_columns_from_selectable(
        self, selectable: Selectable
    ) -> Optional[typing.List[Column]]:
        """
        Build one ``Column`` (same name and type) per exported column.

        Returns ``None`` when ``selectable`` is not a ``Selectable``.
        """
        if not isinstance(selectable, Selectable):
            return None
        return [
            Column(exported.name, exported.type)
            for _, exported in selectable.exported_columns.items()
        ]
--- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/__snapshots__/test_orm.ambr b/tests/__snapshots__/test_orm.ambr new file mode 100644 index 00000000..2116e9e9 --- /dev/null +++ b/tests/__snapshots__/test_orm.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_orm_one_to_many_relationship_with_hybrid_table + ProgrammingError('(snowflake.connector.errors.ProgrammingError) 200009 (22000): Foreign key constraint "SYS_INDEX_HB_TBL_ADDRESS_FOREIGN_KEY_USER_ID_HB_TBL_USER_ID" was violated.') +# --- diff --git a/tests/__snapshots__/test_reflect_dynamic_table.ambr b/tests/__snapshots__/test_reflect_dynamic_table.ambr new file mode 100644 index 00000000..d4cc22b5 --- /dev/null +++ b/tests/__snapshots__/test_reflect_dynamic_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- diff --git a/tests/custom_tables/__init__.py b/tests/custom_tables/__init__.py new file mode 100644 index 00000000..d43f066c --- /dev/null +++ b/tests/custom_tables/__init__.py @@ -0,0 +1,2 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
diff --git a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr new file mode 100644 index 00000000..81c7f90f --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr new file mode 100644 index 00000000..9412fb45 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_compile_hybrid_table + 'CREATE HYBRID TABLE test_hybrid_table (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tgeom GEOMETRY, \tPRIMARY KEY (id))' +# --- +# name: test_compile_hybrid_table_orm + 'CREATE HYBRID TABLE test_hybrid_table_orm (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr 
b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr new file mode 100644 index 00000000..696ff9c8 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_hybrid_table + "[(1, 'test')]" +# --- +# name: test_create_hybrid_table_with_multiple_index + ProgrammingError("(snowflake.connector.errors.ProgrammingError) 391480 (0A000): Another index is being built on table 'TEST_HYBRID_TABLE_WITH_MULTIPLE_INDEX'. Only one index can be built at a time. Either cancel the other index creation or wait until it is complete.") +# --- diff --git a/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr new file mode 100644 index 00000000..6f6cd395 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_simple_reflection_hybrid_table_as_table + 'CREATE TABLE test_hybrid_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/test_compile_dynamic_table.py b/tests/custom_tables/test_compile_dynamic_table.py new file mode 100644 index 00000000..16a039e7 --- /dev/null +++ b/tests/custom_tables/test_compile_dynamic_table.py @@ -0,0 +1,179 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Integer, + MetaData, + String, + Table, + exc, + select, +) +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, DynamicTable +from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( + TargetLag, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse + + +def test_compile_dynamic_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_without_required_args(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="DYNAMIC TABLE must have the following arguments: TargetLag, " + "Warehouse, AsQuery", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + ) + + +def test_compile_dynamic_table_with_primary_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + +def test_compile_dynamic_table_with_foreign_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + 
Column("id", Integer), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ForeignKeyConstraint(["id"], ["table.id"]), + ) + + +def test_compile_dynamic_table_orm(sql_compiler, snapshot): + Base = declarative_base() + metadata = MetaData() + table_name = "test_dynamic_table_orm" + test_dynamic_table_orm = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + class TestDynamicTableOrm(Base): + __table__ = test_dynamic_table_orm + __mapper_args__ = { + "primary_key": [test_dynamic_table_orm.c.id, test_dynamic_table_orm.c.name] + } + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, db_parameters, snapshot): + Base = declarative_base() + schema = db_parameters["schema"] + + class TestDynamicTableOrm(Base): + __tablename__ = "test_dynamic_table_orm_2" + __table_args__ = {"schema": schema} + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return DynamicTable(name, metadata, *arg, **kw) + + __table_args__ = ( + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + id = Column(Integer) + name = Column(String) + + __mapper_args__ = {"primary_key": [id, name]} + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_selectable(sql_compiler, snapshot): + Base = declarative_base() + + test_table_1 = Table( + "test_table_1", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + dynamic_test_table = DynamicTable( + "dynamic_test_table_1", + Base.metadata, + 
TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery(select(test_table_1).where(test_table_1.c.id == 23)), + ) + + value = CreateTable(dynamic_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_compile_hybrid_table.py b/tests/custom_tables/test_compile_hybrid_table.py new file mode 100644 index 00000000..f1af6dc2 --- /dev/null +++ b/tests/custom_tables/test_compile_hybrid_table.py @@ -0,0 +1,52 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, HybridTable + + +@pytest.mark.aws +def test_compile_hybrid_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_hybrid_table" + test_geometry = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + Column("geom", GEOMETRY), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +@pytest.mark.aws +def test_compile_hybrid_table_orm(sql_compiler, snapshot): + Base = declarative_base() + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"" + + value = CreateTable(TestHybridTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_create_dynamic_table.py b/tests/custom_tables/test_create_dynamic_table.py new file mode 100644 index 00000000..4e6c48ca --- /dev/null +++ b/tests/custom_tables/test_create_dynamic_table.py @@ -0,0 +1,93 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing 
Inc. All rights reserved. +# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( + TargetLag, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse + + +def test_create_dynamic_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + dynamic_test_table_1 = DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(1, TimeUnit.HOURS), + Warehouse(warehouse), + AsQuery("SELECT id, name from test_table_1;"), + ) + + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_create_dynamic_table_without_dynamictable_class( + engine_testaccount, db_parameters +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + dynamic_test_table_1 = Table( + "dynamic_test_table_1", + metadata, + 
Column("id", Integer), + Column("name", String), + TargetLag(1, TimeUnit.HOURS), + Warehouse(warehouse), + AsQuery("SELECT id, name from test_table_1;"), + prefixes=["DYNAMIC"], + ) + + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_create_hybrid_table.py b/tests/custom_tables/test_create_hybrid_table.py new file mode 100644 index 00000000..43ae3ab6 --- /dev/null +++ b/tests/custom_tables/test_create_hybrid_table.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +import sqlalchemy.exc +from sqlalchemy import Column, Index, Integer, MetaData, String, select +from sqlalchemy.orm import Session, declarative_base + +from snowflake.sqlalchemy import HybridTable + + +@pytest.mark.aws +def test_create_hybrid_table(engine_testaccount, db_parameters, snapshot): + metadata = MetaData() + table_name = "test_create_hybrid_table" + + dynamic_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = dynamic_test_table_1.insert().values(id=1, name="test") + conn.execute(ins) + conn.commit() + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_hybrid_table = conn.execute(s).fetchall() + assert str(results_hybrid_table) == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_multiple_index( + engine_testaccount, db_parameters, snapshot, sql_compiler +): + metadata = MetaData() + table_name = 
"test_hybrid_table_with_multiple_index" + + hybrid_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String, index=True), + Column("name2", String), + Column("name3", String), + ) + + metadata.create_all(engine_testaccount) + + index = Index("idx_col34", hybrid_test_table_1.c.name2, hybrid_test_table_1.c.name3) + + with pytest.raises(sqlalchemy.exc.ProgrammingError) as exc_info: + index.create(engine_testaccount) + try: + assert exc_info.value == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_orm(sql_compiler, engine_testaccount): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + instance = TestHybridTableOrm(id=0, name="name_example") + session.add(instance) + session.commit() + data = session.query(TestHybridTableOrm).all() + assert str(data) == "[(0, 'name_example')]" + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_reflect_dynamic_table.py b/tests/custom_tables/test_reflect_dynamic_table.py new file mode 100644 index 00000000..8a4a8445 --- /dev/null +++ b/tests/custom_tables/test_reflect_dynamic_table.py @@ -0,0 +1,88 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.custom_commands import NoneType + + +def test_simple_reflection_dynamic_table_as_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = Table( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_without_options_loading(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + 
with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = DynamicTable( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + # TODO: Add support for loading options when table is reflected + assert dynamic_test_table.warehouse is NoneType + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_reflect_hybrid_table.py b/tests/custom_tables/test_reflect_hybrid_table.py new file mode 100644 index 00000000..4a777bf0 --- /dev/null +++ b/tests/custom_tables/test_reflect_hybrid_table.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import MetaData, Table +from sqlalchemy.sql.ddl import CreateTable + + +@pytest.mark.aws +def test_simple_reflection_hybrid_table_as_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_hybrid_table_reflection" + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX index_name (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + hybrid_test_table = Table(table_name, metadata, autoload_with=engine_testaccount) + + constraint = hybrid_test_table.constraints.pop() + constraint.name = "demo_name" + hybrid_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(hybrid_test_table) + + actual = sql_compiler(value) + + # Prefixes reflection not supported, example: "HYBRID, DYNAMIC" + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + 
+@pytest.mark.aws +def test_reflect_hybrid_table_with_index( + engine_testaccount, db_parameters, sql_compiler +): + metadata = MetaData() + schema = db_parameters["schema"] + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + table = Table(table_name, metadata, schema=schema, autoload_with=engine_testaccount) + + try: + assert len(table.indexes) == 1 and table.indexes.pop().name == index_name + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_core.py b/tests/test_core.py index 179133c8..15840838 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -502,19 +502,20 @@ def test_inspect_column(engine_testaccount): users.drop(engine_testaccount) -def test_get_indexes(engine_testaccount): +def test_get_indexes(engine_testaccount, db_parameters): """ Tests get indexes - NOTE: Snowflake doesn't support indexes + NOTE: Only Snowflake Hybrid Tables support indexes """ + schema = db_parameters["schema"] metadata = MetaData() users, addresses = _create_users_addresses_tables_without_sequence( engine_testaccount, metadata ) try: inspector = inspect(engine_testaccount) - assert inspector.get_indexes("users") == [] + assert inspector.get_indexes("users", schema) == [] finally: addresses.drop(engine_testaccount) diff --git a/tests/test_custom_types.py b/tests/test_custom_types.py index a997ffe8..3961a5d3 100644 --- a/tests/test_custom_types.py +++ b/tests/test_custom_types.py @@ -2,7 +2,10 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# -from snowflake.sqlalchemy import custom_types +import pytest +from sqlalchemy import Column, Integer, MetaData, Table, text + +from snowflake.sqlalchemy import TEXT, custom_types def test_string_conversions(): @@ -34,3 +37,31 @@ def test_string_conversions(): sample = getattr(custom_types, type_)() if type_ in sf_custom_types: assert type_ == str(sample) + + +@pytest.mark.feature_max_lob_size +def test_create_table_with_text_type(engine_testaccount): + metadata = MetaData() + table_name = "test_max_lob_size_0" + test_max_lob_size = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("full_name", TEXT(), server_default=text("id::varchar")), + ) + + metadata.create_all(engine_testaccount) + try: + assert test_max_lob_size is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{table_name}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + test_max_lob_size.drop(engine_testaccount) diff --git a/tests/test_index_reflection.py b/tests/test_index_reflection.py new file mode 100644 index 00000000..09f5cfe7 --- /dev/null +++ b/tests/test_index_reflection.py @@ -0,0 +1,34 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import MetaData +from sqlalchemy.engine import reflection + + +@pytest.mark.aws +def test_indexes_reflection(engine_testaccount, db_parameters, sql_compiler): + metadata = MetaData() + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + schema = db_parameters["schema"] + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + insp = reflection.Inspector.from_engine(engine_testaccount) + + try: + with engine_testaccount.connect(): + # Prefixes reflection not supported, example: "HYBRID, DYNAMIC" + indexes = insp.get_indexes(table_name, schema) + assert len(indexes) == 1 + assert indexes[0].get("name") == index_name + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_orm.py b/tests/test_orm.py index f53cd708..cb3a7768 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -7,18 +7,22 @@ import pytest from sqlalchemy import ( + TEXT, Column, Enum, ForeignKey, Integer, Sequence, String, + exc, func, select, text, ) from sqlalchemy.orm import Session, declarative_base, relationship +from snowflake.sqlalchemy import HybridTable + def test_basic_orm(engine_testaccount): """ @@ -55,14 +59,15 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_orm_one_to_many_relationship(engine_testaccount): +def test_orm_one_to_many_relationship(engine_testaccount, db_parameters): """ Tests One to Many relationship """ Base = declarative_base() + prefix = "tbl_" class User(Base): - __tablename__ = "user" + __tablename__ = prefix + "user" id = Column(Integer, Sequence("user_id_seq"), primary_key=True) name = Column(String) @@ -72,13 +77,13 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" id = Column(Integer, Sequence("address_id_seq"), 
primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, ForeignKey("user.id")) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", backref="addresses") + user = relationship(User, backref="addresses") def __repr__(self): return f"" @@ -122,14 +127,143 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) +@pytest.mark.aws +def test_orm_one_to_many_relationship_with_hybrid_table(engine_testaccount, snapshot): + """ + Tests One to Many relationship + """ + Base = declarative_base() + + class User(Base): + __tablename__ = "hb_tbl_user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + def __repr__(self): + return f"" + + class Address(Base): + __tablename__ = "hb_tbl_address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("address_id_seq"), primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) + + user = relationship(User, backref="addresses") + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + jack = User(name="jack", fullname="Jack Bean") + assert jack.addresses == [], "one to many record is empty list" + + jack.addresses = [ + Address(email_address="jack@gmail.com"), + Address(email_address="j25@yahoo.com"), + Address(email_address="jack@hotmail.com"), + ] + + session = Session(bind=engine_testaccount) + session.add(jack) # cascade each Address into the Session as well + session.commit() + + session.delete(jack) + + with pytest.raises(exc.ProgrammingError) as exc_info: + session.query(Address).all() + + assert exc_info.value == snapshot, "Iceberg Table 
enforce FK constraint" + + finally: + Base.metadata.drop_all(engine_testaccount) + + def test_delete_cascade(engine_testaccount): """ Test delete cascade """ Base = declarative_base() + prefix = "tbl_" class User(Base): - __tablename__ = "user" + __tablename__ = prefix + "user" + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + addresses = relationship( + "Address", back_populates="user", cascade="all, delete, delete-orphan" + ) + + def __repr__(self): + return f"" + + class Address(Base): + __tablename__ = prefix + "address" + + id = Column(Integer, Sequence("address_id_seq"), primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) + + user = relationship(User, back_populates="addresses") + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + jack = User(name="jack", fullname="Jack Bean") + assert jack.addresses == [], "one to many record is empty list" + + jack.addresses = [ + Address(email_address="jack@gmail.com"), + Address(email_address="j25@yahoo.com"), + Address(email_address="jack@hotmail.com"), + ] + + session = Session(bind=engine_testaccount) + session.add(jack) # cascade each Address into the Session as well + session.commit() + + got_jack = session.query(User).first() + assert got_jack == jack + + session.delete(jack) + got_addresses = session.query(Address).all() + assert len(got_addresses) == 0, "no address record" + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_delete_cascade_hybrid_table(engine_testaccount): + """ + Test delete cascade + """ + Base = declarative_base() + prefix = "hb_tbl_" + + class User(Base): + __tablename__ = prefix + "user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) id = Column(Integer, Sequence("user_id_seq"), 
primary_key=True) name = Column(String) @@ -143,13 +277,17 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) id = Column(Integer, Sequence("address_id_seq"), primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, ForeignKey("user.id")) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", back_populates="addresses") + user = relationship(User, back_populates="addresses") def __repr__(self): return f"" @@ -413,3 +551,34 @@ class Employee(Base): '[SELECT "Employee".uid FROM "Employee" JOIN LATERAL flatten(PARSE_JSON("Employee"' in caplog.text ) + + +@pytest.mark.feature_max_lob_size +def test_basic_table_with_large_lob_size_in_memory(engine_testaccount, sql_compiler): + Base = declarative_base() + + class User(Base): + __tablename__ = "user" + + id = Column(Integer, primary_key=True) + full_name = Column(TEXT(), server_default=text("id::varchar")) + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + assert User.__table__ is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{User.__tablename__}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index 63cd6d0e..2a6b9f1b 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -169,7 +169,7 @@ def test_no_indexes(engine_testaccount, db_parameters): con=conn, if_exists="replace", ) - assert str(exc.value) == "Snowflake does not support indexes" + assert str(exc.value) == "Only Snowflake Hybrid Tables supports 
indexes" def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_numpy):