Skip to content

Commit

Permalink
[Substrait] Add support for SI1 and nested tuple types. (iree-org/ire…
Browse files Browse the repository at this point in the history
…e-llvm-sandbox#817)

Signed-off-by: Ingo Müller <[email protected]>
  • Loading branch information
ingomueller-net committed Oct 15, 2024
1 parent df050bb commit e00f3a2
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 16 deletions.
1 change: 1 addition & 0 deletions include/structured/Dialect/Substrait/IR/SubstraitTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Substrait_Type<string name, string typeMnemonic, list<Trait> traits = []>
// TODO(ingomueller): Add the other low-hanging fruits here.
def Substrait_AtomicTypes {
list<Type> types = [
SI1, // Boolean
SI32 // I32
];
}
Expand Down
53 changes: 42 additions & 11 deletions lib/Target/SubstraitPB/Export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,51 @@ DECLARE_EXPORT_FUNC(RelOpInterface, Rel)

FailureOr<std::unique_ptr<proto::Type>> exportType(Location loc,
mlir::Type mlirType) {
// TODO(ingomueller): Support other types.
auto si32 = IntegerType::get(mlirType.getContext(), 32, IntegerType::Signed);
if (mlirType != si32)
return emitError(loc) << "could not export unsupported type " << mlirType;
MLIRContext *context = mlirType.getContext();

// Handle SI1.
auto si1 = IntegerType::get(context, 1, IntegerType::Signed);
if (mlirType == si1) {
// TODO(ingomueller): support other nullability modes.
auto i1Type = std::make_unique<proto::Type::Boolean>();
i1Type->set_nullability(
Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED);

auto type = std::make_unique<proto::Type>();
type->set_allocated_bool_(i1Type.release());
return std::move(type);
}

// TODO(ingomueller): support other nullability modes.
auto i32Type = std::make_unique<proto::Type::I32>();
i32Type->set_nullability(
Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED);
// Handle SI32.
auto si32 = IntegerType::get(context, 32, IntegerType::Signed);
if (mlirType == si32) {
// TODO(ingomueller): support other nullability modes.
auto i32Type = std::make_unique<proto::Type::I32>();
i32Type->set_nullability(
Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED);

auto type = std::make_unique<proto::Type>();
type->set_allocated_i32(i32Type.release());
return std::move(type);
}

auto type = std::make_unique<proto::Type>();
type->set_allocated_i32(i32Type.release());
if (auto tupleType = llvm::dyn_cast<TupleType>(mlirType)) {
auto structType = std::make_unique<proto::Type::Struct>();
for (mlir::Type fieldType : tupleType.getTypes()) {
// Convert field type recursively.
FailureOr<std::unique_ptr<proto::Type>> type = exportType(loc, fieldType);
if (failed(type))
return failure();
*structType->add_types() = *type.value();
}

auto type = std::make_unique<proto::Type>();
type->set_allocated_struct_(structType.release());
return std::move(type);
}

return std::move(type);
// TODO(ingomueller): Support other types.
return emitError(loc) << "could not export unsupported type " << mlirType;
}

FailureOr<std::unique_ptr<Plan>> exportOperation(ModuleOp op) {
Expand Down
29 changes: 25 additions & 4 deletions lib/Target/SubstraitPB/Import.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,36 @@ DECLARE_IMPORT_FUNC(Rel, Rel, RelOpInterface)

static mlir::FailureOr<mlir::Type> importType(MLIRContext *context,
const proto::Type &type) {
// TODO(ingomueller): Support more types.
if (!type.has_i32()) {

proto::Type::KindCase kind_case = type.kind_case();
switch (kind_case) {
case proto::Type::kBool: {
return IntegerType::get(context, 1, IntegerType::Signed);
}
case proto::Type::kI32: {
return IntegerType::get(context, 32, IntegerType::Signed);
}
case proto::Type::kStruct: {
const proto::Type::Struct &structType = type.struct_();
llvm::SmallVector<mlir::Type> fieldTypes;
fieldTypes.reserve(structType.types_size());
for (const proto::Type &fieldType : structType.types()) {
FailureOr<mlir::Type> mlirFieldType = importType(context, fieldType);
if (failed(mlirFieldType))
return failure();
fieldTypes.push_back(mlirFieldType.value());
}
return TupleType::get(context, fieldTypes);
}
// TODO(ingomueller): Support more types.
default: {
auto loc = UnknownLoc::get(context);
const pb::FieldDescriptor *desc =
proto::Type::GetDescriptor()->FindFieldByNumber(type.kind_case());
proto::Type::GetDescriptor()->FindFieldByNumber(kind_case);
return emitError(loc) << "could not import unsupported type "
<< desc->name();
}
return IntegerType::get(context, 32, IntegerType::Signed);
}
}

static mlir::FailureOr<NamedTableOp>
Expand Down
39 changes: 38 additions & 1 deletion test/Target/SubstraitPB/Export/types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
// RUN: structured-translate -substrait-to-protobuf %s \
// RUN: --split-input-file --output-split-marker="# -----" \
// RUN: | structured-translate -protobuf-to-substrait \
// RUN: --split-input-file="# -----" \
// RUN: --split-input-file="# -----" --output-split-marker="// ""-----" \
// RUN: | structured-translate -substrait-to-protobuf \
// RUN: --split-input-file --output-split-marker="# -----" \
// RUN: | FileCheck %s

// CHECK-LABEL: relations {
Expand All @@ -30,3 +31,39 @@ substrait.plan version 0 : 42 : 1 {
yield %0 : tuple<si32>
}
}

// -----

// CHECK-LABEL: relations {
// CHECK-NEXT: rel {
// CHECK-NEXT: read {
// CHECK: base_schema {
// CHECK-NEXT: names: "a"
// CHECK-NEXT: names: "b"
// CHECK-NEXT: names: "c"
// CHECK-NEXT: struct {
// CHECK-NEXT: types {
// CHECK-NEXT: bool {
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: types {
// CHECK-NEXT: struct {
// CHECK-NEXT: types {
// CHECK-NEXT: bool {
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: nullability: NULLABILITY_REQUIRED
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: named_table {

substrait.plan version 0 : 42 : 1 {
relation {
%0 = named_table @t1 as ["a", "b", "c"] : tuple<si1, tuple<si1>>
yield %0 : tuple<si1, tuple<si1>>
}
}
51 changes: 51 additions & 0 deletions test/Target/SubstraitPB/Import/types.textpb
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# RUN: structured-translate -protobuf-to-substrait %s \
# RUN: --split-input-file="# ""-----" \
# RUN: | FileCheck %s

# RUN: structured-translate -protobuf-to-substrait %s \
# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \
# RUN: | structured-translate -substrait-to-protobuf \
# RUN: --split-input-file --output-split-marker="# ""-----" \
# RUN: | structured-translate -protobuf-to-substrait \
# RUN: --split-input-file="# ""-----" --output-split-marker="// -----" \
# RUN: | FileCheck %s

# CHECK: substrait.plan
Expand Down Expand Up @@ -39,3 +43,50 @@ version {
minor_number: 42
patch_number: 1
}

# -----

# CHECK: substrait.plan
# CHECK-NEXT: relation
# CHECK-NEXT: named_table
# CHECK-SAME: : tuple<si1, tuple<si1>>

relations {
rel {
read {
common {
direct {
}
}
base_schema {
names: "a"
names: "b"
names: "c"
struct {
types {
bool {
nullability: NULLABILITY_REQUIRED
}
}
types {
struct {
types {
bool {
nullability: NULLABILITY_REQUIRED
}
}
}
}
nullability: NULLABILITY_REQUIRED
}
}
named_table {
names: "t1"
}
}
}
}
version {
minor_number: 42
patch_number: 1
}

0 comments on commit e00f3a2

Please sign in to comment.