Skip to content

Commit

Permalink
[Substrait] Add test that executes exported plan in DataFusion. (iree…
Browse files Browse the repository at this point in the history
…-org/iree-llvm-sandbox#823)

This PR adds an end-to-end test with Apache DataFusion, an Arrow-based SQL execution engine that can consume Substrait plans. This illustrates how our Substrait plans can run on different engines and provides a base for more complete compatibility tests, which we may want to add at some point down the road.

Signed-off-by: Ingo Müller <[email protected]>
  • Loading branch information
ingomueller-net committed Oct 15, 2024
1 parent b931768 commit df050bb
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
-r third_party/llvm-project/mlir/python/requirements.txt

# Testing.
datafusion==32.0.0
lit
pyarrow

Expand Down
54 changes: 54 additions & 0 deletions test/python/dialects/substrait/e2e_datafusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# RUN: %PYTHON %s | FileCheck %s

import datafusion
from datafusion import substrait as dfss
import pyarrow as pa

from mlir_structured.dialects import substrait as ss
from mlir_structured import ir


def run(f):
    """Execute test function ``f`` right away and hand it back unchanged.

    Meant to be applied as a decorator. A banner containing the function's
    name is printed first (the label directive preceding each test matches
    on it), then ``f`` is invoked inside a newly created MLIR context with an
    unknown default location, after the Substrait dialect has been
    registered.
    """
    print("\nTEST:", f.__name__)
    with ir.Context():
        with ir.Location.unknown():
            ss.register_dialect()
            f()
    return f


# CHECK-LABEL: TEST: testNamedTable
@run
def testNamedTable():
    """Run an MLIR-exported `named_table` plan end to end in DataFusion."""
    # Register an in-memory table "t" holding two int32 columns.
    session = datafusion.SessionContext()
    table_schema = pa.schema([("a", pa.int32()), ("b", pa.int32())])
    record_batch = pa.RecordBatch.from_pydict(
        {"a": [1, 2, 3], "b": [7, 8, 9]},
        schema=table_schema,
    )
    session.register_record_batches("t", [[record_batch]])

    # Build the Substrait plan from its textual MLIR form; it reads table "t"
    # and yields it unchanged.
    module = ir.Module.parse('''
substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t as ["a", "b"] : tuple<si32, si32>
yield %0 : tuple<si32, si32>
}
}
''')

    # Serialize the plan to binary protobuf and turn it into `bytes`, which is
    # what the deserializer below is handed.
    serialized = ss.to_binpb(module.operation).encode("utf8")

    # Hand the plan to DataFusion, execute it, and dump the resulting table.
    substrait_plan = dfss.substrait.serde.deserialize_bytes(serialized)
    logical_plan = dfss.substrait.consumer.from_substrait_plan(
        session, substrait_plan
    )
    result = session.create_dataframe_from_logical_plan(logical_plan)

    print(result.to_arrow_table())
    # CHECK-NEXT: pyarrow.Table
    # CHECK-NEXT: a: int32
    # CHECK-NEXT: b: int32
    # CHECK-NEXT: ----
    # CHECK-NEXT{LITERAL}: a: [[1,2,3]]
    # CHECK-NEXT{LITERAL}: b: [[7,8,9]]

0 comments on commit df050bb

Please sign in to comment.