Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
rosner authored Oct 16, 2024
2 parents 1a8c023 + 43c6b56 commit cc289df
Show file tree
Hide file tree
Showing 40 changed files with 1,601 additions and 34 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ jobs:
run: |
gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \
.github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py
- name: Run test for AWS
run: hatch run test-dialect-aws
if: matrix.cloud-provider == 'aws'
- name: Run tests
run: hatch run test-dialect
- uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -203,6 +206,9 @@ jobs:
python -m pip install -U uv
python -m uv pip install -U hatch
python -m hatch env create default
- name: Run test for AWS
run: hatch run sa14:test-dialect-aws
if: matrix.cloud-provider == 'aws'
- name: Run tests
run: hatch run sa14:test-dialect
- uses: actions/upload-artifact@v4
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ jobs:
--signature "${dist_base}.sig" \
--cert "${dist_base}.crt" \
--cert-oidc-issuer https://token.actions.githubusercontent.com \
--cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF}
--cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF}
# Verify using `.sigstore` bundle;
python -m \
sigstore verify identity "${dist}" \
--bundle "${dist_base}.sigstore" \
--cert-oidc-issuer https://token.actions.githubusercontent.com \
--cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/build_and_sign_demand.yml@${GITHUB_REF}
--cert-identity ${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}/.github/workflows/python-publish.yml@${GITHUB_REF}
done
- name: List artifacts after sign
run: ls ./dist
Expand Down
10 changes: 8 additions & 2 deletions DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ Source code is also available at:

# Release Notes

- (Unreleased)

- Add support for dynamic tables and required options
- Add support for hybrid tables
- Fixed SAWarning raised when registering functions with an existing name in the default namespace

- v1.6.1(July 9, 2024)

- Update internal project workflow with pypi publishing
Expand All @@ -24,7 +30,7 @@ Source code is also available at:

- v1.5.3(April 16, 2024)

- Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0
- Limit SQLAlchemy to < 2.0.0 before releasing version compatible with 2.0

- v1.5.2(April 11, 2024)

Expand All @@ -33,7 +39,7 @@ Source code is also available at:

- v1.5.1(November 03, 2023)

- Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057.
- Fixed a compatibility issue with Snowflake Behavioral Change 1057 on outer lateral join, for more details check <https://docs.snowflake.com/en/release-notes/bcr-bundles/2023_04/bcr-1057>.
- Fixed credentials with `externalbrowser` authentication not caching due to incorrect parsing of boolean query parameters.
- This fixes other boolean parameter passing to driver as well.

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ development = [
"pytz",
"numpy",
"mock",
"syrupy==4.6.1",
]
pandas = ["snowflake-connector-python[pandas]"]

Expand Down Expand Up @@ -91,6 +92,7 @@ SQLACHEMY_WARN_20 = "1"
check = "pre-commit run --all-files"
test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/"
test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite"
test-dialect-aws = "pytest -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/"
gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1"
check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'"

Expand All @@ -109,6 +111,7 @@ line-length = 88
line-length = 88

[tool.pytest.ini_options]
addopts = "-m 'not feature_max_lob_size and not aws'"
markers = [
# Optional dependency groups markers
"lambda: AWS lambda tests",
Expand All @@ -126,4 +129,5 @@ markers = [
"timeout: tests that need a timeout time",
"internal: tests that could but should only run on our internal CI",
"external: tests that could but should only run on our external CI",
"feature_max_lob_size: tests that could but should only run on our external CI",
]
8 changes: 8 additions & 0 deletions src/snowflake/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
VARBINARY,
VARIANT,
)
from .sql.custom_schema import DynamicTable, HybridTable
from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse
from .util import _url as URL

base.dialect = dialect = snowdialect.dialect
Expand Down Expand Up @@ -113,4 +115,10 @@
"ExternalStage",
"CreateStage",
"CreateFileFormat",
"DynamicTable",
"AsQuery",
"TargetLag",
"TimeUnit",
"Warehouse",
"HybridTable",
)
1 change: 1 addition & 0 deletions src/snowflake/sqlalchemy/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@

APPLICATION_NAME = "SnowflakeSQLAlchemy"
SNOWFLAKE_SQLALCHEMY_VERSION = VERSION
DIALECT_NAME = "snowflake"
33 changes: 27 additions & 6 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,16 @@
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.selectable import Lateral, SelectState

from .compat import IS_VERSION_20, args_reducer, string_types
from .custom_commands import AWSBucket, AzureContainer, ExternalStage
from snowflake.sqlalchemy._constants import DIALECT_NAME
from snowflake.sqlalchemy.compat import IS_VERSION_20, args_reducer, string_types
from snowflake.sqlalchemy.custom_commands import (
AWSBucket,
AzureContainer,
ExternalStage,
)

