diff --git a/continuous_integration/environment-3.10.yaml b/continuous_integration/environment-3.10.yaml index b0557a915..912e2c54e 100644 --- a/continuous_integration/environment-3.10.yaml +++ b/continuous_integration/environment-3.10.yaml @@ -28,7 +28,7 @@ dependencies: - python=3.10 - scikit-learn>=1.0.0 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/continuous_integration/environment-3.11.yaml b/continuous_integration/environment-3.11.yaml index 1bcf46d45..cd77ac8d5 100644 --- a/continuous_integration/environment-3.11.yaml +++ b/continuous_integration/environment-3.11.yaml @@ -28,7 +28,7 @@ dependencies: - python=3.11 - scikit-learn>=1.0.0 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/continuous_integration/environment-3.12.yaml b/continuous_integration/environment-3.12.yaml index 18a67409b..53b52e629 100644 --- a/continuous_integration/environment-3.12.yaml +++ b/continuous_integration/environment-3.12.yaml @@ -29,7 +29,7 @@ dependencies: - python=3.12 - scikit-learn>=1.0.0 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/continuous_integration/environment-3.9.yaml b/continuous_integration/environment-3.9.yaml index a627318c1..f9f8e9ebf 100644 --- a/continuous_integration/environment-3.9.yaml +++ b/continuous_integration/environment-3.9.yaml @@ -28,7 +28,7 @@ dependencies: - python=3.9 - scikit-learn=1.0.0 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/continuous_integration/gpuci/environment-3.10.yaml b/continuous_integration/gpuci/environment-3.10.yaml index 2420e949f..6d567d498 100644 --- a/continuous_integration/gpuci/environment-3.10.yaml +++ b/continuous_integration/gpuci/environment-3.10.yaml @@ -33,7 +33,7 @@ dependencies: - python=3.10 - scikit-learn>=1.0.0 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/continuous_integration/gpuci/environment-3.9.yaml b/continuous_integration/gpuci/environment-3.9.yaml index f88cf57c7..1e2c50efb 100644 --- a/continuous_integration/gpuci/environment-3.9.yaml +++ b/continuous_integration/gpuci/environment-3.9.yaml @@ -33,7 +33,7 @@ dependencies: - python=3.9 - scikit-learn>=1.0.0 - sphinx -- sqlalchemy<2 +- sqlalchemy - tpot>=0.12.0 # FIXME: https://github.com/fugue-project/fugue/issues/526 - triad<0.9.2 diff --git a/dask_sql/input_utils/hive.py b/dask_sql/input_utils/hive.py index 14bc547f0..b65e4d5ce 100644 --- a/dask_sql/input_utils/hive.py +++ b/dask_sql/input_utils/hive.py @@ -30,12 +30,12 @@ class HiveInputPlugin(BaseInputPlugin): def is_correct_input( self, input_item: Any, table_name: str, format: str = None, **kwargs ): - is_sqlalchemy_hive = sqlalchemy and isinstance( - input_item, sqlalchemy.engine.base.Connection - ) is_hive_cursor = hive and isinstance(input_item, hive.Cursor) - return is_sqlalchemy_hive or is_hive_cursor or format == "hive" + return self.is_sqlalchemy_hive(input_item) or is_hive_cursor or format == "hive" + + def is_sqlalchemy_hive(self, input_item: Any): + return sqlalchemy and isinstance(input_item, sqlalchemy.engine.base.Connection) def to_dc( self, @@ -201,7 +201,11 @@ def _parse_hive_table_description( of the DESCRIBE FORMATTED call, which is unfortunately in a format not easily readable by machines. """ - cursor.execute(f"USE {schema}") + cursor.execute( + sqlalchemy.text(f"USE {schema}") + if self.is_sqlalchemy_hive(cursor) + else f"USE {schema}" + ) if partition: # Hive wants quoted, comma separated list of partition keys partition = partition.replace("=", '="') @@ -283,7 +287,11 @@ def _parse_hive_partition_description( """ Extract all partition informaton for a given table """ - cursor.execute(f"USE {schema}") + cursor.execute( + sqlalchemy.text(f"USE {schema}") + if self.is_sqlalchemy_hive(cursor) + else f"USE {schema}" + ) result = self._fetch_all_results(cursor, f"SHOW PARTITIONS {table_name}") return [row[0] for row in result] @@ -298,7 +306,9 @@ def _fetch_all_results( The former has the fetchall method on the cursor, whereas the latter on the executed query. """ - result = cursor.execute(sql) + result = cursor.execute( + sqlalchemy.text(sql) if self.is_sqlalchemy_hive(cursor) else sql + ) try: return result.fetchall() diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index 90b6f3828..cd4e38928 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -335,6 +335,10 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs): # as expressions are handled differently dask_result.columns = sql_result.columns + # replace all pd.NA scalars, which are resistent to + # check_dype=False and .astype() + dask_result = dask_result.replace({pd.NA: None}) + if sort_columns: sql_result = sql_result.sort_values(sort_columns) dask_result = dask_result.sort_values(sort_columns) diff --git a/tests/integration/test_hive.py b/tests/integration/test_hive.py index 1a86082c1..17f4c1a98 100644 --- a/tests/integration/test_hive.py +++ b/tests/integration/test_hive.py @@ -142,25 +142,43 @@ def hive_cursor(): # Create a non-partitioned column cursor.execute( - f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'" + sqlalchemy.text( + f"CREATE TABLE df (i INTEGER, j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir}'" + ) ) - cursor.execute("INSERT INTO df (i, j) VALUES (1, 2)") - cursor.execute("INSERT INTO df (i, j) VALUES (2, 4)") + cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (1, 2)")) + cursor.execute(sqlalchemy.text("INSERT INTO df (i, j) VALUES (2, 4)")) cursor.execute( - f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'" + sqlalchemy.text( + f"CREATE TABLE df_part (i INTEGER) PARTITIONED BY (j INTEGER) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_parted}'" + ) + ) + cursor.execute( + sqlalchemy.text("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)") + ) + cursor.execute( + sqlalchemy.text("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)") ) - cursor.execute("INSERT INTO df_part PARTITION (j=2) (i) VALUES (1)") - cursor.execute("INSERT INTO df_part PARTITION (j=4) (i) VALUES (2)") cursor.execute( - f""" + sqlalchemy.text( + f""" CREATE TABLE df_parts (i INTEGER) PARTITIONED BY (j INTEGER, k STRING) ROW FORMAT DELIMITED STORED AS PARQUET LOCATION '{tmpdir_multiparted}' """ + ) + ) + cursor.execute( + sqlalchemy.text( + "INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)" + ) + ) + cursor.execute( + sqlalchemy.text( + "INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)" + ) ) - cursor.execute("INSERT INTO df_parts PARTITION (j=1, k='a') (i) VALUES (1)") - cursor.execute("INSERT INTO df_parts PARTITION (j=2, k='b') (i) VALUES (2)") # The data files are created as root user by default. Change that: hive_server.exec_run(["chmod", "a+rwx", "-R", tmpdir]) diff --git a/tests/integration/test_join.py b/tests/integration/test_join.py index 3f19a3211..e47721108 100644 --- a/tests/integration/test_join.py +++ b/tests/integration/test_join.py @@ -463,7 +463,6 @@ def test_join_reorder(c): SELECT a1, b2, c3 FROM a, b, c WHERE b1 < 3 AND c3 < 5 AND a1 = b1 AND b2 = c2 - LIMIT 10 """ explain_string = c.explain(query) @@ -491,15 +490,20 @@ def test_join_reorder(c): assert explain_string.index(second_join) < explain_string.index(first_join) result_df = c.sql(query) - expected_df = pd.DataFrame({"a1": [1] * 10, "b2": [2] * 10, "c3": [4] * 10}) - assert_eq(result_df, expected_df) + merged_df = df.merge(df2, left_on="a1", right_on="b1").merge( + df3, left_on="b2", right_on="c2" + ) + expected_df = merged_df[(merged_df["b1"] < 3) & (merged_df["c3"] < 5)][ + ["a1", "b2", "c3"] + ] + + assert_eq(result_df, expected_df, check_index=False) # By default, join reordering should NOT reorder unfiltered dimension tables query = """ SELECT a1, b2, c3 FROM a, b, c WHERE a1 = b1 AND b2 = c2 - LIMIT 10 """ explain_string = c.explain(query) @@ -510,8 +514,11 @@ def test_join_reorder(c): assert explain_string.index(second_join) < explain_string.index(first_join) result_df = c.sql(query) - expected_df = pd.DataFrame({"a1": [1] * 10, "b2": [2] * 10, "c3": [4, 5] * 5}) - assert_eq(result_df, expected_df) + expected_df = df.merge(df2, left_on="a1", right_on="b1").merge( + df3, left_on="b2", right_on="c2" + )[["a1", "b2", "c3"]] + + assert_eq(result_df, expected_df, check_index=False) @pytest.mark.xfail( diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py index 973802fe4..c341965ce 100644 --- a/tests/integration/test_model.py +++ b/tests/integration/test_model.py @@ -5,6 +5,7 @@ import joblib import pandas as pd import pytest +from packaging.version import parse as parseVersion from tests.utils import assert_eq @@ -17,6 +18,10 @@ xgboost = None dask_cudf = None +sklearn = pytest.importorskip("sklearn") + +SKLEARN_GT_130 = parseVersion(sklearn.__version__) >= parseVersion("1.4") + def check_trained_model(c, model_name="my_model", df_name="timeseries"): sql = f""" @@ -902,10 +907,10 @@ def test_ml_experiment(c, client): ) +@pytest.mark.xfail(reason="tpot is broken with sklearn>=1.4", condition=SKLEARN_GT_130) def test_experiment_automl_classifier(c, client): tpot = pytest.importorskip("tpot", reason="tpot not installed") - # currently tested with tpot== c.sql( """ CREATE EXPERIMENT my_automl_exp1 WITH ( @@ -927,6 +932,7 @@ def test_experiment_automl_classifier(c, client): check_trained_model(c, "my_automl_exp1") +@pytest.mark.xfail(reason="tpot is broken with sklearn>=1.4", condition=SKLEARN_GT_130) def test_experiment_automl_regressor(c, client): tpot = pytest.importorskip("tpot", reason="tpot not installed")