diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 836e0136..b2168af7 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1 @@ -* @snowflakedb/snowcli +* @snowflakedb/ORM diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index be19f1f1..5e9823f2 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -33,8 +33,8 @@ jobs: python-version: '3.8' - name: Upgrade and install tools run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Set PY run: echo "PY=$(hatch run gh-cache-sum)" >> $GITHUB_ENV @@ -49,6 +49,10 @@ jobs: name: Test package build and installation runs-on: ubuntu-latest needs: lint + strategy: + fail-fast: true + matrix: + hatch-env: [default, sa14] steps: - uses: actions/checkout@v4 with: @@ -59,15 +63,14 @@ jobs: python-version: '3.8' - name: Upgrade and install tools run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch - name: Build package run: | - python -m hatch clean - python -m hatch build + python -m hatch -e ${{ matrix.hatch-env }} build --clean - name: Install and check import run: | - python -m pip install dist/snowflake_sqlalchemy-*.whl + python -m uv pip install dist/snowflake_sqlalchemy-*.whl python -c "import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)" test-dialect: @@ -79,7 +82,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -98,8 +101,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Upgrade pip and prepare environment run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Setup parameters file shell: bash @@ -108,6 +111,9 @@ jobs: run: | gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run test for AWS + run: hatch run test-dialect-aws + if: matrix.cloud-provider == 'aws' - name: Run tests run: hatch run test-dialect - uses: actions/upload-artifact@v4 @@ -125,7 +131,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -144,8 +150,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Upgrade pip and install hatch run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch python -m hatch env create default - name: Setup parameters file shell: bash @@ -162,8 +168,8 @@ jobs: path: | ./coverage.xml - test-dialect-run-v20: - name: Test dialect run v20 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + test-dialect-v14: + name: Test dialect v14 ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} needs: [ lint, build-install ] runs-on: ${{ matrix.os }} strategy: @@ -171,7 +177,7 @@ jobs: matrix: os: [ ubuntu-latest, - macos-latest, + macos-13, windows-latest, ] python-version: ["3.8"] @@ -197,21 +203,70 @@ jobs: .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py - name: Upgrade pip and install hatch run: | - python -m pip install -U pip - python -m pip install -U hatch + python -m pip install -U 
uv + python -m uv pip install -U hatch python -m hatch env create default + - name: Run test for AWS + run: hatch run sa14:test-dialect-aws + if: matrix.cloud-provider == 'aws' - name: Run tests - run: hatch run test-run_v20 + run: hatch run sa14:test-dialect - uses: actions/upload-artifact@v4 with: - name: coverage.xml_dialect-run-20-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + name: coverage.xml_dialect-v14-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + path: | + ./coverage.xml + + test-dialect-compatibility-v14: + name: Test dialect v14 compatibility ${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} + needs: lint + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ + ubuntu-latest, + macos-13, + windows-latest, + ] + python-version: ["3.8"] + cloud-provider: [ + aws, + azure, + gcp, + ] + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Upgrade pip and install hatch + run: | + python -m pip install -U uv + python -m uv pip install -U hatch + python -m hatch env create default + - name: Setup parameters file + shell: bash + env: + PARAMETERS_SECRET: ${{ secrets.PARAMETERS_SECRET }} + run: | + gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \ + .github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py + - name: Run tests + run: hatch run sa14:test-dialect-compatibility + - uses: actions/upload-artifact@v4 + with: + name: coverage.xml_dialect-v14-compatibility-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }} path: | ./coverage.xml combine-coverage: name: Combine coverage if: ${{ success() || failure() }} - needs: [test-dialect, test-dialect-compatibility, test-dialect-run-v20] + needs: [test-dialect, test-dialect-compatibility, test-dialect-v14, test-dialect-compatibility-v14] runs-on: ubuntu-latest steps: - name: Set up Python @@ -220,8 +275,8 @@ jobs: python-version: "3.8" - name: Prepare environment run: | - pip install -U pip - pip install -U hatch + python -m pip install -U uv + python -m uv pip install -U hatch hatch env create default - uses: actions/checkout@v4 with: @@ -233,22 +288,15 @@ jobs: run: | hatch run coverage combine -a artifacts/coverage.xml_*/coverage.xml hatch run coverage report -m - hatch run coverage xml -o combined_coverage.xml - hatch run coverage html -d htmlcov - name: Store coverage reports uses: actions/upload-artifact@v4 with: - name: combined_coverage.xml - path: combined_coverage.xml - - name: Store htmlcov report - uses: actions/upload-artifact@v4 - with: - name: combined_htmlcov - path: htmlcov + name: coverage.xml + path: coverage.xml - name: Uplaod to codecov uses: codecov/codecov-action@v4 with: - file: combined_coverage.xml + file: coverage.xml env_vars: OS,PYTHON fail_ci_if_error: false flags: unittests diff --git a/.github/workflows/create_req_files.yml b/.github/workflows/create_req_files.yml index 618b3024..2cb7a371 100644 --- a/.github/workflows/create_req_files.yml +++ b/.github/workflows/create_req_files.yml @@ -21,10 +21,10 @@ jobs: - name: Display Python version run: python -c "import sys; print(sys.version)" - name: Upgrade setuptools, pip and wheel - run: python -m pip install -U setuptools pip wheel + run: python -m pip install -U setuptools pip wheel uv - name: Install Snowflake SQLAlchemy shell: 
bash - run: python -m pip install . + run: python -m uv pip install . - name: Generate reqs file name shell: bash run: echo "requirements_file=temp_requirement/requirements_$(python -c 'from sys import version_info;print(str(version_info.major)+str(version_info.minor))').reqs" >> $GITHUB_ENV @@ -34,7 +34,7 @@ jobs: mkdir temp_requirement echo "# Generated on: $(python --version)" >${{ env.requirements_file }} python -m pip freeze | grep -v snowflake-sqlalchemy 1>>${{ env.requirements_file }} 2>/dev/null - echo "snowflake-sqlalchemy==$(python -m pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} + echo "snowflake-sqlalchemy==$(python -m uv pip show snowflake-sqlalchemy | grep ^Version | cut -d' ' -f2-)" >>${{ env.requirements_file }} id: create-reqs-file - name: Show created req file shell: bash diff --git a/.github/workflows/jira_close.yml b/.github/workflows/jira_close.yml index 5b170d75..7862f483 100644 --- a/.github/workflows/jira_close.yml +++ b/.github/workflows/jira_close.yml @@ -17,7 +17,7 @@ jobs: token: ${{ secrets.SNOWFLAKE_GITHUB_TOKEN }} # stored in GitHub secrets path: . - name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} diff --git a/.github/workflows/jira_comment.yml b/.github/workflows/jira_comment.yml index 954929fa..8533c14c 100644 --- a/.github/workflows/jira_comment.yml +++ b/.github/workflows/jira_comment.yml @@ -9,7 +9,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Jira login - uses: atlassian/gajira-login@master + uses: atlassian/gajira-login@v3 env: JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} @@ -22,7 +22,7 @@ jobs: jira=$(echo -n $TITLE | awk '{print $1}' | sed -e 's/://') echo ::set-output name=jira::$jira - name: Comment on issue - uses: atlassian/gajira-comment@master + uses: atlassian/gajira-comment@v3 if: startsWith(steps.extract.outputs.jira, 'SNOW-') with: issue: "${{ steps.extract.outputs.jira }}" diff --git a/.github/workflows/jira_issue.yml b/.github/workflows/jira_issue.yml index 31b93aae..85c774ca 100644 --- a/.github/workflows/jira_issue.yml +++ b/.github/workflows/jira_issue.yml @@ -23,7 +23,7 @@ jobs: path: . 
- name: Login - uses: atlassian/gajira-login@v2.0.0 + uses: atlassian/gajira-login@v3 env: JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} @@ -31,7 +31,7 @@ jobs: - name: Create JIRA Ticket id: create - uses: atlassian/gajira-create@v2.0.1 + uses: atlassian/gajira-create@v3 with: project: SNOW issuetype: Bug diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index ab4be45b..52f43106 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -13,7 +13,8 @@ on: types: [published] permissions: - contents: read + contents: write + id-token: write jobs: deploy: @@ -30,10 +31,50 @@ jobs: python-version: '3.x' - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install build + python -m pip install -U uv + python -m uv pip install -U hatch - name: Build package - run: python -m build + run: python -m hatch build --clean + - name: List artifacts + run: ls ./dist + - name: Install sigstore + run: python -m pip install sigstore + - name: Signing + run: | + for dist in dist/*; do + dist_base="$(basename "${dist}")" + echo "dist: ${dist}" + echo "dist_base: ${dist_base}" + python -m \ + sigstore sign "${dist}" \ + --output-signature "${dist_base}.sig" \ + --output-certificate "${dist_base}.crt" \ + --bundle "${dist_base}.sigstore" + + # Verify using `.sig` `.crt` pair; + python -m \ + sigstore verify identity "${dist}" \ + --signature "${dist_base}.sig" \ + --cert "${dist_base}.crt" \ + --cert-oidc-issuer https://token.actions.githubusercontent.com \ + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} + + # Verify using `.sigstore` bundle; + python -m \ + sigstore verify identity "${dist}" \ + --bundle "${dist_base}.sigstore" \ + --cert-oidc-issuer https://token.actions.githubusercontent.com \ + --cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF} + done + - name: List artifacts after sign + run: ls ./dist + - name: Copy files to release + run: | + gh release upload ${{ github.event.release.tag_name }} *.sigstore + gh release upload ${{ github.event.release.tag_name }} *.sig + gh release upload ${{ github.event.release.tag_name }} *.crt + env: + GITHUB_TOKEN: ${{ github.TOKEN }} - name: Publish package uses: pypa/gh-action-pypi-publish@release/v1 with: diff --git a/.github/workflows/stale_issue_bot.yml b/.github/workflows/stale_issue_bot.yml index 6d76e9f4..4ee56ff8 100644 --- a/.github/workflows/stale_issue_bot.yml +++ b/.github/workflows/stale_issue_bot.yml @@ -10,7 +10,7 @@ jobs: stale: runs-on: ubuntu-latest steps: - - uses: actions/stale@v7 + - uses: actions/stale@v9 with: close-issue-message: 'To clean up and re-prioritize bugs and feature requests we are closing all issues older than 6 months as of Apr 1, 2023. If there are any issues or feature requests that you would like us to address, please re-create them. 
For urgent issues, opening a support case with this link [Snowflake Community](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge) is the fastest way to get a response'
         days-before-issue-stale: ${{ inputs.staleDays }}
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 83172eb8..b7370b74 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,6 +4,7 @@ repos:
     rev: v4.5.0
     hooks:
       - id: trailing-whitespace
+        exclude: '\.ambr$'
       - id: end-of-file-fixer
       - id: check-yaml
         exclude: .github/repo_meta.yaml
diff --git a/DESCRIPTION.md b/DESCRIPTION.md
index 2f228781..82ddebc9 100644
--- a/DESCRIPTION.md
+++ b/DESCRIPTION.md
@@ -9,6 +9,29 @@ Source code is also available at:

 # Release Notes

+- (Unreleased)
+  - Add support for `PARTITION BY` in `COPY INTO`
+
+- v1.7.0(November 22, 2024)
+
+  - Add support for dynamic tables and required options
+  - Add support for hybrid tables
+  - Fixed SAWarning when registering functions with an existing name in the default namespace
+  - Update options to be defined as keyword arguments instead of positional arguments.
+  - Add support for refresh_mode option in DynamicTable
+  - Add support for Iceberg tables with the Snowflake Catalog
+  - Fix cluster by option to support explicit expressions
+  - Add support for MAP datatype
+
+- v1.6.1(July 9, 2024)
+
+  - Update internal project workflow with PyPI publishing
+
+- v1.6.0(July 8, 2024)
+
+  - Support installing with SQLAlchemy 2.0.x
+  - Use `hatch` & `uv` for managing project virtual environments
+
 - v1.5.4

   - Add ability to set ORDER / NOORDER sequence on columns with IDENTITY
@@ -24,7 +47,7 @@ Source code is also available at:

 - v1.5.1(November 03, 2023)

-  - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057.
+  - Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check <https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057>.
   - Fixed credentials with `externalbrowser` authentication not caching due to incorrect parsing of boolean query parameters.
     - This fixes other boolean parameter passing to driver as well.
diff --git a/README.md b/README.md
index 0c75854e..2dbf6632 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,11 @@ Snowflake SQLAlchemy runs on the top of the Snowflake Connector for Python as a
 [dialect](http://docs.sqlalchemy.org/en/latest/dialects/) to bridge a Snowflake database and SQLAlchemy
 applications.

+
+| :exclamation: | For production-affecting or urgent issues related to the connector, please [create a case with Snowflake Support](https://community.snowflake.com/s/article/How-To-Submit-a-Support-Case-in-Snowflake-Lodge). |
+|---------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+
+
 ## Prerequisites

 ### Snowflake Connector for Python
@@ -101,6 +106,7 @@ containing special characters need to be URL encoded to be parsed correctly. Thi
 characters could lead to authentication failure.
The encoding for the password can be generated using `urllib.parse`: + ```python import urllib.parse urllib.parse.quote("kx@% jj5/g") @@ -111,6 +117,7 @@ urllib.parse.quote("kx@% jj5/g") To create an engine with the proper encodings, either manually constructing the url string by formatting or taking advantage of the `snowflake.sqlalchemy.URL` helper method: + ```python import urllib.parse from snowflake.sqlalchemy import URL @@ -191,14 +198,23 @@ engine = create_engine(...) engine.execute() engine.dispose() -# Do this. +# Better. engine = create_engine(...) connection = engine.connect() try: - connection.execute() + connection.execute(text()) finally: connection.close() engine.dispose() + +# Best +try: + with engine.connect() as connection: + connection.execute(text()) + # or + connection.exec_driver_sql() +finally: + engine.dispose() ``` ### Auto-increment Behavior @@ -214,7 +230,7 @@ t = Table('mytable', metadata, ### Object Name Case Handling -Snowflake stores all case-insensitive object names in uppercase text. In contrast, SQLAlchemy considers all lowercase object names to be case-insensitive. Snowflake SQLAlchemy converts the object name case during schema-level communication, i.e. during table and index reflection. If you use uppercase object names, SQLAlchemy assumes they are case-sensitive and encloses the names with quotes. This behavior will cause mismatches agaisnt data dictionary data received from Snowflake, so unless identifier names have been truly created as case sensitive using quotes, e.g., `"TestDb"`, all lowercase names should be used on the SQLAlchemy side. +Snowflake stores all case-insensitive object names in uppercase text. In contrast, SQLAlchemy considers all lowercase object names to be case-insensitive. Snowflake SQLAlchemy converts the object name case during schema-level communication, i.e. during table and index reflection. If you use uppercase object names, SQLAlchemy assumes they are case-sensitive and encloses the names with quotes. This behavior will cause mismatches against data dictionary data received from Snowflake, so unless identifier names have been truly created as case sensitive using quotes, e.g., `"TestDb"`, all lowercase names should be used on the SQLAlchemy side. ### Index Support @@ -242,14 +258,14 @@ engine = create_engine(URL( specific_date = np.datetime64('2016-03-04T12:03:05.123456789Z') -connection = engine.connect() -connection.execute( - "CREATE OR REPLACE TABLE ts_tbl(c1 TIMESTAMP_NTZ)") -connection.execute( - "INSERT INTO ts_tbl(c1) values(%s)", (specific_date,) -) -df = pd.read_sql_query("SELECT * FROM ts_tbl", engine) -assert df.c1.values[0] == specific_date +with engine.connect() as connection: + connection.exec_driver_sql( + "CREATE OR REPLACE TABLE ts_tbl(c1 TIMESTAMP_NTZ)") + connection.exec_driver_sql( + "INSERT INTO ts_tbl(c1) values(%s)", (specific_date,) + ) + df = pd.read_sql_query("SELECT * FROM ts_tbl", connection) + assert df.c1.values[0] == specific_date ``` The following `NumPy` data types are supported: @@ -329,7 +345,7 @@ This example shows how to create a table with two columns, `id` and `name`, as t t = Table('myuser', metadata, Column('id', Integer, primary_key=True), Column('name', String), - snowflake_clusterby=['id', 'name'], ... + snowflake_clusterby=['id', 'name', text('id > 5')], ... 
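+    # plain strings are rendered as (denormalized) column names, while text()
+    # embeds a raw SQL expression into the CLUSTER BY clause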
) metadata.create_all(engine) ``` diff --git a/ci/build.sh b/ci/build.sh index 4229506d..b63c8e01 100755 --- a/ci/build.sh +++ b/ci/build.sh @@ -3,7 +3,7 @@ # Build snowflake-sqlalchemy set -o pipefail -PYTHON="python3.7" +PYTHON="python3.8" THIS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SQLALCHEMY_DIR="$(dirname "${THIS_DIR}")" DIST_DIR="${SQLALCHEMY_DIR}/dist" @@ -11,14 +11,16 @@ DIST_DIR="${SQLALCHEMY_DIR}/dist" cd "$SQLALCHEMY_DIR" # Clean up previously built DIST_DIR if [ -d "${DIST_DIR}" ]; then - echo "[WARN] ${DIST_DIR} already existing, deleting it..." - rm -rf "${DIST_DIR}" + echo "[WARN] ${DIST_DIR} already existing, deleting it..." + rm -rf "${DIST_DIR}" fi # Constants and setup +export PATH=$PATH:$HOME/.local/bin echo "[Info] Building snowflake-sqlalchemy with $PYTHON" # Clean up possible build artifacts rm -rf build generated_version.py -${PYTHON} -m pip install --upgrade pip setuptools wheel build -${PYTHON} -m build --outdir ${DIST_DIR} . +export UV_NO_CACHE=true +${PYTHON} -m pip install uv hatch +${PYTHON} -m hatch build diff --git a/ci/test_linux.sh b/ci/test_linux.sh index 695251e6..f5afc4fb 100755 --- a/ci/test_linux.sh +++ b/ci/test_linux.sh @@ -6,9 +6,9 @@ # - This script assumes that ../dist/repaired_wheels has the wheel(s) built for all versions to be tested # - This is the script that test_docker.sh runs inside of the docker container -PYTHON_VERSIONS="${1:-3.7 3.8 3.9 3.10 3.11}" -THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -SQLALCHEMY_DIR="$( dirname "${THIS_DIR}")" +PYTHON_VERSIONS="${1:-3.8 3.9 3.10 3.11}" +THIS_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SQLALCHEMY_DIR="$(dirname "${THIS_DIR}")" # Install one copy of tox python3 -m pip install -U tox @@ -16,10 +16,10 @@ python3 -m pip install -U tox # Run tests cd $SQLALCHEMY_DIR for PYTHON_VERSION in ${PYTHON_VERSIONS}; do - echo "[Info] Testing with ${PYTHON_VERSION}" - SHORT_VERSION=$(python3 -c "print('${PYTHON_VERSION}'.replace('.', ''))") - SQLALCHEMY_WHL=$(ls $SQLALCHEMY_DIR/dist/snowflake_sqlalchemy-*-py2.py3-none-any.whl | sort -r | head -n 1) - TEST_ENVLIST=fix_lint,py${SHORT_VERSION}-ci,py${SHORT_VERSION}-coverage - echo "[Info] Running tox for ${TEST_ENVLIST}" - python3 -m tox -e ${TEST_ENVLIST} --installpkg ${SQLALCHEMY_WHL} + echo "[Info] Testing with ${PYTHON_VERSION}" + SHORT_VERSION=$(python3 -c "print('${PYTHON_VERSION}'.replace('.', ''))") + SQLALCHEMY_WHL=$(ls $SQLALCHEMY_DIR/dist/snowflake_sqlalchemy-*-py3-none-any.whl | sort -r | head -n 1) + TEST_ENVLIST=fix_lint,py${SHORT_VERSION}-ci,py${SHORT_VERSION}-coverage + echo "[Info] Running tox for ${TEST_ENVLIST}" + python3 -m tox -e ${TEST_ENVLIST} --installpkg ${SQLALCHEMY_WHL} done diff --git a/pyproject.toml b/pyproject.toml index 3f95df46..84e64faf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Application Frameworks", "Topic :: Software Development :: Libraries :: Python Modules", ] -dependencies = ["SQLAlchemy>=1.4.19,<2.0.0", "snowflake-connector-python<4.0.0"] +dependencies = ["SQLAlchemy>=1.4.19", "snowflake-connector-python<4.0.0"] [tool.hatch.version] path = "src/snowflake/sqlalchemy/version.py" @@ -53,6 +53,7 @@ development = [ "pytz", "numpy", "mock", + "syrupy==4.6.1", ] pandas = ["snowflake-connector-python[pandas]"] @@ -73,6 +74,13 @@ exclude = ["/.github"] packages = ["src/snowflake"] [tool.hatch.envs.default] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.1.0"] +features = ["development", "pandas"] +python 
= "3.8" +installer = "uv" + +[tool.hatch.envs.sa14] +extra-dependencies = ["SQLAlchemy>=1.4.19,<2.0.0"] features = ["development", "pandas"] python = "3.8" @@ -82,10 +90,19 @@ SQLACHEMY_WARN_20 = "1" [tool.hatch.envs.default.scripts] check = "pre-commit run --all-files" -test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite" +test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite" -test-run_v20 = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite --run_v20_sqlalchemy" +test-dialect-aws = "pytest -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/" gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1" +check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'" + +[[tool.hatch.envs.release.matrix]] +python = ["3.8", "3.9", "3.10", "3.11", "3.12"] +features = ["development", "pandas"] + +[tool.hatch.envs.release.scripts] +test-dialect = "pytest -ra -vvv --tb=short --ignore=tests/sqlalchemy_test_suite tests/" +test-compatibility = "pytest -ra -vvv --tb=short tests/sqlalchemy_test_suite tests/" [tool.ruff] line-length = 88 @@ -94,6 +111,7 @@ line-length = 88 line-length = 88 [tool.pytest.ini_options] +addopts = "-m 'not feature_max_lob_size and not aws and not requires_external_volume'" markers = [ # Optional dependency groups markers "lambda: AWS lambda tests", @@ -110,5 +128,7 @@ markers = [ # Other markers "timeout: tests that need a timeout time", "internal: tests that could but should only run on our internal CI", + "requires_external_volume: tests that needs a external volume to be executed", "external: tests that could but should only run on our external CI", + "feature_max_lob_size: tests that could but should only run on our external CI", ] diff --git a/snyk/requirements.txt b/snyk/requirements.txt index 3a77e0f9..0166d751 100644 --- a/snyk/requirements.txt +++ b/snyk/requirements.txt @@ -1,2 +1,2 @@ -SQLAlchemy>=1.4.19,<2.0.0 +SQLAlchemy>=1.4.19 snowflake-connector-python<4.0.0 diff --git a/snyk/requiremtnts.txt b/snyk/requiremtnts.txt new file mode 100644 index 00000000..a92c527e --- /dev/null +++ b/snyk/requiremtnts.txt @@ -0,0 +1,2 @@ +snowflake-connector-python<4.0.0 +SQLAlchemy>=1.4.19,<2.1.0 diff --git a/src/snowflake/sqlalchemy/__init__.py b/src/snowflake/sqlalchemy/__init__.py index 9df6aaa2..7d795b2a 100644 --- a/src/snowflake/sqlalchemy/__init__.py +++ b/src/snowflake/sqlalchemy/__init__.py @@ -9,7 +9,7 @@ else: import importlib.metadata as importlib_metadata -from sqlalchemy.types import ( +from sqlalchemy.types import ( # noqa BIGINT, BINARY, BOOLEAN, @@ -27,8 +27,8 @@ VARCHAR, ) -from . import base, snowdialect -from .custom_commands import ( +from . 
import base, snowdialect # noqa +from .custom_commands import ( # noqa AWSBucket, AzureContainer, CopyFormatter, @@ -41,7 +41,7 @@ MergeInto, PARQUETFormatter, ) -from .custom_types import ( +from .custom_types import ( # noqa ARRAY, BYTEINT, CHARACTER, @@ -50,6 +50,7 @@ FIXED, GEOGRAPHY, GEOMETRY, + MAP, NUMBER, OBJECT, STRING, @@ -61,13 +62,30 @@ VARBINARY, VARIANT, ) -from .util import _url as URL +from .sql.custom_schema import ( # noqa + DynamicTable, + HybridTable, + IcebergTable, + SnowflakeTable, +) +from .sql.custom_schema.options import ( # noqa + AsQueryOption, + ClusterByOption, + IdentifierOption, + KeywordOption, + LiteralOption, + SnowflakeKeyword, + TableOptionKey, + TargetLagOption, + TimeUnit, +) +from .util import _url as URL # noqa base.dialect = dialect = snowdialect.dialect __version__ = importlib_metadata.version("snowflake-sqlalchemy") -__all__ = ( +_custom_types = ( "BIGINT", "BINARY", "BOOLEAN", @@ -102,6 +120,10 @@ "TINYINT", "VARBINARY", "VARIANT", + "MAP", +) + +_custom_commands = ( "MergeInto", "CSVFormatter", "JSONFormatter", @@ -114,3 +136,27 @@ "CreateStage", "CreateFileFormat", ) + +_custom_tables = ("HybridTable", "DynamicTable", "IcebergTable", "SnowflakeTable") + +_custom_table_options = ( + "AsQueryOption", + "TargetLagOption", + "LiteralOption", + "IdentifierOption", + "KeywordOption", + "ClusterByOption", +) + +_enums = ( + "TimeUnit", + "TableOptionKey", + "SnowflakeKeyword", +) +__all__ = ( + *_custom_types, + *_custom_commands, + *_custom_tables, + *_custom_table_options, + *_enums, +) diff --git a/src/snowflake/sqlalchemy/_constants.py b/src/snowflake/sqlalchemy/_constants.py index 46af4454..205ad5d9 100644 --- a/src/snowflake/sqlalchemy/_constants.py +++ b/src/snowflake/sqlalchemy/_constants.py @@ -10,3 +10,5 @@ APPLICATION_NAME = "SnowflakeSQLAlchemy" SNOWFLAKE_SQLALCHEMY_VERSION = VERSION +DIALECT_NAME = "snowflake" +NOT_NULL = "NOT NULL" diff --git a/src/snowflake/sqlalchemy/base.py b/src/snowflake/sqlalchemy/base.py index e008c92f..02e4f741 100644 --- a/src/snowflake/sqlalchemy/base.py +++ b/src/snowflake/sqlalchemy/base.py @@ -5,6 +5,7 @@ import itertools import operator import re +from typing import List from sqlalchemy import exc as sa_exc from sqlalchemy import inspect, sql @@ -13,13 +14,28 @@ from sqlalchemy.orm import context from sqlalchemy.orm.context import _MapperEntity from sqlalchemy.schema import Sequence, Table -from sqlalchemy.sql import compiler, expression +from sqlalchemy.sql import compiler, expression, functions from sqlalchemy.sql.base import CompileState -from sqlalchemy.sql.elements import quoted_name +from sqlalchemy.sql.elements import BindParameter, quoted_name +from sqlalchemy.sql.expression import Executable from sqlalchemy.sql.selectable import Lateral, SelectState -from sqlalchemy.util.compat import string_types -from .custom_commands import AWSBucket, AzureContainer, ExternalStage +from snowflake.sqlalchemy._constants import DIALECT_NAME +from snowflake.sqlalchemy.compat import IS_VERSION_20, args_reducer, string_types +from snowflake.sqlalchemy.custom_commands import ( + AWSBucket, + AzureContainer, + ExternalStage, +) + +from ._constants import NOT_NULL +from .exc import ( + CustomOptionsAreOnlySupportedOnSnowflakeTables, + UnexpectedOptionTypeError, +) +from .functions import flatten +from .sql.custom_schema.custom_table_base import CustomTableBase +from .sql.custom_schema.options.table_option import TableOption from .util import ( _find_left_clause_to_join_from, _set_connection_interpolate_empty_sequences, 
@@ -183,7 +199,6 @@ def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause) [element._from_objects for element in statement._where_criteria] ), ): - potential[from_clause] = () all_clauses = list(potential.keys()) @@ -324,17 +339,9 @@ def _join_determine_implicit_left_side( return left, replace_from_obj_index, use_entity_index + @args_reducer(positions_to_drop=(6, 7)) def _join_left_to_right( - self, - entities_collection, - left, - right, - onclause, - prop, - create_aliases, - aliased_generation, - outerjoin, - full, + self, entities_collection, left, right, onclause, prop, outerjoin, full ): """given raw "left", "right", "onclause" parameters consumed from a particular key within _join(), add a real ORMJoin object to @@ -364,7 +371,7 @@ def _join_left_to_right( use_entity_index, ) = self._join_place_explicit_left_side(entities_collection, left) - if left is right and not create_aliases: + if left is right: raise sa_exc.InvalidRequestError( "Can't construct a join from %s to %s, they " "are the same entity" % (left, right) @@ -373,9 +380,15 @@ def _join_left_to_right( # the right side as given often needs to be adapted. additionally # a lot of things can be wrong with it. handle all that and # get back the new effective "right" side - r_info, right, onclause = self._join_check_and_adapt_right_side( - left, right, onclause, prop, create_aliases, aliased_generation - ) + + if IS_VERSION_20: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop + ) + else: + r_info, right, onclause = self._join_check_and_adapt_right_side( + left, right, onclause, prop, False, False + ) if not r_info.is_selectable: extra_criteria = self._get_extra_criteria(r_info) @@ -551,9 +564,8 @@ def visit_copy_into(self, copy_into, **kw): if isinstance(copy_into.into, Table) else copy_into.into._compiler_dispatch(self, **kw) ) - from_ = None if isinstance(copy_into.from_, Table): - from_ = copy_into.from_ + from_ = copy_into.from_.name # this is intended to catch AWSBucket and AzureContainer elif ( isinstance(copy_into.from_, AWSBucket) @@ -564,6 +576,21 @@ def visit_copy_into(self, copy_into, **kw): # everything else (selects, etc.) 
else: from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})" + + partition_by_value = None + if isinstance(copy_into.partition_by, (BindParameter, Executable)): + partition_by_value = copy_into.partition_by.compile( + compile_kwargs={"literal_binds": True} + ) + elif copy_into.partition_by is not None: + partition_by_value = copy_into.partition_by + + partition_by = ( + f"PARTITION BY {partition_by_value}" + if partition_by_value is not None and partition_by_value != "" + else "" + ) + credentials, encryption = "", "" if isinstance(into, tuple): into, credentials, encryption = into @@ -574,8 +601,7 @@ def visit_copy_into(self, copy_into, **kw): options_list.sort(key=operator.itemgetter(0)) options = ( ( - " " - + " ".join( + " ".join( [ "{} = {}".format( n, @@ -596,7 +622,7 @@ def visit_copy_into(self, copy_into, **kw): options += f" {credentials}" if encryption: options += f" {encryption}" - return f"COPY INTO {into} FROM {from_} {formatter}{options}" + return f"COPY INTO {into} FROM {' '.join([from_, partition_by, formatter, options])}" def visit_copy_formatter(self, formatter, **kw): options_list = list(formatter.options.items()) @@ -880,7 +906,7 @@ def get_column_specification(self, column, **kwargs): return " ".join(colspec) - def post_create_table(self, table): + def handle_cluster_by(self, table): """ Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax. @@ -897,7 +923,7 @@ def post_create_table(self, table): ... metadata, ... sa.Column('id', sa.Integer, primary_key=True), ... sa.Column('name', sa.String), - ... snowflake_clusterby=['id', 'name'] + ... snowflake_clusterby=['id', 'name', text("id > 5")] ... ) >>> print(CreateTable(user).compile(engine)) @@ -905,19 +931,49 @@ def post_create_table(self, table): id INTEGER NOT NULL AUTOINCREMENT, name VARCHAR, PRIMARY KEY (id) - ) CLUSTER BY (id, name) + ) CLUSTER BY (id, name, id > 5) """ text = "" - info = table.dialect_options["snowflake"] + info = table.dialect_options[DIALECT_NAME] cluster = info.get("clusterby") if cluster: text += " CLUSTER BY ({})".format( - ", ".join(self.denormalize_column_name(key) for key in cluster) + ", ".join( + ( + self.denormalize_column_name(key) + if isinstance(key, str) + else str(key) + ) + for key in cluster + ) ) return text + def post_create_table(self, table): + text = self.handle_cluster_by(table) + options = [] + invalid_options: List[str] = [] + + for key, option in table.dialect_options[DIALECT_NAME].items(): + if isinstance(option, TableOption): + options.append(option) + elif key not in ["clusterby", "*"]: + invalid_options.append(key) + + if len(invalid_options) > 0: + raise UnexpectedOptionTypeError(sorted(invalid_options)) + + if isinstance(table, CustomTableBase): + options.sort(key=lambda x: (x.priority.value, x.option_name), reverse=True) + for option in options: + text += "\t" + option.render_option(self) + elif len(options) > 0: + raise CustomOptionsAreOnlySupportedOnSnowflakeTables() + + return text + def visit_create_stage(self, create_stage, **kw): """ This visitor will create the SQL representation for a CREATE STAGE command. 
@@ -979,24 +1035,23 @@ def visit_identity_column(self, identity, **kw):
     def get_identity_options(self, identity_options):
         text = []
         if identity_options.increment is not None:
-            text.append(f"INCREMENT BY {identity_options.increment:d}")
+            text.append("INCREMENT BY %d" % identity_options.increment)
         if identity_options.start is not None:
-            text.append(f"START WITH {identity_options.start:d}")
+            text.append("START WITH %d" % identity_options.start)
         if identity_options.minvalue is not None:
-            text.append(f"MINVALUE {identity_options.minvalue:d}")
+            text.append("MINVALUE %d" % identity_options.minvalue)
         if identity_options.maxvalue is not None:
-            text.append(f"MAXVALUE {identity_options.maxvalue:d}")
+            text.append("MAXVALUE %d" % identity_options.maxvalue)
         if identity_options.nominvalue is not None:
             text.append("NO MINVALUE")
         if identity_options.nomaxvalue is not None:
             text.append("NO MAXVALUE")
         if identity_options.cache is not None:
-            text.append(f"CACHE {identity_options.cache:d}")
+            text.append("CACHE %d" % identity_options.cache)
         if identity_options.cycle is not None:
             text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
         if identity_options.order is not None:
             text.append("ORDER" if identity_options.order else "NOORDER")
-
         return " ".join(text)

@@ -1031,6 +1086,12 @@ def visit_TINYINT(self, type_, **kw):
     def visit_VARIANT(self, type_, **kw):
         return "VARIANT"

+    def visit_MAP(self, type_, **kw):
+        not_null = f" {NOT_NULL}" if type_.not_null else ""
+        return (
+            f"MAP({type_.key_type.compile()}, {type_.value_type.compile()}{not_null})"
+        )
+
     def visit_ARRAY(self, type_, **kw):
         return "ARRAY"

@@ -1066,3 +1127,5 @@ def visit_GEOMETRY(self, type_, **kw):

 construct_arguments = [(Table, {"clusterby": None})]
+
+functions.register_function("flatten", flatten, "snowflake")
diff --git a/src/snowflake/sqlalchemy/compat.py b/src/snowflake/sqlalchemy/compat.py
new file mode 100644
index 00000000..9e97e574
--- /dev/null
+++ b/src/snowflake/sqlalchemy/compat.py
@@ -0,0 +1,36 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+from __future__ import annotations
+
+import functools
+from typing import Callable
+
+from sqlalchemy import __version__ as SA_VERSION
+from sqlalchemy import util
+
+string_types = (str,)
+returns_unicode = util.symbol("RETURNS_UNICODE")
+
+IS_VERSION_20 = tuple(int(v) for v in SA_VERSION.split(".")) >= (2, 0, 0)
+
+
+def args_reducer(positions_to_drop: tuple):
+    """Removes args at the positions provided in the tuple positions_to_drop.
+
+    For example, the tuple (3, 5) removes the items at indexes 3 and 5
+    (0-based) from *args.
+    Keep in mind that on class methods the first position is cls or self.
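+
+    Illustrative example: base.py decorates _join_left_to_right with
+    @args_reducer(positions_to_drop=(6, 7)), so the legacy create_aliases and
+    aliased_generation arguments passed by SQLAlchemy 1.4 are dropped there.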
+    """
+
+    def fn_wrapper(fn: Callable):
+        @functools.wraps(fn)
+        def wrapper(*args):
+            reduced_args = args
+            if not IS_VERSION_20:
+                reduced_args = tuple(
+                    arg for idx, arg in enumerate(args) if idx not in positions_to_drop
+                )
+            fn(*reduced_args)
+
+        return wrapper
+
+    return fn_wrapper
diff --git a/src/snowflake/sqlalchemy/custom_commands.py b/src/snowflake/sqlalchemy/custom_commands.py
index cec16673..1b9260fe 100644
--- a/src/snowflake/sqlalchemy/custom_commands.py
+++ b/src/snowflake/sqlalchemy/custom_commands.py
@@ -10,7 +10,8 @@
 from sqlalchemy.sql.dml import UpdateBase
 from sqlalchemy.sql.elements import ClauseElement
 from sqlalchemy.sql.roles import FromClauseRole
-from sqlalchemy.util.compat import string_types
+
+from .compat import string_types

 NoneType = type(None)

@@ -114,18 +115,23 @@ class CopyInto(UpdateBase):
     __visit_name__ = "copy_into"
     _bind = None

-    def __init__(self, from_, into, formatter=None):
+    def __init__(self, from_, into, partition_by=None, formatter=None):
         self.from_ = from_
         self.into = into
         self.formatter = formatter
         self.copy_options = {}
+        self.partition_by = partition_by

     def __repr__(self):
         """
         repr for debugging / logging purposes only. For compilation logic, see
         the corresponding visitor in base.py
         """
-        return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})"
+        val = f"COPY INTO {self.into} FROM {repr(self.from_)}"
+        if self.partition_by is not None:
+            val += f" PARTITION BY {self.partition_by}"
+
+        return val + f" {repr(self.formatter)} ({self.copy_options})"

     def bind(self):
         return None
diff --git a/src/snowflake/sqlalchemy/custom_types.py b/src/snowflake/sqlalchemy/custom_types.py
index 802d1ce1..f2c950dd 100644
--- a/src/snowflake/sqlalchemy/custom_types.py
+++ b/src/snowflake/sqlalchemy/custom_types.py
@@ -37,6 +37,26 @@ class VARIANT(SnowflakeType):
     __visit_name__ = "VARIANT"


+class StructuredType(SnowflakeType):
+    def __init__(self):
+        super().__init__()
+
+
+class MAP(StructuredType):
+    __visit_name__ = "MAP"
+
+    def __init__(
+        self,
+        key_type: sqltypes.TypeEngine,
+        value_type: sqltypes.TypeEngine,
+        not_null: bool = False,
+    ):
+        self.key_type = key_type
+        self.value_type = value_type
+        self.not_null = not_null
+        super().__init__()
+
+
 class OBJECT(SnowflakeType):
     __visit_name__ = "OBJECT"
diff --git a/src/snowflake/sqlalchemy/exc.py b/src/snowflake/sqlalchemy/exc.py
new file mode 100644
index 00000000..399e94b6
--- /dev/null
+++ b/src/snowflake/sqlalchemy/exc.py
@@ -0,0 +1,82 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+
+from typing import List
+
+from sqlalchemy.exc import ArgumentError
+
+
+class NoPrimaryKeyError(ArgumentError):
+    def __init__(self, target: str):
+        super().__init__(f"Table {target} requires a primary key.")
+
+
+class UnsupportedPrimaryKeysAndForeignKeysError(ArgumentError):
+    def __init__(self, target: str):
+        super().__init__(f"Primary key and foreign keys are not supported in {target}.")
+
+
+class RequiredParametersNotProvidedError(ArgumentError):
+    def __init__(self, target: str, parameters: List[str]):
+        super().__init__(
+            f"{target} requires the following parameters: %s."
% ", ".join(parameters) + ) + + +class UnexpectedTableOptionKeyError(ArgumentError): + def __init__(self, expected: str, actual: str): + super().__init__(f"Expected table option {expected} but got {actual}.") + + +class OptionKeyNotProvidedError(ArgumentError): + def __init__(self, target: str): + super().__init__( + f"Expected option key in {target} option but got NoneType instead." + ) + + +class UnexpectedOptionParameterTypeError(ArgumentError): + def __init__(self, parameter_name: str, target: str, types: List[str]): + super().__init__( + f"Parameter {parameter_name} of {target} requires to be one" + f" of following types: {', '.join(types)}." + ) + + +class CustomOptionsAreOnlySupportedOnSnowflakeTables(ArgumentError): + def __init__(self): + super().__init__( + "Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables." + ) + + +class UnexpectedOptionTypeError(ArgumentError): + def __init__(self, options: List[str]): + super().__init__( + f"The following options are either unsupported or should be defined using a Snowflake table: {', '.join(options)}." + ) + + +class InvalidTableParameterTypeError(ArgumentError): + def __init__(self, name: str, input_type: str, expected_types: List[str]): + expected_types_str = "', '".join(expected_types) + super().__init__( + f"Invalid parameter type '{input_type}' provided for '{name}'. " + f"Expected one of the following types: '{expected_types_str}'.\n" + ) + + +class MultipleErrors(ArgumentError): + def __init__(self, errors): + self.errors = errors + + def __str__(self): + return "".join(str(e) for e in self.errors) + + +class StructuredTypeNotSupportedInTableColumnsError(ArgumentError): + def __init__(self, table_type: str, table_name: str, column_name: str): + super().__init__( + f"Column '{column_name}' is of a structured type, which is only supported on Iceberg tables. " + f"The table '{table_name}' is of type '{table_type}', not Iceberg." + ) diff --git a/src/snowflake/sqlalchemy/functions.py b/src/snowflake/sqlalchemy/functions.py new file mode 100644 index 00000000..c08aa734 --- /dev/null +++ b/src/snowflake/sqlalchemy/functions.py @@ -0,0 +1,16 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import warnings + +from sqlalchemy.sql import functions as sqlfunc + +FLATTEN_WARNING = "For backward compatibility params are not rendered." + + +class flatten(sqlfunc.GenericFunction): + name = "flatten" + + def __init__(self, *args, **kwargs): + warnings.warn(FLATTEN_WARNING, DeprecationWarning, stacklevel=2) + super().__init__(*args, **kwargs) diff --git a/src/snowflake/sqlalchemy/parser/custom_type_parser.py b/src/snowflake/sqlalchemy/parser/custom_type_parser.py new file mode 100644 index 00000000..cf69c594 --- /dev/null +++ b/src/snowflake/sqlalchemy/parser/custom_type_parser.py @@ -0,0 +1,190 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+ +import sqlalchemy.types as sqltypes +from sqlalchemy.sql.type_api import TypeEngine +from sqlalchemy.types import ( + BIGINT, + BINARY, + BOOLEAN, + CHAR, + DATE, + DATETIME, + DECIMAL, + FLOAT, + INTEGER, + REAL, + SMALLINT, + TIME, + TIMESTAMP, + VARCHAR, + NullType, +) + +from ..custom_types import ( + _CUSTOM_DECIMAL, + ARRAY, + DOUBLE, + GEOGRAPHY, + GEOMETRY, + MAP, + OBJECT, + TIMESTAMP_LTZ, + TIMESTAMP_NTZ, + TIMESTAMP_TZ, + VARIANT, +) + +ischema_names = { + "BIGINT": BIGINT, + "BINARY": BINARY, + # 'BIT': BIT, + "BOOLEAN": BOOLEAN, + "CHAR": CHAR, + "CHARACTER": CHAR, + "DATE": DATE, + "DATETIME": DATETIME, + "DEC": DECIMAL, + "DECIMAL": DECIMAL, + "DOUBLE": DOUBLE, + "FIXED": DECIMAL, + "FLOAT": FLOAT, # Snowflake FLOAT datatype doesn't has parameters + "INT": INTEGER, + "INTEGER": INTEGER, + "NUMBER": _CUSTOM_DECIMAL, + # 'OBJECT': ? + "REAL": REAL, + "BYTEINT": SMALLINT, + "SMALLINT": SMALLINT, + "STRING": VARCHAR, + "TEXT": VARCHAR, + "TIME": TIME, + "TIMESTAMP": TIMESTAMP, + "TIMESTAMP_TZ": TIMESTAMP_TZ, + "TIMESTAMP_LTZ": TIMESTAMP_LTZ, + "TIMESTAMP_NTZ": TIMESTAMP_NTZ, + "TINYINT": SMALLINT, + "VARBINARY": BINARY, + "VARCHAR": VARCHAR, + "VARIANT": VARIANT, + "MAP": MAP, + "OBJECT": OBJECT, + "ARRAY": ARRAY, + "GEOGRAPHY": GEOGRAPHY, + "GEOMETRY": GEOMETRY, +} + + +def extract_parameters(text: str) -> list: + """ + Extracts parameters from a comma-separated string, handling parentheses. + + :param text: A string with comma-separated parameters, which may include parentheses. + + :return: A list of parameters as strings. + + :example: + For input `"a, (b, c), d"`, the output is `['a', '(b, c)', 'd']`. + """ + + output_parameters = [] + parameter = "" + open_parenthesis = 0 + for c in text: + + if c == "(": + open_parenthesis += 1 + elif c == ")": + open_parenthesis -= 1 + + if open_parenthesis > 0 or c != ",": + parameter += c + elif c == ",": + output_parameters.append(parameter.strip(" ")) + parameter = "" + if parameter != "": + output_parameters.append(parameter.strip(" ")) + return output_parameters + + +def parse_type(type_text: str) -> TypeEngine: + """ + Parses a type definition string and returns the corresponding SQLAlchemy type. + + The function handles types with or without parameters, such as `VARCHAR(255)` or `INTEGER`. + + :param type_text: A string representing a SQLAlchemy type, which may include parameters + in parentheses (e.g., "VARCHAR(255)" or "DECIMAL(10, 2)"). + :return: An instance of the corresponding SQLAlchemy type class (e.g., `String`, `Integer`), + or `NullType` if the type is not recognized. 
+ + :example: + parse_type("VARCHAR(255)") + String(length=255) + """ + index = type_text.find("(") + type_name = type_text[:index] if index != -1 else type_text + parameters = ( + extract_parameters(type_text[index + 1 : -1]) if type_name != type_text else [] + ) + + col_type_class = ischema_names.get(type_name, None) + col_type_kw = {} + if col_type_class is None: + col_type_class = NullType + else: + if issubclass(col_type_class, sqltypes.Numeric): + col_type_kw = __parse_numeric_type_parameters(parameters) + elif issubclass(col_type_class, (sqltypes.String, sqltypes.BINARY)): + col_type_kw = __parse_type_with_length_parameters(parameters) + elif issubclass(col_type_class, MAP): + col_type_kw = __parse_map_type_parameters(parameters) + if col_type_kw is None: + col_type_class = NullType + col_type_kw = {} + + return col_type_class(**col_type_kw) + + +def __parse_map_type_parameters(parameters): + if len(parameters) != 2: + return None + + key_type_str = parameters[0] + value_type_str = parameters[1] + not_null_str = "NOT NULL" + not_null = False + if ( + len(value_type_str) >= len(not_null_str) + and value_type_str[-len(not_null_str) :] == not_null_str + ): + not_null = True + value_type_str = value_type_str[: -len(not_null_str) - 1] + + key_type: TypeEngine = parse_type(key_type_str) + value_type: TypeEngine = parse_type(value_type_str) + if isinstance(key_type, NullType) or isinstance(value_type, NullType): + return None + + return { + "key_type": key_type, + "value_type": value_type, + "not_null": not_null, + } + + +def __parse_type_with_length_parameters(parameters): + return ( + {"length": int(parameters[0])} + if len(parameters) == 1 and str.isdigit(parameters[0]) + else {} + ) + + +def __parse_numeric_type_parameters(parameters): + result = {} + if len(parameters) >= 1 and str.isdigit(parameters[0]): + result["precision"] = int(parameters[0]) + if len(parameters) == 2 and str.isdigit(parameters[1]): + result["scale"] = int(parameters[1]) + return result diff --git a/src/snowflake/sqlalchemy/requirements.py b/src/snowflake/sqlalchemy/requirements.py index ea30a823..f2844804 100644 --- a/src/snowflake/sqlalchemy/requirements.py +++ b/src/snowflake/sqlalchemy/requirements.py @@ -289,9 +289,25 @@ def datetime_implicit_bound(self): # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. return exclusions.closed() + @property + def date_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + + @property + def time_implicit_bound(self): + # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding + # parameters in string forms of timestamp values. + return exclusions.closed() + @property def timestamp_microseconds_implicit_bound(self): # Supporting this would require behavior breaking change to implicitly convert str to timestamp when binding # parameters in string forms of timestamp values. # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. 
return exclusions.closed() + + @property + def array_type(self): + return exclusions.closed() diff --git a/src/snowflake/sqlalchemy/snowdialect.py b/src/snowflake/sqlalchemy/snowdialect.py index 2e40d03c..f9e2e4c8 100644 --- a/src/snowflake/sqlalchemy/snowdialect.py +++ b/src/snowflake/sqlalchemy/snowdialect.py @@ -3,8 +3,10 @@ # import operator +import re from collections import defaultdict from functools import reduce +from typing import Any from urllib.parse import unquote_plus import sqlalchemy.types as sqltypes @@ -15,32 +17,14 @@ from sqlalchemy.schema import Table from sqlalchemy.sql import text from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.sql.sqltypes import String -from sqlalchemy.types import ( - BIGINT, - BINARY, - BOOLEAN, - CHAR, - DATE, - DATETIME, - DECIMAL, - FLOAT, - INTEGER, - REAL, - SMALLINT, - TIME, - TIMESTAMP, - VARCHAR, - Date, - DateTime, - Float, - Time, -) +from sqlalchemy.types import FLOAT, Date, DateTime, Float, NullType, Time from snowflake.connector import errors as sf_errors from snowflake.connector.connection import DEFAULT_CONFIGURATION from snowflake.connector.constants import UTF8 +from snowflake.sqlalchemy.compat import returns_unicode +from ._constants import DIALECT_NAME from .base import ( SnowflakeCompiler, SnowflakeDDLCompiler, @@ -49,21 +33,19 @@ SnowflakeTypeCompiler, ) from .custom_types import ( - _CUSTOM_DECIMAL, - ARRAY, - GEOGRAPHY, - GEOMETRY, - OBJECT, - TIMESTAMP_LTZ, - TIMESTAMP_NTZ, - TIMESTAMP_TZ, - VARIANT, + MAP, _CUSTOM_Date, _CUSTOM_DateTime, _CUSTOM_Float, _CUSTOM_Time, ) -from .util import _update_connection_application_name, parse_url_boolean +from .parser.custom_type_parser import ischema_names, parse_type +from .sql.custom_schema.custom_table_prefix import CustomTablePrefix +from .util import ( + _update_connection_application_name, + parse_url_boolean, + parse_url_integer, +) colspecs = { Date: _CUSTOM_Date, @@ -72,49 +54,11 @@ Float: _CUSTOM_Float, } -ischema_names = { - "BIGINT": BIGINT, - "BINARY": BINARY, - # 'BIT': BIT, - "BOOLEAN": BOOLEAN, - "CHAR": CHAR, - "CHARACTER": CHAR, - "DATE": DATE, - "DATETIME": DATETIME, - "DEC": DECIMAL, - "DECIMAL": DECIMAL, - "DOUBLE": FLOAT, - "FIXED": DECIMAL, - "FLOAT": FLOAT, - "INT": INTEGER, - "INTEGER": INTEGER, - "NUMBER": _CUSTOM_DECIMAL, - # 'OBJECT': ? - "REAL": REAL, - "BYTEINT": SMALLINT, - "SMALLINT": SMALLINT, - "STRING": VARCHAR, - "TEXT": VARCHAR, - "TIME": TIME, - "TIMESTAMP": TIMESTAMP, - "TIMESTAMP_TZ": TIMESTAMP_TZ, - "TIMESTAMP_LTZ": TIMESTAMP_LTZ, - "TIMESTAMP_NTZ": TIMESTAMP_NTZ, - "TINYINT": SMALLINT, - "VARBINARY": BINARY, - "VARCHAR": VARCHAR, - "VARIANT": VARIANT, - "OBJECT": OBJECT, - "ARRAY": ARRAY, - "GEOGRAPHY": GEOGRAPHY, - "GEOMETRY": GEOMETRY, -} - _ENABLE_SQLALCHEMY_AS_APPLICATION_NAME = True class SnowflakeDialect(default.DefaultDialect): - name = "snowflake" + name = DIALECT_NAME driver = "snowflake" max_identifier_length = 255 cte_follows_insert = True @@ -134,7 +78,7 @@ class SnowflakeDialect(default.DefaultDialect): # unicode strings supports_unicode_statements = True supports_unicode_binds = True - returns_unicode_strings = String.RETURNS_UNICODE + returns_unicode_strings = returns_unicode description_encoding = None # No lastrowid support. 
See SNOW-11155 @@ -195,10 +139,34 @@ class SnowflakeDialect(default.DefaultDialect): @classmethod def dbapi(cls): + return cls.import_dbapi() + + @classmethod + def import_dbapi(cls): from snowflake import connector return connector + @staticmethod + def parse_query_param_type(name: str, value: Any) -> Any: + """Cast param value if possible to type defined in connector-python.""" + if not (maybe_type_configuration := DEFAULT_CONFIGURATION.get(name)): + return value + + _, expected_type = maybe_type_configuration + if not isinstance(expected_type, tuple): + expected_type = (expected_type,) + + if isinstance(value, expected_type): + return value + + elif bool in expected_type: + return parse_url_boolean(value) + elif int in expected_type: + return parse_url_integer(value) + else: + return value + def create_connect_args(self, url: URL): opts = url.translate_connect_args(username="user") if "database" in opts: @@ -235,47 +203,25 @@ def create_connect_args(self, url: URL): # URL sets the query parameter values as strings, we need to cast to expected types when necessary for name, value in query.items(): - maybe_type_configuration = DEFAULT_CONFIGURATION.get(name) - if ( - not maybe_type_configuration - ): # if the parameter is not found in the type mapping, pass it through as a string - opts[name] = value - continue - - (_, expected_type) = maybe_type_configuration - if not isinstance(expected_type, tuple): - expected_type = (expected_type,) - - if isinstance( - value, expected_type - ): # if the expected type is str, pass it through as a string - opts[name] = value - - elif ( - bool in expected_type - ): # if the expected type is bool, parse it and pass as a boolean - opts[name] = parse_url_boolean(value) - else: - # TODO: other types like int are stil passed through as string - # https://github.com/snowflakedb/snowflake-sqlalchemy/issues/447 - opts[name] = value + opts[name] = self.parse_query_param_type(name, value) return ([], opts) - def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): """ Checks if the table exists """ return self._has_object(connection, "TABLE", table_name, schema) - def has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): """ Checks if the sequence exists """ return self._has_object(connection, "SEQUENCE", sequence_name, schema) def _has_object(self, connection, object_type, object_name, schema=None): - full_name = self._denormalize_quote_join(schema, object_name) try: results = connection.execute( @@ -324,8 +270,8 @@ def _denormalize_quote_join(self, *idents): @reflection.cache def _current_database_schema(self, connection, **kw): - res = connection.exec_driver_sql( - "select current_database(), current_schema();" + res = connection.execute( + text("select current_database(), current_schema();") ).fetchone() return ( self.normalize_name(res[0]), @@ -344,14 +290,6 @@ def _map_name_to_idx(result): name_to_idx[col[0]] = idx return name_to_idx - @reflection.cache - def get_indexes(self, connection, table_name, schema=None, **kw): - """ - Gets all indexes - """ - # no index is supported by Snowflake - return [] - @reflection.cache def get_check_constraints(self, connection, table_name, schema, **kw): # check constraints are not supported by Snowflake @@ -507,6 +445,12 @@ def get_foreign_keys(self, connection, table_name, schema=None, **kw): ) return foreign_key_map.get(table_name, []) + 
def table_columns_as_dict(self, columns): + result = {} + for column in columns: + result[column["name"]] = column + return result + @reflection.cache def _get_schema_columns(self, connection, schema, **kw): """Get all columns in the schema, if we hit 'Information schema query returned too much data' problem return @@ -514,10 +458,12 @@ def _get_schema_columns(self, connection, schema, **kw): ans = {} current_database, _ = self._current_database_schema(connection, **kw) full_schema_name = self._denormalize_quote_join(current_database, schema) + full_columns_descriptions = {} try: schema_primary_keys = self._get_schema_primary_keys( connection, full_schema_name, **kw ) + schema_name = self.denormalize_name(schema) result = connection.execute( text( """ @@ -538,7 +484,7 @@ def _get_schema_columns(self, connection, schema, **kw): WHERE ic.table_schema=:table_schema ORDER BY ic.ordinal_position""" ), - {"table_schema": self.denormalize_name(schema)}, + {"table_schema": schema_name}, ) except sa_exc.ProgrammingError as pe: if pe.orig.errno == 90030: @@ -568,10 +514,7 @@ def _get_schema_columns(self, connection, schema, **kw): col_type = self.ischema_names.get(coltype, None) col_type_kw = {} if col_type is None: - sa_util.warn( - f"Did not recognize type '{coltype}' of column '{column_name}'" - ) - col_type = sqltypes.NULLTYPE + col_type = NullType else: if issubclass(col_type, FLOAT): col_type_kw["precision"] = numeric_precision @@ -581,6 +524,33 @@ def _get_schema_columns(self, connection, schema, **kw): col_type_kw["scale"] = numeric_scale elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)): col_type_kw["length"] = character_maximum_length + elif issubclass(col_type, MAP): + if (schema_name, table_name) not in full_columns_descriptions: + full_columns_descriptions[(schema_name, table_name)] = ( + self.table_columns_as_dict( + self._get_table_columns( + connection, table_name, schema_name + ) + ) + ) + + if ( + (schema_name, table_name) in full_columns_descriptions + and column_name + in full_columns_descriptions[(schema_name, table_name)] + ): + ans[table_name].append( + full_columns_descriptions[(schema_name, table_name)][ + column_name + ] + ) + continue + else: + col_type = NullType + if col_type == NullType: + sa_util.warn( + f"Did not recognize type '{coltype}' of column '{column_name}'" + ) type_instance = col_type(**col_type_kw) @@ -615,91 +585,71 @@ def _get_schema_columns(self, connection, schema, **kw): def _get_table_columns(self, connection, table_name, schema=None, **kw): """Get all columns in a table in a schema""" ans = [] - current_database, _ = self._current_database_schema(connection, **kw) - full_schema_name = self._denormalize_quote_join(current_database, schema) - schema_primary_keys = self._get_schema_primary_keys( - connection, full_schema_name, **kw + current_database, default_schema = self._current_database_schema( + connection, **kw ) + schema = schema if schema else default_schema + table_schema = self.denormalize_name(schema) + table_name = self.denormalize_name(table_name) result = connection.execute( text( - """ - SELECT /* sqlalchemy:get_table_columns */ - ic.table_name, - ic.column_name, - ic.data_type, - ic.character_maximum_length, - ic.numeric_precision, - ic.numeric_scale, - ic.is_nullable, - ic.column_default, - ic.is_identity, - ic.comment - FROM information_schema.columns ic - WHERE ic.table_schema=:table_schema - AND ic.table_name=:table_name - ORDER BY ic.ordinal_position""" - ), - { - "table_schema": self.denormalize_name(schema), - 
"table_name": self.denormalize_name(table_name), - }, + "DESC /* sqlalchemy:_get_schema_columns */" + f" TABLE {table_schema}.{table_name} TYPE = COLUMNS" + ) ) for ( - table_name, column_name, coltype, - character_maximum_length, - numeric_precision, - numeric_scale, + _kind, is_nullable, column_default, - is_identity, + primary_key, + _unique_key, + _check, + _expression, comment, + _policy_name, + _privacy_domain, + _name_mapping, ) in result: - table_name = self.normalize_name(table_name) + column_name = self.normalize_name(column_name) if column_name.startswith("sys_clustering_column"): continue # ignoring clustering column - col_type = self.ischema_names.get(coltype, None) - col_type_kw = {} - if col_type is None: + type_instance = parse_type(coltype) + if isinstance(type_instance, NullType): sa_util.warn( f"Did not recognize type '{coltype}' of column '{column_name}'" ) - col_type = sqltypes.NULLTYPE - else: - if issubclass(col_type, FLOAT): - col_type_kw["precision"] = numeric_precision - col_type_kw["decimal_return_scale"] = numeric_scale - elif issubclass(col_type, sqltypes.Numeric): - col_type_kw["precision"] = numeric_precision - col_type_kw["scale"] = numeric_scale - elif issubclass(col_type, (sqltypes.String, sqltypes.BINARY)): - col_type_kw["length"] = character_maximum_length - - type_instance = col_type(**col_type_kw) - current_table_pks = schema_primary_keys.get(table_name) + identity = None + match = re.match( + r"IDENTITY START (?P\d+) INCREMENT (?P\d+) (?PORDER|NOORDER)", + column_default if column_default else "", + ) + if match: + identity = { + "start": int(match.group("start")), + "increment": int(match.group("increment")), + "order_type": match.group("order_type"), + } + is_identity = identity is not None ans.append( { "name": column_name, "type": type_instance, - "nullable": is_nullable == "YES", - "default": column_default, - "autoincrement": is_identity == "YES", + "nullable": is_nullable == "Y", + "default": None if is_identity else column_default, + "autoincrement": is_identity, "comment": comment if comment != "" else None, - "primary_key": ( - ( - column_name - in schema_primary_keys[table_name]["constrained_columns"] - ) - if current_table_pks - else False - ), + "primary_key": primary_key == "Y", } ) + if is_identity: + ans[-1]["identity"] = identity + # If we didn't find any columns for the table, the table doesn't exist. 
if len(ans) == 0:
             raise sa_exc.NoSuchTableError()
@@ -887,6 +837,129 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
             )
         }

+    def get_multi_indexes(
+        self,
+        connection,
+        *,
+        schema,
+        filter_names,
+        **kw,
+    ):
+        """
+        Gets the indexes definition
+        """
+
+        table_prefixes = self.get_multi_prefixes(
+            connection, schema, filter_prefix=CustomTablePrefix.HYBRID.name
+        )
+        if len(table_prefixes) == 0:
+            return []
+        schema = schema or self.default_schema_name
+        if not schema:
+            result = connection.execute(
+                text("SHOW /* sqlalchemy:get_multi_indexes */ INDEXES")
+            )
+        else:
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}"
+                )
+            )
+
+        n2i = self.__class__._map_name_to_idx(result)
+        indexes = {}
+
+        for row in result.cursor.fetchall():
+            table = self.normalize_name(str(row[n2i["table"]]))
+            if (
+                row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY'
+                or table not in filter_names
+                or (schema, table) not in table_prefixes
+                or (
+                    (schema, table) in table_prefixes
+                    and CustomTablePrefix.HYBRID.name
+                    not in table_prefixes[(schema, table)]
+                )
+            ):
+                continue
+            index = {
+                "name": row[n2i["name"]],
+                "unique": row[n2i["is_unique"]] == "Y",
+                "column_names": row[n2i["columns"]],
+                "include_columns": row[n2i["included_columns"]],
+                "dialect_options": {},
+            }
+            if (schema, table) in indexes:
+                indexes[(schema, table)].append(index)
+            else:
+                indexes[(schema, table)] = [index]
+
+        return list(indexes.items())
+
+    def _value_or_default(self, data, table, schema):
+        table = self.normalize_name(str(table))
+        dic_data = dict(data)
+        if (schema, table) in dic_data:
+            return dic_data[(schema, table)]
+        else:
+            return []
+
+    def get_prefixes_from_data(self, n2i, row, **kw):
+        prefixes_found = []
+        for valid_prefix in CustomTablePrefix:
+            key = f"is_{valid_prefix.name.lower()}"
+            if key in n2i and row[n2i[key]] == "Y":
+                prefixes_found.append(valid_prefix.name)
+        return prefixes_found
+
+    @reflection.cache
+    def get_multi_prefixes(
+        self, connection, schema, table_name=None, filter_prefix=None, **kw
+    ):
+        """
+        Gets all table prefixes
+        """
+        schema = schema or self.default_schema_name
+        filter = f"LIKE '{table_name}'" if table_name else ""
+        if schema:
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES IN SCHEMA {schema}"
+                )
+            )
+        else:
+            result = connection.execute(
+                text(
+                    f"SHOW /* sqlalchemy:get_multi_prefixes */ {filter} TABLES LIKE '{table_name}'"
+                )
+            )
+
+        n2i = self.__class__._map_name_to_idx(result)
+        tables_prefixes = {}
+        for row in result.cursor.fetchall():
+            table = self.normalize_name(str(row[n2i["name"]]))
+            table_prefixes = self.get_prefixes_from_data(n2i, row)
+            if filter_prefix and filter_prefix not in table_prefixes:
+                continue
+            if (schema, table) in tables_prefixes:
+                tables_prefixes[(schema, table)].extend(table_prefixes)
+            else:
+                tables_prefixes[(schema, table)] = table_prefixes
+
+        return tables_prefixes
+
+    @reflection.cache
+    def get_indexes(self, connection, tablename, schema, **kw):
+        """
+        Gets the indexes definition
+        """
+        table_name = self.normalize_name(str(tablename))
+        data = self.get_multi_indexes(
+            connection=connection, schema=schema, filter_names=[table_name], **kw
+        )
+
+        return self._value_or_default(data, table_name, schema)
+
     def connect(self, *cargs, **cparams):
         return (
             super().connect(
@@ -904,8 +977,12 @@
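With get_indexes now delegating to get_multi_indexes, index reflection flows through the standard SQLAlchemy inspector: indexes are only reported for HYBRID tables, the implicit SYS_INDEX_*_PRIMARY entry is filtered out, and every other table still yields an empty list. A minimal sketch, assuming a hybrid table named "hb_tbl_user" already exists in the current schema:

from sqlalchemy import create_engine, inspect

engine = create_engine("snowflake://user:password@account/db/schema")
inspector = inspect(engine)

# Lists only the secondary indexes discovered via SHOW INDEXES;
# a plain (non-hybrid) table returns [].
for index in inspector.get_indexes("hb_tbl_user"):
    print(index["name"], index["unique"], index["column_names"])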
@sa_vnt.listens_for(Table, "before_create")
 def check_table(table, connection, _ddl_runner, **kw):
+    from .sql.custom_schema.hybrid_table import HybridTable
+
+    if HybridTable.is_equal_type(table):  # noqa
+        return True
     if isinstance(_ddl_runner.dialect, SnowflakeDialect) and table.indexes:
-        raise NotImplementedError("Snowflake does not support indexes")
+        raise NotImplementedError("Only Snowflake Hybrid Tables support indexes")

 dialect = SnowflakeDialect
diff --git a/src/snowflake/sqlalchemy/sql/__init__.py b/src/snowflake/sqlalchemy/sql/__init__.py
new file mode 100644
index 00000000..ef416f64
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/__init__.py
@@ -0,0 +1,3 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
new file mode 100644
index 00000000..cbc75ebc
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
@@ -0,0 +1,9 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+from .dynamic_table import DynamicTable
+from .hybrid_table import HybridTable
+from .iceberg_table import IcebergTable
+from .snowflake_table import SnowflakeTable
+
+__all__ = ["DynamicTable", "HybridTable", "IcebergTable", "SnowflakeTable"]
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py
new file mode 100644
index 00000000..6c0904a8
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/clustered_table.py
@@ -0,0 +1,37 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import Any, Optional
+
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from .custom_table_base import CustomTableBase
+from .options.as_query_option import AsQueryOption
+from .options.cluster_by_option import ClusterByOption, ClusterByOptionType
+from .options.table_option import TableOptionKey
+
+
+class ClusteredTableBase(CustomTableBase):
+
+    @property
+    def cluster_by(self) -> Optional[ClusterByOption]:
+        return self._get_dialect_option(TableOptionKey.CLUSTER_BY)
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        cluster_by: ClusterByOptionType = None,
+        **kw: Any,
+    ) -> None:
+        if kw.get("_no_init", True):
+            return
+
+        options = [
+            ClusterByOption.create(cluster_by),
+        ]
+
+        kw.update(self._as_dialect_options(options))
+        super().__init__(name, metadata, *args, **kw)
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py
new file mode 100644
index 00000000..6f7ee0c5
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py
@@ -0,0 +1,127 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+# +import typing +from typing import Any, List + +from sqlalchemy.sql.schema import MetaData, SchemaItem, Table + +from ..._constants import DIALECT_NAME +from ...compat import IS_VERSION_20 +from ...custom_commands import NoneType +from ...custom_types import StructuredType +from ...exc import ( + MultipleErrors, + NoPrimaryKeyError, + RequiredParametersNotProvidedError, + StructuredTypeNotSupportedInTableColumnsError, + UnsupportedPrimaryKeysAndForeignKeysError, +) +from .custom_table_prefix import CustomTablePrefix +from .options.invalid_table_option import InvalidTableOption +from .options.table_option import TableOption, TableOptionKey + + +class CustomTableBase(Table): + __table_prefixes__: typing.List[CustomTablePrefix] = [] + _support_primary_and_foreign_keys: bool = True + _enforce_primary_keys: bool = False + _required_parameters: List[TableOptionKey] = [] + _support_structured_types: bool = False + + @property + def table_prefixes(self) -> typing.List[str]: + return [prefix.name for prefix in self.__table_prefixes__] + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if len(self.__table_prefixes__) > 0: + prefixes = kw.get("prefixes", []) + self.table_prefixes + kw.update(prefixes=prefixes) + + if not IS_VERSION_20 and hasattr(super(), "_init"): + kw.pop("_no_init", True) + super()._init(name, metadata, *args, **kw) + else: + super().__init__(name, metadata, *args, **kw) + + if not kw.get("autoload_with", False): + self._validate_table() + + def _validate_table(self): + exceptions: List[Exception] = [] + + columns_validation = self.__validate_columns() + if columns_validation is not None: + exceptions.append(columns_validation) + + for _, option in self.dialect_options[DIALECT_NAME].items(): + if isinstance(option, InvalidTableOption): + exceptions.append(option.exception) + + if isinstance(self.key, NoneType) and self._enforce_primary_keys: + exceptions.append(NoPrimaryKeyError(self.__class__.__name__)) + missing_parameters: List[str] = [] + + for required_parameter in self._required_parameters: + if isinstance(self._get_dialect_option(required_parameter), NoneType): + missing_parameters.append(required_parameter.name.lower()) + if missing_parameters: + exceptions.append( + RequiredParametersNotProvidedError( + self.__class__.__name__, missing_parameters + ) + ) + + if not self._support_primary_and_foreign_keys and ( + self.primary_key or self.foreign_keys + ): + exceptions.append( + UnsupportedPrimaryKeysAndForeignKeysError(self.__class__.__name__) + ) + + if len(exceptions) > 1: + exceptions.sort(key=lambda e: str(e)) + raise MultipleErrors(exceptions) + elif len(exceptions) == 1: + raise exceptions[0] + + def __validate_columns(self): + for column in self.columns: + if not self._support_structured_types and isinstance( + column.type, StructuredType + ): + return StructuredTypeNotSupportedInTableColumnsError( + self.__class__.__name__, self.name, column.name + ) + + def _get_dialect_option( + self, option_name: TableOptionKey + ) -> typing.Optional[TableOption]: + if option_name.value in self.dialect_options[DIALECT_NAME]: + return self.dialect_options[DIALECT_NAME][option_name.value] + return None + + def _as_dialect_options( + self, table_options: List[TableOption] + ) -> typing.Dict[str, TableOption]: + result = {} + for table_option in table_options: + if isinstance(table_option, TableOption) and isinstance( + table_option.option_name, str + ): + result[DIALECT_NAME + "_" + table_option.option_name] = 
table_option + return result + + @classmethod + def is_equal_type(cls, table: Table) -> bool: + for prefix in cls.__table_prefixes__: + if prefix.name not in table._prefixes: + return False + + return True diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py new file mode 100644 index 00000000..de7835d1 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +from enum import Enum + + +class CustomTablePrefix(Enum): + DEFAULT = 0 + EXTERNAL = 1 + EVENT = 2 + HYBRID = 3 + ICEBERG = 4 + DYNAMIC = 5 diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py new file mode 100644 index 00000000..91c379f0 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py @@ -0,0 +1,117 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import typing +from typing import Any, Union + +from sqlalchemy.sql.schema import MetaData, SchemaItem + +from .custom_table_prefix import CustomTablePrefix +from .options import ( + IdentifierOption, + IdentifierOptionType, + KeywordOptionType, + TableOptionKey, + TargetLagOption, + TargetLagOptionType, +) +from .options.keyword_option import KeywordOption +from .table_from_query import TableFromQueryBase + + +class DynamicTable(TableFromQueryBase): + """ + A class representing a dynamic table with configurable options and settings. + + The `DynamicTable` class allows for the creation and querying of tables with + specific options, such as `Warehouse` and `TargetLag`. + + While it does not support reflection at this time, it provides a flexible + interface for creating dynamic tables and management. 
+
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table
+
+    Example using option values:
+        DynamicTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            target_lag=(1, TimeUnit.HOURS),
+            warehouse='warehouse_name',
+            refresh_mode=SnowflakeKeyword.AUTO,
+            as_query="SELECT id, name from test_table_1;"
+        )
+
+    Example using explicit options:
+        DynamicTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            target_lag=TargetLagOption(1, TimeUnit.HOURS),
+            warehouse=IdentifierOption('warehouse_name'),
+            refresh_mode=KeywordOption(SnowflakeKeyword.AUTO),
+            as_query=AsQueryOption("SELECT id, name from test_table_1;")
+        )
+    """
+
+    __table_prefixes__ = [CustomTablePrefix.DYNAMIC]
+    _support_primary_and_foreign_keys = False
+    _required_parameters = [
+        TableOptionKey.WAREHOUSE,
+        TableOptionKey.AS_QUERY,
+        TableOptionKey.TARGET_LAG,
+    ]
+
+    @property
+    def warehouse(self) -> typing.Optional[IdentifierOption]:
+        return self._get_dialect_option(TableOptionKey.WAREHOUSE)
+
+    @property
+    def target_lag(self) -> typing.Optional[TargetLagOption]:
+        return self._get_dialect_option(TableOptionKey.TARGET_LAG)
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        warehouse: IdentifierOptionType = None,
+        target_lag: Union[TargetLagOptionType, KeywordOptionType] = None,
+        refresh_mode: KeywordOptionType = None,
+        **kw: Any,
+    ) -> None:
+        if kw.get("_no_init", True):
+            return
+
+        options = [
+            IdentifierOption.create(TableOptionKey.WAREHOUSE, warehouse),
+            TargetLagOption.create(target_lag),
+            KeywordOption.create(TableOptionKey.REFRESH_MODE, refresh_mode),
+        ]
+
+        kw.update(self._as_dialect_options(options))
+        super().__init__(name, metadata, *args, **kw)
+
+    def _init(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        self.__init__(name, metadata, *args, _no_init=False, **kw)
+
+    def __repr__(self) -> str:
+        return "DynamicTable(%s)" % ", ".join(
+            [repr(self.name)]
+            + [repr(self.metadata)]
+            + [repr(x) for x in self.columns]
+            + [repr(self.target_lag)]
+            + [repr(self.warehouse)]
+            + [repr(self.cluster_by)]
+            + [repr(self.as_query)]
+            + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]]
+        )
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
new file mode 100644
index 00000000..16a58d47
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
@@ -0,0 +1,62 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+#
+
+from typing import Any
+
+from sqlalchemy.sql.schema import MetaData, SchemaItem
+
+from .custom_table_base import CustomTableBase
+from .custom_table_prefix import CustomTablePrefix
+
+
+class HybridTable(CustomTableBase):
+    """
+    A class representing a hybrid table with configurable options and settings.
+
+    The `HybridTable` class allows for the creation and querying of OLTP Snowflake tables.
+
+    While it does not support reflection at this time, it provides a flexible
+    interface for creating and managing hybrid tables.
+ + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-hybrid-table + + Example usage: + HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String) + ) + """ + + __table_prefixes__ = [CustomTablePrefix.HYBRID] + _enforce_primary_keys: bool = True + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + self.__init__(name, metadata, *args, _no_init=False, **kw) + + def __repr__(self) -> str: + return "HybridTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py new file mode 100644 index 00000000..4f62d4f2 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/iceberg_table.py @@ -0,0 +1,102 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import typing +from typing import Any + +from sqlalchemy.sql.schema import MetaData, SchemaItem + +from .custom_table_prefix import CustomTablePrefix +from .options import LiteralOption, LiteralOptionType, TableOptionKey +from .table_from_query import TableFromQueryBase + + +class IcebergTable(TableFromQueryBase): + """ + A class representing an iceberg table with configurable options and settings. + + While it does not support reflection at this time, it provides a flexible + interface for creating iceberg tables and management. 
+
+    For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-iceberg-table
+
+    Example using option values:
+
+        IcebergTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            external_volume='my_external_volume',
+            base_location='my_iceberg_table'
+        )
+
+    Example using explicit options:
+        IcebergTable(
+            "dynamic_test_table_1",
+            metadata,
+            Column("id", Integer),
+            Column("name", String),
+            external_volume=LiteralOption('my_external_volume'),
+            base_location=LiteralOption('my_iceberg_table')
+        )
+    """
+
+    __table_prefixes__ = [CustomTablePrefix.ICEBERG]
+    _support_structured_types = True
+
+    @property
+    def external_volume(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.EXTERNAL_VOLUME)
+
+    @property
+    def base_location(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.BASE_LOCATION)
+
+    @property
+    def catalog(self) -> typing.Optional[LiteralOption]:
+        return self._get_dialect_option(TableOptionKey.CATALOG)
+
+    def __init__(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        external_volume: LiteralOptionType = None,
+        base_location: LiteralOptionType = None,
+        **kw: Any,
+    ) -> None:
+        if kw.get("_no_init", True):
+            return
+
+        options = [
+            LiteralOption.create(TableOptionKey.EXTERNAL_VOLUME, external_volume),
+            LiteralOption.create(TableOptionKey.BASE_LOCATION, base_location),
+            LiteralOption.create(TableOptionKey.CATALOG, "SNOWFLAKE"),
+        ]
+
+        kw.update(self._as_dialect_options(options))
+        super().__init__(name, metadata, *args, **kw)
+
+    def _init(
+        self,
+        name: str,
+        metadata: MetaData,
+        *args: SchemaItem,
+        **kw: Any,
+    ) -> None:
+        self.__init__(name, metadata, *args, _no_init=False, **kw)
+
+    def __repr__(self) -> str:
+        return "IcebergTable(%s)" % ", ".join(
+            [repr(self.name)]
+            + [repr(self.metadata)]
+            + [repr(x) for x in self.columns]
+            + [repr(self.external_volume)]
+            + [repr(self.base_location)]
+            + [repr(self.catalog)]
+            + [repr(self.cluster_by)]
+            + [repr(self.as_query)]
+            + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]]
+        )
diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py
new file mode 100644
index 00000000..e94ea46b
--- /dev/null
+++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/__init__.py
@@ -0,0 +1,33 @@
+#
+# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
+# + +from .as_query_option import AsQueryOption, AsQueryOptionType +from .cluster_by_option import ClusterByOption, ClusterByOptionType +from .identifier_option import IdentifierOption, IdentifierOptionType +from .keyword_option import KeywordOption, KeywordOptionType +from .keywords import SnowflakeKeyword +from .literal_option import LiteralOption, LiteralOptionType +from .table_option import TableOptionKey +from .target_lag_option import TargetLagOption, TargetLagOptionType, TimeUnit + +__all__ = [ + # Options + "IdentifierOption", + "LiteralOption", + "KeywordOption", + "AsQueryOption", + "TargetLagOption", + "ClusterByOption", + # Enums + "TimeUnit", + "SnowflakeKeyword", + "TableOptionKey", + # Types + "IdentifierOptionType", + "LiteralOptionType", + "AsQueryOptionType", + "TargetLagOptionType", + "KeywordOptionType", + "ClusterByOptionType", +] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py new file mode 100644 index 00000000..93994abc --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/as_query_option.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from sqlalchemy.sql import Selectable + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class AsQueryOption(TableOption): + """Class to represent an AS clause in tables. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table#create-table-as-select-also-referred-to-as-ctas + + Example: + as_query=AsQueryOption('select name, address from existing_table where name = "test"') + + is equivalent to: + + as select name, address from existing_table where name = "test" + """ + + def __init__(self, query: Union[str, Selectable]) -> None: + super().__init__() + self._name: TableOptionKey = TableOptionKey.AS_QUERY + self.query = query + + @staticmethod + def create( + value: Optional[Union["AsQueryOption", str, Selectable]] + ) -> "TableOption": + if isinstance(value, (NoneType, AsQueryOption)): + return value + if isinstance(value, (str, Selectable)): + return AsQueryOption(value) + return TableOption._get_invalid_table_option( + TableOptionKey.AS_QUERY, + str(type(value).__name__), + [AsQueryOption.__name__, str.__name__, Selectable.__name__], + ) + + def template(self) -> str: + return "AS %s" + + @property + def priority(self) -> Priority: + return Priority.LOWEST + + def __get_expression(self): + if isinstance(self.query, Selectable): + return self.query.compile(compile_kwargs={"literal_binds": True}) + return self.query + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "AsQueryOption(%s)" % self.__get_expression() + + +AsQueryOptionType = Union[AsQueryOption, str, Selectable] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py new file mode 100644 index 00000000..b92029bb --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/cluster_by_option.py @@ -0,0 +1,58 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +from typing import List, Union + +from sqlalchemy.sql.expression import TextClause + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class ClusterByOption(TableOption): + """Class to represent the cluster by clause in tables. + For further information on this clause, please refer to: https://docs.snowflake.com/en/user-guide/tables-clustering-keys + Example: + cluster_by=ClusterByOption('name', text('id > 0')) + + is equivalent to: + + cluster by (name, id > 0) + """ + + def __init__(self, *expressions: Union[str, TextClause]) -> None: + super().__init__() + self._name: TableOptionKey = TableOptionKey.CLUSTER_BY + self.expressions = expressions + + @staticmethod + def create(value: "ClusterByOptionType") -> "TableOption": + if isinstance(value, (NoneType, ClusterByOption)): + return value + if isinstance(value, List): + return ClusterByOption(*value) + return TableOption._get_invalid_table_option( + TableOptionKey.CLUSTER_BY, + str(type(value).__name__), + [ClusterByOption.__name__, list.__name__], + ) + + def template(self) -> str: + return f"{self.option_name.upper()} (%s)" + + @property + def priority(self) -> Priority: + return Priority.HIGH + + def __get_expression(self): + return ", ".join([str(expression) for expression in self.expressions]) + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "ClusterByOption(%s)" % self.__get_expression() + + +ClusterByOptionType = Union[ClusterByOption, List[Union[str, TextClause]]] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py new file mode 100644 index 00000000..b296898b --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/identifier_option.py @@ -0,0 +1,63 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class IdentifierOption(TableOption): + """Class to represent an identifier option in Snowflake Tables. 
+ + Example: + warehouse = IdentifierOption('my_warehouse') + + is equivalent to: + + WAREHOUSE = my_warehouse + """ + + def __init__(self, value: Union[str]) -> None: + super().__init__() + self.value: str = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, "IdentifierOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + + if isinstance(value, str): + value = IdentifierOption(value) + + if isinstance(value, IdentifierOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, str(type(value).__name__), [IdentifierOption.__name__, str.__name__] + ) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"IdentifierOption(value='{self.value}'{option_name})" + + +IdentifierOptionType = Union[IdentifierOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py new file mode 100644 index 00000000..2bdc9dd3 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/invalid_table_option.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional + +from .table_option import TableOption, TableOptionKey + + +class InvalidTableOption(TableOption): + """Class to store errors and raise them after table initialization in order to avoid recursion error.""" + + def __init__(self, name: TableOptionKey, value: Exception) -> None: + super().__init__() + self.exception: Exception = value + self._name = name + + @staticmethod + def create(name: TableOptionKey, value: Exception) -> Optional[TableOption]: + return InvalidTableOption(name, value) + + def _render(self, compiler) -> str: + raise self.exception + + def __repr__(self) -> str: + return f"ErrorOption(value='{self.exception}')" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py new file mode 100644 index 00000000..ff6b444d --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keyword_option.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class KeywordOption(TableOption): + """Class to represent a keyword option in Snowflake Tables. 
+ + Example: + target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM) + + is equivalent to: + + TARGET_LAG = DOWNSTREAM + """ + + def __init__(self, value: Union[SnowflakeKeyword]) -> None: + super().__init__() + self.value: str = value.value + + @property + def priority(self): + return Priority.HIGH + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def _render(self, compiler) -> str: + return self.template() % self.value.upper() + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[SnowflakeKeyword, "KeywordOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return value + if isinstance(value, SnowflakeKeyword): + value = KeywordOption(value) + + if isinstance(value, KeywordOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, + str(type(value).__name__), + [KeywordOption.__name__, SnowflakeKeyword.__name__], + ) + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if isinstance(self.option_name, NoneType) + else "" + ) + return f"KeywordOption(value='{self.value}'{option_name})" + + +KeywordOptionType = Union[KeywordOption, SnowflakeKeyword] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py new file mode 100644 index 00000000..f4ba87ea --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/keywords.py @@ -0,0 +1,14 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +from enum import Enum + + +class SnowflakeKeyword(Enum): + # TARGET_LAG + DOWNSTREAM = "DOWNSTREAM" + + # REFRESH_MODE + AUTO = "AUTO" + FULL = "FULL" + INCREMENTAL = "INCREMENTAL" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py new file mode 100644 index 00000000..55dd7675 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/literal_option.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from typing import Any, Optional, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .table_option import Priority, TableOption, TableOptionKey + + +class LiteralOption(TableOption): + """Class to represent a literal option in Snowflake Table. 
+ + Example: + warehouse = LiteralOption('my_warehouse') + + is equivalent to: + + WAREHOUSE = 'my_warehouse' + """ + + def __init__(self, value: Union[int, str]) -> None: + super().__init__() + self.value: Any = value + + @property + def priority(self): + return Priority.HIGH + + @staticmethod + def create( + name: TableOptionKey, value: Optional[Union[str, int, "LiteralOption"]] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return None + if isinstance(value, (str, int)): + value = LiteralOption(value) + + if isinstance(value, LiteralOption): + value._set_option_name(name) + return value + + return TableOption._get_invalid_table_option( + name, + str(type(value).__name__), + [LiteralOption.__name__, str.__name__, int.__name__], + ) + + def template(self) -> str: + if isinstance(self.value, int): + return f"{self.option_name.upper()} = %d" + else: + return f"{self.option_name.upper()} = '%s'" + + def _render(self, compiler) -> str: + return self.template() % self.value + + def __repr__(self) -> str: + option_name = ( + f", table_option_key={self.option_name}" + if not isinstance(self.option_name, NoneType) + else "" + ) + return f"LiteralOption(value='{self.value}'{option_name})" + + +LiteralOptionType = Union[LiteralOption, str, int] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py new file mode 100644 index 00000000..5ebb4817 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/table_option.py @@ -0,0 +1,84 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from enum import Enum +from typing import List, Optional + +from snowflake.sqlalchemy import exc +from snowflake.sqlalchemy.custom_commands import NoneType + + +class Priority(Enum): + LOWEST = 0 + VERY_LOW = 1 + LOW = 2 + MEDIUM = 4 + HIGH = 6 + VERY_HIGH = 7 + HIGHEST = 8 + + +class TableOption: + + def __init__(self) -> None: + self._name: Optional[TableOptionKey] = None + + @property + def option_name(self) -> str: + if isinstance(self._name, NoneType): + return None + return str(self._name.value) + + def _set_option_name(self, name: Optional["TableOptionKey"]): + self._name = name + + @property + def priority(self) -> Priority: + return Priority.MEDIUM + + @staticmethod + def create(**kwargs) -> "TableOption": + raise NotImplementedError + + @staticmethod + def _get_invalid_table_option( + parameter_name: "TableOptionKey", input_type: str, expected_types: List[str] + ) -> "TableOption": + from .invalid_table_option import InvalidTableOption + + return InvalidTableOption( + parameter_name, + exc.InvalidTableParameterTypeError( + parameter_name.value, input_type, expected_types + ), + ) + + def _validate_option(self): + if isinstance(self.option_name, NoneType): + raise exc.OptionKeyNotProvidedError(self.__class__.__name__) + + def template(self) -> str: + return f"{self.option_name.upper()} = %s" + + def render_option(self, compiler) -> str: + self._validate_option() + return self._render(compiler) + + def _render(self, compiler) -> str: + raise NotImplementedError + + +class TableOptionKey(Enum): + AS_QUERY = "as_query" + BASE_LOCATION = "base_location" + CATALOG = "catalog" + CATALOG_SYNC = "catalog_sync" + CLUSTER_BY = "cluster by" + DATA_RETENTION_TIME_IN_DAYS = "data_retention_time_in_days" + DEFAULT_DDL_COLLATION = "default_ddl_collation" + EXTERNAL_VOLUME = "external_volume" + MAX_DATA_EXTENSION_TIME_IN_DAYS = "max_data_extension_time_in_days" + 
REFRESH_MODE = "refresh_mode" + STORAGE_SERIALIZATION_POLICY = "storage_serialization_policy" + TARGET_LAG = "target_lag" + WAREHOUSE = "warehouse" diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py new file mode 100644 index 00000000..7c1c0825 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/options/target_lag_option.py @@ -0,0 +1,94 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# from enum import Enum +from enum import Enum +from typing import Optional, Tuple, Union + +from snowflake.sqlalchemy.custom_commands import NoneType + +from .keyword_option import KeywordOption, KeywordOptionType +from .keywords import SnowflakeKeyword +from .table_option import Priority, TableOption, TableOptionKey + + +class TimeUnit(Enum): + SECONDS = "seconds" + MINUTES = "minutes" + HOURS = "hours" + DAYS = "days" + + +class TargetLagOption(TableOption): + """Class to represent the target lag clause in Dynamic Tables. + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table + + Example using the time and unit parameters: + + target_lag = TargetLagOption(10, TimeUnit.SECONDS) + + is equivalent to: + + TARGET_LAG = '10 SECONDS' + + Example using keyword parameter: + + target_lag = KeywordOption(SnowflakeKeyword.DOWNSTREAM) + + is equivalent to: + + TARGET_LAG = DOWNSTREAM + + """ + + def __init__( + self, + time: Optional[int] = 0, + unit: Optional[TimeUnit] = TimeUnit.MINUTES, + ) -> None: + super().__init__() + self.time = time + self.unit = unit + self._name: TableOptionKey = TableOptionKey.TARGET_LAG + + @staticmethod + def create( + value: Union["TargetLagOption", Tuple[int, TimeUnit], KeywordOptionType] + ) -> Optional[TableOption]: + if isinstance(value, NoneType): + return value + + if isinstance(value, Tuple): + time, unit = value + value = TargetLagOption(time, unit) + + if isinstance(value, TargetLagOption): + return value + + if isinstance(value, (KeywordOption, SnowflakeKeyword)): + return KeywordOption.create(TableOptionKey.TARGET_LAG, value) + + return TableOption._get_invalid_table_option( + TableOptionKey.TARGET_LAG, + str(type(value).__name__), + [ + TargetLagOption.__name__, + f"Tuple[int, {TimeUnit.__name__}])", + SnowflakeKeyword.__name__, + ], + ) + + def __get_expression(self): + return f"'{str(self.time)} {str(self.unit.value)}'" + + @property + def priority(self) -> Priority: + return Priority.HIGH + + def _render(self, compiler) -> str: + return self.template() % (self.__get_expression()) + + def __repr__(self) -> str: + return "TargetLagOption(%s)" % self.__get_expression() + + +TargetLagOptionType = Union[TargetLagOption, Tuple[int, TimeUnit]] diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py b/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py new file mode 100644 index 00000000..56a14c83 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/snowflake_table.py @@ -0,0 +1,70 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from typing import Any + +from sqlalchemy.sql.schema import MetaData, SchemaItem + +from .table_from_query import TableFromQueryBase + + +class SnowflakeTable(TableFromQueryBase): + """ + A class representing a table in Snowflake with configurable options and settings. 
+ + While it does not support reflection at this time, it provides a flexible + interface for creating tables and management. + + For further information on this clause, please refer to: https://docs.snowflake.com/en/sql-reference/sql/create-table + Example usage: + + SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by = ["id", text("name > 5")] + ) + + Example using explict options: + + SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by = ClusterByOption("id", text("name > 5")) + ) + + """ + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + if kw.get("_no_init", True): + return + super().__init__(name, metadata, *args, **kw) + + def _init( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + **kw: Any, + ) -> None: + self.__init__(name, metadata, *args, _no_init=False, **kw) + + def __repr__(self) -> str: + return "SnowflakeTable(%s)" % ", ".join( + [repr(self.name)] + + [repr(self.metadata)] + + [repr(x) for x in self.columns] + + [repr(self.cluster_by)] + + [repr(self.as_query)] + + [f"{k}={repr(getattr(self, k))}" for k in ["schema"]] + ) diff --git a/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py new file mode 100644 index 00000000..cbd65de3 --- /dev/null +++ b/src/snowflake/sqlalchemy/sql/custom_schema/table_from_query.py @@ -0,0 +1,54 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import typing +from typing import Any, Optional + +from sqlalchemy.sql import Selectable +from sqlalchemy.sql.schema import Column, MetaData, SchemaItem + +from .clustered_table import ClusteredTableBase +from .options.as_query_option import AsQueryOption, AsQueryOptionType +from .options.table_option import TableOptionKey + + +class TableFromQueryBase(ClusteredTableBase): + + @property + def as_query(self) -> Optional[AsQueryOption]: + return self._get_dialect_option(TableOptionKey.AS_QUERY) + + def __init__( + self, + name: str, + metadata: MetaData, + *args: SchemaItem, + as_query: AsQueryOptionType = None, + **kw: Any, + ) -> None: + items = [item for item in args] + as_query = AsQueryOption.create(as_query) # noqa + kw.update(self._as_dialect_options([as_query])) + if ( + isinstance(as_query, AsQueryOption) + and isinstance(as_query.query, Selectable) + and not self.__has_defined_columns(items) + ): + columns = self.__create_columns_from_selectable(as_query.query) + args = items + columns + super().__init__(name, metadata, *args, **kw) + + def __has_defined_columns(self, items: typing.List[SchemaItem]) -> bool: + for item in items: + if isinstance(item, Column): + return True + + def __create_columns_from_selectable( + self, selectable: Selectable + ) -> Optional[typing.List[Column]]: + if not isinstance(selectable, Selectable): + return + columns: typing.List[Column] = [] + for _, c in selectable.exported_columns.items(): + columns += [Column(c.name, c.type)] + return columns diff --git a/src/snowflake/sqlalchemy/util.py b/src/snowflake/sqlalchemy/util.py index 32e07373..a1aefff9 100644 --- a/src/snowflake/sqlalchemy/util.py +++ b/src/snowflake/sqlalchemy/util.py @@ -7,7 +7,7 @@ from typing import Any from urllib.parse import quote_plus -from sqlalchemy import exc, inspection, sql, util +from sqlalchemy import exc, inspection, sql from sqlalchemy.exc import 
NoForeignKeysError
 from sqlalchemy.orm.interfaces import MapperProperty
 from sqlalchemy.orm.util import _ORMJoin as sa_orm_util_ORMJoin
@@ -19,6 +19,7 @@
 from snowflake.connector.compat import IS_STR
 from snowflake.connector.connection import SnowflakeConnection

+from snowflake.sqlalchemy import compat
 from ._constants import (
     APPLICATION_NAME,
@@ -124,6 +125,13 @@ def parse_url_boolean(value: str) -> bool:
     raise ValueError(f"Invalid boolean value detected: '{value}'")


+def parse_url_integer(value: str) -> int:
+    try:
+        return int(value)
+    except ValueError as e:
+        raise ValueError(f"Invalid int value detected: '{value}'") from e
+
+
 # handle Snowflake BCR bcr-1057
 # the BCR impacts sqlalchemy.orm.context.ORMSelectCompileState and sqlalchemy.sql.selectable.SelectState
 # which used the 'sqlalchemy.util.preloaded.sql_util.find_left_clause_to_join_from' method that
@@ -212,7 +220,7 @@ def __init__(
         # then the "_joined_from_info" concept can go
         left_orm_info = getattr(left, "_joined_from_info", left_info)
         self._joined_from_info = right_info
-        if isinstance(onclause, util.string_types):
+        if isinstance(onclause, compat.string_types):
             onclause = getattr(left_orm_info.entity, onclause)
     # ####
diff --git a/src/snowflake/sqlalchemy/version.py b/src/snowflake/sqlalchemy/version.py
index 61c9fc41..b80a9096 100644
--- a/src/snowflake/sqlalchemy/version.py
+++ b/src/snowflake/sqlalchemy/version.py
@@ -3,4 +3,4 @@
 #
 # Update this for the versions
 # Don't change the forth version number from None
-VERSION = "1.5.3"
+VERSION = "1.7.0"
diff --git a/tests/__snapshots__/test_compile_dynamic_table.ambr b/tests/__snapshots__/test_compile_dynamic_table.ambr
new file mode 100644
index 00000000..81c7f90f
--- /dev/null
+++ b/tests/__snapshots__/test_compile_dynamic_table.ambr
@@ -0,0 +1,13 @@
+# serializer version: 1
+# name: test_compile_dynamic_table
+  "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table"
+# ---
+# name: test_compile_dynamic_table_orm
+  "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table"
+# ---
+# name: test_compile_dynamic_table_orm_with_str_keys
+  "CREATE DYNAMIC TABLE test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table"
+# ---
+# name: test_compile_dynamic_table_with_selectable
+  "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23"
+# ---
diff --git a/tests/__snapshots__/test_core.ambr b/tests/__snapshots__/test_core.ambr
new file mode 100644
index 00000000..7a4e0f99
--- /dev/null
+++ b/tests/__snapshots__/test_core.ambr
@@ -0,0 +1,4 @@
+# serializer version: 1
+# name: test_compile_table_with_cluster_by_with_expression
+  'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY ("Id")) CLUSTER BY ("Id", name, "Id" > 5)'
+# ---
diff --git a/tests/__snapshots__/test_orm.ambr b/tests/__snapshots__/test_orm.ambr
new file mode 100644
index 00000000..2116e9e9
--- /dev/null
+++ b/tests/__snapshots__/test_orm.ambr
@@ -0,0 +1,4 @@
+# serializer version: 1
+# name: test_orm_one_to_many_relationship_with_hybrid_table
+  ProgrammingError('(snowflake.connector.errors.ProgrammingError) 200009 (22000): Foreign key constraint
"SYS_INDEX_HB_TBL_ADDRESS_FOREIGN_KEY_USER_ID_HB_TBL_USER_ID" was violated.') +# --- diff --git a/tests/__snapshots__/test_reflect_dynamic_table.ambr b/tests/__snapshots__/test_reflect_dynamic_table.ambr new file mode 100644 index 00000000..d4cc22b5 --- /dev/null +++ b/tests/__snapshots__/test_reflect_dynamic_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- diff --git a/tests/__snapshots__/test_structured_datatypes.ambr b/tests/__snapshots__/test_structured_datatypes.ambr new file mode 100644 index 00000000..0325a946 --- /dev/null +++ b/tests/__snapshots__/test_structured_datatypes.ambr @@ -0,0 +1,90 @@ +# serializer version: 1 +# name: test_compile_table_with_cluster_by_with_expression + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, VARCHAR), \tPRIMARY KEY ("Id"))' +# --- +# name: test_compile_table_with_double_map + 'CREATE TABLE clustered_user (\t"Id" INTEGER NOT NULL AUTOINCREMENT, \tname MAP(DECIMAL, MAP(DECIMAL, VARCHAR)), \tPRIMARY KEY ("Id"))' +# --- +# name: test_insert_map + list([ + (1, '{\n "100": "item1",\n "200": "item2"\n}'), + ]) +# --- +# name: test_insert_map_orm + ''' + 002014 (22000): SQL compilation error: + Invalid expression [CAST(OBJECT_CONSTRUCT('100', 'item1', '200', 'item2') AS MAP(NUMBER(10,0), VARCHAR(16777216)))] in VALUES clause + ''' +# --- +# name: test_inspect_structured_data_types[structured_type0] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'map_id', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216)), + }), + ]) +# --- +# name: test_inspect_structured_data_types[structured_type1] + list([ + dict({ + 'autoincrement': True, + 'comment': None, + 'default': None, + 'identity': dict({ + 'increment': 1, + 'start': 1, + }), + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=10, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'map_id', + 'nullable': True, + 'primary_key': False, + 'type': MAP(_CUSTOM_DECIMAL(precision=10, scale=0), MAP(_CUSTOM_DECIMAL(precision=10, scale=0), VARCHAR(length=16777216))), + }), + ]) +# --- +# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR(16777216))), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_reflect_structured_data_types[MAP(NUMBER(10, 0), VARCHAR)] + "CREATE ICEBERG TABLE test_reflect_st_types (\tid DECIMAL(38, 0) NOT NULL, \tmap_id MAP(DECIMAL(10, 0), VARCHAR(16777216)), \tCONSTRAINT constraint_name PRIMARY KEY (id))\tCATALOG = 'SNOWFLAKE'" +# --- +# name: test_select_map_orm + list([ + (1, '{\n "100": "item1",\n "200": "item2"\n}'), + (2, '{\n "100": "item1",\n "200": "item2"\n}'), + ]) +# --- +# name: test_select_map_orm.1 + list([ + ]) +# --- +# name: test_select_map_orm.2 + list([ + ]) +# --- diff --git 
a/tests/__snapshots__/test_unit_structured_types.ambr b/tests/__snapshots__/test_unit_structured_types.ambr new file mode 100644 index 00000000..ff861351 --- /dev/null +++ b/tests/__snapshots__/test_unit_structured_types.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_compile_map_with_not_null + 'MAP(DECIMAL(10, 0), VARCHAR NOT NULL)' +# --- diff --git a/tests/conftest.py b/tests/conftest.py index a9c2560a..a91521b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,21 +46,6 @@ TEST_SCHEMA = f"sqlalchemy_tests_{str(uuid.uuid4()).replace('-', '_')}" -def pytest_addoption(parser): - parser.addoption( - "--run_v20_sqlalchemy", - help="Use only 2.0 SQLAlchemy APIs, any legacy features (< 2.0) will not be supported." - "Turning on this option will set future flag to True on Engine and Session objects according to" - "the migration guide: https://docs.sqlalchemy.org/en/14/changelog/migration_20.html", - action="store_true", - ) - - -@pytest.fixture(scope="session") -def run_v20_sqlalchemy(pytestconfig): - return pytestconfig.option.run_v20_sqlalchemy - - @pytest.fixture(scope="session") def on_travis(): return os.getenv("TRAVIS", "").lower() == "true" @@ -109,6 +94,36 @@ def db_parameters(): yield get_db_parameters() +@pytest.fixture(scope="session") +def external_volume(): + db_parameters = get_db_parameters() + if "external_volume" in db_parameters: + yield db_parameters["external_volume"] + else: + raise ValueError("External_volume is not set") + + +@pytest.fixture(scope="session") +def external_stage(): + db_parameters = get_db_parameters() + if "external_stage" in db_parameters: + yield db_parameters["external_stage"] + else: + raise ValueError("External_stage is not set") + + +@pytest.fixture(scope="function") +def base_location(external_stage, engine_testaccount): + unique_id = str(uuid.uuid4()) + base_location = "L" + unique_id.replace("-", "_") + yield base_location + remove_base_location = f""" + REMOVE @{external_stage} pattern='.*{base_location}.*'; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(remove_base_location) + + def get_db_parameters() -> dict: """ Sets the db connection parameters @@ -160,20 +175,21 @@ def url_factory(**kwargs) -> URL: return URL(**url_params) -def get_engine(url: URL, run_v20_sqlalchemy=False, **engine_kwargs): +def get_engine(url: URL, **engine_kwargs): engine_params = { "poolclass": NullPool, - "future": run_v20_sqlalchemy, + "future": True, + "echo": True, } engine_params.update(engine_kwargs) - engine = create_engine(url, **engine_kwargs) + engine = create_engine(url, **engine_params) return engine @pytest.fixture() -def engine_testaccount(request, run_v20_sqlalchemy): +def engine_testaccount(request): url = url_factory() - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine @@ -181,17 +197,17 @@ def engine_testaccount(request, run_v20_sqlalchemy): @pytest.fixture() def engine_testaccount_with_numpy(request): url = url_factory(numpy=True) - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine @pytest.fixture() -def engine_testaccount_with_qmark(request, run_v20_sqlalchemy): +def engine_testaccount_with_qmark(request): snowflake.connector.paramstyle = "qmark" url = url_factory() - engine = get_engine(url, run_v20_sqlalchemy=run_v20_sqlalchemy) + engine = get_engine(url) request.addfinalizer(engine.dispose) yield engine 
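With the run_v20_sqlalchemy option removed, every engine produced by get_engine is a future-style (SQLAlchemy 2.0 API) engine with NullPool and echo=True. A minimal sketch of how a test consumes the engine_testaccount fixture, assuming tests/parameters.py supplies valid credentials (as decrypted in CI):

from sqlalchemy import text

def test_current_version(engine_testaccount):
    # The fixture yields a disposable engine built from url_factory();
    # connections must use the 2.0-style context-manager API.
    with engine_testaccount.connect() as connection:
        row = connection.execute(text("select current_version()")).fetchone()
        assert row is not None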
diff --git a/tests/custom_tables/__init__.py b/tests/custom_tables/__init__.py new file mode 100644 index 00000000..d43f066c --- /dev/null +++ b/tests/custom_tables/__init__.py @@ -0,0 +1,2 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. diff --git a/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr new file mode 100644 index 00000000..66c8f98e --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_dynamic_table.ambr @@ -0,0 +1,40 @@ +# serializer version: 1 +# name: test_compile_dynamic_table + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm + "CREATE DYNAMIC TABLE test_dynamic_table_orm (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_orm_with_str_keys + 'CREATE DYNAMIC TABLE "SCHEMA_DB".test_dynamic_table_orm_2 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = \'10 seconds\'\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_multiple_wrong_option_types + ''' + Invalid parameter type 'IdentifierOption' provided for 'refresh_mode'. Expected one of the following types: 'KeywordOption', 'SnowflakeKeyword'. + Invalid parameter type 'IdentifierOption' provided for 'target_lag'. Expected one of the following types: 'TargetLagOption', 'Tuple[int, TimeUnit])', 'SnowflakeKeyword'. + Invalid parameter type 'KeywordOption' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. + Invalid parameter type 'KeywordOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. + + ''' +# --- +# name: test_compile_dynamic_table_with_one_wrong_option_types + ''' + Invalid parameter type 'LiteralOption' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'. 
+ + ''' +# --- +# name: test_compile_dynamic_table_with_options_objects + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.AUTO] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = AUTO\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.FULL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = FULL\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_refresh_mode[SnowflakeKeyword.INCREMENTAL] + "CREATE DYNAMIC TABLE test_dynamic_table (\tid INTEGER, \tgeom GEOMETRY)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tREFRESH_MODE = INCREMENTAL\tAS SELECT * FROM table" +# --- +# name: test_compile_dynamic_table_with_selectable + "CREATE DYNAMIC TABLE dynamic_test_table_1 (\tid INTEGER, \tname VARCHAR)\tWAREHOUSE = warehouse\tTARGET_LAG = '10 seconds'\tAS SELECT test_table_1.id, test_table_1.name FROM test_table_1 WHERE test_table_1.id = 23" +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr new file mode 100644 index 00000000..9412fb45 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_compile_hybrid_table + 'CREATE HYBRID TABLE test_hybrid_table (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tgeom GEOMETRY, \tPRIMARY KEY (id))' +# --- +# name: test_compile_hybrid_table_orm + 'CREATE HYBRID TABLE test_hybrid_table_orm (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr b/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr new file mode 100644 index 00000000..b243cc09 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_iceberg_table.ambr @@ -0,0 +1,19 @@ +# serializer version: 1 +# name: test_compile_dynamic_table_orm_with_as_query + "CREATE ICEBERG TABLE test_iceberg_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'\tAS SELECT * FROM table" +# --- +# name: test_compile_icberg_table_with_primary_key + "CREATE ICEBERG TABLE test_iceberg_table_with_options (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'" +# --- +# name: test_compile_iceberg_table + "CREATE ICEBERG TABLE test_iceberg_table (\tid INTEGER, \tgeom VARCHAR)\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'" +# --- +# name: test_compile_iceberg_table_with_one_wrong_option_types + ''' + Invalid parameter type 'IdentifierOption' provided for 'external_volume'. Expected one of the following types: 'LiteralOption', 'str', 'int'. 
+ + ''' +# --- +# name: test_compile_iceberg_table_with_options_objects + "CREATE ICEBERG TABLE test_iceberg_table_with_options (\tid INTEGER, \tgeom VARCHAR)\tEXTERNAL_VOLUME = 'my_external_volume'\tCATALOG = 'SNOWFLAKE'\tBASE_LOCATION = 'my_iceberg_table'" +# --- diff --git a/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr new file mode 100644 index 00000000..5ea64c12 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_compile_snowflake_table.ambr @@ -0,0 +1,35 @@ +# serializer version: 1 +# name: test_compile_dynamic_table_orm_with_str_keys + 'CREATE TABLE "SCHEMA_DB".test_snowflake_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_foreign_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL, \tgeom VARCHAR, \tPRIMARY KEY (id), \tFOREIGN KEY(id) REFERENCES "table" (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_dynamic_table_with_primary_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table + 'CREATE TABLE test_table_1 (\tid INTEGER, \tgeom VARCHAR)\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_orm_with_str_keys + 'CREATE TABLE "SCHEMA_DB".test_snowflake_table_orm_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tname VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_explicit_options + 'CREATE TABLE test_table_2 (\tid INTEGER, \tgeom VARCHAR)\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_foreign_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL, \tgeom VARCHAR, \tPRIMARY KEY (id), \tFOREIGN KEY(id) REFERENCES "table" (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_primary_key + 'CREATE TABLE test_table_2 (\tid INTEGER NOT NULL AUTOINCREMENT, \tgeom VARCHAR, \tPRIMARY KEY (id))\tCLUSTER BY (id, id > 100)\tAS SELECT * FROM table' +# --- +# name: test_compile_snowflake_table_with_selectable + 'CREATE TABLE snowflake_test_table_1 (\tid INTEGER, \tgeom VARCHAR)\tAS SELECT test_table_1.id, test_table_1.geom FROM test_table_1 WHERE test_table_1.id = 23' +# --- +# name: test_compile_snowflake_table_with_wrong_option_types + ''' + Invalid parameter type 'AsQueryOption' provided for 'cluster by'. Expected one of the following types: 'ClusterByOption', 'list'. + Invalid parameter type 'ClusterByOption' provided for 'as_query'. Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'. 
+ + ''' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr new file mode 100644 index 00000000..80201495 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_dynamic_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_dynamic_table_without_dynamictable_and_defined_options + CustomOptionsAreOnlySupportedOnSnowflakeTables('Identifier, Literal, TargetLag and other custom options are only supported on Snowflake tables.') +# --- +# name: test_create_dynamic_table_without_dynamictable_class + UnexpectedOptionTypeError('The following options are either unsupported or should be defined using a Snowflake table: as_query, warehouse.') +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr new file mode 100644 index 00000000..696ff9c8 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_hybrid_table.ambr @@ -0,0 +1,7 @@ +# serializer version: 1 +# name: test_create_hybrid_table + "[(1, 'test')]" +# --- +# name: test_create_hybrid_table_with_multiple_index + ProgrammingError("(snowflake.connector.errors.ProgrammingError) 391480 (0A000): Another index is being built on table 'TEST_HYBRID_TABLE_WITH_MULTIPLE_INDEX'. Only one index can be built at a time. Either cancel the other index creation or wait until it is complete.") +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr b/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr new file mode 100644 index 00000000..908a4c60 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_iceberg_table.ambr @@ -0,0 +1,14 @@ +# serializer version: 1 +# name: test_create_iceberg_table + ''' + (snowflake.connector.errors.ProgrammingError) 091017 (22000): S3 bucket 'my_example_bucket' does not exist or not authorized. + [SQL: + CREATE ICEBERG TABLE "Iceberg_Table_1" ( + id INTEGER NOT NULL AUTOINCREMENT, + geom VARCHAR, + PRIMARY KEY (id) + ) EXTERNAL_VOLUME = 'exvol' CATALOG = 'SNOWFLAKE' BASE_LOCATION = 'my_iceberg_table' + + ] + ''' +# --- diff --git a/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr new file mode 100644 index 00000000..98d3137f --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_create_snowflake_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_create_snowflake_table_with_cluster_by + "[(1, 'test')]" +# --- diff --git a/tests/custom_tables/__snapshots__/test_generic_options.ambr b/tests/custom_tables/__snapshots__/test_generic_options.ambr new file mode 100644 index 00000000..eef5e6fd --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_generic_options.ambr @@ -0,0 +1,13 @@ +# serializer version: 1 +# name: test_identifier_option_with_wrong_type + InvalidTableParameterTypeError("Invalid parameter type 'int' provided for 'warehouse'. Expected one of the following types: 'IdentifierOption', 'str'.\n") +# --- +# name: test_identifier_option_without_name + OptionKeyNotProvidedError('Expected option key in IdentifierOption option but got NoneType instead.') +# --- +# name: test_invalid_as_query_option + InvalidTableParameterTypeError("Invalid parameter type 'int' provided for 'as_query'. 
Expected one of the following types: 'AsQueryOption', 'str', 'Selectable'.\n") +# --- +# name: test_literal_option_with_wrong_type + InvalidTableParameterTypeError("Invalid parameter type 'SnowflakeKeyword' provided for 'warehouse'. Expected one of the following types: 'LiteralOption', 'str', 'int'.\n") +# --- diff --git a/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr new file mode 100644 index 00000000..6f6cd395 --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_reflect_hybrid_table.ambr @@ -0,0 +1,4 @@ +# serializer version: 1 +# name: test_simple_reflection_hybrid_table_as_table + 'CREATE TABLE test_hybrid_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr new file mode 100644 index 00000000..7e85841a --- /dev/null +++ b/tests/custom_tables/__snapshots__/test_reflect_snowflake_table.ambr @@ -0,0 +1,29 @@ +# serializer version: 1 +# name: test_inspect_snowflake_table + list([ + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'id', + 'nullable': False, + 'primary_key': True, + 'type': _CUSTOM_DECIMAL(precision=38, scale=0), + }), + dict({ + 'autoincrement': False, + 'comment': None, + 'default': None, + 'name': 'name', + 'nullable': True, + 'primary_key': False, + 'type': VARCHAR(length=16777216), + }), + ]) +# --- +# name: test_simple_reflection_of_table_as_snowflake_table + 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- +# name: test_simple_reflection_of_table_as_sqlalchemy_table + 'CREATE TABLE test_snowflake_table_reflection (\tid DECIMAL(38, 0) NOT NULL, \tname VARCHAR(16777216), \tCONSTRAINT demo_name PRIMARY KEY (id))' +# --- diff --git a/tests/custom_tables/test_compile_dynamic_table.py b/tests/custom_tables/test_compile_dynamic_table.py new file mode 100644 index 00000000..935c61cd --- /dev/null +++ b/tests/custom_tables/test_compile_dynamic_table.py @@ -0,0 +1,271 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
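# The compile tests in this directory render DDL constructs to strings and
# compare them against the .ambr snapshots above. A minimal sketch of what a
# `sql_compiler` helper of this kind could look like (an illustrative
# assumption; the suite's real fixture lives in the shared conftest):
from snowflake.sqlalchemy import snowdialect

def compile_for_snapshot(sql_command):
    # Render against the Snowflake dialect with bound parameters inlined,
    # yielding a plain SQL string that can be compared to a snapshot.
    return str(
        sql_command.compile(
            dialect=snowdialect.dialect(),
            compile_kwargs={"literal_binds": True},
        )
    ).strip()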
+# +import pytest +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Integer, + MetaData, + String, + Table, + exc, + select, +) +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, DynamicTable +from snowflake.sqlalchemy.exc import MultipleErrors +from snowflake.sqlalchemy.sql.custom_schema.options import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + TargetLagOption, + TimeUnit, +) +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword + + +def test_compile_dynamic_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +@pytest.mark.parametrize( + "refresh_mode_keyword", + [ + SnowflakeKeyword.AUTO, + SnowflakeKeyword.FULL, + SnowflakeKeyword.INCREMENTAL, + ], +) +def test_compile_dynamic_table_with_refresh_mode( + sql_compiler, snapshot, refresh_mode_keyword +): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + refresh_mode=refresh_mode_keyword, + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_options_objects(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + test_geometry = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=IdentifierOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_one_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(ArgumentError) as argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=TargetLagOption(10, TimeUnit.SECONDS), + warehouse=LiteralOption("warehouse"), + as_query=AsQueryOption("SELECT * FROM table"), + refresh_mode=KeywordOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_dynamic_table_with_multiple_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_dynamic_table" + with pytest.raises(MultipleErrors) as argument_error: + DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", GEOMETRY), + target_lag=IdentifierOption(SnowflakeKeyword.AUTO), + warehouse=KeywordOption(SnowflakeKeyword.AUTO), + as_query=KeywordOption(SnowflakeKeyword.AUTO), + refresh_mode=IdentifierOption(SnowflakeKeyword.AUTO), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_dynamic_table_without_required_args(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="DynamicTable requires the following parameters: 
warehouse, " "as_query, target_lag.", ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + ) + + +def test_compile_dynamic_table_with_primary_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DynamicTable.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer, primary_key=True), + Column("geom", GEOMETRY), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + ) + + +def test_compile_dynamic_table_with_foreign_key(sql_compiler): + with pytest.raises( + exc.ArgumentError, + match="Primary key and foreign keys are not supported in DynamicTable.", + ): + DynamicTable( + "test_dynamic_table", + MetaData(), + Column("id", Integer), + Column("geom", GEOMETRY), + ForeignKeyConstraint(["id"], ["table.id"]), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + ) + + +def test_compile_dynamic_table_orm(sql_compiler, snapshot): + Base = declarative_base() + metadata = MetaData() + table_name = "test_dynamic_table_orm" + test_dynamic_table_orm = DynamicTable( + table_name, + metadata, + Column("id", Integer), + Column("name", String), + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query="SELECT * FROM table", + ) + + class TestDynamicTableOrm(Base): + __table__ = test_dynamic_table_orm + __mapper_args__ = { + "primary_key": [test_dynamic_table_orm.c.id, test_dynamic_table_orm.c.name] + } + + def __repr__(self): + return f"<TestDynamicTableOrm({self.id!r}, {self.name!r})>" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_orm_with_str_keys(sql_compiler, snapshot): + Base = declarative_base() + + class TestDynamicTableOrm(Base): + __tablename__ = "test_dynamic_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return DynamicTable(name, metadata, *arg, **kw) + + __table_args__ = { + "schema": "SCHEMA_DB", + "target_lag": (10, TimeUnit.SECONDS), + "warehouse": "warehouse", + "as_query": "SELECT * FROM table", + } + + id = Column(Integer) + name = Column(String) + + __mapper_args__ = {"primary_key": [id, name]} + + def __repr__(self): + return f"<TestDynamicTableOrm({self.id!r}, {self.name!r})>" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_with_selectable(sql_compiler, snapshot): + Base = declarative_base() + + test_table_1 = Table( + "test_table_1", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + dynamic_test_table = DynamicTable( + "dynamic_test_table_1", + Base.metadata, + target_lag=(10, TimeUnit.SECONDS), + warehouse="warehouse", + as_query=select(test_table_1).where(test_table_1.c.id == 23), + ) + + value = CreateTable(dynamic_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_compile_hybrid_table.py b/tests/custom_tables/test_compile_hybrid_table.py new file mode 100644 index 00000000..f1af6dc2 --- /dev/null +++ b/tests/custom_tables/test_compile_hybrid_table.py @@ -0,0 +1,52 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
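# Several ORM tests in this suite rely on SQLAlchemy's declarative
# `__table_cls__` hook: when building `__table__`, the declarative layer
# calls this classmethod instead of constructing a plain `Table`, so a
# custom table class can be substituted. A minimal self-contained sketch of
# the pattern (model and column names are illustrative):
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

from snowflake.sqlalchemy import HybridTable

Base = declarative_base()

class ExampleModel(Base):
    __tablename__ = "example_model"

    @classmethod
    def __table_cls__(cls, name, metadata, *arg, **kw):
        # All declared columns and constraints are forwarded here, so the
        # resulting __table__ is a HybridTable rather than a plain Table.
        return HybridTable(name, metadata, *arg, **kw)

    id = Column(Integer, primary_key=True)
    name = Column(String)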
+# +import pytest +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import GEOMETRY, HybridTable + + +@pytest.mark.aws +def test_compile_hybrid_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_hybrid_table" + test_geometry = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + Column("geom", GEOMETRY), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +@pytest.mark.aws +def test_compile_hybrid_table_orm(sql_compiler, snapshot): + Base = declarative_base() + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"<TestHybridTableOrm({self.id!r}, {self.name!r})>" + + value = CreateTable(TestHybridTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_compile_iceberg_table.py b/tests/custom_tables/test_compile_iceberg_table.py new file mode 100644 index 00000000..173e7b0a --- /dev/null +++ b/tests/custom_tables/test_compile_iceberg_table.py @@ -0,0 +1,116 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import IcebergTable +from snowflake.sqlalchemy.sql.custom_schema.options import ( + IdentifierOption, + LiteralOption, +) + + +def test_compile_iceberg_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_iceberg_table" + test_table = IcebergTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + external_volume="my_external_volume", + base_location="my_iceberg_table", + ) + + value = CreateTable(test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_iceberg_table_with_options_objects(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_iceberg_table_with_options" + test_table = IcebergTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + external_volume=LiteralOption("my_external_volume"), + base_location=LiteralOption("my_iceberg_table"), + ) + + value = CreateTable(test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_iceberg_table_with_one_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_wrong_iceberg_table" + with pytest.raises(ArgumentError) as argument_error: + IcebergTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + external_volume=IdentifierOption("my_external_volume"), + base_location=LiteralOption("my_iceberg_table"), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_iceberg_table_with_primary_key(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_iceberg_table_with_options" + test_table = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + external_volume=LiteralOption("my_external_volume"), + base_location=LiteralOption("my_iceberg_table"), + ) + + value = 
CreateTable(test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_dynamic_table_orm_with_as_query(sql_compiler, snapshot): + Base = declarative_base() + + class TestDynamicTableOrm(Base): + __tablename__ = "test_iceberg_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": "my_external_volume", + "base_location": "my_iceberg_table", + "as_query": "SELECT * FROM table", + } + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"<TestDynamicTableOrm({self.id!r}, {self.name!r})>" + + value = CreateTable(TestDynamicTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_compile_snowflake_table.py b/tests/custom_tables/test_compile_snowflake_table.py new file mode 100644 index 00000000..be9383eb --- /dev/null +++ b/tests/custom_tables/test_compile_snowflake_table.py @@ -0,0 +1,180 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import ( + Column, + ForeignKeyConstraint, + Integer, + MetaData, + String, + select, + text, +) +from sqlalchemy.exc import ArgumentError +from sqlalchemy.orm import declarative_base +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import SnowflakeTable +from snowflake.sqlalchemy.sql.custom_schema.options import ( + AsQueryOption, + ClusterByOption, +) + + +def test_compile_snowflake_table(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_table_1" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + cluster_by=["id", text("id > 100")], + as_query="SELECT * FROM table", + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_explicit_options(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_table_2" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_wrong_option_types(snapshot): + metadata = MetaData() + table_name = "test_snowflake_table" + with pytest.raises(ArgumentError) as argument_error: + SnowflakeTable( + table_name, + metadata, + Column("id", Integer), + Column("geom", String), + as_query=ClusterByOption("id", text("id > 100")), + cluster_by=AsQueryOption("SELECT * FROM table"), + ) + + assert str(argument_error.value) == snapshot + + +def test_compile_snowflake_table_with_primary_key(sql_compiler, snapshot): + metadata = MetaData() + table_name = "test_table_2" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_foreign_key(sql_compiler, snapshot): + metadata = MetaData() + + SnowflakeTable( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + ForeignKeyConstraint(["id"], ["table.id"]), + 
cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + table_name = "test_table_2" + test_geometry = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + ForeignKeyConstraint(["id"], ["table.id"]), + cluster_by=ClusterByOption("id", text("id > 100")), + as_query=AsQueryOption("SELECT * FROM table"), + ) + + value = CreateTable(test_geometry) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_orm_with_str_keys(sql_compiler, snapshot): + Base = declarative_base() + + class TestSnowflakeTableOrm(Base): + __tablename__ = "test_snowflake_table_orm_2" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return SnowflakeTable(name, metadata, *arg, **kw) + + __table_args__ = { + "schema": "SCHEMA_DB", + "cluster_by": ["id", text("id > 100")], + "as_query": "SELECT * FROM table", + } + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"<TestSnowflakeTableOrm({self.id!r}, {self.name!r})>" + + value = CreateTable(TestSnowflakeTableOrm.__table__) + + actual = sql_compiler(value) + + assert actual == snapshot + + +def test_compile_snowflake_table_with_selectable(sql_compiler, snapshot): + Base = declarative_base() + + test_table_1 = SnowflakeTable( + "test_table_1", + Base.metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + ForeignKeyConstraint(["id"], ["table.id"]), + cluster_by=ClusterByOption("id", text("id > 100")), + ) + + test_table_2 = SnowflakeTable( + "snowflake_test_table_1", + Base.metadata, + as_query=select(test_table_1).where(test_table_1.c.id == 23), + ) + + value = CreateTable(test_table_2) + + actual = sql_compiler(value) + + assert actual == snapshot diff --git a/tests/custom_tables/test_create_dynamic_table.py b/tests/custom_tables/test_create_dynamic_table.py new file mode 100644 index 00000000..b583faad --- /dev/null +++ b/tests/custom_tables/test_create_dynamic_table.py @@ -0,0 +1,124 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
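# For orientation, the integration tests below create real dynamic tables.
# Judging from the compile-test snapshots above, a definition such as
# DynamicTable(..., target_lag=(1, TimeUnit.HOURS), warehouse=..., as_query=...)
# should emit DDL of roughly this shape (the '1 hour' rendering is an
# assumption extrapolated from the '10 seconds' snapshots):
#
#   CREATE DYNAMIC TABLE dynamic_test_table_1 (id INTEGER, name VARCHAR)
#   WAREHOUSE = <warehouse> TARGET_LAG = '1 hour' REFRESH_MODE = FULL
#   AS SELECT id, name from test_table_1
#
# The tests then assert that selecting from the dynamic table returns the
# same rows as the base table it is defined over.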
+# +import pytest +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable, exc +from snowflake.sqlalchemy.sql.custom_schema.options.as_query_option import AsQueryOption +from snowflake.sqlalchemy.sql.custom_schema.options.identifier_option import ( + IdentifierOption, +) +from snowflake.sqlalchemy.sql.custom_schema.options.keywords import SnowflakeKeyword +from snowflake.sqlalchemy.sql.custom_schema.options.table_option import TableOptionKey +from snowflake.sqlalchemy.sql.custom_schema.options.target_lag_option import ( + TargetLagOption, + TimeUnit, +) + + +def test_create_dynamic_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + dynamic_test_table_1 = DynamicTable( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + target_lag=(1, TimeUnit.HOURS), + warehouse=warehouse, + as_query="SELECT id, name from test_table_1;", + refresh_mode=SnowflakeKeyword.FULL, + ) + + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_create_dynamic_table_without_dynamictable_class( + engine_testaccount, db_parameters, snapshot +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + Table( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + snowflake_warehouse=warehouse, + snowflake_as_query="SELECT id, name from test_table_1;", + prefixes=["DYNAMIC"], + ) + + with pytest.raises(exc.UnexpectedOptionTypeError) as exc_info: + metadata.create_all(engine_testaccount) + assert exc_info.value == snapshot + + +def test_create_dynamic_table_without_dynamictable_and_defined_options( + engine_testaccount, db_parameters, snapshot +): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + + Table( + "dynamic_test_table_1", + metadata, + Column("id", Integer), + Column("name", String), + snowflake_target_lag=TargetLagOption.create((1, TimeUnit.HOURS)), + snowflake_warehouse=IdentifierOption.create( + TableOptionKey.WAREHOUSE, warehouse + ), + snowflake_as_query=AsQueryOption.create("SELECT id, name from test_table_1;"), + prefixes=["DYNAMIC"], + ) + + with pytest.raises(exc.CustomOptionsAreOnlySupportedOnSnowflakeTables) as exc_info: + metadata.create_all(engine_testaccount) + assert 
exc_info.value == snapshot diff --git a/tests/custom_tables/test_create_hybrid_table.py b/tests/custom_tables/test_create_hybrid_table.py new file mode 100644 index 00000000..43ae3ab6 --- /dev/null +++ b/tests/custom_tables/test_create_hybrid_table.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +import sqlalchemy.exc +from sqlalchemy import Column, Index, Integer, MetaData, String, select +from sqlalchemy.orm import Session, declarative_base + +from snowflake.sqlalchemy import HybridTable + + +@pytest.mark.aws +def test_create_hybrid_table(engine_testaccount, db_parameters, snapshot): + metadata = MetaData() + table_name = "test_create_hybrid_table" + + dynamic_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = dynamic_test_table_1.insert().values(id=1, name="test") + conn.execute(ins) + conn.commit() + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table_1) + results_hybrid_table = conn.execute(s).fetchall() + assert str(results_hybrid_table) == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_multiple_index( + engine_testaccount, db_parameters, snapshot, sql_compiler +): + metadata = MetaData() + table_name = "test_hybrid_table_with_multiple_index" + + hybrid_test_table_1 = HybridTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String, index=True), + Column("name2", String), + Column("name3", String), + ) + + metadata.create_all(engine_testaccount) + + index = Index("idx_col34", hybrid_test_table_1.c.name2, hybrid_test_table_1.c.name3) + + with pytest.raises(sqlalchemy.exc.ProgrammingError) as exc_info: + index.create(engine_testaccount) + try: + assert exc_info.value == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_create_hybrid_table_with_orm(sql_compiler, engine_testaccount): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestHybridTableOrm(Base): + __tablename__ = "test_hybrid_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + instance = TestHybridTableOrm(id=0, name="name_example") + session.add(instance) + session.commit() + data = session.query(TestHybridTableOrm).all() + assert str(data) == "[(0, 'name_example')]" + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_create_iceberg_table.py b/tests/custom_tables/test_create_iceberg_table.py new file mode 100644 index 00000000..3ecd703b --- /dev/null +++ b/tests/custom_tables/test_create_iceberg_table.py @@ -0,0 +1,43 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
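# The iceberg test below deliberately points the external volume at an S3
# bucket that does not exist, expects CREATE to raise ProgrammingError, and
# snapshots the error text without its final line (SQLAlchemy appends a
# version-dependent "Background on this error" link there). A small sketch of
# that trimming, with a hypothetical error string:
error_str = "091017 (22000): S3 bucket does not exist\n[SQL: ...]\n(Background on this error at: ...)"
stable_part = error_str[: error_str.rfind("\n")]  # keep only the text before the last newline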
+# +import pytest +from sqlalchemy import Column, Integer, MetaData, String +from sqlalchemy.exc import ProgrammingError + +from snowflake.sqlalchemy import IcebergTable + + +@pytest.mark.aws +def test_create_iceberg_table(engine_testaccount, snapshot): + metadata = MetaData() + external_volume_name = "exvol" + create_external_volume = f""" + CREATE OR REPLACE EXTERNAL VOLUME {external_volume_name} + STORAGE_LOCATIONS = + ( + ( + NAME = 'my-s3-us-west-2' + STORAGE_PROVIDER = 'S3' + STORAGE_BASE_URL = 's3://MY_EXAMPLE_BUCKET/' + STORAGE_AWS_ROLE_ARN = 'arn:aws:iam::123456789012:role/myrole' + ENCRYPTION=(TYPE='AWS_SSE_KMS' KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab') + ) + ); + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_external_volume) + IcebergTable( + "Iceberg_Table_1", + metadata, + Column("id", Integer, primary_key=True), + Column("geom", String), + external_volume=external_volume_name, + base_location="my_iceberg_table", + ) + + with pytest.raises(ProgrammingError) as argument_error: + metadata.create_all(engine_testaccount) + + error_str = str(argument_error.value) + assert error_str[: error_str.rfind("\n")] == snapshot diff --git a/tests/custom_tables/test_create_snowflake_table.py b/tests/custom_tables/test_create_snowflake_table.py new file mode 100644 index 00000000..09140fb8 --- /dev/null +++ b/tests/custom_tables/test_create_snowflake_table.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +from sqlalchemy import Column, Integer, MetaData, String, select, text +from sqlalchemy.orm import Session, declarative_base + +from snowflake.sqlalchemy import SnowflakeTable + + +def test_create_snowflake_table_with_cluster_by( + engine_testaccount, db_parameters, snapshot +): + metadata = MetaData() + table_name = "test_create_snowflake_table" + + test_table_1 = SnowflakeTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("name", String), + cluster_by=["id", text("id > 5")], + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + conn.execute(ins) + conn.commit() + + try: + with engine_testaccount.connect() as conn: + s = select(test_table_1) + results_hybrid_table = conn.execute(s).fetchall() + assert str(results_hybrid_table) == snapshot + finally: + metadata.drop_all(engine_testaccount) + + +def test_create_snowflake_table_with_orm(sql_compiler, engine_testaccount): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestHybridTableOrm(Base): + __tablename__ = "test_snowflake_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return SnowflakeTable(name, metadata, *arg, **kw) + + id = Column(Integer, primary_key=True) + name = Column(String) + + def __repr__(self): + return f"({self.id!r}, {self.name!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + instance = TestHybridTableOrm(id=0, name="name_example") + session.add(instance) + session.commit() + data = session.query(TestHybridTableOrm).all() + assert str(data) == "[(0, 'name_example')]" + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_generic_options.py b/tests/custom_tables/test_generic_options.py new file mode 100644 index 00000000..916b94c6 --- /dev/null +++ b/tests/custom_tables/test_generic_options.py @@ -0,0 +1,83 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. 
All rights reserved. + +import pytest + +from snowflake.sqlalchemy import ( + AsQueryOption, + IdentifierOption, + KeywordOption, + LiteralOption, + SnowflakeKeyword, + TableOptionKey, + TargetLagOption, + exc, +) +from snowflake.sqlalchemy.sql.custom_schema.options.invalid_table_option import ( + InvalidTableOption, +) + + +def test_identifier_option(): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert identifier.render_option(None) == "WAREHOUSE = xsmall" + + +def test_literal_option(): + literal = LiteralOption.create(TableOptionKey.WAREHOUSE, "xsmall") + assert literal.render_option(None) == "WAREHOUSE = 'xsmall'" + + +def test_identifier_option_without_name(snapshot): + identifier = IdentifierOption("xsmall") + with pytest.raises(exc.OptionKeyNotProvidedError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_identifier_option_with_wrong_type(snapshot): + identifier = IdentifierOption.create(TableOptionKey.WAREHOUSE, 23) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + identifier.render_option(None) + assert exc_info.value == snapshot + + +def test_literal_option_with_wrong_type(snapshot): + literal = LiteralOption.create( + TableOptionKey.WAREHOUSE, SnowflakeKeyword.DOWNSTREAM + ) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + literal.render_option(None) + assert exc_info.value == snapshot + + +def test_invalid_as_query_option(snapshot): + as_query = AsQueryOption.create(23) + with pytest.raises(exc.InvalidTableParameterTypeError) as exc_info: + as_query.render_option(None) + assert exc_info.value == snapshot + + +@pytest.mark.parametrize( + "table_option", + [ + IdentifierOption, + LiteralOption, + KeywordOption, + ], +) +def test_generic_option_with_wrong_type(table_option): + literal = table_option.create(TableOptionKey.WAREHOUSE, 0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" + + +@pytest.mark.parametrize( + "table_option", + [ + TargetLagOption, + AsQueryOption, + ], +) +def test_non_generic_option_with_wrong_type(table_option): + literal = table_option.create(0.32) + assert isinstance(literal, InvalidTableOption), "Expected InvalidTableOption" diff --git a/tests/custom_tables/test_reflect_dynamic_table.py b/tests/custom_tables/test_reflect_dynamic_table.py new file mode 100644 index 00000000..52eb4457 --- /dev/null +++ b/tests/custom_tables/test_reflect_dynamic_table.py @@ -0,0 +1,88 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
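# Reflection in the tests below uses SQLAlchemy's autoload mechanism:
# constructing a table object with autoload_with=engine queries the live
# schema and fills in the column definitions. A minimal sketch, assuming an
# already-created engine and an existing table:
from sqlalchemy import MetaData, Table

metadata = MetaData()
reflected = Table("dynamic_test_table", metadata, autoload_with=engine)  # `engine` assumed
print([column.name for column in reflected.columns])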
+# +from sqlalchemy import Column, Integer, MetaData, String, Table, select + +from snowflake.sqlalchemy import DynamicTable +from snowflake.sqlalchemy.custom_commands import NoneType + + +def test_simple_reflection_dynamic_table_as_table(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = Table( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_without_options_loading(engine_testaccount, db_parameters): + warehouse = db_parameters.get("warehouse", "default") + metadata = MetaData() + test_table_1 = Table( + "test_table_1", metadata, Column("id", Integer), Column("name", String) + ) + + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + ins = test_table_1.insert().values(id=1, name="test") + + conn.execute(ins) + conn.commit() + create_table_sql = f""" + CREATE DYNAMIC TABLE dynamic_test_table (id INT, name VARCHAR) + TARGET_LAG = '20 minutes' + WAREHOUSE = {warehouse} + AS SELECT id, name from test_table_1; + """ + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + dynamic_test_table = DynamicTable( + "dynamic_test_table", metadata, autoload_with=engine_testaccount + ) + + # TODO: Add support for loading options when table is reflected + assert isinstance(dynamic_test_table.warehouse, NoneType) + + try: + with engine_testaccount.connect() as conn: + s = select(dynamic_test_table) + results_dynamic_table = conn.execute(s).fetchall() + s = select(test_table_1) + results_table = conn.execute(s).fetchall() + assert results_dynamic_table == results_table + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_reflect_hybrid_table.py b/tests/custom_tables/test_reflect_hybrid_table.py new file mode 100644 index 00000000..4a777bf0 --- /dev/null +++ b/tests/custom_tables/test_reflect_hybrid_table.py @@ -0,0 +1,65 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
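# Reflected primary-key constraints come back with server-assigned names
# (Snowflake generates identifiers such as SYS_CONSTRAINT_...), which would
# make the compiled DDL snapshots unstable. The reflection tests below
# therefore pop the constraint and pin a fixed name before compiling:
constraint = reflected_table.constraints.pop()  # `reflected_table` assumed reflected above
constraint.name = "demo_name"
reflected_table.constraints.add(constraint)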
+# +import pytest +from sqlalchemy import MetaData, Table +from sqlalchemy.sql.ddl import CreateTable + + +@pytest.mark.aws +def test_simple_reflection_hybrid_table_as_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_hybrid_table_reflection" + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX index_name (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + hybrid_test_table = Table(table_name, metadata, autoload_with=engine_testaccount) + + constraint = hybrid_test_table.constraints.pop() + constraint.name = "demo_name" + hybrid_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(hybrid_test_table) + + actual = sql_compiler(value) + + # Prefixes reflection not supported, example: "HYBRID, DYNAMIC" + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_reflect_hybrid_table_with_index( + engine_testaccount, db_parameters, sql_compiler +): + metadata = MetaData() + schema = db_parameters["schema"] + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + table = Table(table_name, metadata, schema=schema, autoload_with=engine_testaccount) + + try: + assert len(table.indexes) == 1 and table.indexes.pop().name == index_name + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/custom_tables/test_reflect_snowflake_table.py b/tests/custom_tables/test_reflect_snowflake_table.py new file mode 100644 index 00000000..603b6187 --- /dev/null +++ b/tests/custom_tables/test_reflect_snowflake_table.py @@ -0,0 +1,92 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
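# Column inspection, exercised by test_inspect_snowflake_table below, returns
# one dict per column with keys such as name, type, nullable and primary_key
# (see the .ambr snapshot above). A minimal sketch, assuming an open
# connection `conn`:
from sqlalchemy import inspect

insp = inspect(conn)  # `conn` assumed to be an active connection
for column_info in insp.get_columns("test_snowflake_table_inspect"):
    print(column_info["name"], column_info["type"], column_info["nullable"])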
+# +from sqlalchemy import MetaData, Table, inspect +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import SnowflakeTable + + +def test_simple_reflection_of_table_as_sqlalchemy_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_reflection" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + snowflake_test_table = Table(table_name, metadata, autoload_with=engine_testaccount) + constraint = snowflake_test_table.constraints.pop() + constraint.name = "demo_name" + snowflake_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(snowflake_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +def test_simple_reflection_of_table_as_snowflake_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_reflection" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + snowflake_test_table = SnowflakeTable( + table_name, metadata, autoload_with=engine_testaccount + ) + constraint = snowflake_test_table.constraints.pop() + constraint.name = "demo_name" + snowflake_test_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(snowflake_test_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +def test_inspect_snowflake_table( + engine_testaccount, db_parameters, sql_compiler, snapshot +): + metadata = MetaData() + table_name = "test_snowflake_table_inspect" + + create_table_sql = f""" + CREATE TABLE {table_name} (id INT primary key, name VARCHAR); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + try: + with engine_testaccount.connect() as conn: + insp = inspect(conn) + table = insp.get_columns(table_name) + assert table == snapshot + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/sqlalchemy_test_suite/conftest.py b/tests/sqlalchemy_test_suite/conftest.py index 31cd7c5c..f0464c7d 100644 --- a/tests/sqlalchemy_test_suite/conftest.py +++ b/tests/sqlalchemy_test_suite/conftest.py @@ -15,6 +15,7 @@ import snowflake.connector from snowflake.sqlalchemy import URL +from snowflake.sqlalchemy.compat import IS_VERSION_20 from ..conftest import get_db_parameters from ..util import random_string @@ -25,6 +26,12 @@ TEST_SCHEMA_2 = f"{TEST_SCHEMA}_2" +if IS_VERSION_20: + collect_ignore_glob = ["test_suite.py"] +else: + collect_ignore_glob = ["test_suite_20.py"] + + # patch sqlalchemy.testing.config.Confi.__init__ for schema name randomization # same schema name would result in conflict as we're running tests in parallel in the CI def config_patched__init__(self, db, db_opts, options, file_config): diff --git a/tests/sqlalchemy_test_suite/test_suite.py b/tests/sqlalchemy_test_suite/test_suite.py index d79e511e..643d1559 100644 --- a/tests/sqlalchemy_test_suite/test_suite.py +++ b/tests/sqlalchemy_test_suite/test_suite.py @@ -69,6 +69,10 @@ def test_empty_insert(self, connection): def test_empty_insert_multiple(self, 
connection): pass + + @pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + + # 2. Patched Tests diff --git a/tests/sqlalchemy_test_suite/test_suite_20.py b/tests/sqlalchemy_test_suite/test_suite_20.py new file mode 100644 index 00000000..1f79c4e9 --- /dev/null +++ b/tests/sqlalchemy_test_suite/test_suite_20.py @@ -0,0 +1,205 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest +from sqlalchemy import Integer, testing +from sqlalchemy.schema import Column, Sequence, Table +from sqlalchemy.testing import config +from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.suite import ( + BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest, +) +from sqlalchemy.testing.suite import ( + CompositeKeyReflectionTest as _CompositeKeyReflectionTest, +) +from sqlalchemy.testing.suite import DateTimeHistoricTest as _DateTimeHistoricTest +from sqlalchemy.testing.suite import FetchLimitOffsetTest as _FetchLimitOffsetTest +from sqlalchemy.testing.suite import HasSequenceTest as _HasSequenceTest +from sqlalchemy.testing.suite import InsertBehaviorTest as _InsertBehaviorTest +from sqlalchemy.testing.suite import LikeFunctionsTest as _LikeFunctionsTest +from sqlalchemy.testing.suite import LongNameBlowoutTest as _LongNameBlowoutTest +from sqlalchemy.testing.suite import SimpleUpdateDeleteTest as _SimpleUpdateDeleteTest +from sqlalchemy.testing.suite import TimeMicrosecondsTest as _TimeMicrosecondsTest +from sqlalchemy.testing.suite import TrueDivTest as _TrueDivTest +from sqlalchemy.testing.suite import * # noqa + +# 1. Unsupported by snowflake db + +del ComponentReflectionTest # require indexes not supported by snowflake +del HasIndexTest # require indexes not supported by snowflake +del QuotedNameArgumentTest # require indexes not supported by snowflake + + +class LongNameBlowoutTest(_LongNameBlowoutTest): + # The combination ("ix",) is removed due to Snowflake not supporting indexes + def ix(self, metadata, connection): + pytest.skip("ix requires the index feature, which is not supported by Snowflake") + + +class FetchLimitOffsetTest(_FetchLimitOffsetTest): + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_bound_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_limit_expr_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_offset(self, connection): + pass + + @pytest.mark.skip( + "Snowflake only takes non-negative integer constants for offset/limit" + ) + def test_simple_offset_zero(self, connection): + pass + + +class InsertBehaviorTest(_InsertBehaviorTest): + @pytest.mark.skip( + "Snowflake does not support inserting empty values, the value may be a literal or an expression." + ) + def test_empty_insert(self, connection): + pass + + @pytest.mark.skip( + "Snowflake does not support inserting empty values, the value may be a literal or an expression." 
+ ) + def test_empty_insert_multiple(self, connection): + pass + + @pytest.mark.skip("Snowflake does not support returning in insert.") + def test_no_results_for_non_returning_insert(self, connection, style, executemany): + pass + + +# road to 2.0 +class TrueDivTest(_TrueDivTest): + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer_bound(self, connection): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. + """ + pass + + @pytest.mark.skip("`//` not supported") + def test_floordiv_integer(self, connection, left, right, expected): + """Snowflake does not provide `//` arithmetic operator. + + https://docs.snowflake.com/en/sql-reference/operators-arithmetic. + """ + pass + + +class TimeMicrosecondsTest(_TimeMicrosecondsTest): + def __init__(self): + super().__init__() + + +class DateTimeHistoricTest(_DateTimeHistoricTest): + def __init__(self): + super().__init__() + + +# 2. Patched Tests + + +class HasSequenceTest(_HasSequenceTest): + # Override the define_tables method as snowflake does not support 'nomaxvalue'/'nominvalue' + @classmethod + def define_tables(cls, metadata): + Sequence("user_id_seq", metadata=metadata) + # Replace Sequence("other_seq") creation as in the original test suite, + # the Sequence created with 'nomaxvalue' and 'nominvalue' + # which snowflake does not support: + # Sequence("other_seq", metadata=metadata, nomaxvalue=True, nominvalue=True) + Sequence("other_seq", metadata=metadata) + if testing.requires.schemas.enabled: + Sequence("user_id_seq", schema=config.test_schema, metadata=metadata) + Sequence("schema_seq", schema=config.test_schema, metadata=metadata) + Table( + "user_id_table", + metadata, + Column("id", Integer, primary_key=True), + ) + + +class LikeFunctionsTest(_LikeFunctionsTest): + @testing.requires.regexp_match + @testing.combinations( + ("a.cde.*", {1, 5, 6, 9}), + ("abc.*", {1, 5, 6, 9, 10}), + ("^abc.*", {1, 5, 6, 9, 10}), + (".*9cde.*", {8}), + ("^a.*", set(range(1, 11))), + (".*(b|c).*", set(range(1, 11))), + ("^(b|c).*", set()), + ) + def test_regexp_match(self, text, expected): + super().test_regexp_match(text, expected) + + def test_not_regexp_match(self): + col = self.tables.some_table.c.data + self._test(~col.regexp_match("a.cde.*"), {2, 3, 4, 7, 8, 10}) + + +class SimpleUpdateDeleteTest(_SimpleUpdateDeleteTest): + def test_update(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.update().where(t.c.id == 2), dict(data="d2_new")) + assert not r.is_insert + # snowflake returns a row with numbers of rows updated and number of multi-joined rows updated + assert r.returns_rows + assert r.rowcount == 1 + + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (2, "d2_new"), (3, "d3")], + ) + + def test_delete(self, connection): + t = self.tables.plain_pk + r = connection.execute(t.delete().where(t.c.id == 2)) + assert not r.is_insert + # snowflake returns a row with number of rows deleted + assert r.returns_rows + assert r.rowcount == 1 + eq_( + connection.execute(t.select().order_by(t.c.id)).fetchall(), + [(1, "d1"), (3, "d3")], + ) + + +class CompositeKeyReflectionTest(_CompositeKeyReflectionTest): + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_fk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. 
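# Using xfail rather than skip keeps these tests executing, so if a future
# change aligns Snowflake's reflected key ordering with SQLAlchemy's
# expectation, the run will surface it as an unexpected pass (XPASS).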
+ super().test_fk_column_order() + + @pytest.mark.xfail(reason="Fixing this would require behavior breaking change.") + def test_pk_column_order(self): + # Check https://snowflakecomputing.atlassian.net/browse/SNOW-640134 for details on breaking changes discussion. + super().test_pk_column_order() + + +class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest): + @testing.combinations( + ("id",), ("(3)",), ("col%p",), ("[brack]",), argnames="columnname" + ) + @testing.variation("use_composite", [True, False]) + @testing.combinations( + ("plain",), + ("(2)",), + ("[brackets]",), + argnames="tablename", + ) + def test_fk_ref(self, connection, metadata, use_composite, tablename, columnname): + super().test_fk_ref(connection, metadata, use_composite, tablename, columnname) diff --git a/tests/test_compiler.py b/tests/test_compiler.py index 0fd75c38..40207b41 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -5,7 +5,7 @@ from sqlalchemy import Integer, String, and_, func, select from sqlalchemy.schema import DropColumnComment, DropTableComment from sqlalchemy.sql import column, quoted_name, table -from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing.assertions import AssertsCompiledSQL from snowflake.sqlalchemy import snowdialect diff --git a/tests/test_copy.py b/tests/test_copy.py index e0752d4f..8dfcf286 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -4,7 +4,7 @@ import pytest from sqlalchemy import Column, Integer, MetaData, Sequence, String, Table -from sqlalchemy.sql import select, text +from sqlalchemy.sql import functions, select, text from snowflake.sqlalchemy import ( AWSBucket, @@ -58,8 +58,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_1) - == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv " - "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" + == "COPY INTO 's3://backup' FROM python_tests_foods FILE_FORMAT=(TYPE=csv " + "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) copy_stmt_2 = CopyIntoStorage( @@ -73,8 +73,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): sql_compiler(copy_stmt_2) == "COPY INTO 's3://backup' FROM (SELECT python_tests_foods.id, " "python_tests_foods.name, python_tests_foods.quantity FROM python_tests_foods " - "WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' " - "FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') " + "WHERE python_tests_foods.id = 1) FILE_FORMAT=(TYPE=json COMPRESSION='zstd' " + "FILE_EXTENSION='json') CREDENTIALS=(AWS_ROLE='some_iam_role') " "ENCRYPTION=(TYPE='AWS_SSE_S3')" ) copy_stmt_3 = CopyIntoStorage( @@ -87,7 +87,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): assert ( sql_compiler(copy_stmt_3) == "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' " - "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " + "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " "CREDENTIALS=(AZURE_SAS_TOKEN='token')" ) @@ -95,7 +95,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): assert ( sql_compiler(copy_stmt_3) == "COPY INTO 'azure://snowflake.blob.core.windows.net/snowpile/backup' " - "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet SNAPPY_COMPRESSION=true) " + "FROM python_tests_foods FILE_FORMAT=(TYPE=parquet 
SNAPPY_COMPRESSION=true) " "MAX_FILE_SIZE = 50000000 " "CREDENTIALS=(AZURE_SAS_TOKEN='token')" ) @@ -112,8 +112,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_4) - == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " - "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" + == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " + "ESCAPE=None NULL_IF=('null', 'Null') RECORD_DELIMITER='|') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) @@ -126,8 +126,8 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_5) - == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " - "FIELD_DELIMITER=',') ENCRYPTION=" + == "COPY INTO python_tests_foods FROM 's3://backup' FILE_FORMAT=(TYPE=csv " + "FIELD_DELIMITER=',') ENCRYPTION=" "(KMS_KEY_ID='1234abcd-12ab-34cd-56ef-1234567890ab' TYPE='AWS_SSE_KMS')" ) @@ -138,7 +138,7 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_6) - == "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" + == "COPY INTO @stage_name FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " ) copy_stmt_7 = CopyIntoStorage( @@ -148,7 +148,38 @@ def test_copy_into_location(engine_testaccount, sql_compiler): ) assert ( sql_compiler(copy_stmt_7) - == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv)" + == "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv) " + ) + + copy_stmt_8 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by=text("('YEAR=' || year)"), + ) + assert ( + sql_compiler(copy_stmt_8) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY ('YEAR=' || year) " + ) + + copy_stmt_9 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by=functions.concat( + text("'YEAR='"), text(food_items.columns["name"].name) + ), + ) + assert ( + sql_compiler(copy_stmt_9) + == "COPY INTO @stage_name FROM python_tests_foods PARTITION BY concat('YEAR=', name) " + ) + + copy_stmt_10 = CopyIntoStorage( + from_=food_items, + into=ExternalStage(name="stage_name"), + partition_by="", + ) + assert ( + sql_compiler(copy_stmt_10) == "COPY INTO @stage_name FROM python_tests_foods " ) # NOTE Other than expect known compiled text, submit it to RegressionTests environment and expect them to fail, but @@ -231,7 +262,7 @@ def test_copy_into_storage_csv_extended(sql_compiler): result = sql_compiler(copy_into) expected = ( r"COPY INTO TEST_IMPORT " - r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata " + r"FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata " r"FILE_FORMAT=(TYPE=csv COMPRESSION='auto' DATE_FORMAT='AUTO' " r"ERROR_ON_COLUMN_COUNT_MISMATCH=True ESCAPE=None " r"ESCAPE_UNENCLOSED_FIELD='\134' FIELD_DELIMITER=',' " @@ -288,7 +319,7 @@ def test_copy_into_storage_parquet_named_format(sql_compiler): expected = ( "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " - "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) " + "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet) " "FILE_FORMAT=(format_name = parquet_file_format) force = TRUE" ) assert result == expected @@ -350,7 +381,7 @@ def test_copy_into_storage_parquet_files(sql_compiler): "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " "FROM 
@ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet " - "(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') " + "(file_format => parquet_file_format)) FILES = ('foo.txt','bar.txt') " "FORCE = true" ) assert result == expected @@ -412,6 +443,6 @@ def test_copy_into_storage_parquet_pattern(sql_compiler): "COPY INTO TEST_IMPORT " "FROM (SELECT $1:COL1::number, $1:COL2::varchar " "FROM @ML_POC.PUBLIC.AZURE_STAGE/testdata/out.parquet " - "(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'" + "(file_format => parquet_file_format)) FORCE = true PATTERN = '.*csv'" ) assert result == expected diff --git a/tests/test_core.py b/tests/test_core.py index 6c8d7416..63f097db 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -30,12 +30,14 @@ UniqueConstraint, create_engine, dialects, + exc, insert, inspect, text, ) -from sqlalchemy.exc import DBAPIError, NoSuchTableError +from sqlalchemy.exc import DBAPIError, NoSuchTableError, OperationalError from sqlalchemy.sql import and_, not_, or_, select +from sqlalchemy.sql.ddl import CreateTable import snowflake.connector.errors import snowflake.sqlalchemy.snowdialect @@ -123,14 +125,26 @@ def test_connect_args(): Snowflake connect string supports account name as a replacement of host:port """ + server = "" + if "host" in CONNECTION_PARAMETERS and "port" in CONNECTION_PARAMETERS: + server = "{host}:{port}".format( + host=CONNECTION_PARAMETERS["host"], port=CONNECTION_PARAMETERS["port"] + ) + elif "account" in CONNECTION_PARAMETERS and "region" in CONNECTION_PARAMETERS: + server = "{account}.{region}".format( + account=CONNECTION_PARAMETERS["account"], + region=CONNECTION_PARAMETERS["region"], + ) + elif "account" in CONNECTION_PARAMETERS: + server = CONNECTION_PARAMETERS["account"] + engine = create_engine( - "snowflake://{user}:{password}@{host}:{port}/{database}/{schema}" + "snowflake://{user}:{password}@{server}/{database}/{schema}" "?account={account}&protocol={protocol}".format( user=CONNECTION_PARAMETERS["user"], account=CONNECTION_PARAMETERS["account"], password=CONNECTION_PARAMETERS["password"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], + server=server, database=CONNECTION_PARAMETERS["database"], schema=CONNECTION_PARAMETERS["schema"], protocol=CONNECTION_PARAMETERS["protocol"], @@ -141,32 +155,14 @@ def test_connect_args(): finally: engine.dispose() - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) try: verify_engine_connection(engine) finally: engine.dispose() - - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - warehouse="testwh", - ) - ) + parameters = {**CONNECTION_PARAMETERS} + parameters["warehouse"] = "testwh" + engine = create_engine(URL(**parameters)) try: verify_engine_connection(engine) finally: @@ -174,14 +170,10 @@ def test_connect_args(): def test_boolean_query_argument_parsing(): + engine = create_engine( URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - 
host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], + **CONNECTION_PARAMETERS, validate_default_parameters=True, ) ) @@ -406,16 +398,6 @@ def test_insert_tables(engine_testaccount): str(users.join(addresses)) == "users JOIN addresses ON " "users.id = addresses.user_id" ) - assert ( - str( - users.join( - addresses, - addresses.c.email_address.like(users.c.name + "%"), - ) - ) - == "users JOIN addresses " - "ON addresses.email_address LIKE users.name || :name_1" - ) s = select(users.c.fullname).select_from( users.join( @@ -444,7 +426,7 @@ def test_table_does_not_exist(engine_testaccount): """ meta = MetaData() with pytest.raises(NoSuchTableError): - Table("does_not_exist", meta, autoload=True, autoload_with=engine_testaccount) + Table("does_not_exist", meta, autoload_with=engine_testaccount) @pytest.mark.skip( @@ -470,9 +452,7 @@ def test_reflextion(engine_testaccount): ) try: meta = MetaData() - user_reflected = Table( - "user", meta, autoload=True, autoload_with=engine_testaccount - ) + user_reflected = Table("user", meta, autoload_with=engine_testaccount) assert user_reflected.c == ["user.id", "user.name", "user.fullname"] finally: conn.execute("DROP TABLE IF EXISTS user") @@ -514,19 +494,20 @@ def test_inspect_column(engine_testaccount): users.drop(engine_testaccount) -def test_get_indexes(engine_testaccount): +def test_get_indexes(engine_testaccount, db_parameters): """ Tests get indexes - NOTE: Snowflake doesn't support indexes + NOTE: Only Snowflake Hybrid Tables support indexes """ + schema = db_parameters["schema"] metadata = MetaData() users, addresses = _create_users_addresses_tables_without_sequence( engine_testaccount, metadata ) try: inspector = inspect(engine_testaccount) - assert inspector.get_indexes("users") == [] + assert inspector.get_indexes("users", schema) == [] finally: addresses.drop(engine_testaccount) @@ -710,6 +691,39 @@ def test_create_table_with_cluster_by(engine_testaccount): user.drop(engine_testaccount) +def test_create_table_with_cluster_by_with_expression(engine_testaccount): + metadata = MetaData() + Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", String), + snowflake_clusterby=["Id", "name", text('"Id" > 5')], + ) + metadata.create_all(engine_testaccount) + try: + inspector = inspect(engine_testaccount) + columns_in_table = inspector.get_columns("clustered_user") + assert columns_in_table[0]["name"] == "Id", "name" + finally: + metadata.drop_all(engine_testaccount) + + +def test_compile_table_with_cluster_by_with_expression(sql_compiler, snapshot): + metadata = MetaData() + user = Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", String), + snowflake_clusterby=["Id", "name", text('"Id" > 5')], + ) + + create_table = CreateTable(user) + + assert sql_compiler(create_table) == snapshot + + def test_view_names(engine_testaccount): """ Tests all views @@ -1071,28 +1085,16 @@ def harass_inspector(): assert outcome -@pytest.mark.timeout(15) -def test_region(): - engine = create_engine( - URL( - user="testuser", - password="testpassword", - account="testaccount", - region="eu-central-1", - login_timeout=5, - ) - ) - try: - engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.eu-central-1.snowflakecomputing.com" in ex.orig.msg - - -@pytest.mark.timeout(15) -def 
test_azure(): +@pytest.mark.skip(reason="Test account is not available; it returns a 404 error.") +@pytest.mark.timeout(10) +@pytest.mark.parametrize( + "region", + ( + pytest.param("eu-central-1", id="region"), + pytest.param("east-us-2.azure", id="azure"), + ), +) +def test_connection_timeout_error(region): engine = create_engine( URL( user="testuser", @@ -1102,13 +1104,13 @@ def test_azure(): login_timeout=5, ) ) - try: + + with pytest.raises(OperationalError) as excinfo: engine.connect() - pytest.fail("should not run") - except Exception as ex: - assert ex.orig.errno == 250001 - assert "Failed to connect to DB" in ex.orig.msg - assert "testaccount.east-us-2.azure.snowflakecomputing.com" in ex.orig.msg + + assert excinfo.value.orig.errno == 250001 + assert "Could not connect to Snowflake backend" in excinfo.value.orig.msg + assert region not in excinfo.value.orig.msg def test_load_dialect(): @@ -1535,13 +1537,11 @@ def test_too_many_columns_detection(engine_testaccount, db_parameters): metadata.create_all(engine_testaccount) inspector = inspect(engine_testaccount) # Do test - original_execute = inspector.bind.execute + connection = inspector.bind.connect() + original_execute = connection.execute - def mock_helper(command, *args, **kwargs): - if "_get_schema_columns" in command: - # Creating exception exactly how SQLAlchemy does - raise DBAPIError.instance( - """ + exception_instance = DBAPIError.instance( + """ SELECT /* sqlalchemy:_get_schema_columns */ ic.table_name, ic.column_name, @@ -1556,24 +1556,32 @@ def mock_helper(command, *args, **kwargs): FROM information_schema.columns ic WHERE ic.table_schema='schema_name' ORDER BY ic.ordinal_position""", - {"table_schema": "TESTSCHEMA"}, - ProgrammingError( - "Information schema query returned too much data. 
Please repeat query with more " + "selective predicates.", + 90030, + ), + Error, + hide_parameters=False, + connection_invalidated=False, + dialect=SnowflakeDialect(), + ismulti=None, + ) + + def mock_helper(command, *args, **kwargs): + if "_get_schema_columns" in command.text: + # Creating exception exactly how SQLAlchemy does + raise exception_instance else: return original_execute(command, *args, **kwargs) - with patch.object(inspector.bind, "execute", side_effect=mock_helper): - column_metadata = inspector.get_columns("users", db_parameters["schema"]) - assert len(column_metadata) == 4 + with patch.object(engine_testaccount, "connect") as conn: + conn.return_value = connection + with patch.object(connection, "execute", side_effect=mock_helper): + with pytest.raises(exc.ProgrammingError) as exception: + inspector.get_columns("users", db_parameters["schema"]) + assert exception.value.orig == exception_instance.orig # Clean up metadata.drop_all(engine_testaccount) @@ -1615,13 +1623,11 @@ def test_column_type_schema(engine_testaccount): """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) columns = table_reflected.columns - assert ( - len(columns) == len(ischema_names_baseline) - 1 - ) # -1 because FIXED is not supported + assert len(columns) == ( + len(ischema_names_baseline) - 2 + ) # -2 because FIXED and MAP are not supported def test_result_type_and_value(engine_testaccount): @@ -1638,9 +1644,7 @@ def test_result_type_and_value(engine_testaccount): ) """ ) - table_reflected = Table( - table_name, MetaData(), autoload=True, autoload_with=conn - ) + table_reflected = Table(table_name, MetaData(), autoload_with=conn) current_date = date.today() current_utctime = datetime.utcnow() current_localtime = pytz.utc.localize(current_utctime, is_dst=False).astimezone( @@ -1801,30 +1805,14 @@ def test_normalize_and_denormalize_empty_string_column_name(engine_testaccount): def test_snowflake_sqlalchemy_as_valid_client_type(): engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ), + URL(**CONNECTION_PARAMETERS), connect_args={"internal_application_name": "UnknownClient"}, ) with engine.connect() as conn: with pytest.raises(snowflake.connector.errors.NotSupportedError): conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) with engine.connect() as conn: conn.exec_driver_sql("select 1").cursor.fetch_pandas_all() @@ -1855,16 +1843,7 @@ def test_snowflake_sqlalchemy_as_valid_client_type(): "3.0.0", (type(None), str), ) - engine = create_engine( - URL( - user=CONNECTION_PARAMETERS["user"], - password=CONNECTION_PARAMETERS["password"], - account=CONNECTION_PARAMETERS["account"], - host=CONNECTION_PARAMETERS["host"], - port=CONNECTION_PARAMETERS["port"], - protocol=CONNECTION_PARAMETERS["protocol"], - ) - ) + engine = create_engine(URL(**CONNECTION_PARAMETERS)) with engine.connect() as conn: conn.exec_driver_sql("select 
1").cursor.fetch_pandas_all() assert ( diff --git a/tests/test_custom_functions.py b/tests/test_custom_functions.py new file mode 100644 index 00000000..2a1e1cb5 --- /dev/null +++ b/tests/test_custom_functions.py @@ -0,0 +1,25 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. + +import pytest +from sqlalchemy import func + +from snowflake.sqlalchemy import snowdialect + + +def test_flatten_does_not_render_params(): + """This behavior is for backward compatibility. + + In previous version params were not rendered. + In future this behavior will change. + """ + flat = func.flatten("[1, 2]", outer=True) + res = flat.compile(dialect=snowdialect.dialect()) + + assert str(res) == "flatten(%(flatten_1)s)" + + +def test_flatten_emits_warning(): + expected_warning = "For backward compatibility params are not rendered." + with pytest.warns(DeprecationWarning, match=expected_warning): + func.flatten().compile(dialect=snowdialect.dialect()) diff --git a/tests/test_custom_types.py b/tests/test_custom_types.py index a997ffe8..3961a5d3 100644 --- a/tests/test_custom_types.py +++ b/tests/test_custom_types.py @@ -2,7 +2,10 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -from snowflake.sqlalchemy import custom_types +import pytest +from sqlalchemy import Column, Integer, MetaData, Table, text + +from snowflake.sqlalchemy import TEXT, custom_types def test_string_conversions(): @@ -34,3 +37,31 @@ def test_string_conversions(): sample = getattr(custom_types, type_)() if type_ in sf_custom_types: assert type_ == str(sample) + + +@pytest.mark.feature_max_lob_size +def test_create_table_with_text_type(engine_testaccount): + metadata = MetaData() + table_name = "test_max_lob_size_0" + test_max_lob_size = Table( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("full_name", TEXT(), server_default=text("id::varchar")), + ) + + metadata.create_all(engine_testaccount) + try: + assert test_max_lob_size is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{table_name}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + test_max_lob_size.drop(engine_testaccount) diff --git a/tests/test_index_reflection.py b/tests/test_index_reflection.py new file mode 100644 index 00000000..09f5cfe7 --- /dev/null +++ b/tests/test_index_reflection.py @@ -0,0 +1,34 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. 
+# +import pytest +from sqlalchemy import MetaData +from sqlalchemy.engine import reflection + + +@pytest.mark.aws +def test_indexes_reflection(engine_testaccount, db_parameters, sql_compiler): + metadata = MetaData() + + table_name = "test_hybrid_table_2" + index_name = "INDEX_NAME_2" + schema = db_parameters["schema"] + + create_table_sql = f""" + CREATE HYBRID TABLE {table_name} (id INT primary key, name VARCHAR, INDEX {index_name} (name)); + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + insp = reflection.Inspector.from_engine(engine_testaccount) + + try: + with engine_testaccount.connect(): + # Prefix reflection is not supported, e.g. "HYBRID, DYNAMIC" + indexes = insp.get_indexes(table_name, schema) + assert len(indexes) == 1 + assert indexes[0].get("name") == index_name + + finally: + metadata.drop_all(engine_testaccount) diff --git a/tests/test_orm.py b/tests/test_orm.py index e485d737..cb3a7768 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -7,20 +7,24 @@ import pytest from sqlalchemy import ( + TEXT, Column, Enum, ForeignKey, Integer, Sequence, String, + exc, func, select, text, ) from sqlalchemy.orm import Session, declarative_base, relationship +from snowflake.sqlalchemy import HybridTable -def test_basic_orm(engine_testaccount, run_v20_sqlalchemy): + +def test_basic_orm(engine_testaccount): """ Tests declarative """ @@ -46,7 +50,6 @@ def __repr__(self): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) our_user = session.query(User).filter_by(name="ed").first() @@ -56,14 +59,15 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_orm_one_to_many_relationship(engine_testaccount, run_v20_sqlalchemy): +def test_orm_one_to_many_relationship(engine_testaccount, db_parameters): """ Tests One to Many relationship """ Base = declarative_base() + prefix = "tbl_" class User(Base): - __tablename__ = "user" + __tablename__ = prefix + "user" id = Column(Integer, Sequence("user_id_seq"), primary_key=True) name = Column(String) @@ -73,13 +77,13 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" id = Column(Integer, Sequence("address_id_seq"), primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, ForeignKey("user.id")) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", backref="addresses") + user = relationship(User, backref="addresses") def __repr__(self): return f"" @@ -97,7 +101,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -124,14 +127,143 @@ def __repr__(self): Base.metadata.drop_all(engine_testaccount) -def test_delete_cascade(engine_testaccount, run_v20_sqlalchemy): +@pytest.mark.aws +def test_orm_one_to_many_relationship_with_hybrid_table(engine_testaccount, snapshot): + """ + Tests One to Many relationship + """ + Base = declarative_base() + + class User(Base): + __tablename__ = "hb_tbl_user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + def __repr__(self): + return f"" + + class Address(Base): 
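+ # Address also maps to a HybridTable via __table_cls__, so its FK to hb_tbl_user is enforced; the ProgrammingError asserted at the end of this test relies on that enforcement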
+ __tablename__ = "hb_tbl_address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) + + id = Column(Integer, Sequence("address_id_seq"), primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) + + user = relationship(User, backref="addresses") + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + jack = User(name="jack", fullname="Jack Bean") + assert jack.addresses == [], "one to many record is empty list" + + jack.addresses = [ + Address(email_address="jack@gmail.com"), + Address(email_address="j25@yahoo.com"), + Address(email_address="jack@hotmail.com"), + ] + + session = Session(bind=engine_testaccount) + session.add(jack) # cascade each Address into the Session as well + session.commit() + + session.delete(jack) + + with pytest.raises(exc.ProgrammingError) as exc_info: + session.query(Address).all() + + assert exc_info.value == snapshot, "Hybrid Table enforces FK constraint" + + finally: + Base.metadata.drop_all(engine_testaccount) + + +def test_delete_cascade(engine_testaccount): + """ + Test delete cascade + """ + Base = declarative_base() + prefix = "tbl_" + + class User(Base): + __tablename__ = prefix + "user" + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + name = Column(String) + fullname = Column(String) + + addresses = relationship( + "Address", back_populates="user", cascade="all, delete, delete-orphan" + ) + + def __repr__(self): + return f"" + + class Address(Base): + __tablename__ = prefix + "address" + + id = Column(Integer, Sequence("address_id_seq"), primary_key=True) + email_address = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id")) + + user = relationship(User, back_populates="addresses") + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + jack = User(name="jack", fullname="Jack Bean") + assert jack.addresses == [], "one to many record is empty list" + + jack.addresses = [ + Address(email_address="jack@gmail.com"), + Address(email_address="j25@yahoo.com"), + Address(email_address="jack@hotmail.com"), + ] + + session = Session(bind=engine_testaccount) + session.add(jack) # cascade each Address into the Session as well + session.commit() + + got_jack = session.query(User).first() + assert got_jack == jack + + session.delete(jack) + got_addresses = session.query(Address).all() + assert len(got_addresses) == 0, "no address record" + finally: + Base.metadata.drop_all(engine_testaccount) + + +@pytest.mark.aws +def test_delete_cascade_hybrid_table(engine_testaccount): + """ + Test delete cascade + """ + Base = declarative_base() + prefix = "hb_tbl_" + + class User(Base): + __tablename__ = prefix + "user" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) id = Column(Integer, Sequence("user_id_seq"), primary_key=True) name = Column(String) @@ -145,13 +277,17 @@ def __repr__(self): return f"" class Address(Base): - __tablename__ = "address" + __tablename__ = prefix + "address" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return HybridTable(name, metadata, *arg, **kw) id = Column(Integer, Sequence("address_id_seq"), primary_key=True) email_address = Column(String, nullable=False) - user_id = Column(Integer, ForeignKey("user.id")) + user_id = Column(Integer, 
ForeignKey(f"{User.__tablename__}.id")) - user = relationship("User", back_populates="addresses") + user = relationship(User, back_populates="addresses") def __repr__(self): return f"" @@ -169,7 +305,6 @@ def __repr__(self): ] session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(jack) # cascade each Address into the Session as well session.commit() @@ -189,7 +324,7 @@ def __repr__(self): WIP """, ) -def test_orm_query(engine_testaccount, run_v20_sqlalchemy): +def test_orm_query(engine_testaccount): """ Tests ORM query """ @@ -210,7 +345,6 @@ def __repr__(self): # TODO: insert rows session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy # TODO: query.all() for name, fullname in session.query(User.name, User.fullname): @@ -220,7 +354,7 @@ def __repr__(self): # MultipleResultsFound if not one result -def test_schema_including_db(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_db(engine_testaccount, db_parameters): """ Test schema parameter including database separated by a dot. """ @@ -243,7 +377,6 @@ class User(Base): ed_user = User(name="ed", fullname="Edward Jones") session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy session.add(ed_user) ret_user = session.query(User.id, User.name).first() @@ -255,7 +388,7 @@ class User(Base): Base.metadata.drop_all(engine_testaccount) -def test_schema_including_dot(engine_testaccount, db_parameters, run_v20_sqlalchemy): +def test_schema_including_dot(engine_testaccount, db_parameters): """ Tests pseudo schema name including dot. """ @@ -276,7 +409,6 @@ class User(Base): fullname = Column(String) session = Session(bind=engine_testaccount) - session.future = run_v20_sqlalchemy query = session.query(User.id) assert str(query).startswith( 'SELECT {db}."{schema}.{schema}".{db}.users.id'.format( @@ -285,9 +417,7 @@ class User(Base): ) -def test_schema_translate_map( - engine_testaccount, db_parameters, sql_compiler, run_v20_sqlalchemy -): +def test_schema_translate_map(engine_testaccount, db_parameters): """ Test schema translate map execution option works replaces schema correctly """ @@ -310,7 +440,6 @@ class User(Base): schema_translate_map={schema_map: db_parameters["schema"]} ) as con: session = Session(bind=con) - session.future = run_v20_sqlalchemy with con.begin(): Base.metadata.create_all(con) try: @@ -367,18 +496,29 @@ class Department(Base): .select_from(Employee) .outerjoin(sub) ) - assert ( - str(query.compile(engine_testaccount)).replace("\n", "") - == "SELECT employees.employee_id, departments.department_id " + compiled_stmts = ( + # v1.x + "SELECT employees.employee_id, departments.department_id " "FROM departments, employees LEFT OUTER JOIN LATERAL " "(SELECT departments.department_id AS department_id, departments.name AS name " - "FROM departments) AS anon_1" + "FROM departments) AS anon_1", + # v2.x + "SELECT employees.employee_id, departments.department_id " + "FROM employees LEFT OUTER JOIN LATERAL " + "(SELECT departments.department_id AS department_id, departments.name AS name " + "FROM departments) AS anon_1, departments", ) + compiled_stmt = str(query.compile(engine_testaccount)).replace("\n", "") + assert compiled_stmt in compiled_stmts + with caplog.at_level(logging.DEBUG): assert [res for res in session.execute(query)] assert ( "SELECT employees.employee_id, departments.department_id FROM departments" in caplog.text + ) or ( + "SELECT employees.employee_id, departments.department_id FROM employees" + in 
caplog.text ) @@ -411,3 +551,34 @@ class Employee(Base): '[SELECT "Employee".uid FROM "Employee" JOIN LATERAL flatten(PARSE_JSON("Employee"' in caplog.text ) + + +@pytest.mark.feature_max_lob_size +def test_basic_table_with_large_lob_size_in_memory(engine_testaccount, sql_compiler): + Base = declarative_base() + + class User(Base): + __tablename__ = "user" + + id = Column(Integer, primary_key=True) + full_name = Column(TEXT(), server_default=text("id::varchar")) + + def __repr__(self): + return f"" + + Base.metadata.create_all(engine_testaccount) + + try: + assert User.__table__ is not None + + with engine_testaccount.connect() as conn: + with conn.begin(): + query = text(f"SELECT GET_DDL('TABLE', '{User.__tablename__}')") + result = conn.execute(query) + row = str(result.mappings().fetchone()) + assert ( + "VARCHAR(134217728)" in row + ), f"Expected VARCHAR(134217728) in {row}" + + finally: + Base.metadata.drop_all(engine_testaccount) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index ef64d65e..2a6b9f1b 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -27,6 +27,7 @@ from snowflake.connector import ProgrammingError from snowflake.connector.pandas_tools import make_pd_writer, pd_writer +from snowflake.sqlalchemy.compat import IS_VERSION_20 def _create_users_addresses_tables(engine_testaccount, metadata): @@ -168,7 +169,7 @@ def test_no_indexes(engine_testaccount, db_parameters): con=conn, if_exists="replace", ) - assert str(exc.value) == "Snowflake does not support indexes" + assert str(exc.value) == "Only Snowflake Hybrid Tables supports indexes" def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_numpy): @@ -240,8 +241,8 @@ def test_timezone(db_parameters, engine_testaccount, engine_testaccount_with_num conn.exec_driver_sql(f"DROP TABLE {test_table_name};") -def test_pandas_writeback(engine_testaccount, run_v20_sqlalchemy): - if run_v20_sqlalchemy and sys.version_info < (3, 8): +def test_pandas_writeback(engine_testaccount): + if IS_VERSION_20 and sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." ) @@ -352,8 +353,8 @@ def test_pandas_invalid_make_pd_writer(engine_testaccount): ) -def test_percent_signs(engine_testaccount, run_v20_sqlalchemy): - if run_v20_sqlalchemy and sys.version_info < (3, 8): +def test_percent_signs(engine_testaccount): + if IS_VERSION_20 and sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." 
) @@ -376,7 +377,7 @@ def test_percent_signs(engine_testaccount, run_v20_sqlalchemy): not_like_sql = f"select * from {table_name} where c2 not like '%b%'" like_sql = f"select * from {table_name} where c2 like '%b%'" calculate_sql = "SELECT 1600 % 400 AS a, 1599 % 400 as b" - if run_v20_sqlalchemy: + if IS_VERSION_20: not_like_sql = sqlalchemy.text(not_like_sql) like_sql = sqlalchemy.text(like_sql) calculate_sql = sqlalchemy.text(calculate_sql) diff --git a/tests/test_qmark.py b/tests/test_qmark.py index f98fa7d3..3761181a 100644 --- a/tests/test_qmark.py +++ b/tests/test_qmark.py @@ -12,11 +12,11 @@ THIS_DIR = os.path.dirname(os.path.realpath(__file__)) -def test_qmark_bulk_insert(run_v20_sqlalchemy, engine_testaccount_with_qmark): +def test_qmark_bulk_insert(engine_testaccount_with_qmark): """ Bulk insert using qmark paramstyle """ - if run_v20_sqlalchemy and sys.version_info < (3, 8): + if sys.version_info < (3, 8): pytest.skip( "In Python 3.7, this test depends on pandas features of which the implementation is incompatible with sqlachemy 2.0, and pandas does not support Python 3.7 anymore." ) diff --git a/tests/test_structured_datatypes.py b/tests/test_structured_datatypes.py new file mode 100644 index 00000000..4ea0892b --- /dev/null +++ b/tests/test_structured_datatypes.py @@ -0,0 +1,271 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import pytest +from sqlalchemy import ( + Column, + Integer, + MetaData, + Sequence, + Table, + cast, + exc, + inspect, + text, +) +from sqlalchemy.orm import Session, declarative_base +from sqlalchemy.sql import select +from sqlalchemy.sql.ddl import CreateTable + +from snowflake.sqlalchemy import NUMBER, IcebergTable, SnowflakeTable +from snowflake.sqlalchemy.custom_types import MAP, TEXT +from snowflake.sqlalchemy.exc import StructuredTypeNotSupportedInTableColumnsError + + +def test_compile_table_with_map_type(sql_compiler, snapshot): + metadata = MetaData() + user_table = Table( + "clustered_user", + metadata, + Column("Id", Integer, primary_key=True), + Column("name", MAP(NUMBER(), TEXT())), + ) + + create_table = CreateTable(user_table) + + assert sql_compiler(create_table) == snapshot + + +@pytest.mark.requires_external_volume +def test_create_table_structured_datatypes( + engine_testaccount, external_volume, base_location +): + metadata = MetaData() + table_name = "test_map0" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", MAP(NUMBER(10, 0), TEXT())), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + try: + assert test_map is not None + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_map(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_insert_map" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", MAP(NUMBER(10, 0), TEXT())), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + try: + with engine_testaccount.connect() as conn: + slt = select( + 1, + cast( + text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT()) + ), + ) + ins = test_map.insert().from_select(["id", "map_id"], slt) + conn.execute(ins) + + results = conn.execute(test_map.select()) + data = results.fetchmany() + results.close() + 
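# round-trip check: the rows read back from the table must match the stored snapshot +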
snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +@pytest.mark.parametrize( + "structured_type", + [ + MAP(NUMBER(10, 0), TEXT()), + MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), TEXT())), + ], +) +def test_inspect_structured_data_types( + engine_testaccount, external_volume, base_location, snapshot, structured_type +): + metadata = MetaData() + table_name = "test_st_types" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", structured_type), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + try: + inspector = inspect(engine_testaccount) + columns = inspector.get_columns(table_name) + + assert isinstance(columns[0]["type"], NUMBER) + assert isinstance(columns[1]["type"], MAP) + assert columns == snapshot + + finally: + test_map.drop(engine_testaccount) + + +@pytest.mark.requires_external_volume +@pytest.mark.parametrize( + "structured_type", + [ + "MAP(NUMBER(10, 0), VARCHAR)", + "MAP(NUMBER(10, 0), MAP(NUMBER(10, 0), VARCHAR))", + ], +) +def test_reflect_structured_data_types( + engine_testaccount, + external_volume, + base_location, + snapshot, + structured_type, + sql_compiler, +): + metadata = MetaData() + table_name = "test_reflect_st_types" + create_table_sql = f""" +CREATE OR REPLACE ICEBERG TABLE {table_name} ( + id number(38,0) primary key, + map_id {structured_type}) +CATALOG = 'SNOWFLAKE' +EXTERNAL_VOLUME = '{external_volume}' +BASE_LOCATION = '{base_location}'; + """ + + with engine_testaccount.connect() as connection: + connection.exec_driver_sql(create_table_sql) + + iceberg_table = IcebergTable(table_name, metadata, autoload_with=engine_testaccount) + constraint = iceberg_table.constraints.pop() + constraint.name = "constraint_name" + iceberg_table.constraints.add(constraint) + + try: + with engine_testaccount.connect(): + value = CreateTable(iceberg_table) + + actual = sql_compiler(value) + + assert actual == snapshot + + finally: + metadata.drop_all(engine_testaccount) + + +@pytest.mark.requires_external_volume +def test_insert_map_orm( + sql_compiler, external_volume, base_location, engine_testaccount, snapshot +): + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __tablename__ = "test_iceberg_table_orm" + + @classmethod + def __table_cls__(cls, name, metadata, *arg, **kw): + return IcebergTable(name, metadata, *arg, **kw) + + __table_args__ = { + "external_volume": external_volume, + "base_location": base_location, + } + + id = Column(Integer, Sequence("user_id_seq"), primary_key=True) + map_id = Column(MAP(NUMBER(10, 0), TEXT())) + + def __repr__(self): + return f"({self.id!r}, {self.map_id!r})" + + Base.metadata.create_all(engine_testaccount) + + try: + cast_expr = cast( + text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT()) + ) + instance = TestIcebergTableOrm(id=0, map_id=cast_expr) + session.add(instance) + with pytest.raises(exc.ProgrammingError) as programming_error: + session.commit() + # TODO: Support variant in insert statement + assert str(programming_error.value.orig) == snapshot + finally: + Base.metadata.drop_all(engine_testaccount) + + +def test_snowflake_tables_with_structured_types(sql_compiler): + metadata = MetaData() + with pytest.raises( + StructuredTypeNotSupportedInTableColumnsError + ) as programming_error: + SnowflakeTable( + "clustered_user", + metadata, + Column("Id", Integer, 
primary_key=True), + Column("name", MAP(NUMBER(10, 0), TEXT())), + ) + assert programming_error is not None + + +@pytest.mark.requires_external_volume +def test_select_map_orm(engine_testaccount, external_volume, base_location, snapshot): + metadata = MetaData() + table_name = "test_select_map_orm" + test_map = IcebergTable( + table_name, + metadata, + Column("id", Integer, primary_key=True), + Column("map_id", MAP(NUMBER(10, 0), TEXT())), + external_volume=external_volume, + base_location=base_location, + ) + metadata.create_all(engine_testaccount) + + with engine_testaccount.connect() as conn: + slt1 = select( + 2, + cast(text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())), + ) + slt2 = select( + 1, + cast(text("{'100':'item1', '200':'item2'}"), MAP(NUMBER(10, 0), TEXT())), + ).union_all(slt1) + ins = test_map.insert().from_select(["id", "map_id"], slt2) + conn.execute(ins) + conn.commit() + + Base = declarative_base() + session = Session(bind=engine_testaccount) + + class TestIcebergTableOrm(Base): + __table__ = test_map + + def __repr__(self): + return f"({self.id!r}, {self.map_id!r})" + + try: + data = session.query(TestIcebergTableOrm).all() + snapshot.assert_match(data) + finally: + test_map.drop(engine_testaccount) diff --git a/tests/test_unit_structured_types.py b/tests/test_unit_structured_types.py new file mode 100644 index 00000000..c7bcd6ef --- /dev/null +++ b/tests/test_unit_structured_types.py @@ -0,0 +1,73 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# +import pytest + +from snowflake.sqlalchemy import NUMBER +from snowflake.sqlalchemy.custom_types import MAP, TEXT +from src.snowflake.sqlalchemy.parser.custom_type_parser import ( + extract_parameters, + parse_type, +) + + +def test_compile_map_with_not_null(snapshot): + user_table = MAP(NUMBER(10, 0), TEXT(), not_null=True) + assert user_table.compile() == snapshot + + +def test_extract_parameters(): + example = "a, b(c, d, f), d" + assert extract_parameters(example) == ["a", "b(c, d, f)", "d"] + + +@pytest.mark.parametrize( + "input_type, expected_type", + [ + ("BIGINT", "BIGINT"), + ("BINARY(16)", "BINARY(16)"), + ("BOOLEAN", "BOOLEAN"), + ("CHAR(5)", "CHAR(5)"), + ("CHARACTER(5)", "CHAR(5)"), + ("DATE", "DATE"), + ("DATETIME(3)", "DATETIME"), + ("DECIMAL(10, 2)", "DECIMAL(10, 2)"), + ("DEC(10, 2)", "DECIMAL(10, 2)"), + ("DOUBLE", "FLOAT"), + ("FLOAT", "FLOAT"), + ("FIXED(10, 2)", "DECIMAL(10, 2)"), + ("INT", "INTEGER"), + ("INTEGER", "INTEGER"), + ("NUMBER(12, 4)", "DECIMAL(12, 4)"), + ("REAL", "REAL"), + ("BYTEINT", "SMALLINT"), + ("SMALLINT", "SMALLINT"), + ("STRING(255)", "VARCHAR(255)"), + ("TEXT(255)", "VARCHAR(255)"), + ("VARCHAR(255)", "VARCHAR(255)"), + ("TIME(6)", "TIME"), + ("TIMESTAMP(3)", "TIMESTAMP"), + ("TIMESTAMP_TZ(3)", "TIMESTAMP_TZ"), + ("TIMESTAMP_LTZ(3)", "TIMESTAMP_LTZ"), + ("TIMESTAMP_NTZ(3)", "TIMESTAMP_NTZ"), + ("TINYINT", "SMALLINT"), + ("VARBINARY(16)", "BINARY(16)"), + ("VARCHAR(255)", "VARCHAR(255)"), + ("VARIANT", "VARIANT"), + ( + "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR NOT NULL))", + "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR NOT NULL))", + ), + ( + "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR))", + "MAP(DECIMAL(10, 0), MAP(DECIMAL(10, 0), VARCHAR))", + ), + ("MAP(DECIMAL(10, 0), VARIANT)", "MAP(DECIMAL(10, 0), VARIANT)"), + ("OBJECT", "OBJECT"), + ("ARRAY", "ARRAY"), + ("GEOGRAPHY", "GEOGRAPHY"), + ("GEOMETRY", "GEOMETRY"), + ], +) +def test_snowflake_data_types(input_type, expected_type): + assert 
parse_type(input_type).compile() == expected_type diff --git a/tests/util.py b/tests/util.py index db0b0c9c..264478ff 100644 --- a/tests/util.py +++ b/tests/util.py @@ -29,6 +29,7 @@ ARRAY, GEOGRAPHY, GEOMETRY, + MAP, OBJECT, TIMESTAMP_LTZ, TIMESTAMP_NTZ, @@ -72,6 +73,7 @@ "ARRAY": ARRAY, "GEOGRAPHY": GEOGRAPHY, "GEOMETRY": GEOMETRY, + "MAP": MAP, } diff --git a/tox.ini b/tox.ini index 0c1cb483..102e2273 100644 --- a/tox.ini +++ b/tox.ini @@ -34,22 +34,17 @@ passenv = setenv = COVERAGE_FILE = {env:COVERAGE_FILE:{toxworkdir}/.coverage.{envname}} SQLALCHEMY_WARN_20 = 1 - ci: SNOWFLAKE_PYTEST_OPTS = -vvv + ci: SNOWFLAKE_PYTEST_OPTS = -vvv --tb=long commands = pytest \ {env:SNOWFLAKE_PYTEST_OPTS:} \ --cov "snowflake.sqlalchemy" \ --junitxml {toxworkdir}/junit_{envname}.xml \ + --ignore=tests/sqlalchemy_test_suite \ {posargs:tests} pytest {env:SNOWFLAKE_PYTEST_OPTS:} \ --cov "snowflake.sqlalchemy" --cov-append \ --junitxml {toxworkdir}/junit_{envname}.xml \ {posargs:tests/sqlalchemy_test_suite} - pytest \ - {env:SNOWFLAKE_PYTEST_OPTS:} \ - --cov "snowflake.sqlalchemy" --cov-append \ - --junitxml {toxworkdir}/junit_{envname}.xml \ - --run_v20_sqlalchemy \ - {posargs:tests} [testenv:.pkg_external] deps = build @@ -80,13 +75,14 @@ passenv = PROGRAMDATA deps = {[testenv]deps} + tomlkit pre-commit >= 2.9.0 skip_install = True commands = pre-commit run --all-files python -c 'import pathlib; print("hint: run \{\} install to add checks as pre-commit hook".format(pathlib.Path(r"{envdir}") / "bin" / "pre-commit"))' [pytest] -addopts = -ra --strict-markers --ignore=tests/sqlalchemy_test_suite +addopts = -ra --ignore=tests/sqlalchemy_test_suite junit_family = legacy log_level = info markers =