from .functions import flatten
from .sql.custom_schema.options.table_option_base import TableOptionBase
from .util import (
_find_left_clause_to_join_from,
_set_connection_interpolate_empty_sequences,
Expand Down Expand Up @@ -184,7 +191,6 @@ def _join_determine_implicit_left_side(self, raw_columns, left, right, onclause)
[element._from_objects for element in statement._where_criteria]
),
):

potential[from_clause] = ()

all_clauses = list(potential.keys())
Expand Down Expand Up @@ -879,7 +885,7 @@ def get_column_specification(self, column, **kwargs):

return " ".join(colspec)

def post_create_table(self, table):
def handle_cluster_by(self, table):
"""
Handles snowflake-specific ``CREATE TABLE ... CLUSTER BY`` syntax.
Expand Down Expand Up @@ -909,14 +915,29 @@ def post_create_table(self, table):
<BLANKLINE>
"""
text = ""
info = table.dialect_options["snowflake"]
info = table.dialect_options[DIALECT_NAME]
cluster = info.get("clusterby")
if cluster:
text += " CLUSTER BY ({})".format(
", ".join(self.denormalize_column_name(key) for key in cluster)
)
return text

def post_create_table(self, table):
    """Render dialect-specific DDL appended after CREATE TABLE.

    Emits the Snowflake ``CLUSTER BY`` clause (via ``handle_cluster_by``)
    followed by every custom table option attached to the table's
    ``snowflake`` dialect options, ordered by option priority and name
    (highest first).
    """
    ddl_text = self.handle_cluster_by(table)
    # Pick out only the dialect options that are custom table options
    # (e.g. dynamic/hybrid table settings), ignoring plain kwargs.
    custom_options = sorted(
        (
            opt
            for _, opt in table.dialect_options[DIALECT_NAME].items()
            if isinstance(opt, TableOptionBase)
        ),
        key=lambda opt: (opt.__priority__.value, opt.__option_name__),
        reverse=True,
    )
    for opt in custom_options:
        ddl_text += "\t" + opt.render_option(self)

    return ddl_text

