diff --git a/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/substrait/SubstraitUtil.java b/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/substrait/SubstraitUtil.java index f12c76960..b7ca656af 100644 --- a/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/substrait/SubstraitUtil.java +++ b/native-io/lakesoul-io-java/src/main/java/com/dmetasoul/lakesoul/lakesoul/io/substrait/SubstraitUtil.java @@ -29,6 +29,7 @@ import org.apache.arrow.c.CDataDictionaryProvider; import org.apache.arrow.c.Data; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.IntervalUnit; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -51,7 +52,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import static io.substrait.extension.DefaultExtensionCatalog.FUNCTIONS_BOOLEAN; +import static io.substrait.extension.DefaultExtensionCatalog.*; public class SubstraitUtil { public static final SimpleExtension.ExtensionCollection EXTENSIONS; @@ -64,6 +65,8 @@ public class SubstraitUtil { public static final Expression CONST_FALSE = ExpressionCreator.bool(false, false); + public static final Expression CONST_ZERO = ExpressionCreator.i64(false, 0); + private static final LibLakeSoulIO LIB; private static final Pointer BUFFER1; @@ -97,8 +100,37 @@ public static Expression or(Expression left, Expression right) { } public static Expression not(Expression expression) { - SimpleExtension.FunctionAnchor fa = SimpleExtension.FunctionAnchor.of(BooleanNamespace, "not:bool"); - return ExpressionCreator.scalarFunction(EXTENSIONS.getScalarFunction(fa), TypeCreator.NULLABLE.BOOLEAN, expression); + return makeUnary(expression, FUNCTIONS_BOOLEAN, "not:bool", TypeCreator.NULLABLE.BOOLEAN); + } + + public static Expression in(Expression expr, List set) { + List eqList = set.stream().map( + lit -> makeBinary(expr, lit, FUNCTIONS_COMPARISON, "equal:any_any", TypeCreator.NULLABLE.BOOLEAN) + ).collect(Collectors.toList()); + Expression ret = null; + for (int i = 0; i < eqList.size(); i++) { + if (i == 0) { + ret = eqList.get(i); + } else { + ret = or(ret, eqList.get(i)); + } + } + return ret; + } + + public static Expression notIn(Expression expr, List set) { + List notEqList = set.stream().map( + lit -> makeBinary(expr, lit, FUNCTIONS_COMPARISON, "not_equal:any_any", TypeCreator.NULLABLE.BOOLEAN) + ).collect(Collectors.toList()); + Expression ret = null; + for (int i = 0; i < notEqList.size(); i++) { + if (i == 0) { + ret = notEqList.get(i); + } else { + ret = and(ret, notEqList.get(i)); + } + } + return ret; } public static Expression makeBinary(Expression left, Expression right, String namespace, String funcKey, Type type) { @@ -106,6 +138,11 @@ public static Expression makeBinary(Expression left, Expression right, String na return ExpressionCreator.scalarFunction(EXTENSIONS.getScalarFunction(fa), type, left, right); } + public static Expression makeUnary(Expression expr, String namespace, String funcKey, Type type) { + SimpleExtension.FunctionAnchor fa = SimpleExtension.FunctionAnchor.of(namespace, funcKey); + return ExpressionCreator.scalarFunction(EXTENSIONS.getScalarFunction(fa), type, expr); + } + public static io.substrait.proto.Plan substraitExprToProto(Expression e, String tableName) { return planToProto(exprToFilter(e, tableName)); } @@ -329,7 +366,9 @@ public Type visit(ArrowType.Int anInt) { @Override public Type visit(ArrowType.FloatingPoint floatingPoint) { - return typeCreator.FP32; + if (floatingPoint.getPrecision() == FloatingPointPrecision.SINGLE) return typeCreator.FP32; + if (floatingPoint.getPrecision() == FloatingPointPrecision.DOUBLE) return typeCreator.FP64; + return null; } @Override @@ -393,7 +432,7 @@ public Type visit(ArrowType.Duration duration) { } } - public static Expression anyToSubstraitLiteral(Type type, Object any) throws IOException { + public static Expression.Literal anyToSubstraitLiteral(Type type, Object any) throws IOException { if (type instanceof Type.Date) { if (any instanceof Integer) { return ExpressionCreator.date(false, (Integer) any); @@ -415,38 +454,39 @@ public static Expression anyToSubstraitLiteral(Type type, Object any) throws IOE return ExpressionCreator.timestampTZ(false, DateTimeUtils.toMicros(any)); } } - if (type instanceof Type.Str || any instanceof String) { + + if (any instanceof String) { return ExpressionCreator.string(false, (String) any); } - if (type instanceof Type.Bool || any instanceof Boolean) { + if (any instanceof Boolean) { return ExpressionCreator.bool(false, (Boolean) any); } - if (type instanceof Type.Binary || any instanceof byte[]) { + if (any instanceof byte[]) { return ExpressionCreator.binary(false, (byte[]) any); } - if (type instanceof Type.I8 || any instanceof Byte) { - return ExpressionCreator.i8(false, Byte.parseByte(any.toString())); + if (any instanceof Byte) { + return ExpressionCreator.i8(false, (Byte) any); } - if (type instanceof Type.I16 || any instanceof Short) { - return ExpressionCreator.i16(false, Short.parseShort(any.toString())); + if (any instanceof Short) { + return ExpressionCreator.i16(false, (Short) any); } - if (type instanceof Type.I32) { - return ExpressionCreator.i32(false, Integer.parseInt(any.toString())); + if (any instanceof Integer) { + return ExpressionCreator.i32(false, (Integer) any); } - if (type instanceof Type.I64) { - return ExpressionCreator.i64(false, Long.parseLong(any.toString())); + if (any instanceof Long) { + return ExpressionCreator.i64(false, (Long) any); } - if (type instanceof Type.FP32 || any instanceof Float) { + if (any instanceof Float) { return ExpressionCreator.fp32(false, (Float) any); } - if (type instanceof Type.FP64 || any instanceof Double) { + if (any instanceof Double) { return ExpressionCreator.fp64(false, (Double) any); } if (type instanceof Type.Decimal || any instanceof BigDecimal) { int precision = 10; int scale = 0; - if (type != null) { + if (type instanceof Type.Decimal) { precision = ((Type.Decimal) type).precision(); scale = ((Type.Decimal) type).scale(); }