Skip to content

Commit

Permalink
Merge branch 'main' into dpp_bug
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahyurick authored Jan 23, 2024
2 parents 4a726be + f53c3e0 commit c1b876b
Show file tree
Hide file tree
Showing 11 changed files with 74 additions and 29 deletions.
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies:
- python=3.10
- scikit-learn>=1.0.0
- sphinx
- sqlalchemy<2
- sqlalchemy
- tpot>=0.12.0
# FIXME: https://github.com/fugue-project/fugue/issues/526
- triad<0.9.2
Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.11.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies:
- python=3.11
- scikit-learn>=1.0.0
- sphinx
- sqlalchemy<2
- sqlalchemy
- tpot>=0.12.0
# FIXME: https://github.com/fugue-project/fugue/issues/526
- triad<0.9.2
Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.12.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies:
- python=3.12
- scikit-learn>=1.0.0
- sphinx
- sqlalchemy<2
- sqlalchemy
- tpot>=0.12.0
# FIXME: https://github.com/fugue-project/fugue/issues/526
- triad<0.9.2
Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/environment-3.9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies:
- python=3.9
- scikit-learn=1.0.0
- sphinx
- sqlalchemy<2
- sqlalchemy
- tpot>=0.12.0
# FIXME: https://github.com/fugue-project/fugue/issues/526
- triad<0.9.2
Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/gpuci/environment-3.10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- python=3.10
- scikit-learn>=1.0.0
- sphinx
- sqlalchemy<2
- sqlalchemy
- tpot>=0.12.0
# FIXME: https://github.com/fugue-project/fugue/issues/526
- triad<0.9.2
Expand Down
2 changes: 1 addition & 1 deletion continuous_integration/gpuci/environment-3.9.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- python=3.9
- scikit-learn>=1.0.0
- sphinx
- sqlalchemy<2
- sqlalchemy
- tpot>=0.12.0
# FIXME: https://github.com/fugue-project/fugue/issues/526
- triad<0.9.2
Expand Down
24 changes: 17 additions & 7 deletions dask_sql/input_utils/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ class HiveInputPlugin(BaseInputPlugin):
def is_correct_input(
self, input_item: Any, table_name: str, format: str = None, **kwargs
):
is_sqlalchemy_hive = sqlalchemy and isinstance(
input_item, sqlalchemy.engine.base.Connection
)
is_hive_cursor = hive and isinstance(input_item, hive.Cursor)

return is_sqlalchemy_hive or is_hive_cursor or format == "hive"
return self.is_sqlalchemy_hive(input_item) or is_hive_cursor or format == "hive"

def is_sqlalchemy_hive(self, input_item: Any):
return sqlalchemy and isinstance(input_item, sqlalchemy.engine.base.Connection)

