Skip to content

Commit

Permalink
Add flag to run test only on aws
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jvasquezrojas committed Oct 4, 2024
1 parent 5af3a07 commit 3ea345c
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 9 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ jobs:
.github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py
- name: Run tests
run: hatch run test-dialect
if: matrix.cloud-provider != 'aws'
- name: Run test for AWS
run: hatch run test-dialect-aws
if: matrix.cloud-provider == 'aws'
- uses: actions/upload-artifact@v4
with:
name: coverage.xml_dialect-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
Expand Down Expand Up @@ -205,6 +209,10 @@ jobs:
python -m hatch env create default
- name: Run tests
run: hatch run sa14:test-dialect
if: matrix.cloud-provider != 'aws'
- name: Run test for AWS
run: hatch run sa14:test-dialect-aws
if: matrix.cloud-provider == 'aws'
- uses: actions/upload-artifact@v4
with:
name: coverage.xml_dialect-v14-${{ matrix.os }}-${{ matrix.python-version }}-${{ matrix.cloud-provider }}
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ SQLACHEMY_WARN_20 = "1"
check = "pre-commit run --all-files"
test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/"
test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite"
test-dialect-aws = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite -m 'aws' tests/"
gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1"
check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'"

Expand All @@ -110,7 +111,7 @@ line-length = 88
line-length = 88

[tool.pytest.ini_options]
addopts = "-m 'not feature_max_lob_size'"
addopts = "-m 'not feature_max_lob_size and not aws'"
markers = [
# Optional dependency groups markers
"lambda: AWS lambda tests",
Expand Down
4 changes: 3 additions & 1 deletion tests/custom_tables/test_compile_hybrid_table.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import pytest
from sqlalchemy import Column, Integer, MetaData, String
from sqlalchemy.orm import declarative_base
from sqlalchemy.sql.ddl import CreateTable

from snowflake.sqlalchemy import GEOMETRY, HybridTable


@pytest.mark.aws
def test_compile_hybrid_table(sql_compiler, snapshot):
metadata = MetaData()
table_name = "test_hybrid_table"
Expand All @@ -27,6 +28,7 @@ def test_compile_hybrid_table(sql_compiler, snapshot):
assert actual == snapshot


@pytest.mark.aws
def test_compile_hybrid_table_orm(sql_compiler, snapshot):
Base = declarative_base()

Expand Down
3 changes: 3 additions & 0 deletions tests/custom_tables/test_create_hybrid_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from src.snowflake.sqlalchemy import HybridTable


@pytest.mark.aws
def test_create_hybrid_table(engine_testaccount, db_parameters, snapshot):
metadata = MetaData()
table_name = "test_create_hybrid_table"
Expand Down Expand Up @@ -37,6 +38,7 @@ def test_create_hybrid_table(engine_testaccount, db_parameters, snapshot):
metadata.drop_all(engine_testaccount)


@pytest.mark.aws
def test_create_hybrid_table_with_multiple_index(
engine_testaccount, db_parameters, snapshot, sql_compiler
):
Expand Down Expand Up @@ -64,6 +66,7 @@ def test_create_hybrid_table_with_multiple_index(
metadata.drop_all(engine_testaccount)


@pytest.mark.aws
def test_create_hybrid_table_with_orm(sql_compiler, engine_testaccount):
Base = declarative_base()
session = Session(bind=engine_testaccount)
Expand Down
3 changes: 3 additions & 0 deletions tests/custom_tables/test_reflect_hybrid_table.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
import pytest
from sqlalchemy import MetaData, Table
from sqlalchemy.sql.ddl import CreateTable


@pytest.mark.aws
def test_simple_reflection_hybrid_table_as_table(
engine_testaccount, db_parameters, sql_compiler, snapshot
):
Expand Down Expand Up @@ -37,6 +39,7 @@ def test_simple_reflection_hybrid_table_as_table(
metadata.drop_all(engine_testaccount)


@pytest.mark.aws
def test_reflect_hybrid_table_with_index(
engine_testaccount, db_parameters, sql_compiler
):
Expand Down
2 changes: 2 additions & 0 deletions tests/test_index_reflection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
import pytest
from sqlalchemy import MetaData
from sqlalchemy.engine import reflection


@pytest.mark.aws
def test_indexes_reflection(engine_testaccount, db_parameters, sql_compiler):
metadata = MetaData()

Expand Down
72 changes: 65 additions & 7 deletions tests/test_orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Integer,
Sequence,
String,
Table,
exc,
func,
select,
Expand Down Expand Up @@ -128,6 +127,7 @@ def __repr__(self):
Base.metadata.drop_all(engine_testaccount)


@pytest.mark.aws
def test_orm_one_to_many_relationship_with_hybrid_table(engine_testaccount, snapshot):
"""
Tests One to Many relationship
Expand Down Expand Up @@ -191,21 +191,79 @@ def __repr__(self):
Base.metadata.drop_all(engine_testaccount)


@pytest.mark.parametrize(
"table_class, prefix", [(Table, "tbl_"), (HybridTable, "hb_tbl_")]
)
def test_delete_cascade(engine_testaccount, table_class, prefix):
def test_delete_cascade(engine_testaccount):
"""
Test delete cascade
"""
Base = declarative_base()
prefix = "tbl_"

class User(Base):
__tablename__ = prefix + "user"

id = Column(Integer, Sequence("user_id_seq"), primary_key=True)
name = Column(String)
fullname = Column(String)

addresses = relationship(
"Address", back_populates="user", cascade="all, delete, delete-orphan"
)

def __repr__(self):
return f"<User({self.name!r}, {self.fullname!r})>"

class Address(Base):
__tablename__ = prefix + "address"

id = Column(Integer, Sequence("address_id_seq"), primary_key=True)
email_address = Column(String, nullable=False)
user_id = Column(Integer, ForeignKey(f"{User.__tablename__}.id"))

user = relationship(User, back_populates="addresses")

def __repr__(self):
return f"<Address({repr(self.email_address)})>"

Base.metadata.create_all(engine_testaccount)

try:
jack = User(name="jack", fullname="Jack Bean")
assert jack.addresses == [], "one to many record is empty list"

jack.addresses = [
Address(email_address="[email protected]"),
Address(email_address="[email protected]"),
Address(email_address="[email protected]"),
]

session = Session(bind=engine_testaccount)
session.add(jack) # cascade each Address into the Session as well
session.commit()

got_jack = session.query(User).first()
assert got_jack == jack

session.delete(jack)
got_addresses = session.query(Address).all()
assert len(got_addresses) == 0, "no address record"
finally:
Base.metadata.drop_all(engine_testaccount)


@pytest.mark.aws
def test_delete_cascade_hybrid_table(engine_testaccount):
"""
Test delete cascade
"""
Base = declarative_base()
prefix = "hb_tbl_"

class User(Base):
__tablename__ = prefix + "user"

@classmethod
def __table_cls__(cls, name, metadata, *arg, **kw):
return table_class(name, metadata, *arg, **kw)
return HybridTable(name, metadata, *arg, **kw)

id = Column(Integer, Sequence("user_id_seq"), primary_key=True)
name = Column(String)
Expand All @@ -223,7 +281,7 @@ class Address(Base):

@classmethod
def __table_cls__(cls, name, metadata, *arg, **kw):
return table_class(name, metadata, *arg, **kw)
return HybridTable(name, metadata, *arg, **kw)

id = Column(Integer, Sequence("address_id_seq"), primary_key=True)
email_address = Column(String, nullable=False)
Expand Down

0 comments on commit 3ea345c

Please sign in to comment.