Skip to content

Commit

Permalink
[NativeIO] Substrait filter completion (#512)
Browse files Browse the repository at this point in the history
Signed-off-by: zenghua <[email protected]>
Co-authored-by: zenghua <[email protected]>
  • Loading branch information
Ceng23333 and zenghua authored Jul 22, 2024
1 parent 9d4a49a commit b805300
Showing 1 changed file with 59 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.arrow.c.CDataDictionaryProvider;
import org.apache.arrow.c.Data;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.IntervalUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
Expand All @@ -51,7 +52,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static io.substrait.extension.DefaultExtensionCatalog.FUNCTIONS_BOOLEAN;
import static io.substrait.extension.DefaultExtensionCatalog.*;

public class SubstraitUtil {
public static final SimpleExtension.ExtensionCollection EXTENSIONS;
Expand All @@ -64,6 +65,8 @@ public class SubstraitUtil {

public static final Expression CONST_FALSE = ExpressionCreator.bool(false, false);

public static final Expression CONST_ZERO = ExpressionCreator.i64(false, 0);

private static final LibLakeSoulIO LIB;

private static final Pointer BUFFER1;
Expand Down Expand Up @@ -97,15 +100,49 @@ public static Expression or(Expression left, Expression right) {
}

public static Expression not(Expression expression) {
SimpleExtension.FunctionAnchor fa = SimpleExtension.FunctionAnchor.of(BooleanNamespace, "not:bool");
return ExpressionCreator.scalarFunction(EXTENSIONS.getScalarFunction(fa), TypeCreator.NULLABLE.BOOLEAN, expression);
return makeUnary(expression, FUNCTIONS_BOOLEAN, "not:bool", TypeCreator.NULLABLE.BOOLEAN);
}

public static Expression in(Expression expr, List<Expression.Literal> set) {
List<Expression> eqList = set.stream().map(
lit -> makeBinary(expr, lit, FUNCTIONS_COMPARISON, "equal:any_any", TypeCreator.NULLABLE.BOOLEAN)
).collect(Collectors.toList());
Expression ret = null;
for (int i = 0; i < eqList.size(); i++) {
if (i == 0) {
ret = eqList.get(i);
} else {
ret = or(ret, eqList.get(i));
}
}
return ret;
}

public static Expression notIn(Expression expr, List<Expression.Literal> set) {
List<Expression> notEqList = set.stream().map(
lit -> makeBinary(expr, lit, FUNCTIONS_COMPARISON, "not_equal:any_any", TypeCreator.NULLABLE.BOOLEAN)
).collect(Collectors.toList());
Expression ret = null;
for (int i = 0; i < notEqList.size(); i++) {
if (i == 0) {
ret = notEqList.get(i);
} else {
ret = and(ret, notEqList.get(i));
}
}
return ret;
}

public static Expression makeBinary(Expression left, Expression right, String namespace, String funcKey, Type type) {
SimpleExtension.FunctionAnchor fa = SimpleExtension.FunctionAnchor.of(namespace, funcKey);
return ExpressionCreator.scalarFunction(EXTENSIONS.getScalarFunction(fa), type, left, right);
}

public static Expression makeUnary(Expression expr, String namespace, String funcKey, Type type) {
SimpleExtension.FunctionAnchor fa = SimpleExtension.FunctionAnchor.of(namespace, funcKey);
return ExpressionCreator.scalarFunction(EXTENSIONS.getScalarFunction(fa), type, expr);
}

public static io.substrait.proto.Plan substraitExprToProto(Expression e, String tableName) {
return planToProto(exprToFilter(e, tableName));
}
Expand Down Expand Up @@ -329,7 +366,9 @@ public Type visit(ArrowType.Int anInt) {

@Override
public Type visit(ArrowType.FloatingPoint floatingPoint) {
return typeCreator.FP32;
if (floatingPoint.getPrecision() == FloatingPointPrecision.SINGLE) return typeCreator.FP32;
if (floatingPoint.getPrecision() == FloatingPointPrecision.DOUBLE) return typeCreator.FP64;
return null;
}

@Override
Expand Down Expand Up @@ -393,7 +432,7 @@ public Type visit(ArrowType.Duration duration) {
}
}

public static Expression anyToSubstraitLiteral(Type type, Object any) throws IOException {
public static Expression.Literal anyToSubstraitLiteral(Type type, Object any) throws IOException {
if (type instanceof Type.Date) {
if (any instanceof Integer) {
return ExpressionCreator.date(false, (Integer) any);
Expand All @@ -415,38 +454,39 @@ public static Expression anyToSubstraitLiteral(Type type, Object any) throws IOE
return ExpressionCreator.timestampTZ(false, DateTimeUtils.toMicros(any));
}
}
if (type instanceof Type.Str || any instanceof String) {

if (any instanceof String) {
return ExpressionCreator.string(false, (String) any);
}
if (type instanceof Type.Bool || any instanceof Boolean) {
if (any instanceof Boolean) {
return ExpressionCreator.bool(false, (Boolean) any);
}
if (type instanceof Type.Binary || any instanceof byte[]) {
if (any instanceof byte[]) {
return ExpressionCreator.binary(false, (byte[]) any);
}

if (type instanceof Type.I8 || any instanceof Byte) {
return ExpressionCreator.i8(false, Byte.parseByte(any.toString()));
if (any instanceof Byte) {
return ExpressionCreator.i8(false, (Byte) any);
}
if (type instanceof Type.I16 || any instanceof Short) {
return ExpressionCreator.i16(false, Short.parseShort(any.toString()));
if (any instanceof Short) {
return ExpressionCreator.i16(false, (Short) any);
}
if (type instanceof Type.I32) {
return ExpressionCreator.i32(false, Integer.parseInt(any.toString()));
if (any instanceof Integer) {
return ExpressionCreator.i32(false, (Integer) any);
}
if (type instanceof Type.I64) {
return ExpressionCreator.i64(false, Long.parseLong(any.toString()));
if (any instanceof Long) {
return ExpressionCreator.i64(false, (Long) any);
}
if (type instanceof Type.FP32 || any instanceof Float) {
if (any instanceof Float) {
return ExpressionCreator.fp32(false, (Float) any);
}
if (type instanceof Type.FP64 || any instanceof Double) {
if (any instanceof Double) {
return ExpressionCreator.fp64(false, (Double) any);
}
if (type instanceof Type.Decimal || any instanceof BigDecimal) {
int precision = 10;
int scale = 0;
if (type != null) {
if (type instanceof Type.Decimal) {
precision = ((Type.Decimal) type).precision();
scale = ((Type.Decimal) type).scale();
}
Expand Down

0 comments on commit b805300

Please sign in to comment.