def visit_create_stage(self, create_stage, **kw):
"""
This visitor will create the SQL representation for a CREATE STAGE command.
Expand Down Expand Up @@ -1065,4 +1086,4 @@ def visit_GEOMETRY(self, type_, **kw):

construct_arguments = [(Table, {"clusterby": None})]

functions.register_function("flatten", flatten)
functions.register_function("flatten", flatten, "snowflake")
141 changes: 131 additions & 10 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from snowflake.connector.constants import UTF8
from snowflake.sqlalchemy.compat import returns_unicode

from ._constants import DIALECT_NAME
from .base import (
SnowflakeCompiler,
SnowflakeDDLCompiler,
Expand All @@ -64,6 +65,7 @@
_CUSTOM_Float,
_CUSTOM_Time,
)
from .sql.custom_schema.custom_table_prefix import CustomTablePrefix
from .util import (
_update_connection_application_name,
parse_url_boolean,
Expand Down Expand Up @@ -119,7 +121,7 @@


class SnowflakeDialect(default.DefaultDialect):
name = "snowflake"
name = DIALECT_NAME
driver = "snowflake"
max_identifier_length = 255
cte_follows_insert = True
Expand Down Expand Up @@ -351,14 +353,6 @@ def _map_name_to_idx(result):
name_to_idx[col[0]] = idx
return name_to_idx

@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
"""
Gets all indexes
"""
# no index is supported by Snowflake
return []

@reflection.cache
def get_check_constraints(self, connection, table_name, schema, **kw):
# check constraints are not supported by Snowflake
Expand Down Expand Up @@ -894,6 +888,129 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
)
}

def get_multi_indexes(
    self,
    connection,
    *,
    schema,
    filter_names,
    **kw,
):
    """Reflect index definitions for multiple tables in one round trip.

    Only Snowflake HYBRID tables support indexes, so tables without the
    HYBRID prefix (as reported by ``get_multi_prefixes``) are skipped,
    as are the implicit ``SYS_INDEX_<table>_PRIMARY`` entries.

    :param connection: active SQLAlchemy connection
    :param schema: schema to inspect; falls back to the default schema
    :param filter_names: iterable of (normalized) table names to include
    :return: list of ((schema, table), [index_dict, ...]) pairs
    """

    table_prefixes = self.get_multi_prefixes(
        connection, schema, filter_prefix=CustomTablePrefix.HYBRID.name
    )
    if len(table_prefixes) == 0:
        return []
    schema = schema or self.default_schema_name
    if not schema:
        result = connection.execute(
            text("SHOW /* sqlalchemy:get_multi_indexes */ INDEXES")
        )
    else:
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}"
            )
        )

    n2i = self.__class__._map_name_to_idx(result)
    indexes = {}

    for row in result.cursor.fetchall():
        table = self.normalize_name(str(row[n2i["table"]]))
        if (
            # skip the implicit primary-key index Snowflake creates
            row[n2i["name"]] == f'SYS_INDEX_{row[n2i["table"]]}_PRIMARY'
            or table not in filter_names
            or (schema, table) not in table_prefixes
            or CustomTablePrefix.HYBRID.name
            not in table_prefixes[(schema, table)]
        ):
            continue
        index = {
            "name": row[n2i["name"]],
            "unique": row[n2i["is_unique"]] == "Y",
            "column_names": row[n2i["columns"]],
            "include_columns": row[n2i["included_columns"]],
            "dialect_options": {},
        }
        # BUG FIX: the original did
        #   indexes[key] = indexes[key].append(index)
        # which reassigns the entry to None (list.append returns None),
        # losing every previously collected index for that table.
        indexes.setdefault((schema, table), []).append(index)

    return list(indexes.items())

def _value_or_default(self, data, table, schema):
    """Look up the entry keyed by (schema, normalized table) in *data*.

    *data* is a sequence of key/value pairs (as returned by the
    ``get_multi_*`` reflection helpers); returns ``[]`` when the
    table is not present.
    """
    normalized = self.normalize_name(str(table))
    return dict(data).get((schema, normalized), [])

def get_prefixes_from_data(self, n2i, row, **kw):
    """Return the custom-table prefix names flagged in a SHOW TABLES row.

    For each known prefix, the row exposes an ``is_<prefix>`` column
    whose value is ``"Y"`` when the table carries that prefix.
    """
    return [
        prefix.name
        for prefix in CustomTablePrefix
        if (flag := f"is_{prefix.name.lower()}") in n2i
        and row[n2i[flag]] == "Y"
    ]

@reflection.cache
def get_multi_prefixes(
    self, connection, schema, table_name=None, filter_prefix=None, **kw
):
    """Collect the custom table prefixes (e.g. HYBRID, DYNAMIC) per table.

    :param connection: active SQLAlchemy connection
    :param schema: schema to inspect; falls back to the default schema
    :param table_name: optional table name to restrict the SHOW command
    :param filter_prefix: if given, only keep tables carrying this prefix
    :return: dict mapping (schema, table) -> list of prefix names
    """
    schema = schema or self.default_schema_name
    # renamed from `filter` to avoid shadowing the builtin
    like_clause = f"LIKE '{table_name}'" if table_name else ""
    # NOTE(review): schema/table_name are interpolated directly into the
    # SHOW command; they are expected to come from reflection internals,
    # not untrusted user input — confirm before exposing externally.
    if schema:
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:get_multi_prefixes */ {like_clause} TABLES IN SCHEMA {schema}"
            )
        )
    else:
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:get_multi_prefixes */ {like_clause} TABLES LIKE '{table_name}'"
            )
        )

    n2i = self.__class__._map_name_to_idx(result)
    tables_prefixes = {}
    for row in result.cursor.fetchall():
        table = self.normalize_name(str(row[n2i["name"]]))
        table_prefixes = self.get_prefixes_from_data(n2i, row)
        if filter_prefix and filter_prefix not in table_prefixes:
            continue
        # BUG FIX: the original appended the whole prefix list into the
        # existing flat list of names (producing a nested list), which
        # broke membership tests like `HYBRID in table_prefixes[key]`.
        tables_prefixes.setdefault((schema, table), []).extend(table_prefixes)

    return tables_prefixes

@reflection.cache
def get_indexes(self, connection, tablename, schema, **kw):
    """Return the index definitions for a single table.

    Delegates to ``get_multi_indexes`` with a one-element filter and
    unwraps the result, yielding ``[]`` when the table has no indexes.
    """
    normalized = self.normalize_name(str(tablename))
    multi_result = self.get_multi_indexes(
        connection=connection,
        schema=schema,
        filter_names=[normalized],
        **kw,
    )
    return self._value_or_default(multi_result, normalized, schema)

def connect(self, *cargs, **cparams):
return (
super().connect(
Expand All @@ -911,8 +1028,12 @@ def connect(self, *cargs, **cparams):

@sa_vnt.listens_for(Table, "before_create")
def check_table(table, connection, _ddl_runner, **kw):
    """Event hook run before CREATE TABLE is emitted.

    Snowflake only supports indexes on Hybrid Tables, so reject any
    other table that declares indexes before DDL generation.
    """
    # local import avoids a circular dependency at module load time
    from .sql.custom_schema.hybrid_table import HybridTable

    if HybridTable.is_equal_type(table):  # noqa
        return True
    if table.indexes and isinstance(_ddl_runner.dialect, SnowflakeDialect):
        raise NotImplementedError("Only Snowflake Hybrid Tables supports indexes")


dialect = SnowflakeDialect
3 changes: 3 additions & 0 deletions src/snowflake/sqlalchemy/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
7 changes: 7 additions & 0 deletions src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
from .dynamic_table import DynamicTable
from .hybrid_table import HybridTable

__all__ = ["DynamicTable", "HybridTable"]
Loading

0 comments on commit cc289df

Please sign in to comment.