From 49d3b9d6ba6507cd6f4af74846cafb8ad38bcbe4 Mon Sep 17 00:00:00 2001 From: Pratik Joseph Dabre Date: Wed, 12 Jun 2024 09:59:43 -0700 Subject: [PATCH] Add utility functions to resolve intermediate type and fix bug in json handling of aggregate functions --- .../common/type/TypeSignatureUtils.java | 92 +++++++++++++++++++ ...actSqlInvokedFunctionNamespaceManager.java | 1 + .../NativeFunctionNamespaceManager.java | 60 ++++++++++-- .../main/types/PrestoToVeloxQueryPlan.cpp | 6 +- 4 files changed, 148 insertions(+), 11 deletions(-) create mode 100644 presto-common/src/main/java/com/facebook/presto/common/type/TypeSignatureUtils.java diff --git a/presto-common/src/main/java/com/facebook/presto/common/type/TypeSignatureUtils.java b/presto-common/src/main/java/com/facebook/presto/common/type/TypeSignatureUtils.java new file mode 100644 index 0000000000000..014f8468c4c50 --- /dev/null +++ b/presto-common/src/main/java/com/facebook/presto/common/type/TypeSignatureUtils.java @@ -0,0 +1,92 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.common.type; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +public final class TypeSignatureUtils +{ + private TypeSignatureUtils() {} + + public static TypeSignature resolveIntermediateType(TypeSignature typeSignature, List parameters, List argumentTypes) + { + Map typeSignatureMap = getTypeSignatureMap(parameters, argumentTypes); + return resolveTypeSignatures(typeSignature, typeSignatureMap).getTypeSignature(); + } + + // todo: Change to ImmutableList when open sourcing + + private static NamedTypeSignature resolveTypeSignatures(TypeSignature typeSignature, Map typeSignatureMap) + { + TypeSignature resolvedTypeSignature = typeSignatureMap.getOrDefault(typeSignature, typeSignature); + List namedTypeSignatures = new ArrayList<>(); + List typeSignatures = new ArrayList<>(); + List typeSignaturesList = typeSignature.getParameters(); + for (TypeSignatureParameter typeSignatureParameter : typeSignaturesList) { + TypeSignature typeSignatureOrNamedTypeSignature = typeSignatureParameter.getTypeSignatureOrNamedTypeSignature().orElseThrow(() -> + new IllegalStateException("Could not get type signature for type parameter [" + typeSignatureParameter + "]")); + TypeSignature resolvedTypeParameterSignature = typeSignatureMap.getOrDefault(typeSignatureOrNamedTypeSignature, typeSignatureOrNamedTypeSignature); + if (resolvedTypeSignature.getBase().equals("row")) { + if (!typeSignatureOrNamedTypeSignature.getParameters().isEmpty()) { + namedTypeSignatures.add(resolveTypeSignatures(resolvedTypeParameterSignature, typeSignatureMap)); + } + else { + namedTypeSignatures.add(new NamedTypeSignature(Optional.empty(), new TypeSignature(resolvedTypeParameterSignature.getBase(), Collections.emptyList()))); + } + } + else { + if (!typeSignatureOrNamedTypeSignature.getParameters().isEmpty()) { + typeSignatures.add(resolveTypeSignatures(resolvedTypeParameterSignature, typeSignatureMap).getTypeSignature()); + } + else { + typeSignatures.add(new TypeSignature(resolvedTypeParameterSignature.getBase(), Collections.emptyList())); + } + } + } + return new NamedTypeSignature(Optional.empty(), new TypeSignature(resolvedTypeSignature.getBase(), + (typeSignatures.isEmpty() ? namedTypeSignatures : typeSignatures).stream().map( + signature -> signature instanceof NamedTypeSignature ? + TypeSignatureParameter.of((NamedTypeSignature) signature) + : TypeSignatureParameter.of((TypeSignature) signature)).collect(Collectors.toList()))); + } + + /** + * Parameter and argument type mapping must be consistent + */ + + public static Map getTypeSignatureMap(List parameters, List argumentTypes) + { + HashMap typeSignatureMap = new HashMap<>(); + if (argumentTypes.size() != parameters.size()) { + throw new IllegalStateException("Parameters size and argumentTypes size do not match!"); + } + for (int i = 0; i < argumentTypes.size(); i++) { + TypeSignature parameter = parameters.get(i); + TypeSignature argumentType = argumentTypes.get(i); + if (argumentTypes.get(i).getParameters().isEmpty()) { + typeSignatureMap.put(parameter, argumentType); + } + else { + typeSignatureMap.putAll(getTypeSignatureMap(parameter.getTypeParametersAsTypeSignatures(), argumentType.getTypeParametersAsTypeSignatures())); + } + } + return typeSignatureMap; + } +} diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java index a33fc4c5ec24a..c824148f7b781 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/AbstractSqlInvokedFunctionNamespaceManager.java @@ -357,6 +357,7 @@ protected AggregationFunctionImplementation sqlInvokedFunctionToAggregationImple getClass().getSimpleName(), implementationType)); case CPP: + case REST: checkArgument( function.getAggregationMetadata().isPresent(), "Need aggregationMetadata to get aggregation function implementation"); diff --git a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManager.java b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManager.java index 4d9bde7c5ea1d..1f1ba7bb57bf5 100644 --- a/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManager.java +++ b/presto-function-namespace-managers/src/main/java/com/facebook/presto/functionNamespace/prestissimo/NativeFunctionNamespaceManager.java @@ -28,6 +28,7 @@ import com.facebook.presto.spi.NodeManager; import com.facebook.presto.spi.PrestoException; import com.facebook.presto.spi.function.AggregationFunctionImplementation; +import com.facebook.presto.spi.function.AggregationFunctionMetadata; import com.facebook.presto.spi.function.AlterRoutineCharacteristics; import com.facebook.presto.spi.function.FunctionHandle; import com.facebook.presto.spi.function.FunctionMetadata; @@ -40,6 +41,7 @@ import com.facebook.presto.spi.function.SqlFunctionHandle; import com.facebook.presto.spi.function.SqlFunctionId; import com.facebook.presto.spi.function.SqlFunctionSupplier; +import com.facebook.presto.spi.function.SqlInvokedAggregationFunctionImplementation; import com.facebook.presto.spi.function.SqlInvokedFunction; import com.facebook.presto.spi.function.TypeVariableConstraint; import com.google.common.base.Suppliers; @@ -59,7 +61,9 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; +import java.util.stream.Collectors; +import static com.facebook.presto.common.type.TypeSignatureUtils.resolveIntermediateType; import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR; import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED; import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; @@ -141,19 +145,52 @@ public final AggregationFunctionImplementation getAggregateFunctionImplementatio checkArgument(functionHandle instanceof SqlFunctionHandle, "Unsupported FunctionHandle type '%s'", functionHandle.getClass().getSimpleName()); SqlFunctionHandle sqlFunctionHandle = (SqlFunctionHandle) functionHandle; + if (functionHandle instanceof NativeFunctionHandle) { + NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle; + return processNativeFunctionHandle(nativeFunctionHandle, typeManager); + } + else { + return processSqlFunctionHandle(sqlFunctionHandle, typeManager); + } + } + + private AggregationFunctionImplementation processNativeFunctionHandle(NativeFunctionHandle nativeFunctionHandle, TypeManager typeManager) + { + if (!aggregationImplementationByHandle.containsKey(nativeFunctionHandle)) { + Signature signature = nativeFunctionHandle.getSignature(); + SqlFunction function = getSqlFunctionFromSignature(signature); + SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function; + + checkArgument( + sqlFunction.getAggregationMetadata().isPresent(), + "Need aggregationMetadata to get aggregation function implementation"); + + AggregationFunctionMetadata aggregationMetadata = sqlFunction.getAggregationMetadata().get(); + TypeSignature intermediateType = aggregationMetadata.getIntermediateType(); + List typeSignatures = sqlFunction.getParameters().stream().map(Parameter::getType).collect(Collectors.toList()); + TypeSignature resolvedIntermediateType = resolveIntermediateType(intermediateType, typeSignatures, signature.getArgumentTypes()); + aggregationImplementationByHandle.put( + nativeFunctionHandle, + new SqlInvokedAggregationFunctionImplementation( + typeManager.getType(resolvedIntermediateType), + typeManager.getType(signature.getReturnType()), + aggregationMetadata.isOrderSensitive())); + } + return aggregationImplementationByHandle.get(nativeFunctionHandle); + } - // Cache results if applicable + private AggregationFunctionImplementation processSqlFunctionHandle(SqlFunctionHandle sqlFunctionHandle, TypeManager typeManager) + { if (!aggregationImplementationByHandle.containsKey(sqlFunctionHandle)) { SqlFunctionId functionId = sqlFunctionHandle.getFunctionId(); - if (!latestFunctions.containsKey(functionId)) { + if (!memoizedFunctionsSupplier.get().containsKey(functionId)) { throw new PrestoException(GENERIC_USER_ERROR, format("Function '%s' is missing from cache", functionId.getId())); } aggregationImplementationByHandle.put( sqlFunctionHandle, - sqlInvokedFunctionToAggregationImplementation(latestFunctions.get(functionId), typeManager)); + sqlInvokedFunctionToAggregationImplementation(memoizedFunctionsSupplier.get().get(functionId), typeManager)); } - return aggregationImplementationByHandle.get(sqlFunctionHandle); } @@ -303,18 +340,23 @@ public final FunctionHandle getFunctionHandle(Optional