Skip to content

Commit

Permalink
add PARTITION BY option for CopyInto
Browse files Browse the repository at this point in the history
  • Loading branch information
azban committed Jul 28, 2023
1 parent e3f675e commit 99e4eae
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
11 changes: 7 additions & 4 deletions src/snowflake/sqlalchemy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ def visit_copy_into(self, copy_into, **kw):
if isinstance(copy_into.into, Table)
else copy_into.into._compiler_dispatch(self, **kw)
)
from_ = None
if isinstance(copy_into.from_, Table):
from_ = copy_into.from_
# this is intended to catch AWSBucket and AzureContainer
Expand All @@ -228,6 +227,11 @@ def visit_copy_into(self, copy_into, **kw):
# everything else (selects, etc.)
else:
from_ = f"({copy_into.from_._compiler_dispatch(self, **kw)})"

partition_by = ""
if copy_into.partition_by is not None:
partition_by = f"PARTITION BY {partition_by}"

credentials, encryption = "", ""
if isinstance(into, tuple):
into, credentials, encryption = into
Expand All @@ -238,8 +242,7 @@ def visit_copy_into(self, copy_into, **kw):
options_list.sort(key=operator.itemgetter(0))
options = (
(
" "
+ " ".join(
" ".join(
[
"{} = {}".format(
n,
Expand All @@ -258,7 +261,7 @@ def visit_copy_into(self, copy_into, **kw):
options += f" {credentials}"
if encryption:
options += f" {encryption}"
return f"COPY INTO {into} FROM {from_} {formatter}{options}"
return f"COPY INTO {into} FROM {' '.join([from_, partition_by, formatter, options])}"

def visit_copy_formatter(self, formatter, **kw):
options_list = list(formatter.options.items())
Expand Down
11 changes: 8 additions & 3 deletions src/snowflake/sqlalchemy/custom_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,23 @@ class CopyInto(UpdateBase):
__visit_name__ = "copy_into"
_bind = None

def __init__(self, from_, into, formatter=None):
def __init__(self, from_, into, partition_by=None, formatter=None):
self.from_ = from_
self.into = into
self.formatter = formatter
self.copy_options = {}
self.partition_by = partition_by

def __repr__(self):
"""
repr for debugging / logging purposes only. For compilation logic, see
the corresponding visitor in base.py
"""
return f"COPY INTO {self.into} FROM {repr(self.from_)} {repr(self.formatter)} ({self.copy_options})"
val = f"COPY INTO {self.into} FROM {repr(self.from_)}"
if self.partition_by is not None:
val += f" PARTITION BY {self.partition_by}"

return val + f" {repr(self.formatter)} ({self.copy_options})"

def bind(self):
return None
Expand Down Expand Up @@ -530,7 +535,7 @@ def __repr__(self):
)

def credentials(
self, aws_role=None, aws_key_id=None, aws_secret_key=None, aws_token=None
self, aws_role=None, aws_key_id=None, aws_secret_key=None, aws_token=None
):
if aws_role is None and (aws_key_id is None and aws_secret_key is None):
raise ValueError(
Expand Down
10 changes: 10 additions & 0 deletions tests/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ def test_copy_into_location(engine_testaccount, sql_compiler):
== "COPY INTO @name.stage_name/prefix/file FROM python_tests_foods FILE_FORMAT=(TYPE=csv)"
)

copy_stmt_7 = CopyIntoStorage(
from_=food_items,
into=ExternalStage(name="stage_name"),
partition_by="('YEAR=' || year)"
)
assert (
sql_compiler(copy_stmt_7)
== "COPY INTO @stage_name FROM python_tests_foods PARTITION BY ('YEAR=' || year)"
)

# NOTE Other than expect known compiled text, submit it to RegressionTests environment and expect them to fail, but
# because of the right reasons
acceptable_exc_reasons = {
Expand Down

0 comments on commit 99e4eae

Please sign in to comment.