Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
goodwanghan authored Jun 13, 2024
1 parent de00706 commit d20bfa0
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 7 deletions.
36 changes: 36 additions & 0 deletions tests/fugue_dask/test_execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from fugue_dask.execution_engine import DaskExecutionEngine
from fugue_test.builtin_suite import BuiltInTests
from fugue_test.execution_suite import ExecutionEngineTests
from fugue.column import col, all_cols
import fugue.column.functions as ff

_CONF = {
"fugue.rpc.server": "fugue.rpc.flask.FlaskRPCServer",
Expand All @@ -50,6 +52,40 @@ def test_get_parallelism(self):
def test__join_outer_pandas_incompatible(self):
return

# TODO: dask-sql 2024.5.0 has a bug, can't pass the HAVING tests
def test_select(self):
a = ArrayDataFrame(
[[1, 2], [None, 2], [None, 1], [3, 4], [None, 4]], "a:double,b:int"
)

# simple
b = fa.select(a, col("b"), (col("b") + 1).alias("c").cast(str))
self.df_eq(
b,
[[2, "3"], [2, "3"], [1, "2"], [4, "5"], [4, "5"]],
"b:int,c:str",
throw=True,
)

# with distinct
b = fa.select(
a, col("b"), (col("b") + 1).alias("c").cast(str), distinct=True
)
self.df_eq(
b,
[[2, "3"], [1, "2"], [4, "5"]],
"b:int,c:str",
throw=True,
)

# wildcard
b = fa.select(a, all_cols(), where=col("a") + col("b") == 3)
self.df_eq(b, [[1, 2]], "a:double,b:int", throw=True)

# aggregation
b = fa.select(a, col("a"), ff.sum(col("b")).cast(float).alias("b"))
self.df_eq(b, [[1, 2], [3, 4], [None, 7]], "a:double,b:double", throw=True)

def test_to_df(self):
e = self.engine
a = e.to_df([[1, 2], [3, 4]], "a:int,b:int")
Expand Down
16 changes: 9 additions & 7 deletions tests/fugue_ibis/mock/execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,17 @@ def sample(
f"one and only one of n and frac should be non-negative, {n}, {frac}"
),
)
tn = self.get_temp_table_name()
idf = self.to_df(df)
tn = f"({idf.native.compile()})"
if seed is not None:
_seed = f",{seed}"
else:
_seed = ""
if frac is not None:
sql = f"SELECT * FROM {tn} USING SAMPLE bernoulli({frac*100} PERCENT)"
sql = f"SELECT * FROM {tn} USING SAMPLE {frac*100}% (bernoulli{_seed})"
else:
sql = f"SELECT * FROM {tn} USING SAMPLE reservoir({n} ROWS)"
if seed is not None:
sql += f" REPEATABLE ({seed})"
idf = self.to_df(df)
_res = f"WITH {tn} AS ({idf.native.compile()}) " + sql
sql = f"SELECT * FROM {tn} USING SAMPLE {n} ROWS (reservoir{_seed})"
_res = f"SELECT * FROM ({sql})" # ibis has a bug to inject LIMIT
return self.to_df(self.backend.sql(_res))

def _register_df(
Expand Down

0 comments on commit d20bfa0

Please sign in to comment.