def to_dc(
self,
Expand Down Expand Up @@ -201,7 +201,11 @@ def _parse_hive_table_description(
of the DESCRIBE FORMATTED call, which is unfortunately
in a format not easily readable by machines.
"""
cursor.execute(f"USE {schema}")
cursor.execute(
sqlalchemy.text(f"USE {schema}")
if self.is_sqlalchemy_hive(cursor)
else f"USE {schema}"
)
if partition:
# Hive wants quoted, comma separated list of partition keys
partition = partition.replace("=", '="')
Expand Down Expand Up @@ -283,7 +287,11 @@ def _parse_hive_partition_description(
"""
Extract all partition informaton for a given table
"""
cursor.execute(f"USE {schema}")
cursor.execute(
sqlalchemy.text(f"USE {schema}")
if self.is_sqlalchemy_hive(cursor)
else f"USE {schema}"
)
result = self._fetch_all_results(cursor, f"SHOW PARTITIONS {table_name}")

return [row[0] for row in result]
Expand All @@ -298,7 +306,9 @@ def _fetch_all_results(
The former has the fetchall method on the cursor,
whereas the latter on the executed query.
"""
result = cursor.execute(sql)
result = cursor.execute(
sqlalchemy.text(sql) if self.is_sqlalchemy_hive(cursor) else sql
)

try:
return result.fetchall()
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
# as expressions are handled differently
dask_result.columns = sql_result.columns

# replace all pd.NA scalars, which are resistent to
# check_dype=False and .astype()
dask_result = dask_result.replace({pd.NA: None})

if sort_columns:
sql_result = sql_result.sort_values(sort_columns)
dask_result = dask_result.sort_values(sort_columns)
Expand Down
36 changes: 27 additions & 9 deletions tests/integration/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,25 +142,43 @@ def hive_cursor():

# Create a non-partitioned column
cursor.execute(
f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'"
sqlalchemy.text(
f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'"
)
)
cursor.execute("INSERT INTO df (i, j) VALUES (1, 2)")
cursor.execute("INSERT INTO df (i, j) VALUES (2, 4)")
cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (1, 2)"))
cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (2, 4)"))

cursor.execute(
f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'"
sqlalchemy.text(
f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'"
)
)
cursor.execute(
sqlalchemy.text("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)")
)
cursor.execute(
sqlalchemy.text("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)")
)
cursor.execute("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)")
cursor.execute("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)")

cursor.execute(
f"""
sqlalchemy.text(
f"""
CREATE TABLE df_parts (i INTEGER) PARTITIONED BY (j INTEGER, k STRING)
ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_multiparted}'
"""
)
)
cursor.execute(
sqlalchemy.text(
"INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)"
)
)
cursor.execute(
sqlalchemy.text(
"INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)"
)
)
cursor.execute("INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)")
cursor.execute("INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)")

# The data files are created as root user by default. Change that:
hive_server.exec_run(["chmod", "a+rwx", "-R", tmpdir])
Expand Down
19 changes: 13 additions & 6 deletions tests/integration/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,6 @@ def test_join_reorder(c):
SELECT a1, b2, c3
FROM a, b, c
WHERE b1 < 3 AND c3 < 5 AND a1 = b1 AND b2 = c2
LIMIT 10
"""

explain_string = c.explain(query)
Expand Down Expand Up @@ -491,15 +490,20 @@ def test_join_reorder(c):
assert explain_string.index(second_join) < explain_string.index(first_join)

result_df = c.sql(query)
expected_df = pd.DataFrame({"a1": [1] * 10, "b2": [2] * 10, "c3": [4] * 10})
assert_eq(result_df, expected_df)
merged_df = df.merge(df2, left_on="a1", right_on="b1").merge(
df3, left_on="b2", right_on="c2"
)
expected_df = merged_df[(merged_df["b1"] < 3) & (merged_df["c3"] < 5)][
["a1", "b2", "c3"]
]

assert_eq(result_df, expected_df, check_index=False)

# By default, join reordering should NOT reorder unfiltered dimension tables
query = """
SELECT a1, b2, c3
FROM a, b, c
WHERE a1 = b1 AND b2 = c2
LIMIT 10
"""

explain_string = c.explain(query)
Expand All @@ -510,8 +514,11 @@ def test_join_reorder(c):
assert explain_string.index(second_join) < explain_string.index(first_join)

result_df = c.sql(query)
expected_df = pd.DataFrame({"a1": [1] * 10, "b2": [2] * 10, "c3": [4, 5] * 5})
assert_eq(result_df, expected_df)
expected_df = df.merge(df2, left_on="a1", right_on="b1").merge(
df3, left_on="b2", right_on="c2"
)[["a1", "b2", "c3"]]

assert_eq(result_df, expected_df, check_index=False)


@pytest.mark.xfail(
Expand Down
8 changes: 7 additions & 1 deletion tests/integration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import joblib
import pandas as pd
import pytest
from packaging.version import parse as parseVersion

from tests.utils import assert_eq

Expand All @@ -17,6 +18,10 @@
xgboost = None
dask_cudf = None

sklearn = pytest.importorskip("sklearn")

SKLEARN_GT_130 = parseVersion(sklearn.__version__) >= parseVersion("1.4")


def check_trained_model(c, model_name="my_model", df_name="timeseries"):
sql = f"""
Expand Down Expand Up @@ -902,10 +907,10 @@ def test_ml_experiment(c, client):
)


@pytest.mark.xfail(reason="tpot is broken with sklearn>=1.4", condition=SKLEARN_GT_130)
def test_experiment_automl_classifier(c, client):
tpot = pytest.importorskip("tpot", reason="tpot not installed")

# currently tested with tpot==
c.sql(
"""
CREATE EXPERIMENT my_automl_exp1 WITH (
Expand All @@ -927,6 +932,7 @@ def test_experiment_automl_classifier(c, client):
check_trained_model(c, "my_automl_exp1")


@pytest.mark.xfail(reason="tpot is broken with sklearn>=1.4", condition=SKLEARN_GT_130)
def test_experiment_automl_regressor(c, client):
tpot = pytest.importorskip("tpot", reason="tpot not installed")

Expand Down

0 comments on commit c1b876b

Please sign in to comment.