From 72f90cb90bcb7b03eaef26980198398ff7165773 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 10 Sep 2023 11:55:28 +0800 Subject: [PATCH 01/10] ENH: basic support TPC-H --- xorbits_sql/executor.py | 86 +++++++++++++++++++++++++++++---- xorbits_sql/tests/test_tpc_h.py | 64 ++++++++++++++++++++++++ 2 files changed, 141 insertions(+), 9 deletions(-) create mode 100644 xorbits_sql/tests/test_tpc_h.py diff --git a/xorbits_sql/executor.py b/xorbits_sql/executor.py index 9bdc223..8b5221f 100644 --- a/xorbits_sql/executor.py +++ b/xorbits_sql/executor.py @@ -36,6 +36,15 @@ exp.Variance: "var", } +_SQLGLOT_NUMBER_TO_PD_TYPES = { + exp.DataType.Type.FLOAT: pandas.Float32Dtype, + exp.DataType.Type.DOUBLE: pandas.Float64Dtype, + exp.DataType.Type.INT: pandas.Int32Dtype, + exp.DataType.Type.TINYINT: pandas.Int8Dtype, + exp.DataType.Type.SMALLINT: pandas.Int16Dtype, + exp.DataType.Type.BIGINT: pandas.Int64Dtype, +} + class XorbitsExecutor: def __init__(self, tables: Tables | None = None): @@ -45,14 +54,16 @@ def __init__(self, tables: Tables | None = None): @lru_cache(1) def _exp_visitors(cls) -> TypeDispatcher: dispatcher = TypeDispatcher() + for func in exp.ALL_FUNCTIONS: + dispatcher.register(func, cls._func) dispatcher.register(exp.Alias, cls._alias) dispatcher.register(exp.Binary, cls._func) dispatcher.register(exp.Boolean, cls._boolean) + dispatcher.register(exp.Cast, cls._cast) dispatcher.register(exp.Column, cls._column) dispatcher.register(exp.Literal, cls._literal) dispatcher.register(exp.Ordered, cls._ordered) - for func in exp.ALL_FUNCTIONS: - dispatcher.register(func, cls._func) + dispatcher.register(exp.Paren, cls._paren) return dispatcher @classmethod @@ -74,7 +85,7 @@ def _literal(literal: exp.Literal, context: dict[str, pd.DataFrame]): elif literal.is_int: return int(literal.this) elif literal.is_star: - return ... 
+ return slice(None) else: return float(literal.this) @@ -82,14 +93,61 @@ def _literal(literal: exp.Literal, context: dict[str, pd.DataFrame]): def _boolean(boolean: exp.Boolean, context: dict[str, pd.DataFrame]): return True if boolean.this else False + @classmethod + def _cast( + cls, + cast: exp.Cast, + context: dict[str, pd.DataFrame], + ): + this = cls._visit_exp(cast.this, context) + to = getattr(exp.DataType.Type, str(cast.to)) + + if to == exp.DataType.Type.DATE: + if pandas.api.types.is_scalar(this): + return pandas.to_datetime(this).to_pydatetime().date() + else: + return pd.to_datetime(this).dt.date + elif to in (exp.DataType.Type.DATETIME, exp.DataType.Type.TIMESTAMP): + return pd.to_datetime(this) + elif to == exp.DataType.Type.BOOLEAN: + if pandas.api.types.is_scalar(this): + return bool(this) + else: + return this.astype(pandas.BooleanDtype()) + elif to in exp.DataType.TEXT_TYPES: + if pandas.api.types.is_scalar(this): + return str(this) + else: + return this.astype(pandas.StringDtype("pyarrow")) + elif to in _SQLGLOT_NUMBER_TO_PD_TYPES: + pd_type = _SQLGLOT_NUMBER_TO_PD_TYPES[to]() + if pandas.api.types.is_scalar(this): + return pd_type.type(this) + else: + return this.astype(pd_type) + else: + raise NotImplementedError(f"Casting {cast.this} to '{to}' not implemented.") + @staticmethod def _column(column: exp.Column, context: dict[str, pd.DataFrame]) -> pd.Series: return context[column.table][column.name] @classmethod - def _alias(cls, alias: exp.Alias, context: dict[str, pd.DataFrame]) -> pd.Series: + def _alias( + cls, + alias: exp.Alias, + context: dict[str, pd.DataFrame], + ) -> pd.Series: return cls._visit_exp(alias.this, context).rename(alias.output_name) + @classmethod + def _paren( + cls, + paren: exp.Paren, + context: dict[str, pd.DataFrame], + ): + return cls._visit_exp(paren.this, context) + @classproperty @lru_cache(1) def _operator_visitors(cls) -> TypeDispatcher: @@ -198,7 +256,14 @@ def _scan_csv(step: planner.Scan) -> dict[str, 
pd.DataFrame]: args = source.expressions filename = source.name - df = pd.read_csv(filename, **{arg.name: arg for arg in args}) + + delimiter = "," + args = iter(arg.name for arg in args) + for k, v in zip(args, args): + if k == "delimiter": + delimiter = v + + df = pd.read_csv(filename, sep=delimiter) return {alias: df} def _project_and_filter( @@ -226,7 +291,10 @@ def aggregate( if step.operands: for op in step.operands: - df[op.alias_or_name] = self._visit_exp(op, context) + if isinstance(op.this, exp.Star): + df[op.alias_or_name] = 1 + else: + df[op.alias_or_name] = self._visit_exp(op, context) aggregations = dict() names = list(step.group) @@ -261,14 +329,14 @@ def join( source = step.name source_df = context[source] source_context = {source: source_df} - column_slices = {source: slice(0, source_df.shape[1])} + column_slices = {source: slice(0, len(source_df.dtypes))} df = None for name, join in step.joins.items(): df = context[name] join_context = {name: df} start = max(r.stop for r in column_slices.values()) - column_slices[name] = slice(start, df.shape[1] + start) + column_slices[name] = slice(start, len(df.dtypes) + start) if join.get("source_key"): df = self._hash_join(join, source_context, join_context) @@ -365,7 +433,7 @@ def sort( df = next(iter(context.values())) for projection in step.projections: df[projection.alias_or_name] = self._visit_exp(projection, context) - slc = slice(df.shape[1] - len(step.projections), df.shape[1]) + slc = slice(len(df.dtypes) - len(step.projections), len(df.dtypes)) sort_context = {"": df, **context} diff --git a/xorbits_sql/tests/test_tpc_h.py b/xorbits_sql/tests/test_tpc_h.py new file mode 100644 index 0000000..59e8174 --- /dev/null +++ b/xorbits_sql/tests/test_tpc_h.py @@ -0,0 +1,64 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import duckdb +import pandas as pd +import pytest +from sqlglot import exp, parse_one + +from .. import execute +from .helpers import FILE_DIR, TPCH_SCHEMA, load_sql + +DIR = FILE_DIR + "/tpc-h/" + + +@pytest.fixture +def prepare_data(): + conn = duckdb.connect() + + for table, columns in TPCH_SCHEMA.items(): + conn.execute( + f""" + CREATE VIEW {table} AS + SELECT * + FROM READ_CSV('{DIR}{table}.csv', delim='|', header=True, columns={columns}) + """ + ) + + sqls = [(sql, expected) for _, sql, expected in load_sql("tpc-h/tpc-h.sql")] + + try: + yield conn, sqls + finally: + conn.close() + + +def _to_csv(expression: exp.Expression) -> exp.Expression: + if isinstance(expression, exp.Table) and expression.name not in ("revenue"): + return parse_one( + f"READ_CSV('{DIR}{expression.name}.csv', 'delimiter', '|') AS {expression.alias_or_name}" + ) + return expression + + +def test_execute_tpc_h(prepare_data): + conn, sqls = prepare_data + for sql, _ in sqls[:1]: + expected = conn.execute(sql).fetchdf() + result = execute( + parse_one(sql, dialect="duckdb").transform(_to_csv).sql(pretty=True), + TPCH_SCHEMA, + dialect="duckdb", + ).fetch() + pd.testing.assert_frame_equal(result, expected) From 2ce4d49ba31669c898be0508490c9375d982db5d Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 10 Sep 2023 13:09:50 +0800 Subject: [PATCH 02/10] More fixes --- xorbits_sql/core.py | 2 +- xorbits_sql/executor.py | 47 +++++++++++++++++++++++++++++------------ 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/xorbits_sql/core.py b/xorbits_sql/core.py index 
37a5b57..8022cc3 100644 --- a/xorbits_sql/core.py +++ b/xorbits_sql/core.py @@ -109,7 +109,7 @@ def execute( logger.debug("Logical Plan: %s", plan) now = time.time() - result = XorbitsExecutor(tables=tables_).execute(plan) + result = XorbitsExecutor(tables=tables_, schema=schema).execute(plan) logger.debug("Query finished: %f", time.time() - now) diff --git a/xorbits_sql/executor.py b/xorbits_sql/executor.py index 8b5221f..205a991 100644 --- a/xorbits_sql/executor.py +++ b/xorbits_sql/executor.py @@ -20,7 +20,7 @@ import pandas import xorbits import xorbits.pandas as pd -from sqlglot import exp, planner +from sqlglot import MappingSchema, exp, planner from xoscar.utils import TypeDispatcher, classproperty from .errors import ExecuteError, UnsupportedError @@ -36,19 +36,22 @@ exp.Variance: "var", } -_SQLGLOT_NUMBER_TO_PD_TYPES = { - exp.DataType.Type.FLOAT: pandas.Float32Dtype, - exp.DataType.Type.DOUBLE: pandas.Float64Dtype, - exp.DataType.Type.INT: pandas.Int32Dtype, - exp.DataType.Type.TINYINT: pandas.Int8Dtype, - exp.DataType.Type.SMALLINT: pandas.Int16Dtype, - exp.DataType.Type.BIGINT: pandas.Int64Dtype, +_SQLGLOT_TYPE_TO_DTYPE = { + "float": "float32", + "double": "float64", + "int": "int32", + "tinyint": "int8", + "smallint": "int16", + "bigint": "int64", } class XorbitsExecutor: - def __init__(self, tables: Tables | None = None): + def __init__( + self, tables: Tables | None = None, schema: MappingSchema | None = None + ): self.tables = tables or Tables() + self.schema = schema @classproperty @lru_cache(1) @@ -119,10 +122,10 @@ def _cast( return str(this) else: return this.astype(pandas.StringDtype("pyarrow")) - elif to in _SQLGLOT_NUMBER_TO_PD_TYPES: - pd_type = _SQLGLOT_NUMBER_TO_PD_TYPES[to]() + elif str(cast.to) in _SQLGLOT_TYPE_TO_DTYPE: + pd_type = _SQLGLOT_TYPE_TO_DTYPE[str(cast.to)] if pandas.api.types.is_scalar(this): - return pd_type.type(this) + return pandas.Series([this], dtype=pd_type)[0] else: return this.astype(pd_type) else: @@ -250,7 
+253,16 @@ def scan( return {step.name: self._project_and_filter(step, context, df)} @staticmethod - def _scan_csv(step: planner.Scan) -> dict[str, pd.DataFrame]: + def _schema_to_dtype(schema: dict[str, str]) -> dict[str, str]: + result = dict() + for name, type_name in schema.items(): + try: + result[name] = _SQLGLOT_TYPE_TO_DTYPE[type_name.lower()] + except KeyError: + continue + return result + + def _scan_csv(self, step: planner.Scan) -> dict[str, pd.DataFrame]: alias = step.source.alias source: exp.ReadCSV = step.source.this @@ -263,7 +275,11 @@ def _scan_csv(step: planner.Scan) -> dict[str, pd.DataFrame]: if k == "delimiter": delimiter = v - df = pd.read_csv(filename, sep=delimiter) + dtype = None + if self.schema and alias in self.schema.mapping: + dtype = self._schema_to_dtype(self.schema.mapping[alias]) + + df = pd.read_csv(filename, sep=delimiter, dtype=dtype) return {alias: df} def _project_and_filter( @@ -458,6 +474,9 @@ def sort( if isinstance(step.limit, int): df = df.iloc[: step.limit] + projection_columns = [p.alias_or_name for p in step.projections] + df = df.loc[:, projection_columns] + return {step.name: df} def set_operation( From e331374bb0ff4cb3da14367874c01a0a1b7c5fa4 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Wed, 13 Sep 2023 22:00:49 +0800 Subject: [PATCH 03/10] Support like --- xorbits_sql/executor.py | 6 ++++++ xorbits_sql/tests/test_tpc_h.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/xorbits_sql/executor.py b/xorbits_sql/executor.py index 205a991..ec17656 100644 --- a/xorbits_sql/executor.py +++ b/xorbits_sql/executor.py @@ -165,6 +165,7 @@ def _operator_visitors(cls) -> TypeDispatcher: dispatcher.register(exp.LTE, operator.le) dispatcher.register(exp.Mul, operator.mul) dispatcher.register(exp.NEQ, operator.ne) + dispatcher.register(exp.Like, cls._like) dispatcher.register(exp.Sub, operator.sub) return dispatcher @@ -182,6 +183,11 @@ def _func(cls, func: exp.Expression, context: dict[str, pd.DataFrame]) -> 
pd.Ser ) return func(*values) + @classmethod + def _like(cls, left: pd.Series, right: str): + r = right.replace("_", ".").replace("%", ".*") + return left.str.contains(r, regex=True, na=True) + def execute(self, plan: planner.Plan) -> pd.DataFrame: finished = set() queue = set(plan.leaves) diff --git a/xorbits_sql/tests/test_tpc_h.py b/xorbits_sql/tests/test_tpc_h.py index 59e8174..c76ddc7 100644 --- a/xorbits_sql/tests/test_tpc_h.py +++ b/xorbits_sql/tests/test_tpc_h.py @@ -54,7 +54,7 @@ def _to_csv(expression: exp.Expression) -> exp.Expression: def test_execute_tpc_h(prepare_data): conn, sqls = prepare_data - for sql, _ in sqls[:1]: + for sql, _ in sqls[1:2]: expected = conn.execute(sql).fetchdf() result = execute( parse_one(sql, dialect="duckdb").transform(_to_csv).sql(pretty=True), From 2f59c5806af258461fe1fcec55f1abad35ba8f8c Mon Sep 17 00:00:00 2001 From: qinxuye Date: Sun, 17 Sep 2023 15:42:48 +0800 Subject: [PATCH 04/10] Support q02 --- xorbits_sql/executor.py | 59 +++++++++++++++---------------- xorbits_sql/tests/test_execute.py | 2 +- xorbits_sql/tests/test_tpc_h.py | 2 +- 3 files changed, 31 insertions(+), 32 deletions(-) diff --git a/xorbits_sql/executor.py b/xorbits_sql/executor.py index ec17656..92b42de 100644 --- a/xorbits_sql/executor.py +++ b/xorbits_sql/executor.py @@ -306,9 +306,7 @@ def _project_and_filter( def aggregate( self, step: planner.Aggregate, context: dict[str, pd.DataFrame] ) -> dict[str, pd.DataFrame]: - dfs = list(context.values()) - assert len(dfs) == 1 - df = dfs[0] + df = context[step.source] group_by = [self._visit_exp(g, context) for g in step.group.values()] if step.operands: @@ -338,7 +336,9 @@ def aggregate( result.columns = names if step.projections or step.condition: - result = self._project_and_filter(step, {step.name: result}, result) + result = self._project_and_filter( + step, {step.name: result, **{name: result for name in context}}, result + ) if isinstance(step.limit, int): result = result.iloc[: step.limit] @@ 
-361,9 +361,9 @@ def join( column_slices[name] = slice(start, len(df.dtypes) + start) if join.get("source_key"): - df = self._hash_join(join, source_context, join_context) + df = self._hash_join(join, source_df, source_context, df, join_context) else: - df = self._nested_loop_join(join, source_context, join_context) + df = self._nested_loop_join(join, source_df, df) condition = self._visit_exp(join["condition"], {name: df}) if condition is not True: @@ -373,6 +373,7 @@ def join( name: df.iloc[:, column_slice] for name, column_slice in column_slices.items() } + source_df = df if not step.condition and not step.projections: return source_context @@ -382,48 +383,48 @@ def join( if step.projections: return {step.name: sink} else: - return source_context + return {name: sink for name in source_context} def _nested_loop_join( self, join: dict, - source_context: dict[str, pd.DataFrame], - join_context: dict[str, pd.DataFrame], + source_df: pd.DataFrame, + join_df: pd.DataFrame, ) -> pd.DataFrame: def func(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: if pandas.__version__ >= "1.2.0": - return left.merge(right, on="cross") + return left.merge(right, how="cross") else: left["_on"] = 1 right["_on"] = 1 result = left.merge(right, on="_on") return result[left.dtypes.index.tolist() + right.dtypes.index.tolist()] - source_df = next(iter(source_context.values())) - join_df = next(iter(join_context.values())) - return source_df.cartisan_chunk(join_df, func) + return source_df.cartesian_chunk(join_df, func) def _hash_join( self, join: dict, + source_df: pd.DataFrame, source_context: dict[str, pd.DataFrame], + join_df: pd.DataFrame, join_context: dict[str, pd.DataFrame], ) -> pd.DataFrame: cols = [] - source_df = next(iter(source_context.values())) + source_df = pd.DataFrame({c: source_df[c] for c in source_df.dtypes.index}) cols.extend(source_df.dtypes.index.tolist()) left_ons = [] for i, source_key in enumerate(join["source_key"]): - col_name = f"_on_{i}" + col_name = 
f"__on_{i}" left_ons.append(col_name) source_df[col_name] = self._visit_exp(source_key, source_context) - join_df = next(iter(join_context.values())) + join_df = pd.DataFrame({c: join_df[c] for c in join_df.dtypes.index}) cols.extend(join_df.dtypes.index.tolist()) right_ons = [] for i, join_key in enumerate(join["join_key"]): - col_name = f"_on_{i}" + col_name = f"__on_{i}" right_ons.append(col_name) join_df[col_name] = self._visit_exp(join_key, join_context) @@ -434,9 +435,12 @@ def _hash_join( how = "right" result = source_df.merge(join_df, how=how, left_on=left_ons, right_on=right_ons) - result = result[ - [col for col in result.dtypes.index if not col.startswith("_on_")] + ilocs = [ + i + for i, col in enumerate(result.dtypes.index) + if not col.startswith("__on_") ] + result = result.iloc[:, ilocs] result.columns = cols return result @@ -451,11 +455,10 @@ def _ordered(cls, ordered: exp.Ordered, context: dict[str, pd.DataFrame]): def sort( self, step: planner.Sort, context: dict[str, pd.DataFrame] ) -> dict[str, pd.DataFrame]: - assert len(context) == 1 - df = next(iter(context.values())) + df = context[step.name] + df = pd.DataFrame({n: df[n] for n in df.dtypes.index}) for projection in step.projections: df[projection.alias_or_name] = self._visit_exp(projection, context) - slc = slice(len(df.dtypes) - len(step.projections), len(df.dtypes)) sort_context = {"": df, **context} @@ -464,7 +467,7 @@ def sort( ascendings = [] na_position = None for i, (s, descending, cur_na_position) in enumerate(sort): - sort_col = f"_s_{i}" + sort_col = f"__s_{i}" sort_cols.append(sort_col) ascendings.append(not descending) if na_position is None: @@ -473,17 +476,13 @@ def sort( raise NotImplementedError("nulls_first must be same for all sort keys") df[sort_col] = s - df = df.sort_values( - by=sort_cols, ascending=ascendings, na_position=na_position - ).iloc[:, slc] + df = df.sort_values(by=sort_cols, ascending=ascendings, na_position=na_position) + df = df[[p.alias_or_name for p 
in step.projections]] if isinstance(step.limit, int): df = df.iloc[: step.limit] - projection_columns = [p.alias_or_name for p in step.projections] - df = df.loc[:, projection_columns] - - return {step.name: df} + return {step.name: df.reset_index(drop=True)} def set_operation( self, step: planner.SetOperation, context: dict[str, pd.DataFrame] diff --git a/xorbits_sql/tests/test_execute.py b/xorbits_sql/tests/test_execute.py index 20a22f8..97e6672 100644 --- a/xorbits_sql/tests/test_execute.py +++ b/xorbits_sql/tests/test_execute.py @@ -111,6 +111,6 @@ def test_sort(prepare_data): expected = raw_df.sort_values(by="c", ascending=False) expected["b"] *= 5 - expected = expected.iloc[:10] + expected = expected.iloc[:10].reset_index(drop=True) result = execute(sql, tables={"t1": xpd.DataFrame(raw_df)}).fetch() pd.testing.assert_frame_equal(result, expected) diff --git a/xorbits_sql/tests/test_tpc_h.py b/xorbits_sql/tests/test_tpc_h.py index c76ddc7..e09e178 100644 --- a/xorbits_sql/tests/test_tpc_h.py +++ b/xorbits_sql/tests/test_tpc_h.py @@ -54,7 +54,7 @@ def _to_csv(expression: exp.Expression) -> exp.Expression: def test_execute_tpc_h(prepare_data): conn, sqls = prepare_data - for sql, _ in sqls[1:2]: + for sql, _ in sqls[:2]: expected = conn.execute(sql).fetchdf() result = execute( parse_one(sql, dialect="duckdb").transform(_to_csv).sql(pretty=True), From 8b0001e884d91c5d0e30c687cec1ffc19d15a0d0 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Mon, 18 Sep 2023 17:17:50 +0800 Subject: [PATCH 05/10] Support q03 --- xorbits_sql/executor.py | 3 ++- xorbits_sql/tests/test_tpc_h.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/xorbits_sql/executor.py b/xorbits_sql/executor.py index 92b42de..f495590 100644 --- a/xorbits_sql/executor.py +++ b/xorbits_sql/executor.py @@ -121,7 +121,8 @@ def _cast( if pandas.api.types.is_scalar(this): return str(this) else: - return this.astype(pandas.StringDtype("pyarrow")) + # TODO: convert to arrow string when it's 
default in pandas + return this.astype(str) elif str(cast.to) in _SQLGLOT_TYPE_TO_DTYPE: pd_type = _SQLGLOT_TYPE_TO_DTYPE[str(cast.to)] if pandas.api.types.is_scalar(this): diff --git a/xorbits_sql/tests/test_tpc_h.py b/xorbits_sql/tests/test_tpc_h.py index e09e178..46a8b28 100644 --- a/xorbits_sql/tests/test_tpc_h.py +++ b/xorbits_sql/tests/test_tpc_h.py @@ -54,7 +54,7 @@ def _to_csv(expression: exp.Expression) -> exp.Expression: def test_execute_tpc_h(prepare_data): conn, sqls = prepare_data - for sql, _ in sqls[1:2]: + for sql, _ in sqls[:2]: expected = conn.execute(sql).fetchdf() result = execute( parse_one(sql, dialect="duckdb").transform(_to_csv).sql(pretty=True), From 531b244f4a74c40d9c7f7b792b714674c7a9fa98 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Mon, 18 Sep 2023 17:25:32 +0800 Subject: [PATCH 06/10] Fix bool dtype --- xorbits_sql/executor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xorbits_sql/executor.py b/xorbits_sql/executor.py index f495590..58c9782 100644 --- a/xorbits_sql/executor.py +++ b/xorbits_sql/executor.py @@ -116,7 +116,8 @@ def _cast( if pandas.api.types.is_scalar(this): return bool(this) else: - return this.astype(pandas.BooleanDtype()) + # TODO: convert to nullable boolean when it's default in pandas + return this.astype(bool) elif to in exp.DataType.TEXT_TYPES: if pandas.api.types.is_scalar(this): return str(this) From abeb34616dc8321ba3e661ea60ba19d48a342e00 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Mon, 18 Sep 2023 20:00:45 +0800 Subject: [PATCH 07/10] Support q04 --- xorbits_sql/executor.py | 24 ++++++++++++++++++++++-- xorbits_sql/tests/test_tpc_h.py | 2 +- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/xorbits_sql/executor.py b/xorbits_sql/executor.py index 58c9782..b7d6951 100644 --- a/xorbits_sql/executor.py +++ b/xorbits_sql/executor.py @@ -16,6 +16,7 @@ import operator from functools import lru_cache +from typing import Any import pandas import xorbits import xorbits.pandas as pd @@ -65,6 +66,8 @@ def
_exp_visitors(cls) -> TypeDispatcher: dispatcher.register(exp.Cast, cls._cast) dispatcher.register(exp.Column, cls._column) dispatcher.register(exp.Literal, cls._literal) + dispatcher.register(exp.Null, cls._null) + dispatcher.register(exp.Unary, cls._func) dispatcher.register(exp.Ordered, cls._ordered) dispatcher.register(exp.Paren, cls._paren) return dispatcher @@ -96,6 +99,10 @@ def _literal(literal: exp.Literal, context: dict[str, pd.DataFrame]): def _boolean(boolean: exp.Boolean, context: dict[str, pd.DataFrame]): return True if boolean.this else False + @staticmethod + def _null(null: exp.Null, context: dict[str, pd.DataFrame]): + return None + @classmethod def _cast( cls, @@ -163,10 +170,12 @@ def _operator_visitors(cls) -> TypeDispatcher: dispatcher.register(exp.Div, operator.truediv) dispatcher.register(exp.GT, operator.gt) dispatcher.register(exp.GTE, operator.ge) + dispatcher.register(exp.Is, cls._is) dispatcher.register(exp.LT, operator.lt) dispatcher.register(exp.LTE, operator.le) dispatcher.register(exp.Mul, operator.mul) dispatcher.register(exp.NEQ, operator.ne) + dispatcher.register(exp.Not, operator.neg) dispatcher.register(exp.Like, cls._like) dispatcher.register(exp.Sub, operator.sub) return dispatcher @@ -190,6 +199,13 @@ def _like(cls, left: pd.Series, right: str): r = right.replace("_", ".").replace("%", ".*") return left.str.contains(r, regex=True, na=True) + @classmethod + def _is(cls, left: pd.Series, right: Any): + if right is None: + return left.isnull() + else: + return left == right + def execute(self, plan: planner.Plan) -> pd.DataFrame: finished = set() queue = set(plan.leaves) @@ -334,8 +350,12 @@ def aggregate( column=agg.this.alias_or_name, aggfunc=aggfunc ) - result = df.groupby(group_by).agg(**aggregations).reset_index() - result.columns = names + if aggregations: + result = df.groupby(group_by).agg(**aggregations).reset_index() + result.columns = names + else: + assert len(group_by) == len(names) + result = 
pd.DataFrame(dict(zip(names, group_by))).drop_duplicates() if step.projections or step.condition: result = self._project_and_filter( diff --git a/xorbits_sql/tests/test_tpc_h.py b/xorbits_sql/tests/test_tpc_h.py index 46a8b28..dd87433 100644 --- a/xorbits_sql/tests/test_tpc_h.py +++ b/xorbits_sql/tests/test_tpc_h.py @@ -54,7 +54,7 @@ def _to_csv(expression: exp.Expression) -> exp.Expression: def test_execute_tpc_h(prepare_data): conn, sqls = prepare_data - for sql, _ in sqls[:3]: + for sql, _ in sqls[:4]: expected = conn.execute(sql).fetchdf() result = execute( parse_one(sql, dialect="duckdb").transform(_to_csv).sql(pretty=True), From d0ccf0808e7d8d6747530e5918808c1f3b4e48c2 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Mon, 18 Sep 2023 20:06:14 +0800 Subject: [PATCH 08/10] Support q05 --- xorbits_sql/tests/test_tpc_h.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xorbits_sql/tests/test_tpc_h.py b/xorbits_sql/tests/test_tpc_h.py index dd87433..0be2a08 100644 --- a/xorbits_sql/tests/test_tpc_h.py +++ b/xorbits_sql/tests/test_tpc_h.py @@ -54,7 +54,7 @@ def _to_csv(expression: exp.Expression) -> exp.Expression: def test_execute_tpc_h(prepare_data): conn, sqls = prepare_data - for sql, _ in sqls[:4]: + for sql, _ in sqls[:5]: expected = conn.execute(sql).fetchdf() result = execute( parse_one(sql, dialect="duckdb").transform(_to_csv).sql(pretty=True), From 2c86ca7625a87ab27473b7d6870497fce556bdb6 Mon Sep 17 00:00:00 2001 From: qinxuye Date: Mon, 18 Sep 2023 20:29:01 +0800 Subject: [PATCH 09/10] Fix --- xorbits_sql/executor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xorbits_sql/executor.py b/xorbits_sql/executor.py index b7d6951..2aa0fbb 100644 --- a/xorbits_sql/executor.py +++ b/xorbits_sql/executor.py @@ -176,6 +176,7 @@ def _operator_visitors(cls) -> TypeDispatcher: dispatcher.register(exp.Mul, operator.mul) dispatcher.register(exp.NEQ, operator.ne) dispatcher.register(exp.Not, operator.neg) + 
dispatcher.register(exp.Or, operator.or_) dispatcher.register(exp.Like, cls._like) dispatcher.register(exp.Sub, operator.sub) return dispatcher @@ -351,7 +352,10 @@ def aggregate( ) if aggregations: - result = df.groupby(group_by).agg(**aggregations).reset_index() + if step.group: + result = df.groupby(group_by).agg(**aggregations).reset_index() + else: + result = df.agg(**aggregations) result.columns = names else: assert len(group_by) == len(names) From 692fb945fe0d5a9cdc8542b2a0886bbeb92ab22a Mon Sep 17 00:00:00 2001 From: qinxuye Date: Tue, 19 Sep 2023 19:57:25 +0800 Subject: [PATCH 10/10] Support q06 --- xorbits_sql/executor.py | 2 +- xorbits_sql/tests/test_tpc_h.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xorbits_sql/executor.py b/xorbits_sql/executor.py index 2aa0fbb..852aa99 100644 --- a/xorbits_sql/executor.py +++ b/xorbits_sql/executor.py @@ -355,7 +355,7 @@ def aggregate( if step.group: result = df.groupby(group_by).agg(**aggregations).reset_index() else: - result = df.agg(**aggregations) + result = df.agg(**aggregations).reset_index(drop=True) result.columns = names else: assert len(group_by) == len(names) diff --git a/xorbits_sql/tests/test_tpc_h.py b/xorbits_sql/tests/test_tpc_h.py index 0be2a08..3d8aaa0 100644 --- a/xorbits_sql/tests/test_tpc_h.py +++ b/xorbits_sql/tests/test_tpc_h.py @@ -54,7 +54,7 @@ def _to_csv(expression: exp.Expression) -> exp.Expression: def test_execute_tpc_h(prepare_data): conn, sqls = prepare_data - for sql, _ in sqls[:5]: + for sql, _ in sqls[:6]: expected = conn.execute(sql).fetchdf() result = execute( parse_one(sql, dialect="duckdb").transform(_to_csv).sql(pretty=True),