From d78f0c07c1701fa9889350b9cee31ae188b7fd71 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Tue, 2 Jul 2024 15:24:48 +0200 Subject: [PATCH 01/21] Snow 1058245 SqlAlchemy 2.0 support (#469) SNOW-1058245-sqlalchemy-20-support: Add support for installation SQLAlchemy 2.0 --- .github/workflows/build_test.yml | 108 +++++++--- .github/workflows/create_req_files.yml | 6 +- .github/workflows/jira_close.yml | 2 +- .github/workflows/jira_comment.yml | 4 +- .github/workflows/jira_issue.yml | 4 +- .github/workflows/python-publish.yml | 2 +- .github/workflows/stale_issue_bot.yml | 2 +- DESCRIPTION.md | 6 +- pyproject.toml | 14 +- snyk/requirements.txt | 2 +- snyk/requiremtnts.txt | 2 + src/snowflake/sqlalchemy/base.py | 44 ++-- src/snowflake/sqlalchemy/compat.py | 36 ++++ src/snowflake/sqlalchemy/custom_commands.py | 3 +- src/snowflake/sqlalchemy/functions.py | 16 ++ src/snowflake/sqlalchemy/requirements.py | 16 ++ src/snowflake/sqlalchemy/snowdialect.py | 67 +++--- src/snowflake/sqlalchemy/util.py | 12 +- src/snowflake/sqlalchemy/version.py | 2 +- tests/conftest.py | 32 +-- tests/sqlalchemy_test_suite/conftest.py | 7 + tests/sqlalchemy_test_suite/test_suite.py | 4 + tests/sqlalchemy_test_suite/test_suite_20.py | 205 +++++++++++++++++++ tests/test_compiler.py | 2 +- tests/test_core.py | 85 +++----- tests/test_custom_functions.py | 25 +++ tests/test_orm.py | 42 ++-- tests/test_pandas.py | 11 +- tests/test_qmark.py | 4 +- tox.ini | 10 +- 30 files changed, 558 insertions(+), 217 deletions(-) create mode 100644 snyk/requiremtnts.txt create mode 100644 src/snowflake/sqlalchemy/compat.py create mode 100644 src/snowflake/sqlalchemy/functions.py create mode 100644 tests/sqlalchemy_test_suite/test_suite_20.py create mode 100644 tests/test_custom_functions.py diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index be19f1f1..3baa6a0d 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -33,8 +33,8 @@ jobs: python-version: '3.8' - name: Upgrade and install tools run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Set PY run: echo "PY=$(hatch run gh-cache-sum)" >> $GITHUB_ENV @@ -49,6 +49,10 @@ jobs: name: Test package build and installation runs-on: ubuntu-latest needs: lint + strategy: + fail-fast: true + matrix: + hatch-env: [default, sa20] steps: - uses: actions/checkout@v4 with: @@ -59,15 +63,14 @@ jobs: python-version: '3.8' - name: Upgrade and install tools run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch - name: Build package run: | - python -m hatch clean - python -m hatch build + python -m hatch -e ${{ matrix.hatch-env }} build --clean - name: Install and check import run: | - python -m pip install dist/snowflake_sqlalchemy-*.whl + python -m uv pip install dist/snowflake_sqlalchemy-*.whl python -c "import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)" test-dialect: @@ -79,7 +82,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -98,8 +101,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Upgrade pip and prepare environment run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Setup parameters file shell: bash @@ 
-125,7 +128,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -144,8 +147,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Upgrade pip and install hatch run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Setup parameters file shell: bash @@ -162,8 +165,8 @@ jobs: path: | ./coverage.xml - test-dialect-run-v20: - name: Test dialect run v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + test-dialect-v20: + name: Test dialect v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} needs: [ lint, build-install ] runs-on: ${{ matrix.os }} strategy: @@ -171,7 +174,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -197,21 +200,67 @@ jobs: .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py - name: Upgrade pip and install hatch run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Run tests - run: hatch run test-run_v20 + run: hatch run sa20:test-dialect - uses: actions/upload-artifact@v4 with: - name: coverage.xml_dialect-run-20-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + name: coverage.xml_dialect-v20-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml + + test-dialect-compatibility-v20: + name: Test dialect v20 compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: lint + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and install hatch + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run tests + run: hatch run sa20:test-dialect-compatibility + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-v20-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} path: | ./coverage.xml combine-coverage: name: Combine coverage if: ${{ success() || failure() }} - needs: [test-dialect, test-dialect-compatibility, test-dialect-run-v20] + needs: [test-dialect, test-dialect-compatibility, test-dialect-v20, test-dialect-compatibility-v20] runs-on: ubuntu-latest steps: - name: Set up Python @@ -220,8 +269,8 @@ jobs: python-version: "3.8" - name: Prepare environment run: | - pip install -U pip - pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch hatch env create default - uses: actions/checkout@v4 with: @@ -233,22 +282,15 @@ jobs: run: | hatch run coverage combine -a artifacts/coverage.xml_*/coverage.xml 
hatch run coverage report -m - hatch run coverage xml -o combined_coverage.xml - hatch run coverage html -d htmlcov - name: Store coverage reports uses: actions/upload-artifact@v4 with: - name: combined_coverage.xml - path: combined_coverage.xml - - name: Store htmlcov report - uses: actions/upload-artifact@v4 - with: - name: combined_htmlcov - path: htmlcov + name: coverage.xml + path: coverage.xml - name: Upload to codecov uses: codecov/codecov-action@v4 with: - file: combined_coverage.xml + file: coverage.xml env_vars: OS,PYTHON fail_ci_if_error: false flags: unittests diff --git a/.github/workflows/create_req_files.yml b/.github/workflows/create_req_files.yml index 618b3024..2cb7a371 100644 --- a/.github/workflows/create_req_files.yml +++ b/.github/workflows/create_req_files.yml @@ -21,10 +21,10 @@ jobs: - name: Display Python version run: python -c "import sys; print(sys.version)" - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel + run: python -m pip install -U setuptools pip wheel uv - name: Install Snowflake SQLAlchemy shell: bash - run: python -m pip install . + run: python -m uv pip install . - name: Generate reqs file name shell: bash run: echo "requirements_file=temp_requirement/requirements_$(python -c 'from sys import version_info;print(str(version_info.major)+str(version_info.minor))').reqs" >> $GITHUB_ENV @@ -34,7 +34,7 @@ jobs: mkdir temp_requirement echo "# Generated on: $(python --version)" >${{ env.requirements_file }} python -m pip freeze | grep -v snowflake-sqlalchemy 1>>${{ env.requirements_file }} 2>/dev/null - echo "snowflake-sqlalchemy==$(python -m pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} + echo "snowflake-sqlalchemy==$(python -m uv pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} id: create-reqs-file - name: Show created req file shell: bash diff --git a/.github/workflows/jira_close.yml b/.github/workflows/jira_close.yml index 5b170d75..7862f483 100644 --- a/.github/workflows/jira_close.yml +++ b/.github/workflows/jira_close.yml @@ -17,7 +17,7 @@ jobs: token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets path: . - name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} diff --git a/.github/workflows/jira_comment.yml b/.github/workflows/jira_comment.yml index 954929fa..8533c14c 100644 --- a/.github/workflows/jira_comment.yml +++ b/.github/workflows/jira_comment.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} @@ -22,7 +22,7 @@ jobs: jira=$(echo -n $TITLE | awk '{print $1}' | sed -e 's/://') echo ::set-output name=jira::$jira - name: Comment on issue - uses: atlassian/gajira-comment@master + uses: atlassian/gajira-comment@v3 if: startsWith(steps.extract.outputs.jira, 'SNOW-') with: issue: "${{ steps.extract.outputs.jira }}" diff --git a/.github/workflows/jira_issue.yml b/.github/workflows/jira_issue.yml index 31b93aae..85c774ca 100644 --- a/.github/workflows/jira_issue.yml +++ b/.github/workflows/jira_issue.yml @@ -23,7 +23,7 @@ jobs: path: .
- name: Login - uses: atlassian/gajira-login@v2.0.0 + uses: atlassian/gajira-login@v3 env: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} @@ -31,7 +31,7 @@ jobs: - name: Create JIRA Ticket id: create - uses: atlassian/gajira-create@v2.0.1 + uses: atlassian/gajira-create@v3 with: project: SNOW issuetype: Bug diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index ab4be45b..23116e7a 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -35,7 +35,7 @@ jobs: - name: Build package run: python -m build - name: Publish package - uses: pypa/gh-action-pypi-publish@release/v1 + uses: pypa/gh-action-pypi-publish@e53eb8b103ffcb59469888563dc324e3c8ba6f06 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/stale_issue_bot.yml b/.github/workflows/stale_issue_bot.yml index 6d76e9f4..4ee56ff8 100644 --- a/.github/workflows/stale_issue_bot.yml +++ b/.github/workflows/stale_issue_bot.yml @@ -10,7 +10,7 @@ jobs: stale: runs-on: ubuntu-latest steps: - - uses: actions/stale@v7 + - uses: actions/stale@v9 with: close-issue-message: 'To clean up and re-prioritize bugs and feature requests we are closing all issues older than 6 months as of Apr 1, 2023. If there are any issues or feature requests that you would like us to address, please re-create them. For urgent issues, opening a support case with this link [Snowflake Community](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge) is the fastest way to get a response' days-before-issue-stale: ${{ inputs.staleDays }} diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 2f228781..8b4dcd37 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,13 +9,17 @@ Source code is also available at: # Release Notes +- v1.6.0(Not released) + + - support for installing with SQLAlchemy 2.0.x + - v1.5.4 - Add ability to set ORDER / NOORDER sequence on columns with IDENTITY - v1.5.3(April 16, 2024) - - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 + - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 - v1.5.2(April 11, 2024) diff --git a/pyproject.toml b/pyproject.toml index 3f95df46..d2316a44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Application Frameworks", "Topic :: Software Development :: Libraries :: Python Modules", ] -dependencies = ["SQLAlchemy>=1.4.19,<2.0.0", "snowflake-connector-python<4.0.0"] +dependencies = ["SQLAlchemy>=1.4.19", "snowflake-connector-python<4.0.0"] [tool.hatch.version] path = "src/snowflake/sqlalchemy/version.py" @@ -73,8 +73,14 @@ exclude = ["/.github"] packages = ["src/snowflake"] [tool.hatch.envs.default] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"] features = ["development", "pandas"] python = "3.8" +installer = "uv" + +[tool.hatch.envs.sa20] +extra-dependencies = ["SQLAlchemy>=1.4.19,<=2.1.0"] +python = "3.8" [tool.hatch.envs.default.env-vars] COVERAGE_FILE = "coverage.xml" @@ -82,10 +88,10 @@ SQLACHEMY_WARN_20 = "1" [tool.hatch.envs.default.scripts] check = "pre-commit run --all-files" -test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite" -test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" -test-run_v20 = "pytest -ra 
-vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite --run_v20_sqlalchemy" +test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" +test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite tests/" gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" +check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" [tool.ruff] line-length = 88 diff --git a/snyk/requirements.txt b/snyk/requirements.txt index 3a77e0f9..0166d751 100644 --- a/snyk/requirements.txt +++ b/snyk/requirements.txt @@ -1,2 +1,2 @@ -SQLAlchemy>=1.4.19,<2.0.0 +SQLAlchemy>=1.4.19 snowflake-connector-python<4.0.0 diff --git a/snyk/requiremtnts.txt b/snyk/requiremtnts.txt new file mode 100644 index 00000000..a92c527e --- /dev/null +++ b/snyk/requiremtnts.txt @@ -0,0 +1,2 @@ +snowflake-connector-python<4.0.0 +SQLAlchemy>=1.4.19,<2.1.0 diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index e008c92f..1aaa881e 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -13,13 +13,14 @@ from sqlalchemy.orm import context from sqlalchemy.orm.context import _MapperEntity from sqlalchemy.schema import Sequence, Table -from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql import compiler, expression, functions from sqlalchemy.sql.base import CompileState from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.selectable import Lateral, SelectState -from sqlalchemy.util.compat import string_types +from .compat import IS_VERSION_20, args_reducer, string_types from .custom_commands import AWSBucket, AzureContainer, ExternalStage +from .functions import flatten from .util import ( _find_left_clause_to_join_from, _set_connection_interpolate_empty_sequences, @@ -324,17 +325,9 @@ def _join_determine_implicit_left_side( return left, replace_from_obj_index, use_entity_index + @args_reducer(positions_to_drop=(6, 7)) def _join_left_to_right( - self, - entities_collection, - left, - right, - onclause, - prop, - create_aliases, - aliased_generation, - outerjoin, - full, + self, entities_collection, left, right, onclause, prop, outerjoin, full ): """given raw "left", "right", "onclause" parameters consumed from a particular key within _join(), add a real ORMJoin object to @@ -364,7 +357,7 @@ def _join_left_to_right( use_entity_index, ) = self._join_place_explicit_left_side(entities_collection, left) - if left is right and not create_aliases: + if left is right: raise sa_exc.InvalidRequestError( "Can't construct a join from %s to %s, they " "are the same entity" % (left, right) @@ -373,9 +366,15 @@ def _join_left_to_right( # the right side as given often needs to be adapted. additionally # a lot of things can be wrong with it. 
handle all that and # get back the new effective "right" side - r_info, right, onclause = self._join_check_and_adapt_right_side( - left, right, onclause, prop, create_aliases, aliased_generation - ) + + if IS_VERSION_20: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop + ) + else: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, False, False + ) if not r_info.is_selectable: extra_criteria = self._get_extra_criteria(r_info) @@ -979,24 +978,23 @@ def visit_identity_column(self, identity, **kw): def get_identity_options(self, identity_options): text = [] if identity_options.increment is not None: - text.append(f"INCREMENT BY {identity_options.increment:d}") + text.append("INCREMENT BY %d" % identity_options.increment) if identity_options.start is not None: - text.append(f"START WITH {identity_options.start:d}") + text.append("START WITH %d" % identity_options.start) if identity_options.minvalue is not None: - text.append(f"MINVALUE {identity_options.minvalue:d}") + text.append("MINVALUE %d" % identity_options.minvalue) if identity_options.maxvalue is not None: - text.append(f"MAXVALUE {identity_options.maxvalue:d}") + text.append("MAXVALUE %d" % identity_options.maxvalue) if identity_options.nominvalue is not None: text.append("NO MINVALUE") if identity_options.nomaxvalue is not None: text.append("NO MAXVALUE") if identity_options.cache is not None: - text.append(f"CACHE {identity_options.cache:d}") + text.append("CACHE %d" % identity_options.cache) if identity_options.cycle is not None: text.append("CYCLE" if identity_options.cycle else "NO CYCLE") if identity_options.order is not None: text.append("ORDER" if identity_options.order else "NOORDER") - return " ".join(text) @@ -1066,3 +1064,5 @@ def visit_GEOMETRY(self, type_, **kw): construct_arguments = [(Table, {"clusterby": None})] + +functions.register_function("flatten", flatten) diff --git a/src/snowflake/sqlalchemy/compat.py b/src/snowflake/sqlalchemy/compat.py new file mode 100644 index 00000000..9e97e574 --- /dev/null +++ b/src/snowflake/sqlalchemy/compat.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +from __future__ import annotations + +import functools +from typing import Callable + +from sqlalchemy import __version__ as SA_VERSION +from sqlalchemy import util + +string_types = (str,) +returns_unicode = util.symbol("RETURNS_UNICODE") + +IS_VERSION_20 = tuple(int(v) for v in SA_VERSION.split(".")) >= (2, 0, 0) + + +def args_reducer(positions_to_drop: tuple): + """Removes args at positions provided in tuple positions_to_drop. + + For example tuple (3, 5) will remove items at third and fifth position. + Keep in mind that on class methods first postion is cls or self. 
diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py index cec16673..15585bd5 100644 --- a/src/snowflake/sqlalchemy/custom_commands.py +++ b/src/snowflake/sqlalchemy/custom_commands.py @@ -10,7 +10,8 @@ from sqlalchemy.sql.dml import UpdateBase from sqlalchemy.sql.elements import ClauseElement from sqlalchemy.sql.roles import FromClauseRole -from sqlalchemy.util.compat import string_types + +from .compat import string_types NoneType = type(None) diff --git a/src/snowflake/sqlalchemy/functions.py b/src/snowflake/sqlalchemy/functions.py new file mode 100644 index 00000000..c08aa734 --- /dev/null +++ b/src/snowflake/sqlalchemy/functions.py @@ -0,0 +1,16 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import warnings + +from sqlalchemy.sql import functions as sqlfunc + +FLATTEN_WARNING = "For backward compatibility params are not rendered." + + +class flatten(sqlfunc.GenericFunction): + name = "flatten" + + def __init__(self, *args, **kwargs): + warnings.warn(FLATTEN_WARNING, DeprecationWarning, stacklevel=2) + super().__init__(*args, **kwargs) diff --git a/src/snowflake/sqlalchemy/requirements.py b/src/snowflake/sqlalchemy/requirements.py index ea30a823..f2844804 100644 --- a/src/snowflake/sqlalchemy/requirements.py +++ b/src/snowflake/sqlalchemy/requirements.py @@ -289,9 +289,25 @@ def datetime_implicit_bound(self): # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. return exclusions.closed() + @property + def date_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + + @property + def time_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + @property def timestamp_microseconds_implicit_bound(self): # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding # parameters in string forms of timestamp values. # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion.
return exclusions.closed() + + @property + def array_type(self): + return exclusions.closed() diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 2e40d03c..04305a00 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -5,6 +5,7 @@ import operator from collections import defaultdict from functools import reduce +from typing import Any from urllib.parse import unquote_plus import sqlalchemy.types as sqltypes @@ -15,7 +16,6 @@ from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.sql.sqltypes import String from sqlalchemy.types import ( BIGINT, BINARY, @@ -40,6 +40,7 @@ from snowflake.connector import errors as sf_errors from snowflake.connector.connection import DEFAULT_CONFIGURATION from snowflake.connector.constants import UTF8 +from snowflake.sqlalchemy.compat import returns_unicode from .base import ( SnowflakeCompiler, @@ -63,7 +64,11 @@ _CUSTOM_Float, _CUSTOM_Time, ) -from .util import _update_connection_application_name, parse_url_boolean +from .util import ( + _update_connection_application_name, + parse_url_boolean, + parse_url_integer, +) colspecs = { Date: _CUSTOM_Date, @@ -134,7 +139,7 @@ class SnowflakeDialect(default.DefaultDialect): # unicode strings supports_unicode_statements = True supports_unicode_binds = True - returns_unicode_strings = String.RETURNS_UNICODE + returns_unicode_strings = returns_unicode description_encoding = None # No lastrowid support. See SNOW-11155 @@ -195,10 +200,34 @@ class SnowflakeDialect(default.DefaultDialect): @classmethod def dbapi(cls): + return cls.import_dbapi() + + @classmethod + def import_dbapi(cls): from snowflake import connector return connector + @staticmethod + def parse_query_param_type(name: str, value: Any) -> Any: + """Cast param value if possible to type defined in connector-python.""" + if not (maybe_type_configuration := DEFAULT_CONFIGURATION.get(name)): + return value + + _, expected_type = maybe_type_configuration + if not isinstance(expected_type, tuple): + expected_type = (expected_type,) + + if isinstance(value, expected_type): + return value + + elif bool in expected_type: + return parse_url_boolean(value) + elif int in expected_type: + return parse_url_integer(value) + else: + return value + def create_connect_args(self, url: URL): opts = url.translate_connect_args(username="user") if "database" in opts: @@ -235,47 +264,25 @@ def create_connect_args(self, url: URL): # URL sets the query parameter values as strings, we need to cast to expected types when necessary for name, value in query.items(): - maybe_type_configuration = DEFAULT_CONFIGURATION.get(name) - if ( - not maybe_type_configuration - ): # if the parameter is not found in the type mapping, pass it through as a string - opts[name] = value - continue - - (_, expected_type) = maybe_type_configuration - if not isinstance(expected_type, tuple): - expected_type = (expected_type,) - - if isinstance( - value, expected_type - ): # if the expected type is str, pass it through as a string - opts[name] = value - - elif ( - bool in expected_type - ): # if the expected type is bool, parse it and pass as a boolean - opts[name] = parse_url_boolean(value) - else: - # TODO: other types like int are stil passed through as string - # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/447 - opts[name] = value + opts[name] = self.parse_query_param_type(name, value) return ([], opts)
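With parse_query_param_type in place, boolean- and integer-typed connector settings survive the string-only URL query round trip instead of silently staying strings. A small sketch of the effect (account name and credentials below are placeholders; validate_default_parameters and login_timeout are existing entries in the connector's DEFAULT_CONFIGURATION):

```python
# Sketch: URL query parameters arrive as strings and are cast back to the
# types declared in snowflake.connector's DEFAULT_CONFIGURATION.
from sqlalchemy import create_engine

engine = create_engine(
    "snowflake://testuser:testpassword@testaccount/testdb/public"
    "?validate_default_parameters=true&login_timeout=30"
)
# create_connect_args now yields validate_default_parameters=True (bool) and
# login_timeout=30 (int) rather than the raw strings "true" and "30".
```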
- def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): """ Checks if the table exists """ return self._has_object(connection, "TABLE", table_name, schema) - def has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): """ Checks if the sequence exists """ return self._has_object(connection, "SEQUENCE", sequence_name, schema) def _has_object(self, connection, object_type, object_name, schema=None): - full_name = self._denormalize_quote_join(schema, object_name) try: results = connection.execute( diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 32e07373..a1aefff9 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -7,7 +7,7 @@ from typing import Any from urllib.parse import quote_plus -from sqlalchemy import exc, inspection, sql, util +from sqlalchemy import exc, inspection, sql from sqlalchemy.exc import NoForeignKeysError from sqlalchemy.orm.interfaces import MapperProperty from sqlalchemy.orm.util import _ORMJoin as sa_orm_util_ORMJoin @@ -19,6 +19,7 @@ from snowflake.connector.compat import IS_STR from snowflake.connector.connection import SnowflakeConnection +from snowflake.sqlalchemy import compat from ._constants import ( APPLICATION_NAME, @@ -124,6 +125,13 @@ def parse_url_boolean(value: str) -> bool: raise ValueError(f"Invalid boolean value detected: '{value}'") +def parse_url_integer(value: str) -> int: + try: + return int(value) + except ValueError as e: + raise ValueError(f"Invalid int value detected: '{value}'") from e + + # handle Snowflake BCR bcr-1057 # the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState # which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that @@ -212,7 +220,7 @@ def __init__( # then the "_joined_from_info" concept can go left_orm_info = getattr(left, "_joined_from_info", left_info) self._joined_from_info = right_info - if isinstance(onclause, util.string_types): + if isinstance(onclause, compat.string_types): onclause = getattr(left_orm_info.entity, onclause) # #### diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index 61c9fc41..56509b7d 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the fourth version number from None -VERSION = "1.5.3" +VERSION = "1.6.0" diff --git a/tests/conftest.py b/tests/conftest.py index a9c2560a..d4dab3d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,21 +46,6 @@ TEST_SCHEMA = f"sqlalchemy_tests_{str(uuid.uuid4()).replace('-', '_')}" -def pytest_addoption(parser): - parser.addoption( - "--run_v20_sqlalchemy", - help="Use only 2.0 SQLAlchemy APIs, any legacy features (< 2.0) will not be supported."
- "Turning on this option will set future flag to True on Engine and Session objects according to" - "the migration guide: https://docs.sqlalchemy.org/en/14/changelog/migration_20.html", - action="store_true", - ) - - -@pytest.fixture(scope="session") -def run_v20_sqlalchemy(pytestconfig): - return pytestconfig.option.run_v20_sqlalchemy - - @pytest.fixture(scope="session") def on_travis(): return os.getenv("TRAVIS", "").lower() == "true" @@ -160,20 +145,21 @@ def url_factory(**kwargs) -> URL: return URL(**url_params) -def get_engine(url: URL, run_v20_sqlalchemy=False, **engine_kwargs): +def get_engine(url: URL, **engine_kwargs): engine_params = { "poolclass": NullPool, - "future": run_v20_sqlalchemy, + "future": True, + "echo": True, } engine_params.update(engine_kwargs) - engine = create_engine(url, **engine_kwargs) + engine = create_engine(url, **engine_params) return engine @pytest.fixture() -def engine_testaccount(request, run_v20_sqlalchemy): +def engine_testaccount(request): url = url_factory() - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine @@ -181,17 +167,17 @@ def engine_testaccount(request): @pytest.fixture() def engine_testaccount_with_numpy(request): url = url_factory(numpy=True) - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine @pytest.fixture() -def engine_testaccount_with_qmark(request, run_v20_sqlalchemy): +def engine_testaccount_with_qmark(request): snowflake.connector.paramstyle = "qmark" url = url_factory() - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine diff --git a/tests/sqlalchemy_test_suite/conftest.py b/tests/sqlalchemy_test_suite/conftest.py index 31cd7c5c..f0464c7d 100644 --- a/tests/sqlalchemy_test_suite/conftest.py +++ b/tests/sqlalchemy_test_suite/conftest.py @@ -15,6 +15,7 @@ import snowflake.connector from snowflake.sqlalchemy import URL +from snowflake.sqlalchemy.compat import IS_VERSION_20 from ..conftest import get_db_parameters from ..util import random_string @@ -25,6 +26,12 @@ TEST_SCHEMA_2 = f"{TEST_SCHEMA}_2" +if IS_VERSION_20: + collect_ignore_glob = ["test_suite.py"] +else: + collect_ignore_glob = ["test_suite_20.py"] + + # patch sqlalchemy.testing.config.Config.__init__ for schema name randomization # same schema name would result in conflict as we're running tests in parallel in the CI def config_patched__init__(self, db, db_opts, options, file_config): diff --git a/tests/sqlalchemy_test_suite/test_suite.py b/tests/sqlalchemy_test_suite/test_suite.py index d79e511e..643d1559 100644 --- a/tests/sqlalchemy_test_suite/test_suite.py +++ b/tests/sqlalchemy_test_suite/test_suite.py @@ -69,6 +69,10 @@ def test_empty_insert(self, connection): def test_empty_insert_multiple(self, connection): pass + @pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + # 2. Patched Tests diff --git a/tests/sqlalchemy_test_suite/test_suite_20.py b/tests/sqlalchemy_test_suite/test_suite_20.py new file mode 100644 index 00000000..1f79c4e9 --- /dev/null +++ b/tests/sqlalchemy_test_suite/test_suite_20.py @@ -0,0 +1,205 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+# +import pytest +from sqlalchemy import Integer, testing +from sqlalchemy.schema import Column, Sequence, Table +from sqlalchemy.testing import config +from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.suite import ( + BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest, +) +from sqlalchemy.testing.suite import ( + CompositeKeyReflectionTest as _CompositeKeyReflectionTest, +) +from sqlalchemy.testing.suite import DateTimeHistoricTest as _DateTimeHistoricTest +from sqlalchemy.testing.suite import FetchLimitOffsetTest as _FetchLimitOffsetTest +from sqlalchemy.testing.suite import HasSequenceTest as _HasSequenceTest +from sqlalchemy.testing.suite import InsertBehaviorTest as _InsertBehaviorTest +from sqlalchemy.testing.suite import LikeFunctionsTest as _LikeFunctionsTest +from sqlalchemy.testing.suite import LongNameBlowoutTest as _LongNameBlowoutTest +from sqlalchemy.testing.suite import SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest +from sqlalchemy.testing.suite import TimeMicrosecondsTest as _TimeMicrosecondsTest +from sqlalchemy.testing.suite import TrueDivTest as _TrueDivTest +from sqlalchemy.testing.suite import * # noqa + +# 1. Unsupported by snowflake db + +del ComponentReflectionTest # require indexes not supported by snowflake +del HasIndexTest # require indexes not supported by snowflake +del QuotedNameArgumentTest # require indexes not supported by snowflake + + +class LongNameBlowoutTest(_LongNameBlowoutTest): + # The combination ("ix",) is removed due to Snowflake not supporting indexes + def ix(self, metadata, connection): + pytest.skip("ix required index feature not supported by Snowflake") + + +class FetchLimitOffsetTest(_FetchLimitOffsetTest): + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_bound_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_limit_expr_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_offset_zero(self, connection): + pass + + +class InsertBehaviorTest(_InsertBehaviorTest): + @pytest.mark.skip( + "Snowflake does not support inserting empty values, the value may be a literal or an expression." + ) + def test_empty_insert(self, connection): + pass + + @pytest.mark.skip( + "Snowflake does not support inserting empty values, The value may be a literal or an expression." + ) + def test_empty_insert_multiple(self, connection): + pass + + @pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + + +# road to 2.0 +class TrueDivTest(_TrueDivTest): + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer_bound(self, connection): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. + """ + pass + + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer(self, connection, left, right, expected): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. 
+ """ + pass + + +class TimeMicrosecondsTest(_TimeMicrosecondsTest): + def __init__(self): + super().__init__() + + +class DateTimeHistoricTest(_DateTimeHistoricTest): + def __init__(self): + super().__init__() + + +# 2. Patched Tests + + +class HasSequenceTest(_HasSequenceTest): + # Override the define_tables method as snowflake does not support 'nomaxvalue'/'nominvalue' + @classmethod + def define_tables(cls, metadata): + Sequence("user_id_seq", metadata=metadata) + # Replace Sequence("other_seq") creation as in the original test suite, + # the Sequence created with 'nomaxvalue' and 'nominvalue' + # which snowflake does not support: + # Sequence("other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True) + Sequence("other_seq", metadata=metadata) + if testing.requires.schemas.enabled: + Sequence("user_id_seq", schema=config.test_schema, metadata=metadata) + Sequence("schema_seq", schema=config.test_schema, metadata=metadata) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + +class LikeFunctionsTest(_LikeFunctionsTest): + @testing.requires.regexp_match + @testing.combinations( + ("a.cde.*", {1, 5, 6, 9}), + ("abc.*", {1, 5, 6, 9, 10}), + ("^abc.*", {1, 5, 6, 9, 10}), + (".*9cde.*", {8}), + ("^a.*", set(range(1, 11))), + (".*(b|c).*", set(range(1, 11))), + ("^(b|c).*", set()), + ) + def test_regexp_match(self, text, expected): + super().test_regexp_match(text, expected) + + def test_not_regexp_match(self): + col = self.tables.some_table.c.data + self._test(~col.regexp_match("a.cde.*"), {2, 3, 4, 7, 8, 10}) + + +class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest): + def test_update(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.update().where(t.c.id == 2), dict(data="d2_new")) + assert not r.is_insert + # snowflake returns a row with numbers of rows updated and number of multi-joined rows updated + assert r.returns_rows + assert r.rowcount == 1 + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + # snowflake returns a row with number of rows deleted + assert r.returns_rows + assert r.rowcount == 1 + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + +class CompositeKeyReflectionTest(_CompositeKeyReflectionTest): + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_fk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. + super().test_fk_column_order() + + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_pk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. 
+ super().test_pk_column_order() + + +class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest): + @testing.combinations( + ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" + ) + @testing.variation("use_composite", [True, False]) + @testing.combinations( + ("plain",), + ("(2)",), + ("[brackets]",), + argnames="tablename", + ) + def test_fk_ref(self, connection, metadata, use_composite, tablename, columnname): + super().test_fk_ref(connection, metadata, use_composite, tablename, columnname) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 0fd75c38..40207b41 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -5,7 +5,7 @@ from sqlalchemy import Integer, String, and_, func, select from sqlalchemy.schema import DropColumnComment, DropTableComment from sqlalchemy.sql import column, quoted_name, table -from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing.assertions import AssertsCompiledSQL from snowflake.sqlalchemy import snowdialect diff --git a/tests/test_core.py b/tests/test_core.py index 6c8d7416..179133c8 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -34,7 +34,7 @@ inspect, text, ) -from sqlalchemy.exc import DBAPIError, NoSuchTableError +from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError from sqlalchemy.sql import and_, not_, or_, select import snowflake.connector.errors @@ -406,16 +406,6 @@ def test_insert_tables(engine_testaccount): str(users.join(addresses)) == "users JOIN addresses ON " "users.id = addresses.user_id" ) - assert ( - str( - users.join( - addresses, - addresses.c.email_address.like(users.c.name + "%"), - ) - ) - == "users JOIN addresses " - "ON addresses.email_address LIKE users.name || :name_1" - ) s = select(users.c.fullname).select_from( users.join( @@ -444,7 +434,7 @@ def test_table_does_not_exist(engine_testaccount): """ meta = MetaData() with pytest.raises(NoSuchTableError): - Table("does_not_exist", meta, autoload=True, autoload_with=engine_testaccount) + Table("does_not_exist", meta, autoload_with=engine_testaccount) @pytest.mark.skip( @@ -470,9 +460,7 @@ def test_reflextion(engine_testaccount): ) try: meta = MetaData() - user_reflected = Table( - "user", meta, autoload=True, autoload_with=engine_testaccount - ) + user_reflected = Table("user", meta, autoload_with=engine_testaccount) assert user_reflected.c == ["user.id", "user.name", "user.fullname"] finally: conn.execute("DROP TABLE IF EXISTS user") @@ -1071,28 +1059,15 @@ def harass_inspector(): assert outcome -@pytest.mark.timeout(15) -def test_region(): - engine = create_engine( - URL( - user="testuser", - password="testpassword", - account="testaccount", - region="eu-central-1", - login_timeout=5, - ) - ) - try: - engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.eu-central-1.snowflakecomputing.com" in ex.orig.msg - - -@pytest.mark.timeout(15) -def test_azure(): +@pytest.mark.timeout(10) +@pytest.mark.parametrize( + "region", + ( + pytest.param("eu-central-1", id="region"), + pytest.param("east-us-2.azure", id="azure"), + ), +) +def test_connection_timeout_error(region): engine = create_engine( URL( user="testuser", @@ -1102,13 +1077,13 @@ def test_azure(): login_timeout=5, ) ) - try: + + with pytest.raises(OperationalError) as excinfo: engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert 
"Failed to connect to DB" in ex.orig.msg - assert "testaccount.east-us-2.azure.snowflakecomputing.com" in ex.orig.msg + + assert excinfo.value.orig.errno == 250001 + assert "Could not connect to Snowflake backend" in excinfo.value.orig.msg + assert region not in excinfo.value.orig.msg def test_load_dialect(): @@ -1535,11 +1510,16 @@ def test_too_many_columns_detection(engine_testaccount, db_parameters): metadata.create_all(engine_testaccount) inspector = inspect(engine_testaccount) # Do test - original_execute = inspector.bind.execute + connection = inspector.bind.connect() + original_execute = connection.execute + + too_many_columns_was_raised = False def mock_helper(command, *args, **kwargs): - if "_get_schema_columns" in command: + if "_get_schema_columns" in command.text: # Creating exception exactly how SQLAlchemy does + nonlocal too_many_columns_was_raised + too_many_columns_was_raised = True raise DBAPIError.instance( """ SELECT /* sqlalchemy:_get_schema_columns */ @@ -1571,9 +1551,12 @@ def mock_helper(command, *args, **kwargs): else: return original_execute(command, *args, **kwargs) - with patch.object(inspector.bind, "execute", side_effect=mock_helper): - column_metadata = inspector.get_columns("users", db_parameters["schema"]) + with patch.object(engine_testaccount, "connect") as conn: + conn.return_value = connection + with patch.object(connection, "execute", side_effect=mock_helper): + column_metadata = inspector.get_columns("users", db_parameters["schema"]) assert len(column_metadata) == 4 + assert too_many_columns_was_raised # Clean up metadata.drop_all(engine_testaccount) @@ -1615,9 +1598,7 @@ def test_column_type_schema(engine_testaccount): """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) columns = table_reflected.columns assert ( len(columns) == len(ischema_names_baseline) - 1 @@ -1638,9 +1619,7 @@ def test_result_type_and_value(engine_testaccount): ) """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) current_date = date.today() current_utctime = datetime.utcnow() current_localtime = pytz.utc.localize(current_utctime, is_dst=False).astimezone( diff --git a/tests/test_custom_functions.py b/tests/test_custom_functions.py new file mode 100644 index 00000000..2a1e1cb5 --- /dev/null +++ b/tests/test_custom_functions.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import pytest +from sqlalchemy import func + +from snowflake.sqlalchemy import snowdialect + + +def test_flatten_does_not_render_params(): + """This behavior is for backward compatibility. + + In previous version params were not rendered. + In future this behavior will change. + """ + flat = func.flatten("[1, 2]", outer=True) + res = flat.compile(dialect=snowdialect.dialect()) + + assert str(res) == "flatten(%(flatten_1)s)" + + +def test_flatten_emits_warning(): + expected_warning = "For backward compatibility params are not rendered." 
+ with pytest.warns(DeprecationWarning, match=expected_warning): + func.flatten().compile(dialect=snowdialect.dialect()) diff --git a/tests/test_orm.py b/tests/test_orm.py index e485d737..f53cd708 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Session, declarative_base, relationship -def test_basic_orm(engine_testaccount, run_v20_sqlalchemy): +def test_basic_orm(engine_testaccount): """ Tests declarative """ @@ -46,7 +46,6 @@ def __repr__(self): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) our_user = session.query(User).filter_by(name="ed").first() @@ -56,7 +55,7 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_orm_one_to_many_relationship(engine_testaccount, run_v20_sqlalchemy): +def test_orm_one_to_many_relationship(engine_testaccount): """ Tests One to Many relationship """ @@ -97,7 +96,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -124,7 +122,7 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_delete_cascade(engine_testaccount, run_v20_sqlalchemy): +def test_delete_cascade(engine_testaccount): """ Test delete cascade """ @@ -169,7 +167,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -189,7 +186,7 @@ def __repr__(self): WIP """, ) -def test_orm_query(engine_testaccount, run_v20_sqlalchemy): +def test_orm_query(engine_testaccount): """ Tests ORM query """ @@ -210,7 +207,6 @@ def __repr__(self): # TODO: insert rows session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy # TODO: query.all() for name, fullname in session.query(User.name, User.fullname): @@ -220,7 +216,7 @@ def __repr__(self): # MultipleResultsFound if not one result -def test_schema_including_db(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_db(engine_testaccount, db_parameters): """ Test schema parameter including database separated by a dot. """ @@ -243,7 +239,6 @@ class User(Base): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) ret_user = session.query(User.id, User.name).first() @@ -255,7 +250,7 @@ class User(Base): Base.metadata.drop_all(engine_testaccount) -def test_schema_including_dot(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_dot(engine_testaccount, db_parameters): """ Tests pseudo schema name including dot. 
""" @@ -276,7 +271,6 @@ class User(Base): fullname = Column(String) session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy query = session.query(User.id) assert str(query).startswith( 'SELECT {db}."{schema}.{schema}".{db}.users.id'.format( @@ -285,9 +279,7 @@ class User(Base): ) -def test_schema_translate_map( - engine_testaccount, db_parameters, sql_compiler, run_v20_sqlalchemy -): +def test_schema_translate_map(engine_testaccount, db_parameters): """ Test schema translate map execution option works replaces schema correctly """ @@ -310,7 +302,6 @@ class User(Base): schema_translate_map={schema_map: db_parameters["schema"]} ) as con: session = Session(bind=con) - session.future = run_v20_sqlalchemy with con.begin(): Base.metadata.create_all(con) try: @@ -367,18 +358,29 @@ class Department(Base): .select_from(Employee) .outerjoin(sub) ) - assert ( - str(query.compile(engine_testaccount)).replace("\n", "") - == "SELECT employees.employee_id, departments.department_id " + compiled_stmts = ( + # v1.x + "SELECT employees.employee_id, departments.department_id " "FROM departments, employees LEFT OUTER JOIN LATERAL " "(SELECT departments.department_id AS department_id, departments.name AS name " - "FROM departments) AS anon_1" + "FROM departments) AS anon_1", + # v2.x + "SELECT employees.employee_id, departments.department_id " + "FROM employees LEFT OUTER JOIN LATERAL " + "(SELECT departments.department_id AS department_id, departments.name AS name " + "FROM departments) AS anon_1, departments", ) + compiled_stmt = str(query.compile(engine_testaccount)).replace("\n", "") + assert compiled_stmt in compiled_stmts + with caplog.at_level(logging.DEBUG): assert [res for res in session.execute(query)] assert ( "SELECT employees.employee_id, departments.department_id FROM departments" in caplog.text + ) or ( + "SELECT employees.employee_id, departments.department_id FROM employees" + in caplog.text ) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index ef64d65e..63cd6d0e 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -27,6 +27,7 @@ from snowflake.connector import ProgrammingError from snowflake.connector.pandas_tools import make_pd_writer, pd_writer +from snowflake.sqlalchemy.compat import IS_VERSION_20 def _create_users_addresses_tables(engine_testaccount, metadata): @@ -240,8 +241,8 @@ def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_num conn.exec_driver_sql(f"DROP TABLE {test_table_name};") -def test_pandas_writeback(engine_testaccount, run_v20_sqlalchemy): - if run_v20_sqlalchemy and sys.version_info < (3, 8): +def test_pandas_writeback(engine_testaccount): + if IS_VERSION_20 and sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." ) @@ -352,8 +353,8 @@ def test_pandas_invalid_make_pd_writer(engine_testaccount): ) -def test_percent_signs(engine_testaccount, run_v20_sqlalchemy): - if run_v20_sqlalchemy and sys.version_info < (3, 8): +def test_percent_signs(engine_testaccount): + if IS_VERSION_20 and sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." 
) @@ -376,7 +377,7 @@ def test_percent_signs(engine_testaccount): not_like_sql = f"select * from {table_name} where c2 not like '%b%'" like_sql = f"select * from {table_name} where c2 like '%b%'" calculate_sql = "SELECT 1600 % 400 AS a, 1599 % 400 as b" - if run_v20_sqlalchemy: + if IS_VERSION_20: not_like_sql = sqlalchemy.text(not_like_sql) like_sql = sqlalchemy.text(like_sql) calculate_sql = sqlalchemy.text(calculate_sql) diff --git a/tests/test_qmark.py b/tests/test_qmark.py index f98fa7d3..3761181a 100644 --- a/tests/test_qmark.py +++ b/tests/test_qmark.py @@ -12,11 +12,11 @@ THIS_DIR = os.path.dirname(os.path.realpath(__file__)) -def test_qmark_bulk_insert(run_v20_sqlalchemy, engine_testaccount_with_qmark): +def test_qmark_bulk_insert(engine_testaccount_with_qmark): """ Bulk insert using qmark paramstyle """ - if run_v20_sqlalchemy and sys.version_info < (3, 8): + if sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlalchemy 2.0, and pandas does not support Python 3.7 anymore." ) diff --git a/tox.ini b/tox.ini index 0c1cb483..7f605627 100644 --- a/tox.ini +++ b/tox.ini @@ -34,7 +34,7 @@ passenv = setenv = COVERAGE_FILE = {env:COVERAGE_FILE:{toxworkdir}/.coverage.{envname}} SQLALCHEMY_WARN_20 = 1 - ci: SNOWFLAKE_PYTEST_OPTS = -vvv + ci: SNOWFLAKE_PYTEST_OPTS = -vvv --tb=long commands = pytest \ {env:SNOWFLAKE_PYTEST_OPTS:} \ --cov "snowflake.sqlalchemy" \ @@ -44,12 +44,6 @@ commands = pytest \ --cov "snowflake.sqlalchemy" --cov-append \ --junitxml {toxworkdir}/junit_{envname}.xml \ {posargs:tests/sqlalchemy_test_suite} - pytest \ - {env:SNOWFLAKE_PYTEST_OPTS:} \ - --cov "snowflake.sqlalchemy" --cov-append \ - --junitxml {toxworkdir}/junit_{envname}.xml \ - --run_v20_sqlalchemy \ - {posargs:tests} [testenv:.pkg_external] deps = build @@ -86,7 +80,7 @@ commands = pre-commit run --all-files python -c 'import pathlib; print("hint: run \{\} install to add checks as pre-commit hook".format(pathlib.Path(r"{envdir}") / "bin" / "pre-commit"))' [pytest] -addopts = -ra --strict-markers --ignore=tests/sqlalchemy_test_suite +addopts = -ra --ignore=tests/sqlalchemy_test_suite junit_family = legacy log_level = info markers = From 423b8c13ec23d6d63d39f4019bc0f1caa97909ac Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Wed, 3 Jul 2024 14:46:35 +0200 Subject: [PATCH 02/21] SNOW-1516075: use SQLALchemy 2.0 as default dependency (#511) * SNOW-1516075: use SQLALchemy 2.0 as default dependency --- .github/workflows/build_test.yml | 20 ++++++++++---------- DESCRIPTION.md | 3 ++- README.md | 31 +++++++++++++++++++++---------- ci/build.sh | 13 ++++++++----- ci/test_linux.sh | 18 +++++++++--------- pyproject.toml | 15 ++++++++++++--- 6 files changed, 62 insertions(+), 38 deletions(-) diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index 3baa6a0d..d7f7832b 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -52,7 +52,7 @@ jobs: strategy: fail-fast: true matrix: - hatch-env: [default, sa20] + hatch-env: [default, sa14] steps: - uses: actions/checkout@v4 with: @@ -165,8 +165,8 @@ path: | ./coverage.xml - test-dialect-v20: - name: Test dialect v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + test-dialect-v14: + name: Test dialect v14 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} needs: [ lint, build-install ] runs-on: ${{ matrix.os }} strategy: @@ -204,15 +204,15 @@
jobs: python -m uv pip install -U hatch python -m hatch env create default - name: Run tests - run: hatch run sa20:test-dialect + run: hatch run sa14:test-dialect - uses: actions/upload-artifact@v4 with: - name: coverage.xml_dialect-v20-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + name: coverage.xml_dialect-v14-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} path: | ./coverage.xml - test-dialect-compatibility-v20: - name: Test dialect v20 compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + test-dialect-compatibility-v14: + name: Test dialect v14 compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} needs: lint runs-on: ${{ matrix.os }} strategy: @@ -250,17 +250,17 @@ jobs: gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py - name: Run tests - run: hatch run sa20:test-dialect-compatibility + run: hatch run sa14:test-dialect-compatibility - uses: actions/upload-artifact@v4 with: - name: coverage.xml_dialect-v20-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + name: coverage.xml_dialect-v14-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} path: | ./coverage.xml combine-coverage: name: Combine coverage if: ${{ success() || failure() }} - needs: [test-dialect, test-dialect-compatibility, test-dialect-v20, test-dialect-compatibility-v20] + needs: [test-dialect, test-dialect-compatibility, test-dialect-v14, test-dialect-compatibility-v14] runs-on: ubuntu-latest steps: - name: Set up Python diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 8b4dcd37..782c426d 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,9 +9,10 @@ Source code is also available at: # Release Notes -- v1.6.0(Not released) +- v1.6.0(July 4, 2024) - support for installing with SQLAlchemy 2.0.x + - use `hatch` & `uv` for managing project virtual environments - v1.5.4 diff --git a/README.md b/README.md index 0c75854e..c428353f 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,7 @@ containing special characters need to be URL encoded to be parsed correctly. Thi characters could lead to authentication failure. The encoding for the password can be generated using `urllib.parse`: + ```python import urllib.parse urllib.parse.quote("kx@% jj5/g") @@ -111,6 +112,7 @@ urllib.parse.quote("kx@% jj5/g") To create an engine with the proper encodings, either manually constructing the url string by formatting or taking advantage of the `snowflake.sqlalchemy.URL` helper method: + ```python import urllib.parse from snowflake.sqlalchemy import URL @@ -191,14 +193,23 @@ engine = create_engine(...) engine.execute() engine.dispose() -# Do this. +# Better. engine = create_engine(...) 
connection = engine.connect()
try:
-    connection.execute()
+    connection.execute(text())
finally:
    connection.close()
engine.dispose()
+
+# Best
+try:
+    with engine.connect() as connection:
+        connection.execute(text())
+        # or
+        connection.exec_driver_sql()
+finally:
+    engine.dispose()
 ```

### Auto-increment Behavior
@@ -242,14 +253,14 @@ engine = create_engine(URL(

 specific_date = np.datetime64('2016-03-04T12:03:05.123456789Z')

-connection = engine.connect()
-connection.execute(
-    "CREATE OR REPLACE TABLE ts_tbl(c1 TIMESTAMP_NTZ)")
-connection.execute(
-    "INSERT INTO ts_tbl(c1) values(%s)", (specific_date,)
-)
-df = pd.read_sql_query("SELECT * FROM ts_tbl", engine)
-assert df.c1.values[0] == specific_date
+with engine.connect() as connection:
+    connection.exec_driver_sql(
+        "CREATE OR REPLACE TABLE ts_tbl(c1 TIMESTAMP_NTZ)")
+    connection.exec_driver_sql(
+        "INSERT INTO ts_tbl(c1) values(%s)", (specific_date,)
+    )
+    df = pd.read_sql_query("SELECT * FROM ts_tbl", connection)
+    assert df.c1.values[0] == specific_date
 ```

 The following `NumPy` data types are supported:
diff --git a/ci/build.sh b/ci/build.sh
index 4229506d..85d67df7 100755
--- a/ci/build.sh
+++ b/ci/build.sh
@@ -3,7 +3,7 @@
 # Build snowflake-sqlalchemy
 set -o pipefail

-PYTHON="python3.7"
+PYTHON="python3.8"
 THIS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
 SQLALCHEMY_DIR="$(dirname "${THIS_DIR}")"
 DIST_DIR="${SQLALCHEMY_DIR}/dist"
@@ -11,8 +11,8 @@ DIST_DIR="${SQLALCHEMY_DIR}/dist"
 cd "$SQLALCHEMY_DIR"
 # Clean up previously built DIST_DIR
 if [ -d "${DIST_DIR}" ]; then
-    echo "[WARN] ${DIST_DIR} already existing, deleting it..."
-    rm -rf "${DIST_DIR}"
+  echo "[WARN] ${DIST_DIR} already existing, deleting it..."
+  rm -rf "${DIST_DIR}"
 fi

 # Constants and setup
@@ -20,5 +20,8 @@ fi
 echo "[Info] Building snowflake-sqlalchemy with $PYTHON"
 # Clean up possible build artifacts
 rm -rf build generated_version.py
-${PYTHON} -m pip install --upgrade pip setuptools wheel build
-${PYTHON} -m build --outdir ${DIST_DIR} .
+# ${PYTHON} -m pip install --upgrade pip setuptools wheel build
+# ${PYTHON} -m build --outdir ${DIST_DIR} .
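
For reference, the connection pattern the README changes above recommend looks like this when written out end to end; a minimal sketch with placeholder credentials, using the `text()` wrapper that SQLAlchemy 2.0 requires for raw SQL strings:

```python
from sqlalchemy import create_engine, text

# Placeholder URL; substitute a real account, user, and password.
engine = create_engine("snowflake://testuser:testpass@abc12345/testdb/public")
try:
    # The context manager closes the connection even if the query raises.
    with engine.connect() as connection:
        result = connection.execute(text("SELECT CURRENT_VERSION()"))
        print(result.scalar())
finally:
    engine.dispose()  # returns all pooled connections to the pool and closes them
```
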
+export UV_NO_CACHE=true +${PYTHON} -m pip install uv hatch +${PYTHON} -m hatch build diff --git a/ci/test_linux.sh b/ci/test_linux.sh index 695251e6..f5afc4fb 100755 --- a/ci/test_linux.sh +++ b/ci/test_linux.sh @@ -6,9 +6,9 @@ # - This script assumes that ../dist/repaired_wheels has the wheel(s) built for all versions to be tested # - This is the script that test_docker.sh runs inside of the docker container -PYTHON_VERSIONS="${1:-3.7 3.8 3.9 3.10 3.11}" -THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -SQLALCHEMY_DIR="$( dirname "${THIS_DIR}")" +PYTHON_VERSIONS="${1:-3.8 3.9 3.10 3.11}" +THIS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SQLALCHEMY_DIR="$(dirname "${THIS_DIR}")" # Install one copy of tox python3 -m pip install -U tox @@ -16,10 +16,10 @@ python3 -m pip install -U tox # Run tests cd $SQLALCHEMY_DIR for PYTHON_VERSION in ${PYTHON_VERSIONS}; do - echo "[Info] Testing with ${PYTHON_VERSION}" - SHORT_VERSION=$(python3 -c "print('${PYTHON_VERSION}'.replace('.', ''))") - SQLALCHEMY_WHL=$(ls $SQLALCHEMY_DIR/dist/snowflake_sqlalchemy-*-py2.py3-none-any.whl | sort -r | head -n 1) - TEST_ENVLIST=fix_lint,py${SHORT_VERSION}-ci,py${SHORT_VERSION}-coverage - echo "[Info] Running tox for ${TEST_ENVLIST}" - python3 -m tox -e ${TEST_ENVLIST} --installpkg ${SQLALCHEMY_WHL} + echo "[Info] Testing with ${PYTHON_VERSION}" + SHORT_VERSION=$(python3 -c "print('${PYTHON_VERSION}'.replace('.', ''))") + SQLALCHEMY_WHL=$(ls $SQLALCHEMY_DIR/dist/snowflake_sqlalchemy-*-py3-none-any.whl | sort -r | head -n 1) + TEST_ENVLIST=fix_lint,py${SHORT_VERSION}-ci,py${SHORT_VERSION}-coverage + echo "[Info] Running tox for ${TEST_ENVLIST}" + python3 -m tox -e ${TEST_ENVLIST} --installpkg ${SQLALCHEMY_WHL} done diff --git a/pyproject.toml b/pyproject.toml index d2316a44..58544017 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,13 +73,14 @@ exclude = ["/.github"] packages = ["src/snowflake"] [tool.hatch.envs.default] -extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.1.0"] features = ["development", "pandas"] python = "3.8" installer = "uv" -[tool.hatch.envs.sa20] -extra-dependencies = ["SQLAlchemy>=1.4.19,<=2.1.0"] +[tool.hatch.envs.sa14] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"] +features = ["development", "pandas"] python = "3.8" [tool.hatch.envs.default.env-vars] @@ -93,6 +94,14 @@ test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalch gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" +[[tool.hatch.envs.release.matrix]] +python = ["3.8", "3.9", "3.10", "3.11", "3.12"] +features = ["development", "pandas"] + +[tool.hatch.envs.release.scripts] +test-dialect = "pytest -ra -vvv --tb=short --ignore=tests/sqlalchemy_test_suite tests/" +test-compatibility = "pytest -ra -vvv --tb=short tests/sqlalchemy_test_suite tests/" + [tool.ruff] line-length = 88 From 71308ce07465e106222923a31435e29dd022f2f5 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Thu, 4 Jul 2024 10:31:43 +0200 Subject: [PATCH 03/21] SNOW-1516075: set release date to July 8th (#514) --- DESCRIPTION.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 782c426d..79971c53 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,7 @@ Source code is also available at: # Release Notes -- v1.6.0(July 4, 2024) +- v1.6.0(July 8, 2024) - support for installing with SQLAlchemy 2.0.x - use `hatch` 
& `uv` for managing project virtual environments From bde2372c3a79ac799a318369fec34c48112758ec Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Thu, 4 Jul 2024 12:29:28 +0200 Subject: [PATCH 04/21] SNOW-1519492: add export PATH in build.sh script (#516) --- ci/build.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ci/build.sh b/ci/build.sh index 85d67df7..b63c8e01 100755 --- a/ci/build.sh +++ b/ci/build.sh @@ -16,12 +16,11 @@ if [ -d "${DIST_DIR}" ]; then fi # Constants and setup +export PATH=$PATH:$HOME/.local/bin echo "[Info] Building snowflake-sqlalchemy with $PYTHON" # Clean up possible build artifacts rm -rf build generated_version.py -# ${PYTHON} -m pip install --upgrade pip setuptools wheel build -# ${PYTHON} -m build --outdir ${DIST_DIR} . export UV_NO_CACHE=true ${PYTHON} -m pip install uv hatch ${PYTHON} -m hatch build From 411ae559f5d6402e4d01d2b07a5e1011153292ed Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Thu, 4 Jul 2024 15:59:00 +0200 Subject: [PATCH 05/21] SNOW-1519635: skip dialect tests in snowflake tests (#517) --- pyproject.toml | 2 +- tox.ini | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 58544017..9cdd9fb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,7 +90,7 @@ SQLACHEMY_WARN_20 = "1" [tool.hatch.envs.default.scripts] check = "pre-commit run --all-files" test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" -test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite tests/" +test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" diff --git a/tox.ini b/tox.ini index 7f605627..2f7360a6 100644 --- a/tox.ini +++ b/tox.ini @@ -39,6 +39,7 @@ commands = pytest \ {env:SNOWFLAKE_PYTEST_OPTS:} \ --cov "snowflake.sqlalchemy" \ --junitxml {toxworkdir}/junit_{envname}.xml \ + --ignore=tests/sqlalchemy_test_suite \ {posargs:tests} pytest {env:SNOWFLAKE_PYTEST_OPTS:} \ --cov "snowflake.sqlalchemy" --cov-append \ @@ -74,6 +75,7 @@ passenv = PROGRAMDATA deps = {[testenv]deps} + tomlkit >= 1.12.0 pre-commit >= 2.9.0 skip_install = True commands = pre-commit run --all-files From 64fafbb7e94c5c256f51c918182d4e70412d4195 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Thu, 4 Jul 2024 16:58:50 +0200 Subject: [PATCH 06/21] SNOW-1519766: drop tomlkit version for fix_lint (#518) --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 2f7360a6..102e2273 100644 --- a/tox.ini +++ b/tox.ini @@ -75,7 +75,7 @@ passenv = PROGRAMDATA deps = {[testenv]deps} - tomlkit >= 1.12.0 + tomlkit pre-commit >= 2.9.0 skip_install = True commands = pre-commit run --all-files From 305d2980cec33d37dbb9418684adcef72a3eaf76 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Tue, 9 Jul 2024 12:38:23 +0200 Subject: [PATCH 07/21] mraba/update-python_publish-workflow (#520) SNOW-1519875: update publish branch workflow for v1.6.1 --- .github/workflows/python-publish.yml | 8 ++++---- DESCRIPTION.md | 4 ++++ src/snowflake/sqlalchemy/version.py | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/.github/workflows/python-publish.yml 
b/.github/workflows/python-publish.yml index 23116e7a..0a9f22bd 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -30,12 +30,12 @@ jobs: python-version: '3.x' - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install build + python -m pip install -U uv + python -m uv pip install -U hatch - name: Build package - run: python -m build + run: python -m hatch build --clean - name: Publish package - uses: pypa/gh-action-pypi-publish@e53eb8b103ffcb59469888563dc324e3c8ba6f06 + uses: pypa/gh-action-pypi-publish@release/v1 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 79971c53..38cd70f7 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,10 @@ Source code is also available at: # Release Notes +- v1.6.1(July 9, 2024) + + - Update internal project workflow with pypi publishing + - v1.6.0(July 8, 2024) - support for installing with SQLAlchemy 2.0.x diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index 56509b7d..d90f706b 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = "1.6.0" +VERSION = "1.6.1" From dd7fc8aca7460fc669c7bb6667e45c83f615865e Mon Sep 17 00:00:00 2001 From: Angel Antonio Avalos Cisneros Date: Thu, 8 Aug 2024 14:53:13 -0700 Subject: [PATCH 08/21] sign artifacts before publish (#522) --- .github/workflows/python-publish.yml | 43 +++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 0a9f22bd..a1eb1a0c 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -13,7 +13,8 @@ on: types: [published] permissions: - contents: read + contents: write + id-token: write jobs: deploy: @@ -34,6 +35,46 @@ jobs: python -m uv pip install -U hatch - name: Build package run: python -m hatch build --clean + - name: List artifacts + run: ls ./dist + - name: Install sigstore + run: python -m pip install sigstore + - name: Signing + run: | + for dist in dist/*; do + dist_base="$(basename "${dist}")" + echo "dist: ${dist}" + echo "dist_base: ${dist_base}" + python -m \ + sigstore sign "${dist}" \ + --output-signature "${dist_base}.sig" \ + --output-certificate "${dist_base}.crt" \ + --bundle "${dist_base}.sigstore" + + # Verify using `.sig` `.crt` pair; + python -m \ + sigstore verify identity "${dist}" \ + --signature "${dist_base}.sig" \ + --cert "${dist_base}.crt" \ + --cert-oidc-issuer https://token.actions.githubusercontent.com \ + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF} + + # Verify using `.sigstore` bundle; + python -m \ + sigstore verify identity "${dist}" \ + --bundle "${dist_base}.sigstore" \ + --cert-oidc-issuer https://token.actions.githubusercontent.com \ + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF} + done + - name: List artifacts after sign + run: ls ./dist + - name: Copy files to release + run: | + gh release upload ${{ github.event.release.tag_name }} *.sigstore + gh release upload ${{ github.event.release.tag_name }} *.sig + gh release upload ${{ github.event.release.tag_name }} *.crt + env: + GITHUB_TOKEN: ${{ github.TOKEN }} - name: Publish package uses: 
pypa/gh-action-pypi-publish@release/v1 with: From fd8c29a08696feab256d51fc2c42773cfd74c590 Mon Sep 17 00:00:00 2001 From: Angel Antonio Avalos Cisneros Date: Mon, 19 Aug 2024 13:08:35 -0700 Subject: [PATCH 09/21] Update python-publish.yml (#526) --- .github/workflows/python-publish.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index a1eb1a0c..52f43106 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -57,14 +57,14 @@ jobs: --signature "${dist_base}.sig" \ --cert "${dist_base}.crt" \ --cert-oidc-issuer https://token.actions.githubusercontent.com \ - --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF} + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} # Verify using `.sigstore` bundle; python -m \ sigstore verify identity "${dist}" \ --bundle "${dist_base}.sigstore" \ --cert-oidc-issuer https://token.actions.githubusercontent.com \ - --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF} + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} done - name: List artifacts after sign run: ls ./dist From 957a4699cf21151070071969cd4996735e33001b Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Mon, 9 Sep 2024 03:38:57 -0600 Subject: [PATCH 10/21] Add tests to try max lob size in memory feature (#529) * Add test with large object --- pyproject.toml | 2 ++ tests/test_custom_types.py | 33 ++++++++++++++++++++++++++++++++- tests/test_orm.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9cdd9fb4..99aacbee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,6 +109,7 @@ line-length = 88 line-length = 88 [tool.pytest.ini_options] +addopts = "-m 'not feature_max_lob_size'" markers = [ # Optional dependency groups markers "lambda: AWS lambda tests", @@ -126,4 +127,5 @@ markers = [ "timeout: tests that need a timeout time", "internal: tests that could but should only run on our internal CI", "external: tests that could but should only run on our external CI", + "feature_max_lob_size: tests that could but should only run on our external CI", ] diff --git a/tests/test_custom_types.py b/tests/test_custom_types.py index a997ffe8..3961a5d3 100644 --- a/tests/test_custom_types.py +++ b/tests/test_custom_types.py @@ -2,7 +2,10 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
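
The tests added below verify rendered column types by round-tripping through Snowflake's `GET_DDL` system function, which returns the CREATE statement for an object; a minimal sketch of that technique, with an illustrative helper name:

```python
from sqlalchemy import text

def fetch_table_ddl(connection, table_name: str) -> str:
    # GET_DDL('TABLE', name) returns the DDL Snowflake would emit for the
    # table, so a test can assert on strings like VARCHAR(134217728).
    result = connection.execute(text(f"SELECT GET_DDL('TABLE', '{table_name}')"))
    return str(result.scalar())
```
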
# -from snowflake.sqlalchemy import custom_types +import pytest +from sqlalchemy import Column, Integer, MetaData, Table, text + +from snowflake.sqlalchemy import TEXT, custom_types def test_string_conversions(): @@ -34,3 +37,31 @@ def test_string_conversions(): sample = getattr(custom_types, type_)() if type_ in sf_custom_types: assert type_ == str(sample) + + +@pytest.mark.feature_max_lob_size +def test_create_table_with_text_type(engine_testaccount): + metadata = MetaData() + table_name = "test_max_lob_size_0" + test_max_lob_size = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("full_name", TEXT(), server_default=text("id::varchar")), + ) + + metadata.create_all(engine_testaccount) + try: + assert test_max_lob_size is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{table_name}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + test_max_lob_size.drop(engine_testaccount) diff --git a/tests/test_orm.py b/tests/test_orm.py index f53cd708..f51c9a90 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -7,6 +7,7 @@ import pytest from sqlalchemy import ( + TEXT, Column, Enum, ForeignKey, @@ -413,3 +414,34 @@ class Employee(Base): '[SELECT "Employee".uid FROM "Employee" JOIN LATERAL flatten(PARSE_JSON("Employee"' in caplog.text ) + + +@pytest.mark.feature_max_lob_size +def test_basic_table_with_large_lob_size_in_memory(engine_testaccount, sql_compiler): + Base = declarative_base() + + class User(Base): + __tablename__ = "user" + + id = Column(Integer, primary_key=True) + full_name = Column(TEXT(), server_default=text("id::varchar")) + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + assert User.__table__ is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{User.__tablename__}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + Base.metadata.drop_all(engine_testaccount) From 16ad10fbb90d2fc98d3ab7218fe41e2ac708db33 Mon Sep 17 00:00:00 2001 From: Marcin Raba Date: Wed, 2 Oct 2024 10:49:48 +0200 Subject: [PATCH 11/21] SNOW-1655751: register overwritten functions under `snowflake` namespace (#532) * SNOW-1655751: register overwritten functions under `snowflake` namespace --- DESCRIPTION.md | 7 +++++-- src/snowflake/sqlalchemy/base.py | 3 +-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 38cd70f7..67b50ab0 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,9 @@ Source code is also available at: # Release Notes +- 1.6.2 + - Fixed SAWarning when registering functions with existing name in default namespace + - v1.6.1(July 9, 2024) - Update internal project workflow with pypi publishing @@ -24,7 +27,7 @@ Source code is also available at: - v1.5.3(April 16, 2024) - - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 + - Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0 - v1.5.2(April 11, 2024) @@ -33,7 +36,7 @@ Source code is also available at: - v1.5.1(November 03, 2023) - - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check 
https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057. + - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check . - Fixed credentials with `externalbrowser` authentication not caching due to incorrect parsing of boolean query parameters. - This fixes other boolean parameter passing to driver as well. diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 1aaa881e..3e504f7b 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -184,7 +184,6 @@ def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause) [element._from_objects for element in statement._where_criteria] ), ): - potential[from_clause] = () all_clauses = list(potential.keys()) @@ -1065,4 +1064,4 @@ def visit_GEOMETRY(self, type_, **kw): construct_arguments = [(Table, {"clusterby": None})] -functions.register_function("flatten", flatten) +functions.register_function("flatten", flatten, "snowflake") From b5af4e31611b4ac9e4467eee4ad6235a4b6d8d57 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Wed, 2 Oct 2024 09:49:44 -0600 Subject: [PATCH 12/21] Adding support for snowflake dynamic tables to SqlAlchemy Core (#531) * Add support for dynamic tables * Update DESCRIPTION.md * Remove unnesary code to support dynamic tables in sqlalchemy 1.4 * Fix bug to support sqlalchemy v1.4 * Add syrupy * Remove non necessary parameter * Add snapshots --- DESCRIPTION.md | 5 +- pyproject.toml | 1 + src/snowflake/sqlalchemy/__init__.py | 7 + src/snowflake/sqlalchemy/_constants.py | 1 + src/snowflake/sqlalchemy/base.py | 30 ++- src/snowflake/sqlalchemy/snowdialect.py | 3 +- src/snowflake/sqlalchemy/sql/__init__.py | 3 + .../sqlalchemy/sql/custom_schema/__init__.py | 6 + .../sql/custom_schema/custom_table_base.py | 51 +++++ .../sql/custom_schema/dynamic_table.py | 86 +++++++++ .../sql/custom_schema/options/__init__.py | 9 + .../sql/custom_schema/options/as_query.py | 62 ++++++ .../sql/custom_schema/options/table_option.py | 26 +++ .../options/table_option_base.py | 30 +++ .../sql/custom_schema/options/target_lag.py | 60 ++++++ .../sql/custom_schema/options/warehouse.py | 51 +++++ .../sql/custom_schema/table_from_query.py | 60 ++++++ .../test_compile_dynamic_table.ambr | 13 ++ .../test_reflect_dynamic_table.ambr | 4 + tests/test_compile_dynamic_table.py | 177 ++++++++++++++++++ tests/test_create_dynamic_table.py | 93 +++++++++ tests/test_reflect_dynamic_table.py | 88 +++++++++ 22 files changed, 860 insertions(+), 6 deletions(-) create mode 100644 src/snowflake/sqlalchemy/sql/__init__.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/__init__.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py create mode 100644 tests/__snapshots__/test_compile_dynamic_table.ambr create mode 100644 
tests/__snapshots__/test_reflect_dynamic_table.ambr create mode 100644 tests/test_compile_dynamic_table.py create mode 100644 tests/test_create_dynamic_table.py create mode 100644 tests/test_reflect_dynamic_table.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 67b50ab0..205685f1 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,10 @@ Source code is also available at: # Release Notes -- 1.6.2 +- (Unreleased) + + - Add support for dynamic tables and required options + - Fixed SAWarning when registering functions with existing name in default namespace - Fixed SAWarning when registering functions with existing name in default namespace - v1.6.1(July 9, 2024) diff --git a/pyproject.toml b/pyproject.toml index 99aacbee..4fe06a9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ development = [ "pytz", "numpy", "mock", + "syrupy==4.6.1", ] pandas = ["snowflake-connector-python[pandas]"] diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 9df6aaa2..30cd140c 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -61,6 +61,8 @@ VARBINARY, VARIANT, ) +from .sql.custom_schema import DynamicTable +from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse from .util import _url as URL base.dialect = dialect = snowdialect.dialect @@ -113,4 +115,9 @@ "ExternalStage", "CreateStage", "CreateFileFormat", + "DynamicTable", + "AsQuery", + "TargetLag", + "TimeUnit", + "Warehouse", ) diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index 46af4454..839745ee 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -10,3 +10,4 @@ APPLICATION_NAME = "SnowflakeSQLAlchemy" SNOWFLAKE_SQLALCHEMY_VERSION = VERSION +DIALECT_NAME = "snowflake" diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 3e504f7b..56631728 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -18,9 +18,16 @@ from sqlalchemy.sql.elements import quoted_name from sqlalchemy.sql.selectable import Lateral, SelectState -from .compat import IS_VERSION_20, args_reducer, string_types -from .custom_commands import AWSBucket, AzureContainer, ExternalStage +from snowflake.sqlalchemy._constants import DIALECT_NAME +from snowflake.sqlalchemy.compat import IS_VERSION_20, args_reducer, string_types +from snowflake.sqlalchemy.custom_commands import ( + AWSBucket, + AzureContainer, + ExternalStage, +) + from .functions import flatten +from .sql.custom_schema.options.table_option_base import TableOptionBase from .util import ( _find_left_clause_to_join_from, _set_connection_interpolate_empty_sequences, @@ -878,7 +885,7 @@ def get_column_specification(self, column, **kwargs): return " ".join(colspec) - def post_create_table(self, table): + def handle_cluster_by(self, table): """ Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax. 
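
As the docstring above notes, the CLUSTER BY clause is driven by the `clusterby` dialect option; a short sketch of how a caller requests it (table and column names illustrative):

```python
from sqlalchemy import Column, Integer, MetaData, String, Table

metadata = MetaData()
user = Table(
    "user",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("name", String),
    # Stored under dialect_options["snowflake"]["clusterby"] and rendered
    # by the DDL compiler as: CLUSTER BY (id, name)
    snowflake_clusterby=["id", "name"],
)
```
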
@@ -908,7 +915,7 @@ def post_create_table(self, table): """ text = "" - info = table.dialect_options["snowflake"] + info = table.dialect_options[DIALECT_NAME] cluster = info.get("clusterby") if cluster: text += " CLUSTER BY ({})".format( @@ -916,6 +923,21 @@ def post_create_table(self, table): ) return text + def post_create_table(self, table): + text = self.handle_cluster_by(table) + options = [ + option + for _, option in table.dialect_options[DIALECT_NAME].items() + if isinstance(option, TableOptionBase) + ] + options.sort( + key=lambda x: (x.__priority__.value, x.__option_name__), reverse=True + ) + for option in options: + text += "\t" + option.render_option(self) + + return text + def visit_create_stage(self, create_stage, **kw): """ This visitor will create the SQL representation for a CREATE STAGE command. diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 04305a00..b0472eb6 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -42,6 +42,7 @@ from snowflake.connector.constants import UTF8 from snowflake.sqlalchemy.compat import returns_unicode +from ._constants import DIALECT_NAME from .base import ( SnowflakeCompiler, SnowflakeDDLCompiler, @@ -119,7 +120,7 @@ class SnowflakeDialect(default.DefaultDialect): - name = "snowflake" + name = DIALECT_NAME driver = "snowflake" max_identifier_length = 255 cte_follows_insert = True diff --git a/src/snowflake/sqlalchemy/sql/__init__.py b/src/snowflake/sqlalchemy/sql/__init__.py new file mode 100644 index 00000000..ef416f64 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/__init__.py @@ -0,0 +1,3 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py new file mode 100644 index 00000000..4bbac246 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py @@ -0,0 +1,6 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from .dynamic_table import DynamicTable + +__all__ = ["DynamicTable"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py new file mode 100644 index 00000000..0c04f33f --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
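
A toy illustration of the ordering rule in `post_create_table` above: options sort by `(priority value, option name)` in reverse, so higher-priority options render first and names break ties in reverse alphabetical order. This matches the snapshots later in this patch, where WAREHOUSE precedes TARGET_LAG precedes AS:

```python
from enum import Enum

class Priority(Enum):
    LOWEST = 0
    HIGH = 6

# (option name, priority) pairs standing in for TableOptionBase instances.
options = [("as_query", Priority.LOWEST), ("target_lag", Priority.HIGH), ("warehouse", Priority.HIGH)]
options.sort(key=lambda o: (o[1].value, o[0]), reverse=True)
print([name for name, _ in options])  # ['warehouse', 'target_lag', 'as_query']
```
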
+#
+import typing
+from typing import Any
+
+from sqlalchemy.exc import ArgumentError
+from sqlalchemy.sql.schema import MetaData, SchemaItem, Table
+
+from ..._constants import DIALECT_NAME
+from ...compat import IS_VERSION_20
+from ...custom_commands import NoneType
+from .options.table_option import TableOption
+
+
+class CustomTableBase(Table):
+    __table_prefix__ = ""
+    _support_primary_and_foreign_keys = True
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        if self.__table_prefix__ != "":
+            prefixes = kw.get("prefixes", []) + [self.__table_prefix__]
+            kw.update(prefixes=prefixes)
+        if not IS_VERSION_20 and hasattr(super(), "_init"):
+            super()._init(name, metadata, *args, **kw)
+        else:
+            super().__init__(name, metadata, *args, **kw)
+
+        if not kw.get("autoload_with", False):
+            self._validate_table()
+
+    def _validate_table(self):
+        if not self._support_primary_and_foreign_keys and (
+            self.primary_key or self.foreign_keys
+        ):
+            raise ArgumentError(
+                f"Primary key and foreign keys are not supported in {self.__table_prefix__} TABLE."
+            )
+
+        return True
+
+    def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]:
+        if option_name in self.dialect_options[DIALECT_NAME]:
+            return self.dialect_options[DIALECT_NAME][option_name]
+        return NoneType
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
new file mode 100644
index 00000000..7d0a02e6
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
@@ -0,0 +1,86 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+import typing
+from typing import Any
+
+from sqlalchemy.exc import ArgumentError
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from snowflake.sqlalchemy.custom_commands import NoneType
+
+from .options.target_lag import TargetLag
+from .options.warehouse import Warehouse
+from .table_from_query import TableFromQueryBase
+
+
+class DynamicTable(TableFromQueryBase):
+    """
+    A class representing a dynamic table with configurable options and settings.
+
+    The `DynamicTable` class allows for the creation and querying of tables with
+    specific options, such as `Warehouse` and `TargetLag`.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating and managing dynamic tables.
+ + """ + + __table_prefix__ = "DYNAMIC" + + _support_primary_and_foreign_keys = False + + @property + def warehouse(self) -> typing.Optional[Warehouse]: + return self._get_dialect_option(Warehouse.__option_name__) + + @property + def target_lag(self) -> typing.Optional[TargetLag]: + return self._get_dialect_option(TargetLag.__option_name__) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + super().__init__(name, metadata, *args, **kw) + + def _validate_table(self): + missing_attributes = [] + if self.target_lag is NoneType: + missing_attributes.append("TargetLag") + if self.warehouse is NoneType: + missing_attributes.append("Warehouse") + if self.as_query is NoneType: + missing_attributes.append("AsQuery") + if missing_attributes: + raise ArgumentError( + "DYNAMIC TABLE must have the following arguments: %s" + % ", ".join(missing_attributes) + ) + super()._validate_table() + + def __repr__(self) -> str: + return "DynamicTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [repr(self.target_lag)] + + [repr(self.warehouse)] + + [repr(self.as_query)] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py new file mode 100644 index 00000000..052e2d96 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py @@ -0,0 +1,9 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from .as_query import AsQuery +from .target_lag import TargetLag, TimeUnit +from .warehouse import Warehouse + +__all__ = ["Warehouse", "AsQuery", "TargetLag", "TimeUnit"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py new file mode 100644 index 00000000..68076af9 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py @@ -0,0 +1,62 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Union + +from sqlalchemy.sql import Selectable + +from .table_option import TableOption +from .table_option_base import Priority + + +class AsQuery(TableOption): + """Class to represent an AS clause in tables. + This configuration option is used to specify the query from which the table is created. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas + + + AsQuery example usage using an input string: + DynamicTable( + "sometable", metadata, + Column("name", String(50)), + Column("address", String(100)), + AsQuery('select name, address from existing_table where name = "test"') + ) + + AsQuery example usage using a selectable statement: + DynamicTable( + "sometable", + Base.metadata, + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery(select(test_table_1).where(test_table_1.c.id == 23)) + ) + + """ + + __option_name__ = "as_query" + __priority__ = Priority.LOWEST + + def __init__(self, query: Union[str, Selectable]) -> None: + r"""Construct an as_query object. 
+
+        :param \*expressions:
+            AS <query>
+
+        """
+        self.query = query
+
+    @staticmethod
+    def template() -> str:
+        return "AS %s"
+
+    def get_expression(self):
+        if isinstance(self.query, Selectable):
+            return self.query.compile(compile_kwargs={"literal_binds": True})
+        return self.query
+
+    def render_option(self, compiler) -> str:
+        return AsQuery.template() % (self.get_expression())
+
+    def __repr__(self) -> str:
+        return "Query(%s)" % self.get_expression()
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
new file mode 100644
index 00000000..7ac27575
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
@@ -0,0 +1,26 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from typing import Any
+
+from sqlalchemy import exc
+from sqlalchemy.sql.base import SchemaEventTarget
+from sqlalchemy.sql.schema import SchemaItem, Table
+
+from snowflake.sqlalchemy._constants import DIALECT_NAME
+
+from .table_option_base import TableOptionBase
+
+
+class TableOption(TableOptionBase, SchemaItem):
+    def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+        if self.__option_name__ == "default":
+            raise exc.SQLAlchemyError(f"{self.__class__.__name__} does not have a name")
+        if not isinstance(parent, Table):
+            raise exc.SQLAlchemyError(
+                f"{self.__class__.__name__} option can only be applied to Table"
+            )
+        parent.dialect_options[DIALECT_NAME][self.__option_name__] = self
+
+    def _set_table_option_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+        pass
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py
new file mode 100644
index 00000000..54008ec8
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py
@@ -0,0 +1,30 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from enum import Enum
+
+
+class Priority(Enum):
+    LOWEST = 0
+    VERY_LOW = 1
+    LOW = 2
+    MEDIUM = 4
+    HIGH = 6
+    VERY_HIGH = 7
+    HIGHEST = 8
+
+
+class TableOptionBase:
+    __option_name__ = "default"
+    __visit_name__ = __option_name__
+    __priority__ = Priority.MEDIUM
+
+    @staticmethod
+    def template() -> str:
+        raise NotImplementedError
+
+    def get_expression(self):
+        raise NotImplementedError
+
+    def render_option(self, compiler) -> str:
+        raise NotImplementedError
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py
new file mode 100644
index 00000000..4331a4cb
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py
@@ -0,0 +1,60 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from enum import Enum
+from typing import Optional
+
+from .table_option import TableOption
+from .table_option_base import Priority
+
+
+class TimeUnit(Enum):
+    SECONDS = "seconds"
+    MINUTES = "minutes"
+    HOURS = "hour"
+    DAYS = "days"
+
+
+class TargetLag(TableOption):
+    """Class to represent the target lag clause.
+    This configuration option is used to specify the target lag time for the dynamic table.
+ For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table + + + Target lag example usage: + DynamicTable("sometable", metadata, + Column("name", String(50)), + Column("address", String(100)), + TargetLag(20, TimeUnit.MINUTES), + ) + """ + + __option_name__ = "target_lag" + __priority__ = Priority.HIGH + + def __init__( + self, + time: Optional[int] = 0, + unit: Optional[TimeUnit] = TimeUnit.MINUTES, + down_stream: Optional[bool] = False, + ) -> None: + self.time = time + self.unit = unit + self.down_stream = down_stream + + @staticmethod + def template() -> str: + return "TARGET_LAG = %s" + + def get_expression(self): + return ( + f"'{str(self.time)} {str(self.unit.value)}'" + if not self.down_stream + else "DOWNSTREAM" + ) + + def render_option(self, compiler) -> str: + return TargetLag.template() % (self.get_expression()) + + def __repr__(self) -> str: + return "TargetLag(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py new file mode 100644 index 00000000..a5b8cce0 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional + +from .table_option import TableOption +from .table_option_base import Priority + + +class Warehouse(TableOption): + """Class to represent the warehouse clause. + This configuration option is used to specify the warehouse for the dynamic table. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table + + + Warehouse example usage: + DynamicTable("sometable", metadata, + Column("name", String(50)), + Column("address", String(100)), + Warehouse('my_warehouse_name') + ) + """ + + __option_name__ = "warehouse" + __priority__ = Priority.HIGH + + def __init__( + self, + name: Optional[str], + ) -> None: + r"""Construct a Warehouse object. + + :param \*expressions: + Dynamic table warehouse option. + WAREHOUSE = + + """ + self.name = name + + @staticmethod + def template() -> str: + return "WAREHOUSE = %s" + + def get_expression(self): + return self.name + + def render_option(self, compiler) -> str: + return Warehouse.template() % (self.get_expression()) + + def __repr__(self) -> str: + return "Warehouse(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py new file mode 100644 index 00000000..60e8995f --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py @@ -0,0 +1,60 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
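
Taken together, the three options above compose as follows; a sketch mirroring the docstring examples, with illustrative table and warehouse names (`AsQuery`, `TargetLag`, `TimeUnit`, and `Warehouse` are all exported from `snowflake.sqlalchemy` by this patch):

```python
from sqlalchemy import Column, Integer, MetaData, String

from snowflake.sqlalchemy import AsQuery, DynamicTable, TargetLag, TimeUnit, Warehouse

metadata = MetaData()
product_summary = DynamicTable(
    "product_summary",
    metadata,
    Column("id", Integer),
    Column("name", String),
    TargetLag(20, TimeUnit.MINUTES),           # TARGET_LAG = '20 minutes'
    Warehouse("my_wh"),                        # WAREHOUSE = my_wh
    AsQuery("SELECT id, name FROM products"),  # AS SELECT id, name FROM products
)
```
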
+# +import typing +from typing import Any, Optional + +from sqlalchemy.sql import Selectable +from sqlalchemy.sql.schema import Column, MetaData, SchemaItem +from sqlalchemy.util import NoneType + +from .custom_table_base import CustomTableBase +from .options.as_query import AsQuery + + +class TableFromQueryBase(CustomTableBase): + + @property + def as_query(self): + return self._get_dialect_option(AsQuery.__option_name__) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + items = [item for item in args] + as_query: AsQuery = self.__get_as_query_from_items(items) + if ( + as_query is not NoneType + and isinstance(as_query.query, Selectable) + and not self.__has_defined_columns(items) + ): + columns = self.__create_columns_from_selectable(as_query.query) + args = items + columns + super().__init__(name, metadata, *args, **kw) + + def __get_as_query_from_items( + self, items: typing.List[SchemaItem] + ) -> Optional[AsQuery]: + for item in items: + if isinstance(item, AsQuery): + return item + return NoneType + + def __has_defined_columns(self, items: typing.List[SchemaItem]) -> bool: + for item in items: + if isinstance(item, Column): + return True + + def __create_columns_from_selectable( + self, selectable: Selectable + ) -> Optional[typing.List[Column]]: + if not isinstance(selectable, Selectable): + return + columns: typing.List[Column] = [] + for _, c in selectable.exported_columns.items(): + columns += [Column(c.name, c.type)] + return columns diff --git a/tests/__snapshots__/test_compile_dynamic_table.ambr b/tests/__snapshots__/test_compile_dynamic_table.ambr new file mode 100644 index 00000000..81c7f90f --- /dev/null +++ b/tests/__snapshots__/test_compile_dynamic_table.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/__snapshots__/test_reflect_dynamic_table.ambr b/tests/__snapshots__/test_reflect_dynamic_table.ambr new file mode 100644 index 00000000..d4cc22b5 --- /dev/null +++ b/tests/__snapshots__/test_reflect_dynamic_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- diff --git a/tests/test_compile_dynamic_table.py b/tests/test_compile_dynamic_table.py new file mode 100644 index 00000000..73ce54aa --- /dev/null +++ b/tests/test_compile_dynamic_table.py @@ -0,0 +1,177 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
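
`TableFromQueryBase` above infers the table's columns when `AsQuery` wraps a SQLAlchemy `Selectable` and no explicit columns are passed; a sketch of that behavior, with illustrative names:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, select

from snowflake.sqlalchemy import AsQuery, DynamicTable, TargetLag, TimeUnit, Warehouse

metadata = MetaData()
src = Table("src", metadata, Column("id", Integer, primary_key=True), Column("name", String))

# No Column(...) arguments here: id and name are copied from the SELECT's
# exported columns by the private column-inference helper.
derived = DynamicTable(
    "derived",
    metadata,
    TargetLag(10, TimeUnit.SECONDS),
    Warehouse("wh"),
    AsQuery(select(src).where(src.c.id == 23)),
)
assert [c.name for c in derived.columns] == ["id", "name"]
```
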
+# +import pytest +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Integer, + MetaData, + String, + Table, + exc, + select, +) +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, DynamicTable +from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( + TargetLag, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse + + +def test_compile_dynamic_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_without_required_args(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="DYNAMIC TABLE must have the following arguments: TargetLag, " + "Warehouse, AsQuery", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + ) + + +def test_compile_dynamic_table_with_primary_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + +def test_compile_dynamic_table_with_foreign_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer), + Column("geom", GEOMETRY), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ForeignKeyConstraint(["id"], ["table.id"]), + ) + + +def test_compile_dynamic_table_orm(sql_compiler, snapshot): + Base = declarative_base() + metadata = MetaData() + table_name = "test_dynamic_table_orm" + test_dynamic_table_orm = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + class TestDynamicTableOrm(Base): + __table__ = test_dynamic_table_orm + __mapper_args__ = { + "primary_key": [test_dynamic_table_orm.c.id, test_dynamic_table_orm.c.name] + } + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot): + Base = declarative_base() + + class TestDynamicTableOrm(Base): + __tablename__ = "test_dynamic_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return DynamicTable(name, metadata, *arg, **kw) + + __table_args__ = ( + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery("SELECT * FROM table"), + ) + + id = Column(Integer) + name = Column(String) + + __mapper_args__ = {"primary_key": [id, name]} + + def __repr__(self): + return f"" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + 
assert actual == snapshot + + +def test_compile_dynamic_table_with_selectable(sql_compiler, snapshot): + Base = declarative_base() + + test_table_1 = Table( + "test_table_1", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + dynamic_test_table = DynamicTable( + "dynamic_test_table_1", + Base.metadata, + TargetLag(10, TimeUnit.SECONDS), + Warehouse("warehouse"), + AsQuery(select(test_table_1).where(test_table_1.c.id == 23)), + ) + + value = CreateTable(dynamic_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/test_create_dynamic_table.py b/tests/test_create_dynamic_table.py new file mode 100644 index 00000000..4e6c48ca --- /dev/null +++ b/tests/test_create_dynamic_table.py @@ -0,0 +1,93 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( + TargetLag, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse + + +def test_create_dynamic_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + dynamic_test_table_1 = DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(1, TimeUnit.HOURS), + Warehouse(warehouse), + AsQuery("SELECT id, name from test_table_1;"), + ) + + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_create_dynamic_table_without_dynamictable_class( + engine_testaccount, db_parameters +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + dynamic_test_table_1 = Table( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + TargetLag(1, TimeUnit.HOURS), + Warehouse(warehouse), + AsQuery("SELECT id, name from test_table_1;"), + prefixes=["DYNAMIC"], + ) + + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_reflect_dynamic_table.py b/tests/test_reflect_dynamic_table.py new file mode 100644 index 00000000..8a4a8445 --- /dev/null +++ b/tests/test_reflect_dynamic_table.py 
@@ -0,0 +1,88 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.custom_commands import NoneType + + +def test_simple_reflection_dynamic_table_as_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = Table( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_without_options_loading(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = DynamicTable( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + # TODO: Add support for loading options when table is reflected + assert dynamic_test_table.warehouse is NoneType + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) From 43c6b563e462884faf7b7063bbf7fe10de7a5f60 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Tue, 8 Oct 2024 07:34:47 -0600 Subject: [PATCH 13/21] Add support for hybrid tables and indexes (#533) * Add support for hybrid tables * Update DESCRIPTION.md and add support for indexes --- .github/workflows/build_test.yml | 6 + DESCRIPTION.md | 2 +- pyproject.toml | 3 +- src/snowflake/sqlalchemy/__init__.py | 3 +- src/snowflake/sqlalchemy/snowdialect.py | 138 +++++++++++++++- .../sqlalchemy/sql/custom_schema/__init__.py | 3 +- .../sql/custom_schema/custom_table_base.py | 23 ++- .../sql/custom_schema/custom_table_prefix.py | 13 ++ .../sql/custom_schema/dynamic_table.py | 3 +- .../sql/custom_schema/hybrid_table.py | 67 ++++++++ tests/__snapshots__/test_orm.ambr | 4 + tests/custom_tables/__init__.py | 2 + 
.../test_compile_dynamic_table.ambr | 13 ++ .../test_compile_hybrid_table.ambr | 7 + .../test_create_hybrid_table.ambr | 7 + .../test_reflect_hybrid_table.ambr | 4 + .../test_compile_dynamic_table.py | 4 +- .../test_compile_hybrid_table.py | 52 ++++++ .../test_create_dynamic_table.py | 0 .../custom_tables/test_create_hybrid_table.py | 95 +++++++++++ .../test_reflect_dynamic_table.py | 0 .../test_reflect_hybrid_table.py | 65 ++++++++ tests/test_core.py | 7 +- tests/test_index_reflection.py | 34 ++++ tests/test_orm.py | 155 +++++++++++++++++- tests/test_pandas.py | 2 +- 26 files changed, 679 insertions(+), 33 deletions(-) create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py create mode 100644 tests/__snapshots__/test_orm.ambr create mode 100644 tests/custom_tables/__init__.py create mode 100644 tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr rename tests/{ => custom_tables}/test_compile_dynamic_table.py (96%) create mode 100644 tests/custom_tables/test_compile_hybrid_table.py rename tests/{ => custom_tables}/test_create_dynamic_table.py (100%) create mode 100644 tests/custom_tables/test_create_hybrid_table.py rename tests/{ => custom_tables}/test_reflect_dynamic_table.py (100%) create mode 100644 tests/custom_tables/test_reflect_hybrid_table.py create mode 100644 tests/test_index_reflection.py diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index d7f7832b..5e9823f2 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -111,6 +111,9 @@ jobs: run: | gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run test for AWS + run: hatch run test-dialect-aws + if: matrix.cloud-provider == 'aws' - name: Run tests run: hatch run test-dialect - uses: actions/upload-artifact@v4 @@ -203,6 +206,9 @@ jobs: python -m pip install -U uv python -m uv pip install -U hatch python -m hatch env create default + - name: Run test for AWS + run: hatch run sa14:test-dialect-aws + if: matrix.cloud-provider == 'aws' - name: Run tests run: hatch run sa14:test-dialect - uses: actions/upload-artifact@v4 diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 205685f1..58c2dfe2 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -12,7 +12,7 @@ Source code is also available at: - (Unreleased) - Add support for dynamic tables and required options - - Fixed SAWarning when registering functions with existing name in default namespace + - Add support for hybrid tables - Fixed SAWarning when registering functions with existing name in default namespace - v1.6.1(July 9, 2024) diff --git a/pyproject.toml b/pyproject.toml index 4fe06a9b..6c72f683 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,6 +92,7 @@ SQLACHEMY_WARN_20 = "1" check = "pre-commit run --all-files" test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" 
+test-dialect-aws = "pytest -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/"
 gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1"
 check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'"
@@ -110,7 +111,7 @@ line-length = 88
 line-length = 88

 [tool.pytest.ini_options]
-addopts = "-m 'not feature_max_lob_size'"
+addopts = "-m 'not feature_max_lob_size and not aws'"
 markers = [
     # Optional dependency groups markers
     "lambda: AWS lambda tests",
diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py
index 30cd140c..0afd44a5 100644
--- a/src/snowflake/sqlalchemy/__init__.py
+++ b/src/snowflake/sqlalchemy/__init__.py
@@ -61,7 +61,7 @@
     VARBINARY,
     VARIANT,
 )
-from .sql.custom_schema import DynamicTable
+from .sql.custom_schema import DynamicTable, HybridTable
 from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse
 from .util import _url as URL

@@ -120,4 +120,5 @@
     "TargetLag",
     "TimeUnit",
     "Warehouse",
+    "HybridTable",
 )
diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py
index b0472eb6..f2fb9b18 100644
--- a/src/snowflake/sqlalchemy/snowdialect.py
+++ b/src/snowflake/sqlalchemy/snowdialect.py
@@ -65,6 +65,7 @@
     _CUSTOM_Float,
     _CUSTOM_Time,
 )
+from .sql.custom_schema.custom_table_prefix import CustomTablePrefix
 from .util import (
     _update_connection_application_name,
     parse_url_boolean,
@@ -352,14 +353,6 @@ def _map_name_to_idx(result):
             name_to_idx[col[0]] = idx
         return name_to_idx

-    @reflection.cache
-    def get_indexes(self, connection, table_name, schema=None, **kw):
-        """
-        Gets all indexes
-        """
-        # no index is supported by Snowflake
-        return []
-
     @reflection.cache
     def get_check_constraints(self, connection, table_name, schema, **kw):
         # check constraints are not supported by Snowflake
@@ -895,6 +888,129 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
             )
         }

+    def get_multi_indexes(
+        self,
+        connection,
+        *,
+        schema,
+        filter_names,
+        **kw,
+    ):
+        """
+        Gets the indexes definition
+        """
+
+        table_prefixes = self.get_multi_prefixes(
+            connection, schema, filter_prefix=CustomTablePrefix.HYBRID.name
+        )
+        if len(table_prefixes) == 0:
+            return []
+        schema = schema or self.default_schema_name
+        if not schema:
+            result = connection.execute(
+                text("SHOW /* sqlalchemy:get_multi_indexes */ INDEXES")
+            )
+        else:
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}"
+                )
+            )
+
+        n2i = self.__class__._map_name_to_idx(result)
+        indexes = {}
+
+        for row in result.cursor.fetchall():
+            table = self.normalize_name(str(row[n2i["table"]]))
+            if (
+                row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY'
+                or table not in filter_names
+                or (schema, table) not in table_prefixes
+                or (
+                    (schema, table) in table_prefixes
+                    and CustomTablePrefix.HYBRID.name
+                    not in table_prefixes[(schema, table)]
+                )
+            ):
+                continue
+            index = {
+                "name": row[n2i["name"]],
+                "unique": row[n2i["is_unique"]] == "Y",
+                "column_names": row[n2i["columns"]],
+                "include_columns": row[n2i["included_columns"]],
+                "dialect_options": {},
+            }
+            if (schema, table) in indexes:
+                # append in place; list.append returns None, so assigning its
+                # result back would drop the indexes collected so far
+                indexes[(schema, table)].append(index)
+            else:
+                indexes[(schema, table)] = [index]
+
+        return list(indexes.items())
+
+    def _value_or_default(self, data, table, schema):
+        table = 
self.normalize_name(str(table)) + dic_data = dict(data) + if (schema, table) in dic_data: + return dic_data[(schema, table)] + else: + return [] + + def get_prefixes_from_data(self, n2i, row, **kw): + prefixes_found = [] + for valid_prefix in CustomTablePrefix: + key = f"is_{valid_prefix.name.lower()}" + if key in n2i and row[n2i[key]] == "Y": + prefixes_found.append(valid_prefix.name) + return prefixes_found + + @reflection.cache + def get_multi_prefixes( + self, connection, schema, table_name=None, filter_prefix=None, **kw + ): + """ + Gets all table prefixes + """ + schema = schema or self.default_schema_name + filter = f"LIKE '{table_name}'" if table_name else "" + if schema: + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES IN SCHEMA {schema}" + ) + ) + else: + result = connection.execute( + text( + f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES LIKE '{table_name}'" + ) + ) + + n2i = self.__class__._map_name_to_idx(result) + tables_prefixes = {} + for row in result.cursor.fetchall(): + table = self.normalize_name(str(row[n2i["name"]])) + table_prefixes = self.get_prefixes_from_data(n2i, row) + if filter_prefix and filter_prefix not in table_prefixes: + continue + if (schema, table) in tables_prefixes: + tables_prefixes[(schema, table)].append(table_prefixes) + else: + tables_prefixes[(schema, table)] = table_prefixes + + return tables_prefixes + + @reflection.cache + def get_indexes(self, connection, tablename, schema, **kw): + """ + Gets the indexes definition + """ + table_name = self.normalize_name(str(tablename)) + data = self.get_multi_indexes( + connection=connection, schema=schema, filter_names=[table_name], **kw + ) + + return self._value_or_default(data, table_name, schema) + def connect(self, *cargs, **cparams): return ( super().connect( @@ -912,8 +1028,12 @@ def connect(self, *cargs, **cparams): @sa_vnt.listens_for(Table, "before_create") def check_table(table, connection, _ddl_runner, **kw): + from .sql.custom_schema.hybrid_table import HybridTable + + if HybridTable.is_equal_type(table): # noqa + return True if isinstance(_ddl_runner.dialect, SnowflakeDialect) and table.indexes: - raise NotImplementedError("Snowflake does not support indexes") + raise NotImplementedError("Only Snowflake Hybrid Tables supports indexes") dialect = SnowflakeDialect diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py index 4bbac246..66b9270f 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py @@ -2,5 +2,6 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
# from .dynamic_table import DynamicTable +from .hybrid_table import HybridTable -__all__ = ["DynamicTable"] +__all__ = ["DynamicTable", "HybridTable"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py index 0c04f33f..b61c270d 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -10,12 +10,17 @@ from ..._constants import DIALECT_NAME from ...compat import IS_VERSION_20 from ...custom_commands import NoneType +from .custom_table_prefix import CustomTablePrefix from .options.table_option import TableOption class CustomTableBase(Table): - __table_prefix__ = "" - _support_primary_and_foreign_keys = True + __table_prefixes__: typing.List[CustomTablePrefix] = [] + _support_primary_and_foreign_keys: bool = True + + @property + def table_prefixes(self) -> typing.List[str]: + return [prefix.name for prefix in self.__table_prefixes__] def __init__( self, @@ -24,8 +29,8 @@ def __init__( *args: SchemaItem, **kw: Any, ) -> None: - if self.__table_prefix__ != "": - prefixes = kw.get("prefixes", []) + [self.__table_prefix__] + if len(self.__table_prefixes__) > 0: + prefixes = kw.get("prefixes", []) + self.table_prefixes kw.update(prefixes=prefixes) if not IS_VERSION_20 and hasattr(super(), "_init"): super()._init(name, metadata, *args, **kw) @@ -40,7 +45,7 @@ def _validate_table(self): self.primary_key or self.foreign_keys ): raise ArgumentError( - f"Primary key and foreign keys are not supported in {self.__table_prefix__} TABLE." + f"Primary key and foreign keys are not supported in {' '.join(self.table_prefixes)} TABLE." ) return True @@ -49,3 +54,11 @@ def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]: if option_name in self.dialect_options[DIALECT_NAME]: return self.dialect_options[DIALECT_NAME][option_name] return NoneType + + @classmethod + def is_equal_type(cls, table: Table) -> bool: + for prefix in cls.__table_prefixes__: + if prefix.name not in table._prefixes: + return False + + return True diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py new file mode 100644 index 00000000..de7835d1 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
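+#
+# Each prefix below maps to the keyword emitted between CREATE and TABLE,
+# e.g. a table flagged with CustomTablePrefix.HYBRID compiles to
+# "CREATE HYBRID TABLE ..." (illustrative; see the compile snapshots under
+# tests/).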
+
+from enum import Enum
+
+
+class CustomTablePrefix(Enum):
+    DEFAULT = 0
+    EXTERNAL = 1
+    EVENT = 2
+    HYBRID = 3
+    ICEBERG = 4
+    DYNAMIC = 5
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
index 7d0a02e6..1a2248fc 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
@@ -10,6 +10,7 @@

 from snowflake.sqlalchemy.custom_commands import NoneType

+from .custom_table_prefix import CustomTablePrefix
 from .options.target_lag import TargetLag
 from .options.warehouse import Warehouse
 from .table_from_query import TableFromQueryBase
@@ -27,7 +28,7 @@ class DynamicTable(TableFromQueryBase):

     """

-    __table_prefix__ = "DYNAMIC"
+    __table_prefixes__ = [CustomTablePrefix.DYNAMIC]

     _support_primary_and_foreign_keys = False

diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
new file mode 100644
index 00000000..bd49a420
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
@@ -0,0 +1,67 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import Any
+
+from sqlalchemy.exc import ArgumentError
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from snowflake.sqlalchemy.custom_commands import NoneType
+
+from .custom_table_base import CustomTableBase
+from .custom_table_prefix import CustomTablePrefix
+
+
+class HybridTable(CustomTableBase):
+    """
+    A class representing a hybrid table with configurable options and settings.
+
+    The `HybridTable` class allows for the creation and querying of OLTP Snowflake tables.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating and managing hybrid tables.
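+
+    Example (an illustrative sketch; table and column names are placeholders):
+        HybridTable(
+            "user",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("name", String),
+        )
+        compiles to roughly:
+        CREATE HYBRID TABLE user (id INTEGER NOT NULL AUTOINCREMENT, name VARCHAR, PRIMARY KEY (id))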
+ """ + + __table_prefixes__ = [CustomTablePrefix.HYBRID] + + _support_primary_and_foreign_keys = True + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + super().__init__(name, metadata, *args, **kw) + + def _validate_table(self): + missing_attributes = [] + if self.key is NoneType: + missing_attributes.append("Primary Key") + if missing_attributes: + raise ArgumentError( + "HYBRID TABLE must have the following arguments: %s" + % ", ".join(missing_attributes) + ) + super()._validate_table() + + def __repr__(self) -> str: + return "HybridTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/tests/__snapshots__/test_orm.ambr b/tests/__snapshots__/test_orm.ambr new file mode 100644 index 00000000..2116e9e9 --- /dev/null +++ b/tests/__snapshots__/test_orm.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_orm_one_to_many_relationship_with_hybrid_table + ProgrammingError('(snowflake.connector.errors.ProgrammingError) 200009 (22000): Foreign key constraint "SYS_INDEX_HB_TBL_ADDRESS_FOREIGN_KEY_USER_ID_HB_TBL_USER_ID" was violated.') +# --- diff --git a/tests/custom_tables/__init__.py b/tests/custom_tables/__init__.py new file mode 100644 index 00000000..d43f066c --- /dev/null +++ b/tests/custom_tables/__init__.py @@ -0,0 +1,2 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. diff --git a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr new file mode 100644 index 00000000..81c7f90f --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr new file mode 100644 index 00000000..9412fb45 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_compile_hybrid_table + 'CREATE HYBRID TABLE test_hybrid_table (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tgeom GEOMETRY, \tPRIMARY KEY (id))' +# --- +# name: test_compile_hybrid_table_orm + 'CREATE HYBRID TABLE test_hybrid_table_orm (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY 
KEY (id))' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr new file mode 100644 index 00000000..696ff9c8 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_hybrid_table + "[(1, 'test')]" +# --- +# name: test_create_hybrid_table_with_multiple_index + ProgrammingError("(snowflake.connector.errors.ProgrammingError) 391480 (0A000): Another index is being built on table 'TEST_HYBRID_TABLE_WITH_MULTIPLE_INDEX'. Only one index can be built at a time. Either cancel the other index creation or wait until it is complete.") +# --- diff --git a/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr new file mode 100644 index 00000000..6f6cd395 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_simple_reflection_hybrid_table_as_table + 'CREATE TABLE test_hybrid_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- diff --git a/tests/test_compile_dynamic_table.py b/tests/custom_tables/test_compile_dynamic_table.py similarity index 96% rename from tests/test_compile_dynamic_table.py rename to tests/custom_tables/test_compile_dynamic_table.py index 73ce54aa..16a039e7 100644 --- a/tests/test_compile_dynamic_table.py +++ b/tests/custom_tables/test_compile_dynamic_table.py @@ -121,11 +121,13 @@ def __repr__(self): assert actual == snapshot -def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot): +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, db_parameters, snapshot): Base = declarative_base() + schema = db_parameters["schema"] class TestDynamicTableOrm(Base): __tablename__ = "test_dynamic_table_orm_2" + __table_args__ = {"schema": schema} @classmethod def __table_cls__(cls, name, metadata, *arg, **kw): diff --git a/tests/custom_tables/test_compile_hybrid_table.py b/tests/custom_tables/test_compile_hybrid_table.py new file mode 100644 index 00000000..f1af6dc2 --- /dev/null +++ b/tests/custom_tables/test_compile_hybrid_table.py @@ -0,0 +1,52 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
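+#
+# These tests compile HybridTable definitions to DDL strings without a live
+# connection. The flow under test is, roughly (`sql_compiler` is the local
+# fixture that renders a DDL element with the Snowflake dialect):
+#
+#     ddl = CreateTable(table)
+#     sql = sql_compiler(ddl)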
+# +import pytest +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, HybridTable + + +@pytest.mark.aws +def test_compile_hybrid_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_hybrid_table" + test_geometry = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + Column("geom", GEOMETRY), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +@pytest.mark.aws +def test_compile_hybrid_table_orm(sql_compiler, snapshot): + Base = declarative_base() + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"" + + value = CreateTable(TestHybridTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/test_create_dynamic_table.py b/tests/custom_tables/test_create_dynamic_table.py similarity index 100% rename from tests/test_create_dynamic_table.py rename to tests/custom_tables/test_create_dynamic_table.py diff --git a/tests/custom_tables/test_create_hybrid_table.py b/tests/custom_tables/test_create_hybrid_table.py new file mode 100644 index 00000000..43ae3ab6 --- /dev/null +++ b/tests/custom_tables/test_create_hybrid_table.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +import sqlalchemy.exc +from sqlalchemy import Column, Index, Integer, MetaData, String, select +from sqlalchemy.orm import Session, declarative_base + +from snowflake.sqlalchemy import HybridTable + + +@pytest.mark.aws +def test_create_hybrid_table(engine_testaccount, db_parameters, snapshot): + metadata = MetaData() + table_name = "test_create_hybrid_table" + + dynamic_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = dynamic_test_table_1.insert().values(id=1, name="test") + conn.execute(ins) + conn.commit() + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_hybrid_table = conn.execute(s).fetchall() + assert str(results_hybrid_table) == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_multiple_index( + engine_testaccount, db_parameters, snapshot, sql_compiler +): + metadata = MetaData() + table_name = "test_hybrid_table_with_multiple_index" + + hybrid_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String, index=True), + Column("name2", String), + Column("name3", String), + ) + + metadata.create_all(engine_testaccount) + + index = Index("idx_col34", hybrid_test_table_1.c.name2, hybrid_test_table_1.c.name3) + + with pytest.raises(sqlalchemy.exc.ProgrammingError) as exc_info: + index.create(engine_testaccount) + try: + assert exc_info.value == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_orm(sql_compiler, engine_testaccount): + Base = declarative_base() + session = 
Session(bind=engine_testaccount) + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + instance = TestHybridTableOrm(id=0, name="name_example") + session.add(instance) + session.commit() + data = session.query(TestHybridTableOrm).all() + assert str(data) == "[(0, 'name_example')]" + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/test_reflect_dynamic_table.py b/tests/custom_tables/test_reflect_dynamic_table.py similarity index 100% rename from tests/test_reflect_dynamic_table.py rename to tests/custom_tables/test_reflect_dynamic_table.py diff --git a/tests/custom_tables/test_reflect_hybrid_table.py b/tests/custom_tables/test_reflect_hybrid_table.py new file mode 100644 index 00000000..4a777bf0 --- /dev/null +++ b/tests/custom_tables/test_reflect_hybrid_table.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import MetaData, Table +from sqlalchemy.sql.ddl import CreateTable + + +@pytest.mark.aws +def test_simple_reflection_hybrid_table_as_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_hybrid_table_reflection" + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX index_name (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + hybrid_test_table = Table(table_name, metadata, autoload_with=engine_testaccount) + + constraint = hybrid_test_table.constraints.pop() + constraint.name = "demo_name" + hybrid_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(hybrid_test_table) + + actual = sql_compiler(value) + + # Prefixes reflection not supported, example: "HYBRID, DYNAMIC" + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_reflect_hybrid_table_with_index( + engine_testaccount, db_parameters, sql_compiler +): + metadata = MetaData() + schema = db_parameters["schema"] + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + table = Table(table_name, metadata, schema=schema, autoload_with=engine_testaccount) + + try: + assert len(table.indexes) == 1 and table.indexes.pop().name == index_name + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_core.py b/tests/test_core.py index 179133c8..15840838 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -502,19 +502,20 @@ def test_inspect_column(engine_testaccount): users.drop(engine_testaccount) -def test_get_indexes(engine_testaccount): +def test_get_indexes(engine_testaccount, db_parameters): """ Tests get indexes - NOTE: Snowflake doesn't support indexes + NOTE: Only Snowflake Hybrid Tables support indexes """ + schema = db_parameters["schema"] metadata = MetaData() users, addresses = _create_users_addresses_tables_without_sequence( 
engine_testaccount, metadata ) try: inspector = inspect(engine_testaccount) - assert inspector.get_indexes("users") == [] + assert inspector.get_indexes("users", schema) == [] finally: addresses.drop(engine_testaccount) diff --git a/tests/test_index_reflection.py b/tests/test_index_reflection.py new file mode 100644 index 00000000..09f5cfe7 --- /dev/null +++ b/tests/test_index_reflection.py @@ -0,0 +1,34 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import MetaData +from sqlalchemy.engine import reflection + + +@pytest.mark.aws +def test_indexes_reflection(engine_testaccount, db_parameters, sql_compiler): + metadata = MetaData() + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + schema = db_parameters["schema"] + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + insp = reflection.Inspector.from_engine(engine_testaccount) + + try: + with engine_testaccount.connect(): + # Prefixes reflection not supported, example: "HYBRID, DYNAMIC" + indexes = insp.get_indexes(table_name, schema) + assert len(indexes) == 1 + assert indexes[0].get("name") == index_name + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_orm.py b/tests/test_orm.py index f51c9a90..cb3a7768 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -14,12 +14,15 @@ Integer, Sequence, String, + exc, func, select, text, ) from sqlalchemy.orm import Session, declarative_base, relationship +from snowflake.sqlalchemy import HybridTable + def test_basic_orm(engine_testaccount): """ @@ -56,14 +59,15 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_orm_one_to_many_relationship(engine_testaccount): +def test_orm_one_to_many_relationship(engine_testaccount, db_parameters): """ Tests One to Many relationship """ Base = declarative_base() + prefix = "tbl_" class User(Base): - __tablename__ = "user" + __tablename__ = prefix + "user" id = Column(Integer, Sequence("user_id_seq"), primary_key=True) name = Column(String) @@ -73,13 +77,13 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" id = Column(Integer, Sequence("address_id_seq"), primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, ForeignKey("user.id")) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", backref="addresses") + user = relationship(User, backref="addresses") def __repr__(self): return f"" @@ -123,14 +127,79 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) +@pytest.mark.aws +def test_orm_one_to_many_relationship_with_hybrid_table(engine_testaccount, snapshot): + """ + Tests One to Many relationship + """ + Base = declarative_base() + + class User(Base): + __tablename__ = "hb_tbl_user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + def __repr__(self): + return f"" + + class Address(Base): + __tablename__ = "hb_tbl_address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("address_id_seq"), 
primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) + + user = relationship(User, backref="addresses") + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + jack = User(name="jack", fullname="Jack Bean") + assert jack.addresses == [], "one to many record is empty list" + + jack.addresses = [ + Address(email_address="jack@gmail.com"), + Address(email_address="j25@yahoo.com"), + Address(email_address="jack@hotmail.com"), + ] + + session = Session(bind=engine_testaccount) + session.add(jack) # cascade each Address into the Session as well + session.commit() + + session.delete(jack) + + with pytest.raises(exc.ProgrammingError) as exc_info: + session.query(Address).all() + + assert exc_info.value == snapshot, "Iceberg Table enforce FK constraint" + + finally: + Base.metadata.drop_all(engine_testaccount) + + def test_delete_cascade(engine_testaccount): """ Test delete cascade """ Base = declarative_base() + prefix = "tbl_" class User(Base): - __tablename__ = "user" + __tablename__ = prefix + "user" id = Column(Integer, Sequence("user_id_seq"), primary_key=True) name = Column(String) @@ -144,13 +213,81 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" + + id = Column(Integer, Sequence("address_id_seq"), primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) + + user = relationship(User, back_populates="addresses") + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + jack = User(name="jack", fullname="Jack Bean") + assert jack.addresses == [], "one to many record is empty list" + + jack.addresses = [ + Address(email_address="jack@gmail.com"), + Address(email_address="j25@yahoo.com"), + Address(email_address="jack@hotmail.com"), + ] + + session = Session(bind=engine_testaccount) + session.add(jack) # cascade each Address into the Session as well + session.commit() + + got_jack = session.query(User).first() + assert got_jack == jack + + session.delete(jack) + got_addresses = session.query(Address).all() + assert len(got_addresses) == 0, "no address record" + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_delete_cascade_hybrid_table(engine_testaccount): + """ + Test delete cascade + """ + Base = declarative_base() + prefix = "hb_tbl_" + + class User(Base): + __tablename__ = prefix + "user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + addresses = relationship( + "Address", back_populates="user", cascade="all, delete, delete-orphan" + ) + + def __repr__(self): + return f"" + + class Address(Base): + __tablename__ = prefix + "address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) id = Column(Integer, Sequence("address_id_seq"), primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, ForeignKey("user.id")) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", back_populates="addresses") + user = relationship(User, back_populates="addresses") def __repr__(self): return f"" diff --git a/tests/test_pandas.py 
b/tests/test_pandas.py index 63cd6d0e..2a6b9f1b 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -169,7 +169,7 @@ def test_no_indexes(engine_testaccount, db_parameters): con=conn, if_exists="replace", ) - assert str(exc.value) == "Snowflake does not support indexes" + assert str(exc.value) == "Only Snowflake Hybrid Tables supports indexes" def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_numpy): From e78319725d4b96ea205ef1264b744c65eb37853d Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Tue, 22 Oct 2024 09:18:29 -0600 Subject: [PATCH 14/21] Add generic options in order to support Iceberg Table in following PR (#537) * Add generic options and remove schema options --- .pre-commit-config.yaml | 1 + DESCRIPTION.md | 2 + src/snowflake/sqlalchemy/__init__.py | 28 +++- src/snowflake/sqlalchemy/base.py | 36 ++-- src/snowflake/sqlalchemy/exc.py | 74 +++++++++ .../sql/custom_schema/custom_table_base.py | 67 ++++++-- .../sql/custom_schema/dynamic_table.py | 84 +++++++--- .../sql/custom_schema/hybrid_table.py | 29 ++-- .../sql/custom_schema/options/__init__.py | 29 +++- .../sql/custom_schema/options/as_query.py | 62 ------- .../custom_schema/options/as_query_option.py | 63 +++++++ .../options/identifier_option.py | 63 +++++++ .../options/invalid_table_option.py | 25 +++ .../custom_schema/options/keyword_option.py | 65 ++++++++ .../sql/custom_schema/options/keywords.py | 14 ++ .../custom_schema/options/literal_option.py | 67 ++++++++ .../sql/custom_schema/options/table_option.py | 91 +++++++++-- .../options/table_option_base.py | 30 ---- .../sql/custom_schema/options/target_lag.py | 60 ------- .../options/target_lag_option.py | 94 +++++++++++ .../sql/custom_schema/options/warehouse.py | 51 ------ .../sql/custom_schema/table_from_query.py | 22 +-- .../test_compile_dynamic_table.ambr | 29 +++- .../test_create_dynamic_table.ambr | 7 + .../__snapshots__/test_generic_options.ambr | 13 ++ .../test_compile_dynamic_table.py | 154 ++++++++++++++---- .../test_create_dynamic_table.py | 75 ++++++--- tests/custom_tables/test_generic_options.py | 83 ++++++++++ .../test_reflect_dynamic_table.py | 2 +- tests/test_core.py | 1 + 30 files changed, 1058 insertions(+), 363 deletions(-) create mode 100644 src/snowflake/sqlalchemy/exc.py delete mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py delete mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py delete mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py delete mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py create mode 100644 tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_generic_options.ambr create mode 100644 tests/custom_tables/test_generic_options.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 
83172eb8..b7370b74 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,7 @@ repos: rev: v4.5.0 hooks: - id: trailing-whitespace + exclude: '\.ambr$' - id: end-of-file-fixer - id: check-yaml exclude: .github/repo_meta.yaml diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 58c2dfe2..909d52cf 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -14,6 +14,8 @@ Source code is also available at: - Add support for dynamic tables and required options - Add support for hybrid tables - Fixed SAWarning when registering functions with existing name in default namespace + - Update options to be defined in key arguments instead of arguments. + - Add support for refresh_mode option in DynamicTable - v1.6.1(July 9, 2024) diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 0afd44a5..e53f9b74 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -62,7 +62,16 @@ VARIANT, ) from .sql.custom_schema import DynamicTable, HybridTable -from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse +from .sql.custom_schema.options import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + SnowflakeKeyword, + TableOptionKey, + TargetLagOption, + TimeUnit, +) from .util import _url as URL base.dialect = dialect = snowdialect.dialect @@ -70,6 +79,7 @@ __version__ = importlib_metadata.version("snowflake-sqlalchemy") __all__ = ( + # Custom Types "BIGINT", "BINARY", "BOOLEAN", @@ -104,6 +114,7 @@ "TINYINT", "VARBINARY", "VARIANT", + # Custom Commands "MergeInto", "CSVFormatter", "JSONFormatter", @@ -115,10 +126,17 @@ "ExternalStage", "CreateStage", "CreateFileFormat", + # Custom Tables + "HybridTable", "DynamicTable", - "AsQuery", - "TargetLag", + # Custom Table Options + "AsQueryOption", + "TargetLagOption", + "LiteralOption", + "IdentifierOption", + "KeywordOption", + # Enums "TimeUnit", - "Warehouse", - "HybridTable", + "TableOptionKey", + "SnowflakeKeyword", ) diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 56631728..023f7afb 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -5,6 +5,7 @@ import itertools import operator import re +from typing import List from sqlalchemy import exc as sa_exc from sqlalchemy import inspect, sql @@ -26,8 +27,13 @@ ExternalStage, ) +from .exc import ( + CustomOptionsAreOnlySupportedOnSnowflakeTables, + UnexpectedOptionTypeError, +) from .functions import flatten -from .sql.custom_schema.options.table_option_base import TableOptionBase +from .sql.custom_schema.custom_table_base import CustomTableBase +from .sql.custom_schema.options.table_option import TableOption from .util import ( _find_left_clause_to_join_from, _set_connection_interpolate_empty_sequences, @@ -925,16 +931,24 @@ def handle_cluster_by(self, table): def post_create_table(self, table): text = self.handle_cluster_by(table) - options = [ - option - for _, option in table.dialect_options[DIALECT_NAME].items() - if isinstance(option, TableOptionBase) - ] - options.sort( - key=lambda x: (x.__priority__.value, x.__option_name__), reverse=True - ) - for option in options: - text += "\t" + option.render_option(self) + options = [] + invalid_options: List[str] = [] + + for key, option in table.dialect_options[DIALECT_NAME].items(): + if isinstance(option, TableOption): + options.append(option) + elif key not in ["clusterby", "*"]: + invalid_options.append(key) + + if len(invalid_options) > 0: + raise 
UnexpectedOptionTypeError(sorted(invalid_options)) + + if isinstance(table, CustomTableBase): + options.sort(key=lambda x: (x.priority.value, x.option_name), reverse=True) + for option in options: + text += "\t" + option.render_option(self) + elif len(options) > 0: + raise CustomOptionsAreOnlySupportedOnSnowflakeTables() return text diff --git a/src/snowflake/sqlalchemy/exc.py b/src/snowflake/sqlalchemy/exc.py new file mode 100644 index 00000000..898de279 --- /dev/null +++ b/src/snowflake/sqlalchemy/exc.py @@ -0,0 +1,74 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +from typing import List + +from sqlalchemy.exc import ArgumentError + + +class NoPrimaryKeyError(ArgumentError): + def __init__(self, target: str): + super().__init__(f"Table {target} required primary key.") + + +class UnsupportedPrimaryKeysAndForeignKeysError(ArgumentError): + def __init__(self, target: str): + super().__init__(f"Primary key and foreign keys are not supported in {target}.") + + +class RequiredParametersNotProvidedError(ArgumentError): + def __init__(self, target: str, parameters: List[str]): + super().__init__( + f"{target} requires the following parameters: %s." % ", ".join(parameters) + ) + + +class UnexpectedTableOptionKeyError(ArgumentError): + def __init__(self, expected: str, actual: str): + super().__init__(f"Expected table option {expected} but got {actual}.") + + +class OptionKeyNotProvidedError(ArgumentError): + def __init__(self, target: str): + super().__init__( + f"Expected option key in {target} option but got NoneType instead." + ) + + +class UnexpectedOptionParameterTypeError(ArgumentError): + def __init__(self, parameter_name: str, target: str, types: List[str]): + super().__init__( + f"Parameter {parameter_name} of {target} requires to be one" + f" of following types: {', '.join(types)}." + ) + + +class CustomOptionsAreOnlySupportedOnSnowflakeTables(ArgumentError): + def __init__(self): + super().__init__( + "Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables." + ) + + +class UnexpectedOptionTypeError(ArgumentError): + def __init__(self, options: List[str]): + super().__init__( + f"The following options are either unsupported or should be defined using a Snowflake table: {', '.join(options)}." + ) + + +class InvalidTableParameterTypeError(ArgumentError): + def __init__(self, name: str, input_type: str, expected_types: List[str]): + expected_types_str = "', '".join(expected_types) + super().__init__( + f"Invalid parameter type '{input_type}' provided for '{name}'. " + f"Expected one of the following types: '{expected_types_str}'.\n" + ) + + +class MultipleErrors(ArgumentError): + def __init__(self, errors): + self.errors = errors + + def __str__(self): + return "".join(str(e) for e in self.errors) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py index b61c270d..671c6957 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -2,21 +2,29 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
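+#
+# Validation in this module collects every problem it finds (invalid option
+# types, a missing primary key, missing required parameters, unsupported
+# primary/foreign keys) before raising. Illustrative sketch (hypothetical
+# names and values):
+#
+#     DynamicTable(
+#         "t", MetaData(), Column("id", Integer),
+#         warehouse=1.5,   # wrong type -> InvalidTableParameterTypeError
+#     )                    # no target_lag/as_query -> missing parameters
+#
+# raises a single MultipleErrors aggregating both problems, instead of
+# stopping at the first failure.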
# import typing -from typing import Any +from typing import Any, List -from sqlalchemy.exc import ArgumentError from sqlalchemy.sql.schema import MetaData, SchemaItem, Table from ..._constants import DIALECT_NAME from ...compat import IS_VERSION_20 from ...custom_commands import NoneType +from ...exc import ( + MultipleErrors, + NoPrimaryKeyError, + RequiredParametersNotProvidedError, + UnsupportedPrimaryKeysAndForeignKeysError, +) from .custom_table_prefix import CustomTablePrefix -from .options.table_option import TableOption +from .options.invalid_table_option import InvalidTableOption +from .options.table_option import TableOption, TableOptionKey class CustomTableBase(Table): __table_prefixes__: typing.List[CustomTablePrefix] = [] _support_primary_and_foreign_keys: bool = True + _enforce_primary_keys: bool = False + _required_parameters: List[TableOptionKey] = [] @property def table_prefixes(self) -> typing.List[str]: @@ -32,7 +40,9 @@ def __init__( if len(self.__table_prefixes__) > 0: prefixes = kw.get("prefixes", []) + self.table_prefixes kw.update(prefixes=prefixes) + if not IS_VERSION_20 and hasattr(super(), "_init"): + kw.pop("_no_init", True) super()._init(name, metadata, *args, **kw) else: super().__init__(name, metadata, *args, **kw) @@ -41,19 +51,56 @@ def __init__( self._validate_table() def _validate_table(self): + exceptions: List[Exception] = [] + + for _, option in self.dialect_options[DIALECT_NAME].items(): + if isinstance(option, InvalidTableOption): + exceptions.append(option.exception) + + if isinstance(self.key, NoneType) and self._enforce_primary_keys: + exceptions.append(NoPrimaryKeyError(self.__class__.__name__)) + missing_parameters: List[str] = [] + + for required_parameter in self._required_parameters: + if isinstance(self._get_dialect_option(required_parameter), NoneType): + missing_parameters.append(required_parameter.name.lower()) + if missing_parameters: + exceptions.append( + RequiredParametersNotProvidedError( + self.__class__.__name__, missing_parameters + ) + ) + if not self._support_primary_and_foreign_keys and ( self.primary_key or self.foreign_keys ): - raise ArgumentError( - f"Primary key and foreign keys are not supported in {' '.join(self.table_prefixes)} TABLE." 
+ exceptions.append( + UnsupportedPrimaryKeysAndForeignKeysError(self.__class__.__name__) ) - return True + if len(exceptions) > 1: + exceptions.sort(key=lambda e: str(e)) + raise MultipleErrors(exceptions) + elif len(exceptions) == 1: + raise exceptions[0] + + def _get_dialect_option( + self, option_name: TableOptionKey + ) -> typing.Optional[TableOption]: + if option_name.value in self.dialect_options[DIALECT_NAME]: + return self.dialect_options[DIALECT_NAME][option_name.value] + return None - def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]: - if option_name in self.dialect_options[DIALECT_NAME]: - return self.dialect_options[DIALECT_NAME][option_name] - return NoneType + def _as_dialect_options( + self, table_options: List[TableOption] + ) -> typing.Dict[str, TableOption]: + result = {} + for table_option in table_options: + if isinstance(table_option, TableOption) and isinstance( + table_option.option_name, str + ): + result[DIALECT_NAME + "_" + table_option.option_name] = table_option + return result @classmethod def is_equal_type(cls, table: Table) -> bool: diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py index 1a2248fc..6db4312d 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py @@ -3,16 +3,21 @@ # import typing -from typing import Any +from typing import Any, Union -from sqlalchemy.exc import ArgumentError from sqlalchemy.sql.schema import MetaData, SchemaItem -from snowflake.sqlalchemy.custom_commands import NoneType - from .custom_table_prefix import CustomTablePrefix -from .options.target_lag import TargetLag -from .options.warehouse import Warehouse +from .options import ( + IdentifierOption, + IdentifierOptionType, + KeywordOptionType, + LiteralOption, + TableOptionKey, + TargetLagOption, + TargetLagOptionType, +) +from .options.keyword_option import KeywordOption from .table_from_query import TableFromQueryBase @@ -26,29 +31,69 @@ class DynamicTable(TableFromQueryBase): While it does not support reflection at this time, it provides a flexible interface for creating dynamic tables and management. 
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
+
+    Example using option values:
+
+        DynamicTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            target_lag=(1, TimeUnit.HOURS),
+            warehouse='warehouse_name',
+            refresh_mode=SnowflakeKeyword.AUTO,
+            as_query="SELECT id, name from test_table_1;"
+        )
+
+    Example using full options:
+
+        DynamicTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            target_lag=TargetLagOption(1, TimeUnit.HOURS),
+            warehouse=IdentifierOption('warehouse_name'),
+            refresh_mode=KeywordOption(SnowflakeKeyword.AUTO),
+            as_query=AsQueryOption("SELECT id, name from test_table_1;")
+        )
     """

     __table_prefixes__ = [CustomTablePrefix.DYNAMIC]

-    _support_primary_and_foreign_keys = False
+    _required_parameters = [
+        TableOptionKey.WAREHOUSE,
+        TableOptionKey.AS_QUERY,
+        TableOptionKey.TARGET_LAG,
+    ]

     @property
-    def warehouse(self) -> typing.Optional[Warehouse]:
-        return self._get_dialect_option(Warehouse.__option_name__)
+    def warehouse(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.WAREHOUSE)

     @property
-    def target_lag(self) -> typing.Optional[TargetLag]:
-        return self._get_dialect_option(TargetLag.__option_name__)
+    def target_lag(self) -> typing.Optional[TargetLagOption]:
+        return self._get_dialect_option(TableOptionKey.TARGET_LAG)

     def __init__(
         self,
         name: str,
         metadata: MetaData,
         *args: SchemaItem,
+        warehouse: IdentifierOptionType = None,
+        target_lag: Union[TargetLagOptionType, KeywordOptionType] = None,
+        refresh_mode: KeywordOptionType = None,
         **kw: Any,
     ) -> None:
         if kw.get("_no_init", True):
             return
+
+        options = [
+            IdentifierOption.create(TableOptionKey.WAREHOUSE, warehouse),
+            TargetLagOption.create(target_lag),
+            KeywordOption.create(TableOptionKey.REFRESH_MODE, refresh_mode),
+        ]
+
+        kw.update(self._as_dialect_options(options))
         super().__init__(name, metadata, *args, **kw)

     def _init(
         self,
@@ -58,22 +103,7 @@ def _init(
         *args: SchemaItem,
         **kw: Any,
     ) -> None:
-        super().__init__(name, metadata, *args, **kw)
-
-    def _validate_table(self):
-        missing_attributes = []
-        if self.target_lag is NoneType:
-            missing_attributes.append("TargetLag")
-        if self.warehouse is NoneType:
-            missing_attributes.append("Warehouse")
-        if self.as_query is NoneType:
-            missing_attributes.append("AsQuery")
-        if missing_attributes:
-            raise ArgumentError(
-                "DYNAMIC TABLE must have the following arguments: %s"
-                % ", ".join(missing_attributes)
-            )
-        super()._validate_table()
+        self.__init__(name, metadata, *args, _no_init=False, **kw)

     def __repr__(self) -> str:
         return "DynamicTable(%s)" % ", ".join(
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
index bd49a420..b7c29e78 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
@@ -4,11 +4,8 @@

 from typing import Any

-from sqlalchemy.exc import ArgumentError
 from sqlalchemy.sql.schema import MetaData, SchemaItem

-from snowflake.sqlalchemy.custom_commands import NoneType
-
 from .custom_table_base import CustomTableBase
 from .custom_table_prefix import CustomTablePrefix

@@ -21,11 +18,20 @@ class HybridTable(CustomTableBase):

     While it does not support reflection at this time, it provides a flexible
     interface for creating and managing hybrid tables.
+ + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-hybrid-table + + Example usage: + HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String) + ) """ __table_prefixes__ = [CustomTablePrefix.HYBRID] - - _support_primary_and_foreign_keys = True + _enforce_primary_keys: bool = True def __init__( self, @@ -45,18 +51,7 @@ def _init( *args: SchemaItem, **kw: Any, ) -> None: - super().__init__(name, metadata, *args, **kw) - - def _validate_table(self): - missing_attributes = [] - if self.key is NoneType: - missing_attributes.append("Primary Key") - if missing_attributes: - raise ArgumentError( - "HYBRID TABLE must have the following arguments: %s" - % ", ".join(missing_attributes) - ) - super()._validate_table() + self.__init__(name, metadata, *args, _no_init=False, **kw) def __repr__(self) -> str: return "HybridTable(%s)" % ", ".join( diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py index 052e2d96..11b54c1a 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py @@ -2,8 +2,29 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from .as_query import AsQuery -from .target_lag import TargetLag, TimeUnit -from .warehouse import Warehouse +from .as_query_option import AsQueryOption, AsQueryOptionType +from .identifier_option import IdentifierOption, IdentifierOptionType +from .keyword_option import KeywordOption, KeywordOptionType +from .keywords import SnowflakeKeyword +from .literal_option import LiteralOption, LiteralOptionType +from .table_option import TableOptionKey +from .target_lag_option import TargetLagOption, TargetLagOptionType, TimeUnit -__all__ = ["Warehouse", "AsQuery", "TargetLag", "TimeUnit"] +__all__ = [ + # Options + "IdentifierOption", + "LiteralOption", + "KeywordOption", + "AsQueryOption", + "TargetLagOption", + # Enums + "TimeUnit", + "SnowflakeKeyword", + "TableOptionKey", + # Types + "IdentifierOptionType", + "LiteralOptionType", + "AsQueryOptionType", + "TargetLagOptionType", + "KeywordOptionType", +] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py deleted file mode 100644 index 68076af9..00000000 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query.py +++ /dev/null @@ -1,62 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# -from typing import Union - -from sqlalchemy.sql import Selectable - -from .table_option import TableOption -from .table_option_base import Priority - - -class AsQuery(TableOption): - """Class to represent an AS clause in tables. - This configuration option is used to specify the query from which the table is created. 
- For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas - - - AsQuery example usage using an input string: - DynamicTable( - "sometable", metadata, - Column("name", String(50)), - Column("address", String(100)), - AsQuery('select name, address from existing_table where name = "test"') - ) - - AsQuery example usage using a selectable statement: - DynamicTable( - "sometable", - Base.metadata, - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery(select(test_table_1).where(test_table_1.c.id == 23)) - ) - - """ - - __option_name__ = "as_query" - __priority__ = Priority.LOWEST - - def __init__(self, query: Union[str, Selectable]) -> None: - r"""Construct an as_query object. - - :param \*expressions: - AS - - """ - self.query = query - - @staticmethod - def template() -> str: - return "AS %s" - - def get_expression(self): - if isinstance(self.query, Selectable): - return self.query.compile(compile_kwargs={"literal_binds": True}) - return self.query - - def render_option(self, compiler) -> str: - return AsQuery.template() % (self.get_expression()) - - def __repr__(self) -> str: - return "Query(%s)" % self.get_expression() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py new file mode 100644 index 00000000..93994abc --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from sqlalchemy.sql import Selectable + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class AsQueryOption(TableOption): + """Class to represent an AS clause in tables. 
+ For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas + + Example: + as_query=AsQueryOption('select name, address from existing_table where name = "test"') + + is equivalent to: + + as select name, address from existing_table where name = "test" + """ + + def __init__(self, query: Union[str, Selectable]) -> None: + super().__init__() + self._name: TableOptionKey = TableOptionKey.AS_QUERY + self.query = query + + @staticmethod + def create( + value: Optional[Union["AsQueryOption", str, Selectable]] + ) -> "TableOption": + if isinstance(value, (NoneType, AsQueryOption)): + return value + if isinstance(value, (str, Selectable)): + return AsQueryOption(value) + return TableOption._get_invalid_table_option( + TableOptionKey.AS_QUERY, + str(type(value).__name__), + [AsQueryOption.__name__, str.__name__, Selectable.__name__], + ) + + def template(self) -> str: + return "AS %s" + + @property + def priority(self) -> Priority: + return Priority.LOWEST + + def __get_expression(self): + if isinstance(self.query, Selectable): + return self.query.compile(compile_kwargs={"literal_binds": True}) + return self.query + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "AsQueryOption(%s)" % self.__get_expression() + + +AsQueryOptionType = Union[AsQueryOption, str, Selectable] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py new file mode 100644 index 00000000..b296898b --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class IdentifierOption(TableOption): + """Class to represent an identifier option in Snowflake Tables. 
+ + Example: + warehouse = IdentifierOption('my_warehouse') + + is equivalent to: + + WAREHOUSE = my_warehouse + """ + + def __init__(self, value: Union[str]) -> None: + super().__init__() + self.value: str = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, "IdentifierOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + + if isinstance(value, str): + value = IdentifierOption(value) + + if isinstance(value, IdentifierOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, str(type(value).__name__), [IdentifierOption.__name__, str.__name__] + ) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"IdentifierOption(value='{self.value}'{option_name})" + + +IdentifierOptionType = Union[IdentifierOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py new file mode 100644 index 00000000..2bdc9dd3 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional + +from .table_option import TableOption, TableOptionKey + + +class InvalidTableOption(TableOption): + """Class to store errors and raise them after table initialization in order to avoid recursion error.""" + + def __init__(self, name: TableOptionKey, value: Exception) -> None: + super().__init__() + self.exception: Exception = value + self._name = name + + @staticmethod + def create(name: TableOptionKey, value: Exception) -> Optional[TableOption]: + return InvalidTableOption(name, value) + + def _render(self, compiler) -> str: + raise self.exception + + def __repr__(self) -> str: + return f"ErrorOption(value='{self.exception}')" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py new file mode 100644 index 00000000..ff6b444d --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class KeywordOption(TableOption): + """Class to represent a keyword option in Snowflake Tables. 
+
+    Example:
+        target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM)
+
+        is equivalent to:
+
+        TARGET_LAG = DOWNSTREAM
+    """
+
+    def __init__(self, value: SnowflakeKeyword) -> None:
+        super().__init__()
+        self.value: str = value.value
+
+    @property
+    def priority(self):
+        return Priority.HIGH
+
+    def template(self) -> str:
+        return f"{self.option_name.upper()} = %s"
+
+    def _render(self, compiler) -> str:
+        return self.template() % self.value.upper()
+
+    @staticmethod
+    def create(
+        name: TableOptionKey, value: Optional[Union[SnowflakeKeyword, "KeywordOption"]]
+    ) -> Optional[TableOption]:
+        if isinstance(value, NoneType):
+            return value
+        if isinstance(value, SnowflakeKeyword):
+            value = KeywordOption(value)
+
+        if isinstance(value, KeywordOption):
+            value._set_option_name(name)
+            return value
+
+        return TableOption._get_invalid_table_option(
+            name,
+            str(type(value).__name__),
+            [KeywordOption.__name__, SnowflakeKeyword.__name__],
+        )
+
+    def __repr__(self) -> str:
+        option_name = (
+            f", table_option_key={self.option_name}"
+            if not isinstance(self.option_name, NoneType)
+            else ""
+        )
+        return f"KeywordOption(value='{self.value}'{option_name})"
+
+
+KeywordOptionType = Union[KeywordOption, SnowflakeKeyword]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py
new file mode 100644
index 00000000..f4ba87ea
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py
@@ -0,0 +1,14 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+
+from enum import Enum
+
+
+class SnowflakeKeyword(Enum):
+    # TARGET_LAG
+    DOWNSTREAM = "DOWNSTREAM"
+
+    # REFRESH_MODE
+    AUTO = "AUTO"
+    FULL = "FULL"
+    INCREMENTAL = "INCREMENTAL"
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py
new file mode 100644
index 00000000..55dd7675
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py
@@ -0,0 +1,67 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from typing import Any, Optional, Union
+
+from snowflake.sqlalchemy.custom_commands import NoneType
+
+from .table_option import Priority, TableOption, TableOptionKey
+
+
+class LiteralOption(TableOption):
+    """Class to represent a literal option in Snowflake Tables.
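+    String values are rendered wrapped in single quotes; integer values are rendered bare.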
+ + Example: + warehouse = LiteralOption('my_warehouse') + + is equivalent to: + + WAREHOUSE = 'my_warehouse' + """ + + def __init__(self, value: Union[int, str]) -> None: + super().__init__() + self.value: Any = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, int, "LiteralOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + if isinstance(value, (str, int)): + value = LiteralOption(value) + + if isinstance(value, LiteralOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, + str(type(value).__name__), + [LiteralOption.__name__, str.__name__, int.__name__], + ) + + def template(self) -> str: + if isinstance(self.value, int): + return f"{self.option_name.upper()} = %d" + else: + return f"{self.option_name.upper()} = '%s'" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"LiteralOption(value='{self.value}'{option_name})" + + +LiteralOptionType = Union[LiteralOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py index 7ac27575..14b91f2e 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py @@ -1,26 +1,83 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from typing import Any +from enum import Enum +from typing import List, Optional -from sqlalchemy import exc -from sqlalchemy.sql.base import SchemaEventTarget -from sqlalchemy.sql.schema import SchemaItem, Table +from snowflake.sqlalchemy import exc +from snowflake.sqlalchemy.custom_commands import NoneType -from snowflake.sqlalchemy._constants import DIALECT_NAME -from .table_option_base import TableOptionBase +class Priority(Enum): + LOWEST = 0 + VERY_LOW = 1 + LOW = 2 + MEDIUM = 4 + HIGH = 6 + VERY_HIGH = 7 + HIGHEST = 8 -class TableOption(TableOptionBase, SchemaItem): - def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: - if self.__option_name__ == "default": - raise exc.SQLAlchemyError(f"{self.__class__.__name__} does not has a name") - if not isinstance(parent, Table): - raise exc.SQLAlchemyError( - f"{self.__class__.__name__} option can only be applied to Table" - ) - parent.dialect_options[DIALECT_NAME][self.__option_name__] = self +class TableOption: - def _set_table_option_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: - pass + def __init__(self) -> None: + self._name: Optional[TableOptionKey] = None + + @property + def option_name(self) -> str: + if isinstance(self._name, NoneType): + return None + return str(self._name.value) + + def _set_option_name(self, name: Optional["TableOptionKey"]): + self._name = name + + @property + def priority(self) -> Priority: + return Priority.MEDIUM + + @staticmethod + def create(**kwargs) -> "TableOption": + raise NotImplementedError + + @staticmethod + def _get_invalid_table_option( + parameter_name: "TableOptionKey", input_type: str, expected_types: List[str] + ) -> "TableOption": + from .invalid_table_option import InvalidTableOption + + return InvalidTableOption( + parameter_name, + exc.InvalidTableParameterTypeError( + parameter_name.value, 
input_type, expected_types + ), + ) + + def _validate_option(self): + if isinstance(self.option_name, NoneType): + raise exc.OptionKeyNotProvidedError(self.__class__.__name__) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def render_option(self, compiler) -> str: + self._validate_option() + return self._render(compiler) + + def _render(self, compiler) -> str: + raise NotImplementedError + + +class TableOptionKey(Enum): + AS_QUERY = "as_query" + BASE_LOCATION = "base_location" + CATALOG = "catalog" + CATALOG_SYNC = "catalog_sync" + DATA_RETENTION_TIME_IN_DAYS = "data_retention_time_in_days" + DEFAULT_DDL_COLLATION = "default_ddl_collation" + EXTERNAL_VOLUME = "external_volume" + MAX_DATA_EXTENSION_TIME_IN_DAYS = "max_data_extension_time_in_days" + REFRESH_MODE = "refresh_mode" + STORAGE_SERIALIZATION_POLICY = "storage_serialization_policy" + TARGET_LAG = "target_lag" + WAREHOUSE = "warehouse" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py deleted file mode 100644 index 54008ec8..00000000 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option_base.py +++ /dev/null @@ -1,30 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# -from enum import Enum - - -class Priority(Enum): - LOWEST = 0 - VERY_LOW = 1 - LOW = 2 - MEDIUM = 4 - HIGH = 6 - VERY_HIGH = 7 - HIGHEST = 8 - - -class TableOptionBase: - __option_name__ = "default" - __visit_name__ = __option_name__ - __priority__ = Priority.MEDIUM - - @staticmethod - def template() -> str: - raise NotImplementedError - - def get_expression(self): - raise NotImplementedError - - def render_option(self, compiler) -> str: - raise NotImplementedError diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py deleted file mode 100644 index 4331a4cb..00000000 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag.py +++ /dev/null @@ -1,60 +0,0 @@ -# -# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. -# from enum import Enum -from enum import Enum -from typing import Optional - -from .table_option import TableOption -from .table_option_base import Priority - - -class TimeUnit(Enum): - SECONDS = "seconds" - MINUTES = "minutes" - HOURS = "hour" - DAYS = "days" - - -class TargetLag(TableOption): - """Class to represent the target lag clause. - This configuration option is used to specify the target lag time for the dynamic table. 
-    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
-
-
-    Target lag example usage:
-    DynamicTable("sometable", metadata,
-                 Column("name", String(50)),
-                 Column("address", String(100)),
-                 TargetLag(20, TimeUnit.MINUTES),
-                 )
-    """
-
-    __option_name__ = "target_lag"
-    __priority__ = Priority.HIGH
-
-    def __init__(
-        self,
-        time: Optional[int] = 0,
-        unit: Optional[TimeUnit] = TimeUnit.MINUTES,
-        down_stream: Optional[bool] = False,
-    ) -> None:
-        self.time = time
-        self.unit = unit
-        self.down_stream = down_stream
-
-    @staticmethod
-    def template() -> str:
-        return "TARGET_LAG = %s"
-
-    def get_expression(self):
-        return (
-            f"'{str(self.time)} {str(self.unit.value)}'"
-            if not self.down_stream
-            else "DOWNSTREAM"
-        )
-
-    def render_option(self, compiler) -> str:
-        return TargetLag.template() % (self.get_expression())
-
-    def __repr__(self) -> str:
-        return "TargetLag(%s)" % self.get_expression()
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py
new file mode 100644
index 00000000..7c1c0825
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py
@@ -0,0 +1,94 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from enum import Enum
+from typing import Optional, Tuple, Union
+
+from snowflake.sqlalchemy.custom_commands import NoneType
+
+from .keyword_option import KeywordOption, KeywordOptionType
+from .keywords import SnowflakeKeyword
+from .table_option import Priority, TableOption, TableOptionKey
+
+
+class TimeUnit(Enum):
+    SECONDS = "seconds"
+    MINUTES = "minutes"
+    HOURS = "hours"
+    DAYS = "days"
+
+
+class TargetLagOption(TableOption):
+    """Class to represent the target lag clause in Dynamic Tables.
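+    Accepts a (time, unit) pair, or SnowflakeKeyword.DOWNSTREAM via KeywordOption.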
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
+
+    Example using the time and unit parameters:
+
+        target_lag = TargetLagOption(10, TimeUnit.SECONDS)
+
+        is equivalent to:
+
+        TARGET_LAG = '10 SECONDS'
+
+    Example using keyword parameter:
+
+        target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM)
+
+        is equivalent to:
+
+        TARGET_LAG = DOWNSTREAM
+
+    """
+
+    def __init__(
+        self,
+        time: Optional[int] = 0,
+        unit: Optional[TimeUnit] = TimeUnit.MINUTES,
+    ) -> None:
+        super().__init__()
+        self.time = time
+        self.unit = unit
+        self._name: TableOptionKey = TableOptionKey.TARGET_LAG
+
+    @staticmethod
+    def create(
+        value: Union["TargetLagOption", Tuple[int, TimeUnit], KeywordOptionType]
+    ) -> Optional[TableOption]:
+        if isinstance(value, NoneType):
+            return value
+
+        if isinstance(value, Tuple):
+            time, unit = value
+            value = TargetLagOption(time, unit)
+
+        if isinstance(value, TargetLagOption):
+            return value
+
+        if isinstance(value, (KeywordOption, SnowflakeKeyword)):
+            return KeywordOption.create(TableOptionKey.TARGET_LAG, value)
+
+        return TableOption._get_invalid_table_option(
+            TableOptionKey.TARGET_LAG,
+            str(type(value).__name__),
+            [
+                TargetLagOption.__name__,
+                f"Tuple[int, {TimeUnit.__name__}]",
+                SnowflakeKeyword.__name__,
+            ],
+        )
+
+    def __get_expression(self):
+        return f"'{str(self.time)} {str(self.unit.value)}'"
+
+    @property
+    def priority(self) -> Priority:
+        return Priority.HIGH
+
+    def _render(self, compiler) -> str:
+        return self.template() % (self.__get_expression())
+
+    def __repr__(self) -> str:
+        return "TargetLagOption(%s)" % self.__get_expression()
+
+
+TargetLagOptionType = Union[TargetLagOption, Tuple[int, TimeUnit]]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py
deleted file mode 100644
index a5b8cce0..00000000
--- a/src/snowflake/sqlalchemy/sql/custom_schema/options/warehouse.py
+++ /dev/null
@@ -1,51 +0,0 @@
-#
-# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
-#
-from typing import Optional
-
-from .table_option import TableOption
-from .table_option_base import Priority
-
-
-class Warehouse(TableOption):
-    """Class to represent the warehouse clause.
-    This configuration option is used to specify the warehouse for the dynamic table.
-    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
-
-
-    Warehouse example usage:
-    DynamicTable("sometable", metadata,
-                 Column("name", String(50)),
-                 Column("address", String(100)),
-                 Warehouse('my_warehouse_name')
-                 )
-    """
-
-    __option_name__ = "warehouse"
-    __priority__ = Priority.HIGH
-
-    def __init__(
-        self,
-        name: Optional[str],
-    ) -> None:
-        r"""Construct a Warehouse object.
-
-        :param \*expressions:
-            Dynamic table warehouse option.
-            WAREHOUSE = <warehouse_name>
-
-        """
-        self.name = name
-
-    @staticmethod
-    def template() -> str:
-        return "WAREHOUSE = %s"
-
-    def get_expression(self):
-        return self.name
-
-    def render_option(self, compiler) -> str:
-        return Warehouse.template() % (self.get_expression())
-
-    def __repr__(self) -> str:
-        return "Warehouse(%s)" % self.get_expression()
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
index 60e8995f..fccc7a0b 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
@@ -6,29 +6,31 @@
 
 from sqlalchemy.sql import Selectable
 from sqlalchemy.sql.schema import Column, MetaData, SchemaItem
-from sqlalchemy.util import NoneType
 
 from .custom_table_base import CustomTableBase
-from .options.as_query import AsQuery
+from .options.as_query_option import AsQueryOption, AsQueryOptionType
+from .options.table_option import TableOptionKey
 
 
 class TableFromQueryBase(CustomTableBase):
+
     @property
-    def as_query(self):
-        return self._get_dialect_option(AsQuery.__option_name__)
+    def as_query(self) -> Optional[AsQueryOption]:
+        return self._get_dialect_option(TableOptionKey.AS_QUERY)
 
     def __init__(
         self,
         name: str,
         metadata: MetaData,
         *args: SchemaItem,
+        as_query: AsQueryOptionType = None,
         **kw: Any,
     ) -> None:
         items = [item for item in args]
-        as_query: AsQuery = self.__get_as_query_from_items(items)
+        as_query = AsQueryOption.create(as_query)  # noqa
+        kw.update(self._as_dialect_options([as_query]))
         if (
-            as_query is not NoneType
+            isinstance(as_query, AsQueryOption)
            and isinstance(as_query.query, Selectable)
            and not self.__has_defined_columns(items)
        ):
@@ -36,14 +38,6 @@ def __init__(
            args = items + columns
        super().__init__(name, metadata, *args, **kw)
 
-    def __get_as_query_from_items(
-        self, items: typing.List[SchemaItem]
-    ) -> Optional[AsQuery]:
-        for item in items:
-            if isinstance(item, AsQuery):
-                return item
-        return NoneType
-
     def __has_defined_columns(self, items: typing.List[SchemaItem]) -> bool:
         for item in items:
             if isinstance(item, Column):
diff --git a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr
index 81c7f90f..66c8f98e 100644
--- a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr
+++ b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr
@@ -6,7 +6,34 @@
   "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table"
 # ---
 # name: test_compile_dynamic_table_orm_with_str_keys
-  "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table"
+  'CREATE DYNAMIC TABLE "SCHEMA_DB".test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = \'10 seconds\'\tAS SELECT * FROM table'
+# ---
+# name: test_compile_dynamic_table_with_multiple_wrong_option_types
+  '''
+    Invalid parameter type 'IdentifierOption' provided for 'refresh_mode'. Expected one of the following types: 'KeywordOption', 'SnowflakeKeyword'.
+    Invalid parameter type 'IdentifierOption' provided for 'target_lag'. Expected one of the following types: 'TargetLagOption', 'Tuple[int, TimeUnit]', 'SnowflakeKeyword'.
+    Invalid parameter type 'KeywordOption' provided for 'as_query'.
Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. + Invalid parameter type 'KeywordOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- +# name: test_compile_dynamic_table_with_one_wrong_option_types + ''' + Invalid parameter type 'LiteralOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- +# name: test_compile_dynamic_table_with_options_objects + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.AUTO] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.FULL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = FULL\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.INCREMENTAL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = INCREMENTAL\tAS SELECT * FROM table" # --- # name: test_compile_dynamic_table_with_selectable "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" diff --git a/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr new file mode 100644 index 00000000..80201495 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_dynamic_table_without_dynamictable_and_defined_options + CustomOptionsAreOnlySupportedOnSnowflakeTables('Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables.') +# --- +# name: test_create_dynamic_table_without_dynamictable_class + UnexpectedOptionTypeError('The following options are either unsupported or should be defined using a Snowflake table: as_query, warehouse.') +# --- diff --git a/tests/custom_tables/__snapshots__/test_generic_options.ambr b/tests/custom_tables/__snapshots__/test_generic_options.ambr new file mode 100644 index 00000000..eef5e6fd --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_generic_options.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_identifier_option_with_wrong_type + InvalidTableParameterTypeError("Invalid parameter type 'int' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'.\n") +# --- +# name: test_identifier_option_without_name + OptionKeyNotProvidedError('Expected option key in IdentifierOption option but got NoneType instead.') +# --- +# name: test_invalid_as_query_option + InvalidTableParameterTypeError("Invalid parameter type 'int' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'.\n") +# --- +# name: test_literal_option_with_wrong_type + InvalidTableParameterTypeError("Invalid parameter type 'SnowflakeKeyword' provided for 'warehouse'. 
Expected one of the following types: 'LiteralOption', 'str', 'int'.\n") +# --- diff --git a/tests/custom_tables/test_compile_dynamic_table.py b/tests/custom_tables/test_compile_dynamic_table.py index 16a039e7..935c61cd 100644 --- a/tests/custom_tables/test_compile_dynamic_table.py +++ b/tests/custom_tables/test_compile_dynamic_table.py @@ -12,16 +12,21 @@ exc, select, ) +from sqlalchemy.exc import ArgumentError from sqlalchemy.orm import declarative_base from sqlalchemy.sql.ddl import CreateTable from snowflake.sqlalchemy import GEOMETRY, DynamicTable -from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery -from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( - TargetLag, +from snowflake.sqlalchemy.exc import MultipleErrors +from snowflake.sqlalchemy.sql.custom_schema.options import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + TargetLagOption, TimeUnit, ) -from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword def test_compile_dynamic_table(sql_compiler, snapshot): @@ -32,9 +37,9 @@ def test_compile_dynamic_table(sql_compiler, snapshot): metadata, Column("id", Integer), Column("geom", GEOMETRY), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) value = CreateTable(test_geometry) @@ -44,11 +49,99 @@ def test_compile_dynamic_table(sql_compiler, snapshot): assert actual == snapshot +@pytest.mark.parametrize( + "refresh_mode_keyword", + [ + SnowflakeKeyword.AUTO, + SnowflakeKeyword.FULL, + SnowflakeKeyword.INCREMENTAL, + ], +) +def test_compile_dynamic_table_with_refresh_mode( + sql_compiler, snapshot, refresh_mode_keyword +): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + refresh_mode=refresh_mode_keyword, + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_options_objects(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=IdentifierOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_one_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(ArgumentError) as argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=LiteralOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_dynamic_table_with_multiple_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(MultipleErrors) as 
argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=IdentifierOption(SnowflakeKeyword.AUTO), + warehouse=KeywordOption(SnowflakeKeyword.AUTO), + as_query=KeywordOption(SnowflakeKeyword.AUTO), + refresh_mode=IdentifierOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + def test_compile_dynamic_table_without_required_args(sql_compiler): with pytest.raises( exc.ArgumentError, - match="DYNAMIC TABLE must have the following arguments: TargetLag, " - "Warehouse, AsQuery", + match="DynamicTable requires the following parameters: warehouse, " + "as_query, target_lag.", ): DynamicTable( "test_dynamic_table", @@ -61,33 +154,33 @@ def test_compile_dynamic_table_without_required_args(sql_compiler): def test_compile_dynamic_table_with_primary_key(sql_compiler): with pytest.raises( exc.ArgumentError, - match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + match="Primary key and foreign keys are not supported in DynamicTable.", ): DynamicTable( "test_dynamic_table", MetaData(), Column("id", Integer, primary_key=True), Column("geom", GEOMETRY), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) def test_compile_dynamic_table_with_foreign_key(sql_compiler): with pytest.raises( exc.ArgumentError, - match="Primary key and foreign keys are not supported in DYNAMIC TABLE.", + match="Primary key and foreign keys are not supported in DynamicTable.", ): DynamicTable( "test_dynamic_table", MetaData(), Column("id", Integer), Column("geom", GEOMETRY), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), ForeignKeyConstraint(["id"], ["table.id"]), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) @@ -100,9 +193,9 @@ def test_compile_dynamic_table_orm(sql_compiler, snapshot): metadata, Column("id", Integer), Column("name", String), - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", ) class TestDynamicTableOrm(Base): @@ -121,23 +214,22 @@ def __repr__(self): assert actual == snapshot -def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, db_parameters, snapshot): +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot): Base = declarative_base() - schema = db_parameters["schema"] class TestDynamicTableOrm(Base): __tablename__ = "test_dynamic_table_orm_2" - __table_args__ = {"schema": schema} @classmethod def __table_cls__(cls, name, metadata, *arg, **kw): return DynamicTable(name, metadata, *arg, **kw) - __table_args__ = ( - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery("SELECT * FROM table"), - ) + __table_args__ = { + "schema": "SCHEMA_DB", + "target_lag": (10, TimeUnit.SECONDS), + "warehouse": "warehouse", + "as_query": "SELECT * FROM table", + } id = Column(Integer) name = Column(String) @@ -167,9 +259,9 @@ def test_compile_dynamic_table_with_selectable(sql_compiler, snapshot): dynamic_test_table = DynamicTable( "dynamic_test_table_1", Base.metadata, - TargetLag(10, TimeUnit.SECONDS), - Warehouse("warehouse"), - AsQuery(select(test_table_1).where(test_table_1.c.id == 23)), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + 
as_query=select(test_table_1).where(test_table_1.c.id == 23), ) value = CreateTable(dynamic_test_table) diff --git a/tests/custom_tables/test_create_dynamic_table.py b/tests/custom_tables/test_create_dynamic_table.py index 4e6c48ca..b583faad 100644 --- a/tests/custom_tables/test_create_dynamic_table.py +++ b/tests/custom_tables/test_create_dynamic_table.py @@ -1,15 +1,20 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # +import pytest from sqlalchemy import Column, Integer, MetaData, String, Table, select -from snowflake.sqlalchemy import DynamicTable -from snowflake.sqlalchemy.sql.custom_schema.options.as_query import AsQuery -from snowflake.sqlalchemy.sql.custom_schema.options.target_lag import ( - TargetLag, +from snowflake.sqlalchemy import DynamicTable, exc +from snowflake.sqlalchemy.sql.custom_schema.options.as_query_option import AsQueryOption +from snowflake.sqlalchemy.sql.custom_schema.options.identifier_option import ( + IdentifierOption, +) +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword +from snowflake.sqlalchemy.sql.custom_schema.options.table_option import TableOptionKey +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag_option import ( + TargetLagOption, TimeUnit, ) -from snowflake.sqlalchemy.sql.custom_schema.options.warehouse import Warehouse def test_create_dynamic_table(engine_testaccount, db_parameters): @@ -32,9 +37,10 @@ def test_create_dynamic_table(engine_testaccount, db_parameters): metadata, Column("id", Integer), Column("name", String), - TargetLag(1, TimeUnit.HOURS), - Warehouse(warehouse), - AsQuery("SELECT id, name from test_table_1;"), + target_lag=(1, TimeUnit.HOURS), + warehouse=warehouse, + as_query="SELECT id, name from test_table_1;", + refresh_mode=SnowflakeKeyword.FULL, ) metadata.create_all(engine_testaccount) @@ -52,7 +58,7 @@ def test_create_dynamic_table(engine_testaccount, db_parameters): def test_create_dynamic_table_without_dynamictable_class( - engine_testaccount, db_parameters + engine_testaccount, db_parameters, snapshot ): warehouse = db_parameters.get("warehouse", "default") metadata = MetaData() @@ -68,26 +74,51 @@ def test_create_dynamic_table_without_dynamictable_class( conn.execute(ins) conn.commit() - dynamic_test_table_1 = Table( + Table( "dynamic_test_table_1", metadata, Column("id", Integer), Column("name", String), - TargetLag(1, TimeUnit.HOURS), - Warehouse(warehouse), - AsQuery("SELECT id, name from test_table_1;"), + snowflake_warehouse=warehouse, + snowflake_as_query="SELECT id, name from test_table_1;", prefixes=["DYNAMIC"], ) + with pytest.raises(exc.UnexpectedOptionTypeError) as exc_info: + metadata.create_all(engine_testaccount) + assert exc_info.value == snapshot + + +def test_create_dynamic_table_without_dynamictable_and_defined_options( + engine_testaccount, db_parameters, snapshot +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + metadata.create_all(engine_testaccount) - try: - with engine_testaccount.connect() as conn: - s = select(dynamic_test_table_1) - results_dynamic_table = conn.execute(s).fetchall() - s = select(test_table_1) - results_table = conn.execute(s).fetchall() - assert results_dynamic_table == results_table + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") - finally: - metadata.drop_all(engine_testaccount) + conn.execute(ins) + 
conn.commit() + + Table( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + snowflake_target_lag=TargetLagOption.create((1, TimeUnit.HOURS)), + snowflake_warehouse=IdentifierOption.create( + TableOptionKey.WAREHOUSE, warehouse + ), + snowflake_as_query=AsQueryOption.create("SELECT id, name from test_table_1;"), + prefixes=["DYNAMIC"], + ) + + with pytest.raises(exc.CustomOptionsAreOnlySupportedOnSnowflakeTables) as exc_info: + metadata.create_all(engine_testaccount) + assert exc_info.value == snapshot diff --git a/tests/custom_tables/test_generic_options.py b/tests/custom_tables/test_generic_options.py new file mode 100644 index 00000000..916b94c6 --- /dev/null +++ b/tests/custom_tables/test_generic_options.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import pytest + +from snowflake.sqlalchemy import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + SnowflakeKeyword, + TableOptionKey, + TargetLagOption, + exc, +) +from snowflake.sqlalchemy.sql.custom_schema.options.invalid_table_option import ( + InvalidTableOption, +) + + +def test_identifier_option(): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert identifier.render_option(None) == "WAREHOUSE = xsmall" + + +def test_literal_option(): + literal = LiteralOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert literal.render_option(None) == "WAREHOUSE = 'xsmall'" + + +def test_identifier_option_without_name(snapshot): + identifier = IdentifierOption("xsmall") + with pytest.raises(exc.OptionKeyNotProvidedError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_identifier_option_with_wrong_type(snapshot): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, 23) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_literal_option_with_wrong_type(snapshot): + literal = LiteralOption.create( + TableOptionKey.WAREHOUSE, SnowflakeKeyword.DOWNSTREAM + ) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + literal.render_option(None) + assert exc_info.value == snapshot + + +def test_invalid_as_query_option(snapshot): + as_query = AsQueryOption.create(23) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + as_query.render_option(None) + assert exc_info.value == snapshot + + +@pytest.mark.parametrize( + "table_option", + [ + IdentifierOption, + LiteralOption, + KeywordOption, + ], +) +def test_generic_option_with_wrong_type(table_option): + literal = table_option.create(TableOptionKey.WAREHOUSE, 0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" + + +@pytest.mark.parametrize( + "table_option", + [ + TargetLagOption, + AsQueryOption, + ], +) +def test_non_generic_option_with_wrong_type(table_option): + literal = table_option.create(0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" diff --git a/tests/custom_tables/test_reflect_dynamic_table.py b/tests/custom_tables/test_reflect_dynamic_table.py index 8a4a8445..52eb4457 100644 --- a/tests/custom_tables/test_reflect_dynamic_table.py +++ b/tests/custom_tables/test_reflect_dynamic_table.py @@ -74,7 +74,7 @@ def test_simple_reflection_without_options_loading(engine_testaccount, db_parame ) # TODO: Add support for loading options when table is reflected - assert dynamic_test_table.warehouse 
is NoneType + assert isinstance(dynamic_test_table.warehouse, NoneType) try: with engine_testaccount.connect() as conn: diff --git a/tests/test_core.py b/tests/test_core.py index 15840838..980db1d2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1060,6 +1060,7 @@ def harass_inspector(): assert outcome +@pytest.mark.skip(reason="Testaccount is not available, it returns 404 error.") @pytest.mark.timeout(10) @pytest.mark.parametrize( "region", From 14be28216fb477d10815fd15e8290c866de4e260 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Tue, 29 Oct 2024 07:03:43 -0600 Subject: [PATCH 15/21] Add support for iceberg table with snowflake catalog (#539) * Add support for Iceberg Table with Snowflake Catalog * Add support for Snowflake Table * Update DESCRIPTION.md --- DESCRIPTION.md | 2 + README.md | 2 +- src/snowflake/sqlalchemy/__init__.py | 48 +++-- src/snowflake/sqlalchemy/base.py | 13 +- .../sqlalchemy/sql/custom_schema/__init__.py | 4 +- .../sql/custom_schema/clustered_table.py | 37 ++++ .../sql/custom_schema/dynamic_table.py | 6 +- .../sql/custom_schema/hybrid_table.py | 2 +- .../sql/custom_schema/iceberg_table.py | 101 ++++++++++ .../sql/custom_schema/options/__init__.py | 3 + .../options/cluster_by_option.py | 58 ++++++ .../sql/custom_schema/options/table_option.py | 1 + .../sql/custom_schema/snowflake_table.py | 70 +++++++ .../sql/custom_schema/table_from_query.py | 4 +- tests/__snapshots__/test_core.ambr | 4 + .../test_compile_iceberg_table.ambr | 19 ++ .../test_compile_snowflake_table.ambr | 35 ++++ .../test_create_iceberg_table.ambr | 14 ++ .../test_create_snowflake_table.ambr | 4 + .../test_reflect_snowflake_table.ambr | 7 + .../test_compile_iceberg_table.py | 116 +++++++++++ .../test_compile_snowflake_table.py | 180 ++++++++++++++++++ .../test_create_iceberg_table.py | 43 +++++ .../test_create_snowflake_table.py | 66 +++++++ .../test_reflect_snowflake_table.py | 69 +++++++ tests/test_core.py | 34 ++++ 26 files changed, 916 insertions(+), 26 deletions(-) create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py create mode 100644 src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py create mode 100644 tests/__snapshots__/test_core.ambr create mode 100644 tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr create mode 100644 tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr create mode 100644 tests/custom_tables/test_compile_iceberg_table.py create mode 100644 tests/custom_tables/test_compile_snowflake_table.py create mode 100644 tests/custom_tables/test_create_iceberg_table.py create mode 100644 tests/custom_tables/test_create_snowflake_table.py create mode 100644 tests/custom_tables/test_reflect_snowflake_table.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 909d52cf..47697d30 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -16,6 +16,8 @@ Source code is also available at: - Fixed SAWarning when registering functions with existing name in default namespace - Update options to be defined in key arguments instead of arguments. 
- Add support for refresh_mode option in DynamicTable + - Add support for iceberg table with Snowflake Catalog + - Fix cluster by option to support explicit expressions - v1.6.1(July 9, 2024) diff --git a/README.md b/README.md index c428353f..c6c13349 100644 --- a/README.md +++ b/README.md @@ -340,7 +340,7 @@ This example shows how to create a table with two columns, `id` and `name`, as t t = Table('myuser', metadata, Column('id', Integer, primary_key=True), Column('name', String), - snowflake_clusterby=['id', 'name'], ... + snowflake_clusterby=['id', 'name', text('id > 5')], ... ) metadata.create_all(engine) ``` diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index e53f9b74..f6c97f0d 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -9,7 +9,7 @@ else: import importlib.metadata as importlib_metadata -from sqlalchemy.types import ( +from sqlalchemy.types import ( # noqa BIGINT, BINARY, BOOLEAN, @@ -27,8 +27,8 @@ VARCHAR, ) -from . import base, snowdialect -from .custom_commands import ( +from . import base, snowdialect # noqa +from .custom_commands import ( # noqa AWSBucket, AzureContainer, CopyFormatter, @@ -41,7 +41,7 @@ MergeInto, PARQUETFormatter, ) -from .custom_types import ( +from .custom_types import ( # noqa ARRAY, BYTEINT, CHARACTER, @@ -61,9 +61,15 @@ VARBINARY, VARIANT, ) -from .sql.custom_schema import DynamicTable, HybridTable -from .sql.custom_schema.options import ( +from .sql.custom_schema import ( # noqa + DynamicTable, + HybridTable, + IcebergTable, + SnowflakeTable, +) +from .sql.custom_schema.options import ( # noqa AsQueryOption, + ClusterByOption, IdentifierOption, KeywordOption, LiteralOption, @@ -72,14 +78,13 @@ TargetLagOption, TimeUnit, ) -from .util import _url as URL +from .util import _url as URL # noqa base.dialect = dialect = snowdialect.dialect __version__ = importlib_metadata.version("snowflake-sqlalchemy") -__all__ = ( - # Custom Types +_custom_types = ( "BIGINT", "BINARY", "BOOLEAN", @@ -114,7 +119,9 @@ "TINYINT", "VARBINARY", "VARIANT", - # Custom Commands +) + +_custom_commands = ( "MergeInto", "CSVFormatter", "JSONFormatter", @@ -126,17 +133,28 @@ "ExternalStage", "CreateStage", "CreateFileFormat", - # Custom Tables - "HybridTable", - "DynamicTable", - # Custom Table Options +) + +_custom_tables = ("HybridTable", "DynamicTable", "IcebergTable", "SnowflakeTable") + +_custom_table_options = ( "AsQueryOption", "TargetLagOption", "LiteralOption", "IdentifierOption", "KeywordOption", - # Enums + "ClusterByOption", +) + +_enums = ( "TimeUnit", "TableOptionKey", "SnowflakeKeyword", ) +__all__ = ( + *_custom_types, + *_custom_commands, + *_custom_tables, + *_custom_table_options, + *_enums, +) diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 023f7afb..4e36c4ad 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -908,7 +908,7 @@ def handle_cluster_by(self, table): ... metadata, ... sa.Column('id', sa.Integer, primary_key=True), ... sa.Column('name', sa.String), - ... snowflake_clusterby=['id', 'name'] + ... snowflake_clusterby=['id', 'name', text("id > 5")] ... 
) >>> print(CreateTable(user).compile(engine)) @@ -916,7 +916,7 @@ def handle_cluster_by(self, table): id INTEGER NOT NULL AUTOINCREMENT, name VARCHAR, PRIMARY KEY (id) - ) CLUSTER BY (id, name) + ) CLUSTER BY (id, name, id > 5) """ @@ -925,7 +925,14 @@ def handle_cluster_by(self, table): cluster = info.get("clusterby") if cluster: text += " CLUSTER BY ({})".format( - ", ".join(self.denormalize_column_name(key) for key in cluster) + ", ".join( + ( + self.denormalize_column_name(key) + if isinstance(key, str) + else str(key) + ) + for key in cluster + ) ) return text diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py index 66b9270f..cbc75ebc 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py @@ -3,5 +3,7 @@ # from .dynamic_table import DynamicTable from .hybrid_table import HybridTable +from .iceberg_table import IcebergTable +from .snowflake_table import SnowflakeTable -__all__ = ["DynamicTable", "HybridTable"] +__all__ = ["DynamicTable", "HybridTable", "IcebergTable", "SnowflakeTable"] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py new file mode 100644 index 00000000..6c0904a8 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py @@ -0,0 +1,37 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from typing import Any, Optional + +from sqlalchemy.sql.schema import MetaData, SchemaItem + +from .custom_table_base import CustomTableBase +from .options.as_query_option import AsQueryOption +from .options.cluster_by_option import ClusterByOption, ClusterByOptionType +from .options.table_option import TableOptionKey + + +class ClusteredTableBase(CustomTableBase): + + @property + def cluster_by(self) -> Optional[AsQueryOption]: + return self._get_dialect_option(TableOptionKey.CLUSTER_BY) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + cluster_by: ClusterByOptionType = None, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + + options = [ + ClusterByOption.create(cluster_by), + ] + + kw.update(self._as_dialect_options(options)) + super().__init__(name, metadata, *args, **kw) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py index 6db4312d..91c379f0 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py @@ -12,7 +12,6 @@ IdentifierOption, IdentifierOptionType, KeywordOptionType, - LiteralOption, TableOptionKey, TargetLagOption, TargetLagOptionType, @@ -45,7 +44,7 @@ class DynamicTable(TableFromQueryBase): as_query="SELECT id, name from test_table_1;" ) - Example using full options: + Example using explicit options: DynamicTable( "dynamic_test_table_1", metadata, @@ -67,7 +66,7 @@ class DynamicTable(TableFromQueryBase): ] @property - def warehouse(self) -> typing.Optional[LiteralOption]: + def warehouse(self) -> typing.Optional[IdentifierOption]: return self._get_dialect_option(TableOptionKey.WAREHOUSE) @property @@ -112,6 +111,7 @@ def __repr__(self) -> str: + [repr(x) for x in self.columns] + [repr(self.target_lag)] + [repr(self.warehouse)] + + [repr(self.cluster_by)] + [repr(self.as_query)] + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] ) diff --git 
a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
index b7c29e78..16a58d47 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
@@ -17,7 +17,7 @@ class HybridTable(CustomTableBase):
     The `HybridTable` class allows for the creation and querying of OLTP Snowflake Tables.
     While it does not support reflection at this time, it provides a flexible
-    interface for creating dynamic tables and management.
+    interface for creating and managing hybrid tables.
 
     For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-hybrid-table
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py
new file mode 100644
index 00000000..5c9c53d9
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py
@@ -0,0 +1,101 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+import typing
+from typing import Any
+
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from .custom_table_prefix import CustomTablePrefix
+from .options import LiteralOption, LiteralOptionType, TableOptionKey
+from .table_from_query import TableFromQueryBase
+
+
+class IcebergTable(TableFromQueryBase):
+    """
+    A class representing an iceberg table with configurable options and settings.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating and managing iceberg tables.
+
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-iceberg-table
+
+    Example using option values:
+
+    IcebergTable(
+        "iceberg_test_table_1",
+        metadata,
+        Column("id", Integer),
+        Column("name", String),
+        external_volume='my_external_volume',
+        base_location='my_iceberg_table'
+    )
+
+    Example using explicit options:
+
+    IcebergTable(
+        "iceberg_test_table_1",
+        metadata,
+        Column("id", Integer),
+        Column("name", String),
+        external_volume=LiteralOption('my_external_volume'),
+        base_location=LiteralOption('my_iceberg_table')
+    )
+    """
+
+    __table_prefixes__ = [CustomTablePrefix.ICEBERG]
+
+    @property
+    def external_volume(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.EXTERNAL_VOLUME)
+
+    @property
+    def base_location(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.BASE_LOCATION)
+
+    @property
+    def catalog(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.CATALOG)
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        external_volume: LiteralOptionType = None,
+        base_location: LiteralOptionType = None,
+        **kw: Any,
+    ) -> None:
+        if kw.get("_no_init", True):
+            return
+
+        options = [
+            LiteralOption.create(TableOptionKey.EXTERNAL_VOLUME, external_volume),
+            LiteralOption.create(TableOptionKey.BASE_LOCATION, base_location),
+            LiteralOption.create(TableOptionKey.CATALOG, "SNOWFLAKE"),
+        ]
+
+        kw.update(self._as_dialect_options(options))
+        super().__init__(name, metadata, *args, **kw)
+
+    def _init(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        self.__init__(name, metadata, *args, _no_init=False, **kw)
+
+    def __repr__(self) -> str:
+        return "IcebergTable(%s)" % ", ".join(
+            [repr(self.name)]
+            + [repr(self.metadata)]
+            + [repr(x)
for x in self.columns] + + [repr(self.external_volume)] + + [repr(self.base_location)] + + [repr(self.catalog)] + + [repr(self.cluster_by)] + + [repr(self.as_query)] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py index 11b54c1a..e94ea46b 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py @@ -3,6 +3,7 @@ # from .as_query_option import AsQueryOption, AsQueryOptionType +from .cluster_by_option import ClusterByOption, ClusterByOptionType from .identifier_option import IdentifierOption, IdentifierOptionType from .keyword_option import KeywordOption, KeywordOptionType from .keywords import SnowflakeKeyword @@ -17,6 +18,7 @@ "KeywordOption", "AsQueryOption", "TargetLagOption", + "ClusterByOption", # Enums "TimeUnit", "SnowflakeKeyword", @@ -27,4 +29,5 @@ "AsQueryOptionType", "TargetLagOptionType", "KeywordOptionType", + "ClusterByOptionType", ] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py new file mode 100644 index 00000000..b92029bb --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import List, Union + +from sqlalchemy.sql.expression import TextClause + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class ClusterByOption(TableOption): + """Class to represent the cluster by clause in tables. 
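+    Expressions may be plain column names or SQL expressions built with sqlalchemy.text().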
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/user-guide/tables-clustering-keys
+    Example:
+        cluster_by=ClusterByOption('name', text('id > 0'))
+
+        is equivalent to:
+
+        cluster by (name, id > 0)
+    """
+
+    def __init__(self, *expressions: Union[str, TextClause]) -> None:
+        super().__init__()
+        self._name: TableOptionKey = TableOptionKey.CLUSTER_BY
+        self.expressions = expressions
+
+    @staticmethod
+    def create(value: "ClusterByOptionType") -> "TableOption":
+        if isinstance(value, (NoneType, ClusterByOption)):
+            return value
+        if isinstance(value, List):
+            return ClusterByOption(*value)
+        return TableOption._get_invalid_table_option(
+            TableOptionKey.CLUSTER_BY,
+            str(type(value).__name__),
+            [ClusterByOption.__name__, list.__name__],
+        )
+
+    def template(self) -> str:
+        return f"{self.option_name.upper()} (%s)"
+
+    @property
+    def priority(self) -> Priority:
+        return Priority.HIGH
+
+    def __get_expression(self):
+        return ", ".join([str(expression) for expression in self.expressions])
+
+    def _render(self, compiler) -> str:
+        return self.template() % (self.__get_expression())
+
+    def __repr__(self) -> str:
+        return "ClusterByOption(%s)" % self.__get_expression()
+
+
+ClusterByOptionType = Union[ClusterByOption, List[Union[str, TextClause]]]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
index 14b91f2e..5ebb4817 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py
@@ -73,6 +73,7 @@ class TableOptionKey(Enum):
     BASE_LOCATION = "base_location"
     CATALOG = "catalog"
     CATALOG_SYNC = "catalog_sync"
+    CLUSTER_BY = "cluster by"
     DATA_RETENTION_TIME_IN_DAYS = "data_retention_time_in_days"
     DEFAULT_DDL_COLLATION = "default_ddl_collation"
     EXTERNAL_VOLUME = "external_volume"
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py
new file mode 100644
index 00000000..56a14c83
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py
@@ -0,0 +1,70 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import Any
+
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from .table_from_query import TableFromQueryBase
+
+
+class SnowflakeTable(TableFromQueryBase):
+    """
+    A class representing a table in Snowflake with configurable options and settings.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating and managing tables.
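+    The cluster_by and as_query options are inherited through TableFromQueryBase.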
+
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table
+    Example usage:
+
+    SnowflakeTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+        cluster_by=["id", text("name > 5")]
+    )
+
+    Example using explicit options:
+
+    SnowflakeTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("name", String),
+        cluster_by=ClusterByOption("id", text("name > 5"))
+    )
+
+    """
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        if kw.get("_no_init", True):
+            return
+        super().__init__(name, metadata, *args, **kw)
+
+    def _init(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        self.__init__(name, metadata, *args, _no_init=False, **kw)
+
+    def __repr__(self) -> str:
+        return "SnowflakeTable(%s)" % ", ".join(
+            [repr(self.name)]
+            + [repr(self.metadata)]
+            + [repr(x) for x in self.columns]
+            + [repr(self.cluster_by)]
+            + [repr(self.as_query)]
+            + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]]
+        )
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
index fccc7a0b..cbd65de3 100644
--- a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py
@@ -7,12 +7,12 @@
 from sqlalchemy.sql import Selectable
 from sqlalchemy.sql.schema import Column, MetaData, SchemaItem
 
-from .custom_table_base import CustomTableBase
+from .clustered_table import ClusteredTableBase
 from .options.as_query_option import AsQueryOption, AsQueryOptionType
 from .options.table_option import TableOptionKey
 
 
-class TableFromQueryBase(CustomTableBase):
+class TableFromQueryBase(ClusteredTableBase):
 
     @property
     def as_query(self) -> Optional[AsQueryOption]:
diff --git a/tests/__snapshots__/test_core.ambr b/tests/__snapshots__/test_core.ambr
new file mode 100644
index 00000000..7a4e0f99
--- /dev/null
+++ b/tests/__snapshots__/test_core.ambr
@@ -0,0 +1,4 @@
+# serializer version: 1
+# name: test_compile_table_with_cluster_by_with_expression
+  'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY ("Id")) CLUSTER BY ("Id", name, "Id" > 5)'
+# ---
diff --git a/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr b/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr
new file mode 100644
index 00000000..b243cc09
--- /dev/null
+++ b/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr
@@ -0,0 +1,19 @@
+# serializer version: 1
+# name: test_compile_iceberg_table_orm_with_as_query
+  "CREATE ICEBERG TABLE test_iceberg_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'\tAS SELECT * FROM table"
+# ---
+# name: test_compile_iceberg_table_with_primary_key
+  "CREATE ICEBERG TABLE test_iceberg_table_with_options (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'"
+# ---
+# name: test_compile_iceberg_table
+  "CREATE ICEBERG TABLE test_iceberg_table (\tid INTEGER, \tgeom VARCHAR)\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'"
+# ---
+# name: 
test_compile_iceberg_table_with_one_wrong_option_types + ''' + Invalid parameter type 'IdentifierOption' provided for 'external_volume'. Expected one of the following types: 'LiteralOption', 'str', 'int'. + + ''' +# --- +# name: test_compile_iceberg_table_with_options_objects + "CREATE ICEBERG TABLE test_iceberg_table_with_options (\tid INTEGER, \tgeom VARCHAR)\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'" +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr new file mode 100644 index 00000000..5ea64c12 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr @@ -0,0 +1,35 @@ +# serializer version: 1 +# name: test_compile_dynamic_table_orm_with_str_keys + 'CREATE TABLE "SCHEMA_DB".test_snowflake_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_foreign_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL, \tgeom VARCHAR, \tPRIMARY KEY (id), \tFOREIGN KEY(id) REFERENCES "table" (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_primary_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table + 'CREATE TABLE test_table_1 (\tid INTEGER, \tgeom VARCHAR)\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_orm_with_str_keys + 'CREATE TABLE "SCHEMA_DB".test_snowflake_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_explicit_options + 'CREATE TABLE test_table_2 (\tid INTEGER, \tgeom VARCHAR)\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_foreign_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL, \tgeom VARCHAR, \tPRIMARY KEY (id), \tFOREIGN KEY(id) REFERENCES "table" (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_primary_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_selectable + 'CREATE TABLE snowflake_test_table_1 (\tid INTEGER, \tgeom VARCHAR)\tAS SELECT test_table_1.id, test_table_1.geom FROM test_table_1 WHERE test_table_1.id = 23' +# --- +# name: test_compile_snowflake_table_with_wrong_option_types + ''' + Invalid parameter type 'AsQueryOption' provided for 'cluster by'. Expected one of the following types: 'ClusterByOption', 'list'. + Invalid parameter type 'ClusterByOption' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. 
+ + ''' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr b/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr new file mode 100644 index 00000000..908a4c60 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr @@ -0,0 +1,14 @@ +# serializer version: 1 +# name: test_create_iceberg_table + ''' + (snowflake.connector.errors.ProgrammingError) 091017 (22000): S3 bucket 'my_example_bucket' does not exist or not authorized. + [SQL: + CREATE ICEBERG TABLE "Iceberg_Table_1" ( + id INTEGER NOT NULL AUTOINCREMENT, + geom VARCHAR, + PRIMARY KEY (id) + ) EXTERNAL_VOLUME = 'exvol' CATALOG = 'SNOWFLAKE' BASE_LOCATION = 'my_iceberg_table' + + ] + ''' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr new file mode 100644 index 00000000..98d3137f --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_create_snowflake_table_with_cluster_by + "[(1, 'test')]" +# --- diff --git a/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr new file mode 100644 index 00000000..6ef09ff7 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_simple_reflection_of_table_as_snowflake_table + 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- +# name: test_simple_reflection_of_table_as_sqlalchemy_table + 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/test_compile_iceberg_table.py b/tests/custom_tables/test_compile_iceberg_table.py new file mode 100644 index 00000000..173e7b0a --- /dev/null +++ b/tests/custom_tables/test_compile_iceberg_table.py @@ -0,0 +1,116 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+#
+import pytest
+from sqlalchemy import Column, Integer, MetaData, String
+from sqlalchemy.exc import ArgumentError
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.sql.ddl import CreateTable
+
+from snowflake.sqlalchemy import IcebergTable
+from snowflake.sqlalchemy.sql.custom_schema.options import (
+    IdentifierOption,
+    LiteralOption,
+)
+
+
+def test_compile_iceberg_table(sql_compiler, snapshot):
+    metadata = MetaData()
+    table_name = "test_iceberg_table"
+    test_table = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer),
+        Column("geom", String),
+        external_volume="my_external_volume",
+        base_location="my_iceberg_table",
+    )
+
+    value = CreateTable(test_table)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+def test_compile_iceberg_table_with_options_objects(sql_compiler, snapshot):
+    metadata = MetaData()
+    table_name = "test_iceberg_table_with_options"
+    test_table = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer),
+        Column("geom", String),
+        external_volume=LiteralOption("my_external_volume"),
+        base_location=LiteralOption("my_iceberg_table"),
+    )
+
+    value = CreateTable(test_table)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+def test_compile_iceberg_table_with_one_wrong_option_types(snapshot):
+    metadata = MetaData()
+    table_name = "test_wrong_iceberg_table"
+    with pytest.raises(ArgumentError) as argument_error:
+        IcebergTable(
+            table_name,
+            metadata,
+            Column("id", Integer),
+            Column("geom", String),
+            external_volume=IdentifierOption("my_external_volume"),
+            base_location=LiteralOption("my_iceberg_table"),
+        )
+
+    assert str(argument_error.value) == snapshot
+
+
+def test_compile_icberg_table_with_primary_key(sql_compiler, snapshot):
+    metadata = MetaData()
+    table_name = "test_iceberg_table_with_options"
+    test_table = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("geom", String),
+        external_volume=LiteralOption("my_external_volume"),
+        base_location=LiteralOption("my_iceberg_table"),
+    )
+
+    value = CreateTable(test_table)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+def test_compile_dynamic_table_orm_with_as_query(sql_compiler, snapshot):
+    Base = declarative_base()
+
+    class TestDynamicTableOrm(Base):
+        __tablename__ = "test_iceberg_table_orm_2"
+
+        @classmethod
+        def __table_cls__(cls, name, metadata, *arg, **kw):
+            return IcebergTable(name, metadata, *arg, **kw)
+
+        __table_args__ = {
+            "external_volume": "my_external_volume",
+            "base_location": "my_iceberg_table",
+            "as_query": "SELECT * FROM table",
+        }
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+        def __repr__(self):
+            return f"<TestDynamicTableOrm({self.id!r}, {self.name!r})>"
+
+    value = CreateTable(TestDynamicTableOrm.__table__)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
diff --git a/tests/custom_tables/test_compile_snowflake_table.py b/tests/custom_tables/test_compile_snowflake_table.py
new file mode 100644
index 00000000..be9383eb
--- /dev/null
+++ b/tests/custom_tables/test_compile_snowflake_table.py
@@ -0,0 +1,180 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+import pytest
+from sqlalchemy import (
+    Column,
+    ForeignKeyConstraint,
+    Integer,
+    MetaData,
+    String,
+    select,
+    text,
+)
+from sqlalchemy.exc import ArgumentError
+from sqlalchemy.orm import declarative_base
+from sqlalchemy.sql.ddl import CreateTable
+
+from snowflake.sqlalchemy import SnowflakeTable
+from snowflake.sqlalchemy.sql.custom_schema.options import (
+    AsQueryOption,
+    ClusterByOption,
+)
+
+
+def test_compile_snowflake_table(sql_compiler, snapshot):
+    metadata = MetaData()
+    table_name = "test_table_1"
+    test_geometry = SnowflakeTable(
+        table_name,
+        metadata,
+        Column("id", Integer),
+        Column("geom", String),
+        cluster_by=["id", text("id > 100")],
+        as_query="SELECT * FROM table",
+    )
+
+    value = CreateTable(test_geometry)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+def test_compile_snowflake_table_with_explicit_options(sql_compiler, snapshot):
+    metadata = MetaData()
+    table_name = "test_table_2"
+    test_geometry = SnowflakeTable(
+        table_name,
+        metadata,
+        Column("id", Integer),
+        Column("geom", String),
+        cluster_by=ClusterByOption("id", text("id > 100")),
+        as_query=AsQueryOption("SELECT * FROM table"),
+    )
+
+    value = CreateTable(test_geometry)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+def test_compile_snowflake_table_with_wrong_option_types(snapshot):
+    metadata = MetaData()
+    table_name = "test_snowflake_table"
+    with pytest.raises(ArgumentError) as argument_error:
+        SnowflakeTable(
+            table_name,
+            metadata,
+            Column("id", Integer),
+            Column("geom", String),
+            as_query=ClusterByOption("id", text("id > 100")),
+            cluster_by=AsQueryOption("SELECT * FROM table"),
+        )
+
+    assert str(argument_error.value) == snapshot
+
+
+def test_compile_snowflake_table_with_primary_key(sql_compiler, snapshot):
+    metadata = MetaData()
+    table_name = "test_table_2"
+    test_geometry = SnowflakeTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("geom", String),
+        cluster_by=ClusterByOption("id", text("id > 100")),
+        as_query=AsQueryOption("SELECT * FROM table"),
+    )
+
+    value = CreateTable(test_geometry)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+def test_compile_snowflake_table_with_foreign_key(sql_compiler, snapshot):
+    metadata = MetaData()
+
+    SnowflakeTable(
+        "table",
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("geom", String),
+        ForeignKeyConstraint(["id"], ["table.id"]),
+        cluster_by=ClusterByOption("id", text("id > 100")),
+        as_query=AsQueryOption("SELECT * FROM table"),
+    )
+
+    table_name = "test_table_2"
+    test_geometry = SnowflakeTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("geom", String),
+        ForeignKeyConstraint(["id"], ["table.id"]),
+        cluster_by=ClusterByOption("id", text("id > 100")),
+        as_query=AsQueryOption("SELECT * FROM table"),
+    )
+
+    value = CreateTable(test_geometry)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+def test_compile_snowflake_table_orm_with_str_keys(sql_compiler, snapshot):
+    Base = declarative_base()
+
+    class TestSnowflakeTableOrm(Base):
+        __tablename__ = "test_snowflake_table_orm_2"
+
+        @classmethod
+        def __table_cls__(cls, name, metadata, *arg, **kw):
+            return SnowflakeTable(name, metadata, *arg, **kw)
+
+        __table_args__ = {
+            "schema": "SCHEMA_DB",
+            "cluster_by": ["id", text("id > 100")],
+            "as_query": "SELECT * FROM table",
+        }
+
+        id = Column(Integer, primary_key=True)
+        name = Column(String)
+
+        def __repr__(self):
+            return f"<TestSnowflakeTableOrm({self.id!r}, {self.name!r})>"
+
+    value = CreateTable(TestSnowflakeTableOrm.__table__)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
+
+
+def test_compile_snowflake_table_with_selectable(sql_compiler, snapshot):
+    Base = declarative_base()
+
+    test_table_1 = SnowflakeTable(
+        "test_table_1",
+        Base.metadata,
+        Column("id", Integer, primary_key=True),
+        Column("geom", String),
+        ForeignKeyConstraint(["id"], ["table.id"]),
+        cluster_by=ClusterByOption("id", text("id > 100")),
+    )
+
+    test_table_2 = SnowflakeTable(
+        "snowflake_test_table_1",
+        Base.metadata,
+        as_query=select(test_table_1).where(test_table_1.c.id == 23),
+    )
+
+    value = CreateTable(test_table_2)
+
+    actual = sql_compiler(value)
+
+    assert actual == snapshot
diff --git a/tests/custom_tables/test_create_iceberg_table.py b/tests/custom_tables/test_create_iceberg_table.py
new file mode 100644
index 00000000..3ecd703b
--- /dev/null
+++ b/tests/custom_tables/test_create_iceberg_table.py
@@ -0,0 +1,43 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+import pytest
+from sqlalchemy import Column, Integer, MetaData, String
+from sqlalchemy.exc import ProgrammingError
+
+from snowflake.sqlalchemy import IcebergTable
+
+
+@pytest.mark.aws
+def test_create_iceberg_table(engine_testaccount, snapshot):
+    metadata = MetaData()
+    external_volume_name = "exvol"
+    create_external_volume = f"""
+    CREATE OR REPLACE EXTERNAL VOLUME {external_volume_name}
+    STORAGE_LOCATIONS =
+    (
+        (
+            NAME = 'my-s3-us-west-2'
+            STORAGE_PROVIDER = 'S3'
+            STORAGE_BASE_URL = 's3://MY_EXAMPLE_BUCKET/'
+            STORAGE_AWS_ROLE_ARN = 'arn:aws:iam::123456789012:role/myrole'
+            ENCRYPTION=(TYPE='AWS_SSE_KMS' KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab')
+        )
+    );
+    """
+    with engine_testaccount.connect() as connection:
+        connection.exec_driver_sql(create_external_volume)
+    IcebergTable(
+        "Iceberg_Table_1",
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("geom", String),
+        external_volume=external_volume_name,
+        base_location="my_iceberg_table",
+    )
+
+    with pytest.raises(ProgrammingError) as argument_error:
+        metadata.create_all(engine_testaccount)
+
+    error_str = str(argument_error.value)
+    assert error_str[: error_str.rfind("\n")] == snapshot
diff --git a/tests/custom_tables/test_create_snowflake_table.py b/tests/custom_tables/test_create_snowflake_table.py
new file mode 100644
index 00000000..09140fb8
--- /dev/null
+++ b/tests/custom_tables/test_create_snowflake_table.py
@@ -0,0 +1,66 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+# +from sqlalchemy import Column, Integer, MetaData, String, select, text +from sqlalchemy.orm import Session, declarative_base + +from snowflake.sqlalchemy import SnowflakeTable + + +def test_create_snowflake_table_with_cluster_by( + engine_testaccount, db_parameters, snapshot +): + metadata = MetaData() + table_name = "test_create_snowflake_table" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + conn.execute(ins) + conn.commit() + + try: + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results_hybrid_table = conn.execute(s).fetchall() + assert str(results_hybrid_table) == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +def test_create_snowflake_table_with_orm(sql_compiler, engine_testaccount): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestHybridTableOrm(Base): + __tablename__ = "test_snowflake_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return SnowflakeTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + instance = TestHybridTableOrm(id=0, name="name_example") + session.add(instance) + session.commit() + data = session.query(TestHybridTableOrm).all() + assert str(data) == "[(0, 'name_example')]" + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_reflect_snowflake_table.py b/tests/custom_tables/test_reflect_snowflake_table.py new file mode 100644 index 00000000..ef84622b --- /dev/null +++ b/tests/custom_tables/test_reflect_snowflake_table.py @@ -0,0 +1,69 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from sqlalchemy import MetaData, Table +from sqlalchemy.sql.ddl import CreateTable + +from src.snowflake.sqlalchemy import SnowflakeTable + + +def test_simple_reflection_of_table_as_sqlalchemy_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_reflection" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + snowflake_test_table = Table(table_name, metadata, autoload_with=engine_testaccount) + constraint = snowflake_test_table.constraints.pop() + constraint.name = "demo_name" + snowflake_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(snowflake_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_of_table_as_snowflake_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_reflection" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + snowflake_test_table = SnowflakeTable( + table_name, metadata, autoload_with=engine_testaccount + ) + constraint = snowflake_test_table.constraints.pop() + constraint.name = "demo_name" + snowflake_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(snowflake_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_core.py b/tests/test_core.py index 980db1d2..9342ad58 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -36,6 +36,7 @@ ) from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError from sqlalchemy.sql import and_, not_, or_, select +from sqlalchemy.sql.ddl import CreateTable import snowflake.connector.errors import snowflake.sqlalchemy.snowdialect @@ -699,6 +700,39 @@ def test_create_table_with_cluster_by(engine_testaccount): user.drop(engine_testaccount) +def test_create_table_with_cluster_by_with_expression(engine_testaccount): + metadata = MetaData() + Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", String), + snowflake_clusterby=["Id", "name", text('"Id" > 5')], + ) + metadata.create_all(engine_testaccount) + try: + inspector = inspect(engine_testaccount) + columns_in_table = inspector.get_columns("clustered_user") + assert columns_in_table[0]["name"] == "Id", "name" + finally: + metadata.drop_all(engine_testaccount) + + +def test_compile_table_with_cluster_by_with_expression(sql_compiler, snapshot): + metadata = MetaData() + user = Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", String), + snowflake_clusterby=["Id", "name", text('"Id" > 5')], + ) + + create_table = CreateTable(user) + + assert sql_compiler(create_table) == snapshot + + def test_view_names(engine_testaccount): """ Tests all views From 31d0da643f0bef64b1fbd49e49e840ff3d29eb87 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Tue, 19 Nov 2024 10:15:48 -0600 Subject: [PATCH 16/21] Add support for map datatype (#541) Add support for map datatype --- DESCRIPTION.md | 3 +- pyproject.toml | 3 
+- src/snowflake/sqlalchemy/__init__.py | 2 + src/snowflake/sqlalchemy/_constants.py | 1 + src/snowflake/sqlalchemy/base.py | 7 + src/snowflake/sqlalchemy/custom_types.py | 20 ++ src/snowflake/sqlalchemy/exc.py | 8 + .../sqlalchemy/parser/custom_type_parser.py | 190 ++++++++++++ src/snowflake/sqlalchemy/snowdialect.py | 213 ++++++-------- .../sql/custom_schema/custom_table_base.py | 16 ++ .../sql/custom_schema/iceberg_table.py | 1 + src/snowflake/sqlalchemy/version.py | 2 +- .../test_structured_datatypes.ambr | 90 ++++++ .../test_unit_structured_types.ambr | 4 + tests/conftest.py | 30 ++ .../test_reflect_snowflake_table.ambr | 22 ++ .../test_reflect_snowflake_table.py | 27 +- tests/test_core.py | 136 ++++----- tests/test_structured_datatypes.py | 271 ++++++++++++++++++ tests/test_unit_structured_types.py | 73 +++++ tests/util.py | 2 + 21 files changed, 898 insertions(+), 223 deletions(-) create mode 100644 src/snowflake/sqlalchemy/parser/custom_type_parser.py create mode 100644 tests/__snapshots__/test_structured_datatypes.ambr create mode 100644 tests/__snapshots__/test_unit_structured_types.ambr create mode 100644 tests/test_structured_datatypes.py create mode 100644 tests/test_unit_structured_types.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 47697d30..33775996 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,7 +9,7 @@ Source code is also available at: # Release Notes -- (Unreleased) +- v1.7.0(November 12, 2024) - Add support for dynamic tables and required options - Add support for hybrid tables @@ -18,6 +18,7 @@ Source code is also available at: - Add support for refresh_mode option in DynamicTable - Add support for iceberg table with Snowflake Catalog - Fix cluster by option to support explicit expressions + - Add support for MAP datatype - v1.6.1(July 9, 2024) diff --git a/pyproject.toml b/pyproject.toml index 6c72f683..84e64faf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,7 @@ line-length = 88 line-length = 88 [tool.pytest.ini_options] -addopts = "-m 'not feature_max_lob_size and not aws'" +addopts = "-m 'not feature_max_lob_size and not aws and not requires_external_volume'" markers = [ # Optional dependency groups markers "lambda: AWS lambda tests", @@ -128,6 +128,7 @@ markers = [ # Other markers "timeout: tests that need a timeout time", "internal: tests that could but should only run on our internal CI", + "requires_external_volume: tests that needs a external volume to be executed", "external: tests that could but should only run on our external CI", "feature_max_lob_size: tests that could but should only run on our external CI", ] diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index f6c97f0d..7d795b2a 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -50,6 +50,7 @@ FIXED, GEOGRAPHY, GEOMETRY, + MAP, NUMBER, OBJECT, STRING, @@ -119,6 +120,7 @@ "TINYINT", "VARBINARY", "VARIANT", + "MAP", ) _custom_commands = ( diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index 839745ee..205ad5d9 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -11,3 +11,4 @@ APPLICATION_NAME = "SnowflakeSQLAlchemy" SNOWFLAKE_SQLALCHEMY_VERSION = VERSION DIALECT_NAME = "snowflake" +NOT_NULL = "NOT NULL" diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index 4e36c4ad..a1e16062 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ 
-27,6 +27,7 @@ ExternalStage, ) +from ._constants import NOT_NULL from .exc import ( CustomOptionsAreOnlySupportedOnSnowflakeTables, UnexpectedOptionTypeError, @@ -1071,6 +1072,12 @@ def visit_TINYINT(self, type_, **kw): def visit_VARIANT(self, type_, **kw): return "VARIANT" + def visit_MAP(self, type_, **kw): + not_null = f" {NOT_NULL}" if type_.not_null else "" + return ( + f"MAP({type_.key_type.compile()}, {type_.value_type.compile()}{not_null})" + ) + def visit_ARRAY(self, type_, **kw): return "ARRAY" diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py index 802d1ce1..f2c950dd 100644 --- a/src/snowflake/sqlalchemy/custom_types.py +++ b/src/snowflake/sqlalchemy/custom_types.py @@ -37,6 +37,26 @@ class VARIANT(SnowflakeType): __visit_name__ = "VARIANT" +class StructuredType(SnowflakeType): + def __init__(self): + super().__init__() + + +class MAP(StructuredType): + __visit_name__ = "MAP" + + def __init__( + self, + key_type: sqltypes.TypeEngine, + value_type: sqltypes.TypeEngine, + not_null: bool = False, + ): + self.key_type = key_type + self.value_type = value_type + self.not_null = not_null + super().__init__() + + class OBJECT(SnowflakeType): __visit_name__ = "OBJECT" diff --git a/src/snowflake/sqlalchemy/exc.py b/src/snowflake/sqlalchemy/exc.py index 898de279..399e94b6 100644 --- a/src/snowflake/sqlalchemy/exc.py +++ b/src/snowflake/sqlalchemy/exc.py @@ -72,3 +72,11 @@ def __init__(self, errors): def __str__(self): return "".join(str(e) for e in self.errors) + + +class StructuredTypeNotSupportedInTableColumnsError(ArgumentError): + def __init__(self, table_type: str, table_name: str, column_name: str): + super().__init__( + f"Column '{column_name}' is of a structured type, which is only supported on Iceberg tables. " + f"The table '{table_name}' is of type '{table_type}', not Iceberg." + ) diff --git a/src/snowflake/sqlalchemy/parser/custom_type_parser.py b/src/snowflake/sqlalchemy/parser/custom_type_parser.py new file mode 100644 index 00000000..cf69c594 --- /dev/null +++ b/src/snowflake/sqlalchemy/parser/custom_type_parser.py @@ -0,0 +1,190 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import sqlalchemy.types as sqltypes +from sqlalchemy.sql.type_api import TypeEngine +from sqlalchemy.types import ( + BIGINT, + BINARY, + BOOLEAN, + CHAR, + DATE, + DATETIME, + DECIMAL, + FLOAT, + INTEGER, + REAL, + SMALLINT, + TIME, + TIMESTAMP, + VARCHAR, + NullType, +) + +from ..custom_types import ( + _CUSTOM_DECIMAL, + ARRAY, + DOUBLE, + GEOGRAPHY, + GEOMETRY, + MAP, + OBJECT, + TIMESTAMP_LTZ, + TIMESTAMP_NTZ, + TIMESTAMP_TZ, + VARIANT, +) + +ischema_names = { + "BIGINT": BIGINT, + "BINARY": BINARY, + # 'BIT': BIT, + "BOOLEAN": BOOLEAN, + "CHAR": CHAR, + "CHARACTER": CHAR, + "DATE": DATE, + "DATETIME": DATETIME, + "DEC": DECIMAL, + "DECIMAL": DECIMAL, + "DOUBLE": DOUBLE, + "FIXED": DECIMAL, + "FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't has parameters + "INT": INTEGER, + "INTEGER": INTEGER, + "NUMBER": _CUSTOM_DECIMAL, + # 'OBJECT': ? 
+ "REAL": REAL, + "BYTEINT": SMALLINT, + "SMALLINT": SMALLINT, + "STRING": VARCHAR, + "TEXT": VARCHAR, + "TIME": TIME, + "TIMESTAMP": TIMESTAMP, + "TIMESTAMP_TZ": TIMESTAMP_TZ, + "TIMESTAMP_LTZ": TIMESTAMP_LTZ, + "TIMESTAMP_NTZ": TIMESTAMP_NTZ, + "TINYINT": SMALLINT, + "VARBINARY": BINARY, + "VARCHAR": VARCHAR, + "VARIANT": VARIANT, + "MAP": MAP, + "OBJECT": OBJECT, + "ARRAY": ARRAY, + "GEOGRAPHY": GEOGRAPHY, + "GEOMETRY": GEOMETRY, +} + + +def extract_parameters(text: str) -> list: + """ + Extracts parameters from a comma-separated string, handling parentheses. + + :param text: A string with comma-separated parameters, which may include parentheses. + + :return: A list of parameters as strings. + + :example: + For input `"a, (b, c), d"`, the output is `['a', '(b, c)', 'd']`. + """ + + output_parameters = [] + parameter = "" + open_parenthesis = 0 + for c in text: + + if c == "(": + open_parenthesis += 1 + elif c == ")": + open_parenthesis -= 1 + + if open_parenthesis > 0 or c != ",": + parameter += c + elif c == ",": + output_parameters.append(parameter.strip(" ")) + parameter = "" + if parameter != "": + output_parameters.append(parameter.strip(" ")) + return output_parameters + + +def parse_type(type_text: str) -> TypeEngine: + """ + Parses a type definition string and returns the corresponding SQLAlchemy type. + + The function handles types with or without parameters, such as `VARCHAR(255)` or `INTEGER`. + + :param type_text: A string representing a SQLAlchemy type, which may include parameters + in parentheses (e.g., "VARCHAR(255)" or "DECIMAL(10, 2)"). + :return: An instance of the corresponding SQLAlchemy type class (e.g., `String`, `Integer`), + or `NullType` if the type is not recognized. + + :example: + parse_type("VARCHAR(255)") + String(length=255) + """ + index = type_text.find("(") + type_name = type_text[:index] if index != -1 else type_text + parameters = ( + extract_parameters(type_text[index + 1 : -1]) if type_name != type_text else [] + ) + + col_type_class = ischema_names.get(type_name, None) + col_type_kw = {} + if col_type_class is None: + col_type_class = NullType + else: + if issubclass(col_type_class, sqltypes.Numeric): + col_type_kw = __parse_numeric_type_parameters(parameters) + elif issubclass(col_type_class, (sqltypes.String, sqltypes.BINARY)): + col_type_kw = __parse_type_with_length_parameters(parameters) + elif issubclass(col_type_class, MAP): + col_type_kw = __parse_map_type_parameters(parameters) + if col_type_kw is None: + col_type_class = NullType + col_type_kw = {} + + return col_type_class(**col_type_kw) + + +def __parse_map_type_parameters(parameters): + if len(parameters) != 2: + return None + + key_type_str = parameters[0] + value_type_str = parameters[1] + not_null_str = "NOT NULL" + not_null = False + if ( + len(value_type_str) >= len(not_null_str) + and value_type_str[-len(not_null_str) :] == not_null_str + ): + not_null = True + value_type_str = value_type_str[: -len(not_null_str) - 1] + + key_type: TypeEngine = parse_type(key_type_str) + value_type: TypeEngine = parse_type(value_type_str) + if isinstance(key_type, NullType) or isinstance(value_type, NullType): + return None + + return { + "key_type": key_type, + "value_type": value_type, + "not_null": not_null, + } + + +def __parse_type_with_length_parameters(parameters): + return ( + {"length": int(parameters[0])} + if len(parameters) == 1 and str.isdigit(parameters[0]) + else {} + ) + + +def __parse_numeric_type_parameters(parameters): + result = {} + if len(parameters) >= 1 and 
str.isdigit(parameters[0]): + result["precision"] = int(parameters[0]) + if len(parameters) == 2 and str.isdigit(parameters[1]): + result["scale"] = int(parameters[1]) + return result diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index f2fb9b18..f9e2e4c8 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -3,6 +3,7 @@ # import operator +import re from collections import defaultdict from functools import reduce from typing import Any @@ -16,26 +17,7 @@ from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.types import ( - BIGINT, - BINARY, - BOOLEAN, - CHAR, - DATE, - DATETIME, - DECIMAL, - FLOAT, - INTEGER, - REAL, - SMALLINT, - TIME, - TIMESTAMP, - VARCHAR, - Date, - DateTime, - Float, - Time, -) +from sqlalchemy.types import FLOAT, Date, DateTime, Float, NullType, Time from snowflake.connector import errors as sf_errors from snowflake.connector.connection import DEFAULT_CONFIGURATION @@ -51,20 +33,13 @@ SnowflakeTypeCompiler, ) from .custom_types import ( - _CUSTOM_DECIMAL, - ARRAY, - GEOGRAPHY, - GEOMETRY, - OBJECT, - TIMESTAMP_LTZ, - TIMESTAMP_NTZ, - TIMESTAMP_TZ, - VARIANT, + MAP, _CUSTOM_Date, _CUSTOM_DateTime, _CUSTOM_Float, _CUSTOM_Time, ) +from .parser.custom_type_parser import ischema_names, parse_type from .sql.custom_schema.custom_table_prefix import CustomTablePrefix from .util import ( _update_connection_application_name, @@ -79,44 +54,6 @@ Float: _CUSTOM_Float, } -ischema_names = { - "BIGINT": BIGINT, - "BINARY": BINARY, - # 'BIT': BIT, - "BOOLEAN": BOOLEAN, - "CHAR": CHAR, - "CHARACTER": CHAR, - "DATE": DATE, - "DATETIME": DATETIME, - "DEC": DECIMAL, - "DECIMAL": DECIMAL, - "DOUBLE": FLOAT, - "FIXED": DECIMAL, - "FLOAT": FLOAT, - "INT": INTEGER, - "INTEGER": INTEGER, - "NUMBER": _CUSTOM_DECIMAL, - # 'OBJECT': ? 
- "REAL": REAL, - "BYTEINT": SMALLINT, - "SMALLINT": SMALLINT, - "STRING": VARCHAR, - "TEXT": VARCHAR, - "TIME": TIME, - "TIMESTAMP": TIMESTAMP, - "TIMESTAMP_TZ": TIMESTAMP_TZ, - "TIMESTAMP_LTZ": TIMESTAMP_LTZ, - "TIMESTAMP_NTZ": TIMESTAMP_NTZ, - "TINYINT": SMALLINT, - "VARBINARY": BINARY, - "VARCHAR": VARCHAR, - "VARIANT": VARIANT, - "OBJECT": OBJECT, - "ARRAY": ARRAY, - "GEOGRAPHY": GEOGRAPHY, - "GEOMETRY": GEOMETRY, -} - _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True @@ -333,8 +270,8 @@ def _denormalize_quote_join(self, *idents): @reflection.cache def _current_database_schema(self, connection, **kw): - res = connection.exec_driver_sql( - "select current_database(), current_schema();" + res = connection.execute( + text("select current_database(), current_schema();") ).fetchone() return ( self.normalize_name(res[0]), @@ -508,6 +445,12 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw): ) return foreign_key_map.get(table_name, []) + def table_columns_as_dict(self, columns): + result = {} + for column in columns: + result[column["name"]] = column + return result + @reflection.cache def _get_schema_columns(self, connection, schema, **kw): """Get all columns in the schema, if we hit 'Information schema query returned too much data' problem return @@ -515,10 +458,12 @@ def _get_schema_columns(self, connection, schema, **kw): ans = {} current_database, _ = self._current_database_schema(connection, **kw) full_schema_name = self._denormalize_quote_join(current_database, schema) + full_columns_descriptions = {} try: schema_primary_keys = self._get_schema_primary_keys( connection, full_schema_name, **kw ) + schema_name = self.denormalize_name(schema) result = connection.execute( text( """ @@ -539,7 +484,7 @@ def _get_schema_columns(self, connection, schema, **kw): WHERE ic.table_schema=:table_schema ORDER BY ic.ordinal_position""" ), - {"table_schema": self.denormalize_name(schema)}, + {"table_schema": schema_name}, ) except sa_exc.ProgrammingError as pe: if pe.orig.errno == 90030: @@ -569,10 +514,7 @@ def _get_schema_columns(self, connection, schema, **kw): col_type = self.ischema_names.get(coltype, None) col_type_kw = {} if col_type is None: - sa_util.warn( - f"Did not recognize type '{coltype}' of column '{column_name}'" - ) - col_type = sqltypes.NULLTYPE + col_type = NullType else: if issubclass(col_type, FLOAT): col_type_kw["precision"] = numeric_precision @@ -582,6 +524,33 @@ def _get_schema_columns(self, connection, schema, **kw): col_type_kw["scale"] = numeric_scale elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)): col_type_kw["length"] = character_maximum_length + elif issubclass(col_type, MAP): + if (schema_name, table_name) not in full_columns_descriptions: + full_columns_descriptions[(schema_name, table_name)] = ( + self.table_columns_as_dict( + self._get_table_columns( + connection, table_name, schema_name + ) + ) + ) + + if ( + (schema_name, table_name) in full_columns_descriptions + and column_name + in full_columns_descriptions[(schema_name, table_name)] + ): + ans[table_name].append( + full_columns_descriptions[(schema_name, table_name)][ + column_name + ] + ) + continue + else: + col_type = NullType + if col_type == NullType: + sa_util.warn( + f"Did not recognize type '{coltype}' of column '{column_name}'" + ) type_instance = col_type(**col_type_kw) @@ -616,91 +585,71 @@ def _get_schema_columns(self, connection, schema, **kw): def _get_table_columns(self, connection, table_name, schema=None, **kw): """Get all columns in a table in a schema""" ans 
= [] - current_database, _ = self._current_database_schema(connection, **kw) - full_schema_name = self._denormalize_quote_join(current_database, schema) - schema_primary_keys = self._get_schema_primary_keys( - connection, full_schema_name, **kw + current_database, default_schema = self._current_database_schema( + connection, **kw ) + schema = schema if schema else default_schema + table_schema = self.denormalize_name(schema) + table_name = self.denormalize_name(table_name) result = connection.execute( text( - """ - SELECT /* sqlalchemy:get_table_columns */ - ic.table_name, - ic.column_name, - ic.data_type, - ic.character_maximum_length, - ic.numeric_precision, - ic.numeric_scale, - ic.is_nullable, - ic.column_default, - ic.is_identity, - ic.comment - FROM information_schema.columns ic - WHERE ic.table_schema=:table_schema - AND ic.table_name=:table_name - ORDER BY ic.ordinal_position""" - ), - { - "table_schema": self.denormalize_name(schema), - "table_name": self.denormalize_name(table_name), - }, + "DESC /* sqlalchemy:_get_schema_columns */" + f" TABLE {table_schema}.{table_name} TYPE = COLUMNS" + ) ) for ( - table_name, column_name, coltype, - character_maximum_length, - numeric_precision, - numeric_scale, + _kind, is_nullable, column_default, - is_identity, + primary_key, + _unique_key, + _check, + _expression, comment, + _policy_name, + _privacy_domain, + _name_mapping, ) in result: - table_name = self.normalize_name(table_name) + column_name = self.normalize_name(column_name) if column_name.startswith("sys_clustering_column"): continue # ignoring clustering column - col_type = self.ischema_names.get(coltype, None) - col_type_kw = {} - if col_type is None: + type_instance = parse_type(coltype) + if isinstance(type_instance, NullType): sa_util.warn( f"Did not recognize type '{coltype}' of column '{column_name}'" ) - col_type = sqltypes.NULLTYPE - else: - if issubclass(col_type, FLOAT): - col_type_kw["precision"] = numeric_precision - col_type_kw["decimal_return_scale"] = numeric_scale - elif issubclass(col_type, sqltypes.Numeric): - col_type_kw["precision"] = numeric_precision - col_type_kw["scale"] = numeric_scale - elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)): - col_type_kw["length"] = character_maximum_length - - type_instance = col_type(**col_type_kw) - current_table_pks = schema_primary_keys.get(table_name) + identity = None + match = re.match( + r"IDENTITY START (?P\d+) INCREMENT (?P\d+) (?PORDER|NOORDER)", + column_default if column_default else "", + ) + if match: + identity = { + "start": int(match.group("start")), + "increment": int(match.group("increment")), + "order_type": match.group("order_type"), + } + is_identity = identity is not None ans.append( { "name": column_name, "type": type_instance, - "nullable": is_nullable == "YES", - "default": column_default, - "autoincrement": is_identity == "YES", + "nullable": is_nullable == "Y", + "default": None if is_identity else column_default, + "autoincrement": is_identity, "comment": comment if comment != "" else None, - "primary_key": ( - ( - column_name - in schema_primary_keys[table_name]["constrained_columns"] - ) - if current_table_pks - else False - ), + "primary_key": primary_key == "Y", } ) + if is_identity: + ans[-1]["identity"] = identity + # If we didn't find any columns for the table, the table doesn't exist. 
if len(ans) == 0: raise sa_exc.NoSuchTableError() diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py index 671c6957..6f7ee0c5 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py @@ -9,10 +9,12 @@ from ..._constants import DIALECT_NAME from ...compat import IS_VERSION_20 from ...custom_commands import NoneType +from ...custom_types import StructuredType from ...exc import ( MultipleErrors, NoPrimaryKeyError, RequiredParametersNotProvidedError, + StructuredTypeNotSupportedInTableColumnsError, UnsupportedPrimaryKeysAndForeignKeysError, ) from .custom_table_prefix import CustomTablePrefix @@ -25,6 +27,7 @@ class CustomTableBase(Table): _support_primary_and_foreign_keys: bool = True _enforce_primary_keys: bool = False _required_parameters: List[TableOptionKey] = [] + _support_structured_types: bool = False @property def table_prefixes(self) -> typing.List[str]: @@ -53,6 +56,10 @@ def __init__( def _validate_table(self): exceptions: List[Exception] = [] + columns_validation = self.__validate_columns() + if columns_validation is not None: + exceptions.append(columns_validation) + for _, option in self.dialect_options[DIALECT_NAME].items(): if isinstance(option, InvalidTableOption): exceptions.append(option.exception) @@ -84,6 +91,15 @@ def _validate_table(self): elif len(exceptions) == 1: raise exceptions[0] + def __validate_columns(self): + for column in self.columns: + if not self._support_structured_types and isinstance( + column.type, StructuredType + ): + return StructuredTypeNotSupportedInTableColumnsError( + self.__class__.__name__, self.name, column.name + ) + def _get_dialect_option( self, option_name: TableOptionKey ) -> typing.Optional[TableOption]: diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py index 5c9c53d9..4f62d4f2 100644 --- a/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py +++ b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py @@ -44,6 +44,7 @@ class IcebergTable(TableFromQueryBase): """ __table_prefixes__ = [CustomTablePrefix.ICEBERG] + _support_structured_types = True @property def external_volume(self) -> typing.Optional[LiteralOption]: diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py index d90f706b..b80a9096 100644 --- a/src/snowflake/sqlalchemy/version.py +++ b/src/snowflake/sqlalchemy/version.py @@ -3,4 +3,4 @@ # # Update this for the versions # Don't change the forth version number from None -VERSION = "1.6.1" +VERSION = "1.7.0" diff --git a/tests/__snapshots__/test_structured_datatypes.ambr b/tests/__snapshots__/test_structured_datatypes.ambr new file mode 100644 index 00000000..0325a946 --- /dev/null +++ b/tests/__snapshots__/test_structured_datatypes.ambr @@ -0,0 +1,90 @@ +# serializer version: 1 +# name: test_compile_table_with_cluster_by_with_expression + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, VARCHAR), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_double_map + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, MAP(DECIMAL, VARCHAR)), \tPRIMARY KEY ("Id"))' +# --- +# name: test_insert_map + list([ + (1, '{\n "100": "item1",\n "200": "item2"\n}'), + ]) +# --- +# name: test_insert_map_orm + ''' + 002014 (22000): SQL compilation error: + 
Invalid expression [CAST(OBJECT_CONSTRUCT('100', 'item1', '200', 'item2') AS MAP(NUMBER(10,0), VARCHAR(16777216)))] in VALUES clause + ''' +# --- +# name: test_inspect_structured_data_types[structured_type0] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'map_id', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216)), + }), + ]) +# --- +# name: test_inspect_structured_data_types[structured_type1] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'map_id', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216))), + }), + ]) +# --- +# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), VARCHAR)] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_select_map_orm + list([ + (1, '{\n "100": "item1",\n "200": "item2"\n}'), + (2, '{\n "100": "item1",\n "200": "item2"\n}'), + ]) +# --- +# name: test_select_map_orm.1 + list([ + ]) +# --- +# name: test_select_map_orm.2 + list([ + ]) +# --- diff --git a/tests/__snapshots__/test_unit_structured_types.ambr b/tests/__snapshots__/test_unit_structured_types.ambr new file mode 100644 index 00000000..ff861351 --- /dev/null +++ b/tests/__snapshots__/test_unit_structured_types.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_map_with_not_null + 'MAP(DECIMAL(10, 0), VARCHAR NOT NULL)' +# --- diff --git a/tests/conftest.py b/tests/conftest.py index d4dab3d1..a91521b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -94,6 +94,36 @@ def db_parameters(): yield get_db_parameters() +@pytest.fixture(scope="session") +def external_volume(): + db_parameters = get_db_parameters() + if "external_volume" in db_parameters: + yield db_parameters["external_volume"] + else: + raise ValueError("External_volume is not set") + + +@pytest.fixture(scope="session") +def external_stage(): + db_parameters = get_db_parameters() + if "external_stage" in db_parameters: + yield db_parameters["external_stage"] + else: + raise ValueError("External_stage is not set") + + +@pytest.fixture(scope="function") +def base_location(external_stage, engine_testaccount): + unique_id = str(uuid.uuid4()) + base_location = "L" + unique_id.replace("-", "_") + yield base_location + remove_base_location = f""" + REMOVE @{external_stage} pattern='.*{base_location}.*'; + """ + with engine_testaccount.connect() as connection: + 
connection.exec_driver_sql(remove_base_location) + + def get_db_parameters() -> dict: """ Sets the db connection parameters diff --git a/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr index 6ef09ff7..7e85841a 100644 --- a/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr +++ b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr @@ -1,4 +1,26 @@ # serializer version: 1 +# name: test_inspect_snowflake_table + list([ + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=38, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'name', + 'nullable': True, + 'primary_key': False, + 'type': VARCHAR(length=16777216), + }), + ]) +# --- # name: test_simple_reflection_of_table_as_snowflake_table 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' # --- diff --git a/tests/custom_tables/test_reflect_snowflake_table.py b/tests/custom_tables/test_reflect_snowflake_table.py index ef84622b..603b6187 100644 --- a/tests/custom_tables/test_reflect_snowflake_table.py +++ b/tests/custom_tables/test_reflect_snowflake_table.py @@ -1,10 +1,10 @@ # # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from sqlalchemy import MetaData, Table +from sqlalchemy import MetaData, Table, inspect from sqlalchemy.sql.ddl import CreateTable -from src.snowflake.sqlalchemy import SnowflakeTable +from snowflake.sqlalchemy import SnowflakeTable def test_simple_reflection_of_table_as_sqlalchemy_table( @@ -67,3 +67,26 @@ def test_simple_reflection_of_table_as_snowflake_table( finally: metadata.drop_all(engine_testaccount) + + +def test_inspect_snowflake_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_inspect" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + try: + with engine_testaccount.connect() as conn: + insp = inspect(conn) + table = insp.get_columns(table_name) + assert table == snapshot + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_core.py b/tests/test_core.py index 9342ad58..63f097db 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -30,6 +30,7 @@ UniqueConstraint, create_engine, dialects, + exc, insert, inspect, text, @@ -124,14 +125,26 @@ def test_connect_args(): Snowflake connect string supports account name as a replacement of host:port """ + server = "" + if "host" in CONNECTION_PARAMETERS and "port" in CONNECTION_PARAMETERS: + server = "{host}:{port}".format( + host=CONNECTION_PARAMETERS["host"], port=CONNECTION_PARAMETERS["port"] + ) + elif "account" in CONNECTION_PARAMETERS and "region" in CONNECTION_PARAMETERS: + server = "{account}.{region}".format( + account=CONNECTION_PARAMETERS["account"], + region=CONNECTION_PARAMETERS["region"], + ) + elif "account" in CONNECTION_PARAMETERS: + server = CONNECTION_PARAMETERS["account"] + engine = create_engine( - "snowflake://{user}:{password}@{host}:{port}/{database}/{schema}" + "snowflake://{user}:{password}@{server}/{database}/{schema}" "?account={account}&protocol={protocol}".format( 
user=CONNECTION_PARAMETERS["user"], account=CONNECTION_PARAMETERS["account"], password=CONNECTION_PARAMETERS["password"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], + server=server, database=CONNECTION_PARAMETERS["database"], schema=CONNECTION_PARAMETERS["schema"], protocol=CONNECTION_PARAMETERS["protocol"], @@ -142,32 +155,14 @@ def test_connect_args(): finally: engine.dispose() - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) try: verify_engine_connection(engine) finally: engine.dispose() - - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - warehouse="testwh", - ) - ) + parameters = {**CONNECTION_PARAMETERS} + parameters["warehouse"] = "testwh" + engine = create_engine(URL(**parameters)) try: verify_engine_connection(engine) finally: @@ -175,14 +170,10 @@ def test_connect_args(): def test_boolean_query_argument_parsing(): + engine = create_engine( URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], + **CONNECTION_PARAMETERS, validate_default_parameters=True, ) ) @@ -1549,15 +1540,8 @@ def test_too_many_columns_detection(engine_testaccount, db_parameters): connection = inspector.bind.connect() original_execute = connection.execute - too_many_columns_was_raised = False - - def mock_helper(command, *args, **kwargs): - if "_get_schema_columns" in command.text: - # Creating exception exactly how SQLAlchemy does - nonlocal too_many_columns_was_raised - too_many_columns_was_raised = True - raise DBAPIError.instance( - """ + exception_instance = DBAPIError.instance( + """ SELECT /* sqlalchemy:_get_schema_columns */ ic.table_name, ic.column_name, @@ -1572,27 +1556,32 @@ def mock_helper(command, *args, **kwargs): FROM information_schema.columns ic WHERE ic.table_schema='schema_name' ORDER BY ic.ordinal_position""", - {"table_schema": "TESTSCHEMA"}, - ProgrammingError( - "Information schema query returned too much data. Please repeat query with more " - "selective predicates.", - 90030, - ), - Error, - hide_parameters=False, - connection_invalidated=False, - dialect=SnowflakeDialect(), - ismulti=None, - ) + {"table_schema": "TESTSCHEMA"}, + ProgrammingError( + "Information schema query returned too much data. 
Please repeat query with more " + "selective predicates.", + 90030, + ), + Error, + hide_parameters=False, + connection_invalidated=False, + dialect=SnowflakeDialect(), + ismulti=None, + ) + + def mock_helper(command, *args, **kwargs): + if "_get_schema_columns" in command.text: + # Creating exception exactly how SQLAlchemy does + raise exception_instance else: return original_execute(command, *args, **kwargs) with patch.object(engine_testaccount, "connect") as conn: conn.return_value = connection with patch.object(connection, "execute", side_effect=mock_helper): - column_metadata = inspector.get_columns("users", db_parameters["schema"]) - assert len(column_metadata) == 4 - assert too_many_columns_was_raised + with pytest.raises(exc.ProgrammingError) as exception: + inspector.get_columns("users", db_parameters["schema"]) + assert exception.value.orig == exception_instance.orig # Clean up metadata.drop_all(engine_testaccount) @@ -1636,9 +1625,9 @@ def test_column_type_schema(engine_testaccount): table_reflected = Table(table_name, MetaData(), autoload_with=conn) columns = table_reflected.columns - assert ( - len(columns) == len(ischema_names_baseline) - 1 - ) # -1 because FIXED is not supported + assert len(columns) == ( + len(ischema_names_baseline) - 2 + ) # -2 because FIXED and MAP is not supported def test_result_type_and_value(engine_testaccount): @@ -1816,30 +1805,14 @@ def test_normalize_and_denormalize_empty_string_column_name(engine_testaccount): def test_snowflake_sqlalchemy_as_valid_client_type(): engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ), + URL(**CONNECTION_PARAMETERS), connect_args={"internal_application_name": "UnknownClient"}, ) with engine.connect() as conn: with pytest.raises(snowflake.connector.errors.NotSupportedError): conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) with engine.connect() as conn: conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() @@ -1870,16 +1843,7 @@ def test_snowflake_sqlalchemy_as_valid_client_type(): "3.0.0", (type(None), str), ) - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) with engine.connect() as conn: conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() assert ( diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py new file mode 100644 index 00000000..4ea0892b --- /dev/null +++ b/tests/test_structured_datatypes.py @@ -0,0 +1,271 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+#
+
+import pytest
+from sqlalchemy import (
+    Column,
+    Integer,
+    MetaData,
+    Sequence,
+    Table,
+    cast,
+    exc,
+    inspect,
+    text,
+)
+from sqlalchemy.orm import Session, declarative_base
+from sqlalchemy.sql import select
+from sqlalchemy.sql.ddl import CreateTable
+
+from snowflake.sqlalchemy import NUMBER, IcebergTable, SnowflakeTable
+from snowflake.sqlalchemy.custom_types import MAP, TEXT
+from snowflake.sqlalchemy.exc import StructuredTypeNotSupportedInTableColumnsError
+
+
+def test_compile_table_with_cluster_by_with_expression(sql_compiler, snapshot):
+    metadata = MetaData()
+    user_table = Table(
+        "clustered_user",
+        metadata,
+        Column("Id", Integer, primary_key=True),
+        Column("name", MAP(NUMBER(), TEXT())),
+    )
+
+    create_table = CreateTable(user_table)
+
+    assert sql_compiler(create_table) == snapshot
+
+
+@pytest.mark.requires_external_volume
+def test_create_table_structured_datatypes(
+    engine_testaccount, external_volume, base_location
+):
+    metadata = MetaData()
+    table_name = "test_map0"
+    test_map = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("map_id", MAP(NUMBER(10, 0), TEXT())),
+        external_volume=external_volume,
+        base_location=base_location,
+    )
+    metadata.create_all(engine_testaccount)
+    try:
+        assert test_map is not None
+    finally:
+        test_map.drop(engine_testaccount)
+
+
+@pytest.mark.requires_external_volume
+def test_insert_map(engine_testaccount, external_volume, base_location, snapshot):
+    metadata = MetaData()
+    table_name = "test_insert_map"
+    test_map = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("map_id", MAP(NUMBER(10, 0), TEXT())),
+        external_volume=external_volume,
+        base_location=base_location,
+    )
+    metadata.create_all(engine_testaccount)
+
+    try:
+        with engine_testaccount.connect() as conn:
+            slt = select(
+                1,
+                cast(
+                    text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())
+                ),
+            )
+            ins = test_map.insert().from_select(["id", "map_id"], slt)
+            conn.execute(ins)
+
+            results = conn.execute(test_map.select())
+            data = results.fetchmany()
+            results.close()
+            snapshot.assert_match(data)
+    finally:
+        test_map.drop(engine_testaccount)
+
+
+@pytest.mark.requires_external_volume
+@pytest.mark.parametrize(
+    "structured_type",
+    [
+        MAP(NUMBER(10, 0), TEXT()),
+        MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT())),
+    ],
+)
+def test_inspect_structured_data_types(
+    engine_testaccount, external_volume, base_location, snapshot, structured_type
+):
+    metadata = MetaData()
+    table_name = "test_st_types"
+    test_map = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("map_id", structured_type),
+        external_volume=external_volume,
+        base_location=base_location,
+    )
+    metadata.create_all(engine_testaccount)
+
+    try:
+        inspector = inspect(engine_testaccount)
+        columns = inspector.get_columns(table_name)
+
+        assert isinstance(columns[0]["type"], NUMBER)
+        assert isinstance(columns[1]["type"], MAP)
+        assert columns == snapshot
+
+    finally:
+        test_map.drop(engine_testaccount)
+
+
+@pytest.mark.requires_external_volume
+@pytest.mark.parametrize(
+    "structured_type",
+    [
+        "MAP(NUMBER(10, 0), VARCHAR)",
+        "MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))",
+    ],
+)
+def test_reflect_structured_data_types(
+    engine_testaccount,
+    external_volume,
+    base_location,
+    snapshot,
+    structured_type,
+    sql_compiler,
+):
+    metadata = MetaData()
+    table_name = "test_reflect_st_types"
+    create_table_sql = f"""
+CREATE OR REPLACE ICEBERG TABLE {table_name} (
+    id number(38,0) primary key,
+    map_id {structured_type})
+CATALOG = 'SNOWFLAKE'
+EXTERNAL_VOLUME = '{external_volume}'
+BASE_LOCATION = '{base_location}';
+    """
+
+    with engine_testaccount.connect() as connection:
+        connection.exec_driver_sql(create_table_sql)
+
+    iceberg_table = IcebergTable(table_name, metadata, autoload_with=engine_testaccount)
+    constraint = iceberg_table.constraints.pop()
+    constraint.name = "constraint_name"
+    iceberg_table.constraints.add(constraint)
+
+    try:
+        with engine_testaccount.connect():
+            value = CreateTable(iceberg_table)
+
+            actual = sql_compiler(value)
+
+            assert actual == snapshot
+
+    finally:
+        metadata.drop_all(engine_testaccount)
+
+
+@pytest.mark.requires_external_volume
+def test_insert_map_orm(
+    sql_compiler, external_volume, base_location, engine_testaccount, snapshot
+):
+    Base = declarative_base()
+    session = Session(bind=engine_testaccount)
+
+    class TestIcebergTableOrm(Base):
+        __tablename__ = "test_iceberg_table_orm"
+
+        @classmethod
+        def __table_cls__(cls, name, metadata, *arg, **kw):
+            return IcebergTable(name, metadata, *arg, **kw)
+
+        __table_args__ = {
+            "external_volume": external_volume,
+            "base_location": base_location,
+        }
+
+        id = Column(Integer, Sequence("user_id_seq"), primary_key=True)
+        map_id = Column(MAP(NUMBER(10, 0), TEXT()))
+
+        def __repr__(self):
+            return f"({self.id!r}, {self.map_id!r})"
+
+    Base.metadata.create_all(engine_testaccount)
+
+    try:
+        cast_expr = cast(
+            text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())
+        )
+        instance = TestIcebergTableOrm(id=0, map_id=cast_expr)
+        session.add(instance)
+        with pytest.raises(exc.ProgrammingError) as programming_error:
+            session.commit()
+        # TODO: Support variant in insert statement
+        assert str(programming_error.value.orig) == snapshot
+    finally:
+        Base.metadata.drop_all(engine_testaccount)
+
+
+def test_snowflake_tables_with_structured_types(sql_compiler):
+    metadata = MetaData()
+    with pytest.raises(
+        StructuredTypeNotSupportedInTableColumnsError
+    ) as programming_error:
+        SnowflakeTable(
+            "clustered_user",
+            metadata,
+            Column("Id", Integer, primary_key=True),
+            Column("name", MAP(NUMBER(10, 0), TEXT())),
+        )
+    assert programming_error is not None
+
+
+@pytest.mark.requires_external_volume
+def test_select_map_orm(engine_testaccount, external_volume, base_location, snapshot):
+    metadata = MetaData()
+    table_name = "test_select_map_orm"
+    test_map = IcebergTable(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("map_id", MAP(NUMBER(10, 0), TEXT())),
+        external_volume=external_volume,
+        base_location=base_location,
+    )
+    metadata.create_all(engine_testaccount)
+
+    with engine_testaccount.connect() as conn:
+        slt1 = select(
+            2,
+            cast(text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())),
+        )
+        slt2 = select(
+            1,
+            cast(text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())),
+        ).union_all(slt1)
+        ins = test_map.insert().from_select(["id", "map_id"], slt2)
+        conn.execute(ins)
+        conn.commit()
+
+    Base = declarative_base()
+    session = Session(bind=engine_testaccount)
+
+    class TestIcebergTableOrm(Base):
+        __table__ = test_map
+
+        def __repr__(self):
+            return f"({self.id!r}, {self.map_id!r})"
+
+    try:
+        data = session.query(TestIcebergTableOrm).all()
+        snapshot.assert_match(data)
+    finally:
+        test_map.drop(engine_testaccount)
diff --git a/tests/test_unit_structured_types.py b/tests/test_unit_structured_types.py
new file mode 100644
index 00000000..c7bcd6ef
--- /dev/null
+++ b/tests/test_unit_structured_types.py
@@ -0,0 +1,73 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+import pytest
+
+from snowflake.sqlalchemy import NUMBER
+from snowflake.sqlalchemy.custom_types import MAP, TEXT
+from snowflake.sqlalchemy.parser.custom_type_parser import (
+    extract_parameters,
+    parse_type,
+)
+
+
+def test_compile_map_with_not_null(snapshot):
+    map_type = MAP(NUMBER(10, 0), TEXT(), not_null=True)
+    assert map_type.compile() == snapshot
+
+
+def test_extract_parameters():
+    example = "a, b(c, d, f), d"
+    assert extract_parameters(example) == ["a", "b(c, d, f)", "d"]
+
+
+@pytest.mark.parametrize(
+    "input_type, expected_type",
+    [
+        ("BIGINT", "BIGINT"),
+        ("BINARY(16)", "BINARY(16)"),
+        ("BOOLEAN", "BOOLEAN"),
+        ("CHAR(5)", "CHAR(5)"),
+        ("CHARACTER(5)", "CHAR(5)"),
+        ("DATE", "DATE"),
+        ("DATETIME(3)", "DATETIME"),
+        ("DECIMAL(10, 2)", "DECIMAL(10, 2)"),
+        ("DEC(10, 2)", "DECIMAL(10, 2)"),
+        ("DOUBLE", "FLOAT"),
+        ("FLOAT", "FLOAT"),
+        ("FIXED(10, 2)", "DECIMAL(10, 2)"),
+        ("INT", "INTEGER"),
+        ("INTEGER", "INTEGER"),
+        ("NUMBER(12, 4)", "DECIMAL(12, 4)"),
+        ("REAL", "REAL"),
+        ("BYTEINT", "SMALLINT"),
+        ("SMALLINT", "SMALLINT"),
+        ("STRING(255)", "VARCHAR(255)"),
+        ("TEXT(255)", "VARCHAR(255)"),
+        ("VARCHAR(255)", "VARCHAR(255)"),
+        ("TIME(6)", "TIME"),
+        ("TIMESTAMP(3)", "TIMESTAMP"),
+        ("TIMESTAMP_TZ(3)", "TIMESTAMP_TZ"),
+        ("TIMESTAMP_LTZ(3)", "TIMESTAMP_LTZ"),
+        ("TIMESTAMP_NTZ(3)", "TIMESTAMP_NTZ"),
+        ("TINYINT", "SMALLINT"),
+        ("VARBINARY(16)", "BINARY(16)"),
+        ("VARCHAR(255)", "VARCHAR(255)"),
+        ("VARIANT", "VARIANT"),
+        (
+            "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR NOT NULL))",
+            "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR NOT NULL))",
+        ),
+        (
+            "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR))",
+            "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR))",
+        ),
+        ("MAP(DECIMAL(10, 0), VARIANT)", "MAP(DECIMAL(10, 0), VARIANT)"),
+        ("OBJECT", "OBJECT"),
+        ("ARRAY", "ARRAY"),
+        ("GEOGRAPHY", "GEOGRAPHY"),
+        ("GEOMETRY", "GEOMETRY"),
+    ],
+)
+def test_snowflake_data_types(input_type, expected_type):
+    assert parse_type(input_type).compile() == expected_type
diff --git a/tests/util.py b/tests/util.py
index db0b0c9c..264478ff 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -29,6 +29,7 @@
     ARRAY,
     GEOGRAPHY,
     GEOMETRY,
+    MAP,
     OBJECT,
     TIMESTAMP_LTZ,
     TIMESTAMP_NTZ,
@@ -72,6 +73,7 @@
     "ARRAY": ARRAY,
     "GEOGRAPHY": GEOGRAPHY,
     "GEOMETRY": GEOMETRY,
+    "MAP": MAP,
 }

From 0d0e6864d9f9f5d32eaf74d2077a28cf907b6298 Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Wed, 20 Nov 2024 14:45:21 -0600
Subject: [PATCH 17/21] Update release notes date (#547)

* Update release notes date november 22

---
 DESCRIPTION.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 33775996..e39984b7 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,7 +9,7 @@ Source code is also available at:

 # Release Notes

-- v1.7.0(November 12, 2024)
+- v1.7.0(November 22, 2024)

   - Add support for dynamic tables and required options
   - Add support for hybrid tables

From 3f633e28a1bd37ffc862e0215116f3a87a684cdc Mon Sep 17 00:00:00 2001
From: Jorge Vasquez Rojas
Date: Thu, 21 Nov 2024 10:10:32 -0600
Subject: [PATCH 18/21] Update CODEOWNERS (#540)

Update CODEOWNERS

---
 .github/CODEOWNERS | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index 836e0136..b2168af7 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -1 +1 @@
-* @snowflakedb/snowcli
+* @snowflakedb/ORM

From 65754a4ab2524d9de2c8b9d56d1fb07f819248d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Luis=20Calder=C3=B3n=20Ach=C3=ADo?=
Date: Thu, 21 Nov 2024 16:06:50 -0600
Subject: [PATCH 19/21] SNOW-878116 Add support for PARTITION BY to COPY INTO
 location (#542)

* add PARTITION BY option for CopyInto

---------

Co-authored-by: azban
---
 DESCRIPTION.md                              |  3 +
 src/snowflake/sqlalchemy/base.py            | 26 +++++++--
 src/snowflake/sqlalchemy/custom_commands.py |  9 ++-
 tests/test_copy.py                          | 65 +++++++++++++++------
 4 files changed, 78 insertions(+), 25 deletions(-)

diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index e39984b7..82ddebc9 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,6 +9,9 @@ Source code is also available at:

 # Release Notes

+- (Unreleased)
+  - Add support for PARTITION BY to COPY INTO <location>
+
 - v1.7.0(November 22, 2024)

   - Add support for dynamic tables and required options

diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py
index a1e16062..02e4f741 100644
--- a/src/snowflake/sqlalchemy/base.py
+++ b/src/snowflake/sqlalchemy/base.py
@@ -16,7 +16,8 @@
 from sqlalchemy.schema import Sequence, Table
 from sqlalchemy.sql import compiler, expression, functions
 from sqlalchemy.sql.base import CompileState
-from sqlalchemy.sql.elements import quoted_name
+from sqlalchemy.sql.elements import BindParameter, quoted_name
+from sqlalchemy.sql.expression import Executable
 from sqlalchemy.sql.selectable import Lateral, SelectState

 from snowflake.sqlalchemy._constants import DIALECT_NAME
@@ -563,9 +564,8 @@ def visit_copy_into(self, copy_into, **kw):
             if isinstance(copy_into.into, Table)
             else copy_into.into._compiler_dispatch(self, **kw)
         )
-        from_ = None
         if isinstance(copy_into.from_, Table):
-            from_ = copy_into.from_
+            from_ = copy_into.from_.name
         # this is intended to catch AWSBucket and AzureContainer
         elif (
             isinstance(copy_into.from_, AWSBucket)
             or isinstance(copy_into.from_, AzureContainer)
             or isinstance(copy_into.from_, ExternalStage)
         ):
             from_ = copy_into.from_._compiler_dispatch(self, **kw)
         # everything else (selects, etc.)
         else:
             from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})"
+
+        partition_by_value = None
+        if isinstance(copy_into.partition_by, (BindParameter, Executable)):
+            partition_by_value = copy_into.partition_by.compile(
+                compile_kwargs={"literal_binds": True}
+            )
+        elif copy_into.partition_by is not None:
+            partition_by_value = copy_into.partition_by
+
+        partition_by = (
+            f"PARTITION BY {partition_by_value}"
+            if partition_by_value is not None and partition_by_value != ""
+            else ""
+        )
+
         credentials, encryption = "", ""
         if isinstance(into, tuple):
             into, credentials, encryption = into
@@ -586,8 +601,7 @@ def visit_copy_into(self, copy_into, **kw):
         options_list.sort(key=operator.itemgetter(0))
         options = (
             (
-                " "
-                + " ".join(
+                " ".join(
                     [
                         "{} = {}".format(
                             n,
@@ -608,7 +622,7 @@ def visit_copy_into(self, copy_into, **kw):
             options += f" {credentials}"
         if encryption:
             options += f" {encryption}"
-        return f"COPY INTO {into} FROM {from_} {formatter}{options}"
+        return f"COPY INTO {into} FROM {' '.join([from_, partition_by, formatter, options])}"

     def visit_copy_formatter(self, formatter, **kw):
         options_list = list(formatter.options.items())
diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py
index 15585bd5..1b9260fe 100644
--- a/src/snowflake/sqlalchemy/custom_commands.py
+++ b/src/snowflake/sqlalchemy/custom_commands.py
@@ -115,18 +115,23 @@ class CopyInto(UpdateBase):

     __visit_name__ = "copy_into"
     _bind = None

-    def __init__(self, from_, into, formatter=None):
+    def __init__(self, from_, into, partition_by=None, formatter=None):
         self.from_ = from_
         self.into = into
         self.formatter = formatter
         self.copy_options = {}
+        self.partition_by = partition_by

     def __repr__(self):
         """
         repr for debugging / logging purposes only.
         For compilation logic, see the corresponding visitor in base.py
         """
-        return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})"
+        val = f"COPY INTO {self.into} FROM {repr(self.from_)}"
+        if self.partition_by is not None:
+            val += f" PARTITION BY {self.partition_by}"
+
+        return val + f" {repr(self.formatter)} ({self.copy_options})"

     def bind(self):
         return None
diff --git a/tests/test_copy.py b/tests/test_copy.py
index e0752d4f..8dfcf286 100644
--- a/tests/test_copy.py
+++ b/tests/test_copy.py
@@ -4,7 +4,7 @@

 import pytest
 from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table
-from sqlalchemy.sql import select, text
+from sqlalchemy.sql import functions, select, text

 from snowflake.sqlalchemy import (
     AWSBucket,
@@ -58,8 +58,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
     )
     assert (
         sql_compiler(copy_stmt_1)
-        == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv "
-        "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
+        == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv "
+        "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
         "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')"
     )
     copy_stmt_2 = CopyIntoStorage(
@@ -73,8 +73,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
         sql_compiler(copy_stmt_2)
         == "COPY INTO 's3://backup' FROM (SELECT python_tests_foods.id, "
         "python_tests_foods.name, python_tests_foods.quantity FROM python_tests_foods "
-        "WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' "
-        "FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') "
+        "WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' "
+        "FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') "
         "ENCRYPTION=(TYPE='AWS_SSE_S3')"
     )
     copy_stmt_3 = CopyIntoStorage(
@@ -87,7 +87,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
     assert (
         sql_compiler(copy_stmt_3)
         == "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' "
-        "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
+        "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
         "CREDENTIALS=(AZURE_SAS_TOKEN='token')"
     )

@@ -95,7 +95,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
     assert (
         sql_compiler(copy_stmt_3)
         == "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' "
-        "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
+        "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) "
         "MAX_FILE_SIZE = 50000000 "
         "CREDENTIALS=(AZURE_SAS_TOKEN='token')"
     )
@@ -112,8 +112,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
     )
     assert (
         sql_compiler(copy_stmt_4)
-        == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
-        "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
+        == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
+        "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION="
         "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')"
     )

@@ -126,8 +126,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
     )
     assert (
         sql_compiler(copy_stmt_5)
-        == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
-        "FIELD_DELIMITER=',') ENCRYPTION="
+        == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv "
"FIELD_DELIMITER=',') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) @@ -138,7 +138,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_6) - == "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" + == "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " ) copy_stmt_7 = CopyIntoStorage( @@ -148,7 +148,38 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_7) - == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" + == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " + ) + + copy_stmt_8 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by=text("('YEAR=' || year)"), + ) + assert ( + sql_compiler(copy_stmt_8) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY ('YEAR=' || year) " + ) + + copy_stmt_9 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by=functions.concat( + text("'YEAR='"), text(food_items.columns["name"].name) + ), + ) + assert ( + sql_compiler(copy_stmt_9) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY concat('YEAR=', name) " + ) + + copy_stmt_10 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by="", + ) + assert ( + sql_compiler(copy_stmt_10) == "COPY INTO @stage_name FROM python_tests_foods " ) # NOTE Other than expect known compiled text, submit it to RegressionTests environment and expect them to fail, but @@ -231,7 +262,7 @@ def test_copy_into_storage_csv_extended(sql_compiler): result = sql_compiler(copy_into) expected = ( r"COPY INTO TEST_IMPORT " - r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata " + r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata " r"FILE_FORMAT=(TYPE=csv COMPRESSION='auto' DATE_FORMAT='AUTO' " r"ERROR_ON_COLUMN_COUNT_MISMATCH=True ESCAPE=None " r"ESCAPE_UNENCLOSED_FIELD='\134' FIELD_DELIMITER=',' " @@ -288,7 +319,7 @@ def test_copy_into_storage_parquet_named_format(sql_compiler): expected = ( "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " - "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) " + "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) " "FILE_FORMAT=(format_name = parquet_file_format) force = TRUE" ) assert result == expected @@ -350,7 +381,7 @@ def test_copy_into_storage_parquet_files(sql_compiler): "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet " - "(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') " + "(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') " "FORCE = true" ) assert result == expected @@ -412,6 +443,6 @@ def test_copy_into_storage_parquet_pattern(sql_compiler): "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet " - "(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'" + "(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'" ) assert result == expected From 9157932f02725f4b356fb0109875eac5511bdb14 Mon Sep 17 00:00:00 2001 From: David Szmolka <69192509+sfc-gh-dszmolka@users.noreply.github.com> Date: Fri, 22 Nov 2024 01:25:00 +0100 Subject: [PATCH 20/21] Amend README for urgent support (#544) --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md 
b/README.md index c6c13349..dac87fe8 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,11 @@ Snowflake SQLAlchemy runs on the top of the Snowflake Connector for Python as a [dialect](http://docs.sqlalchemy.org/en/latest/dialects/) to bridge a Snowflake database and SQLAlchemy applications. + +| :exclamation: | For production-affecting or urgent issues related to the connector, please [create a case with Snowflake Support](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge). | +|---------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| + + ## Prerequisites ### Snowflake Connector for Python From 9b2c6d15c3da64990da4fd49c036d7094cd36e23 Mon Sep 17 00:00:00 2001 From: Jorge Vasquez Rojas Date: Mon, 25 Nov 2024 07:13:35 -0600 Subject: [PATCH 21/21] Fix readme typos (#548) * Fix typo in README.md --------- Co-authored-by: Norman Rosner Co-authored-by: Anthony Holten --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index dac87fe8..2dbf6632 100644 --- a/README.md +++ b/README.md @@ -209,7 +209,7 @@ finally: # Best try: - with engine.connext() as connection: + with engine.connect() as connection: connection.execute(text()) # or connection.exec_driver_sql() @@ -230,7 +230,7 @@ t = Table('mytable', metadata, ### Object Name Case Handling -Snowflake stores all case-insensitive object names in uppercase text. In contrast, SQLAlchemy considers all lowercase object names to be case-insensitive. Snowflake SQLAlchemy converts the object name case during schema-level communication, i.e. during table and index reflection. If you use uppercase object names, SQLAlchemy assumes they are case-sensitive and encloses the names with quotes. This behavior will cause mismatches agaisnt data dictionary data received from Snowflake, so unless identifier names have been truly created as case sensitive using quotes, e.g., `"TestDb"`, all lowercase names should be used on the SQLAlchemy side. +Snowflake stores all case-insensitive object names in uppercase text. In contrast, SQLAlchemy considers all lowercase object names to be case-insensitive. Snowflake SQLAlchemy converts the object name case during schema-level communication, i.e. during table and index reflection. If you use uppercase object names, SQLAlchemy assumes they are case-sensitive and encloses the names with quotes. This behavior will cause mismatches against data dictionary data received from Snowflake, so unless identifier names have been truly created as case sensitive using quotes, e.g., `"TestDb"`, all lowercase names should be used on the SQLAlchemy side. ### Index Support
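
Taken together, PATCH 19 gives `CopyIntoStorage` a `partition_by` keyword that is compiled into a `PARTITION BY` clause between `FROM` and the format options. A minimal sketch of the new option in use, mirroring `copy_stmt_8` above — the table, stage name, and partition expression here are illustrative placeholders, not taken from the patch:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, text

from snowflake.sqlalchemy import CopyIntoStorage, ExternalStage
from snowflake.sqlalchemy.snowdialect import SnowflakeDialect

metadata = MetaData()
# Hypothetical table to unload; any Table works as the from_ argument.
orders = Table(
    "orders",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("region", String),
)

# partition_by accepts a plain string, a BindParameter, or an Executable such
# as text(); expressions are compiled with literal binds (see visit_copy_into).
copy_stmt = CopyIntoStorage(
    from_=orders,
    into=ExternalStage(name="unload_stage"),
    partition_by=text("('REGION=' || region)"),
)

print(copy_stmt.compile(dialect=SnowflakeDialect()))
# Expected, modulo whitespace:
# COPY INTO @unload_stage FROM orders PARTITION BY ('REGION=' || region)
```

Passing `partition_by=""` (or omitting it) keeps the clause out of the compiled statement, as `copy_stmt_10` in the tests above verifies.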