Skip to content

Commit

Permalink
Add utility functions to resolve intermediate type and fix bug in jso…
Browse files Browse the repository at this point in the history
…n handling of aggregate functions
  • Loading branch information
Pratik Joseph Dabre authored and Pratik Joseph Dabre committed Aug 1, 2024
1 parent 9c6435a commit 49d3b9d
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.common.type;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

public final class TypeSignatureUtils
{
private TypeSignatureUtils() {}

public static TypeSignature resolveIntermediateType(TypeSignature typeSignature, List<TypeSignature> parameters, List<TypeSignature> argumentTypes)
{
Map<TypeSignature, TypeSignature> typeSignatureMap = getTypeSignatureMap(parameters, argumentTypes);
return resolveTypeSignatures(typeSignature, typeSignatureMap).getTypeSignature();
}

// todo: Change to ImmutableList when open sourcing

private static NamedTypeSignature resolveTypeSignatures(TypeSignature typeSignature, Map<TypeSignature, TypeSignature> typeSignatureMap)
{
TypeSignature resolvedTypeSignature = typeSignatureMap.getOrDefault(typeSignature, typeSignature);
List<NamedTypeSignature> namedTypeSignatures = new ArrayList<>();
List<TypeSignature> typeSignatures = new ArrayList<>();
List<TypeSignatureParameter> typeSignaturesList = typeSignature.getParameters();
for (TypeSignatureParameter typeSignatureParameter : typeSignaturesList) {
TypeSignature typeSignatureOrNamedTypeSignature = typeSignatureParameter.getTypeSignatureOrNamedTypeSignature().orElseThrow(() ->
new IllegalStateException("Could not get type signature for type parameter [" + typeSignatureParameter + "]"));
TypeSignature resolvedTypeParameterSignature = typeSignatureMap.getOrDefault(typeSignatureOrNamedTypeSignature, typeSignatureOrNamedTypeSignature);
if (resolvedTypeSignature.getBase().equals("row")) {
if (!typeSignatureOrNamedTypeSignature.getParameters().isEmpty()) {
namedTypeSignatures.add(resolveTypeSignatures(resolvedTypeParameterSignature, typeSignatureMap));
}
else {
namedTypeSignatures.add(new NamedTypeSignature(Optional.empty(), new TypeSignature(resolvedTypeParameterSignature.getBase(), Collections.emptyList())));
}
}
else {
if (!typeSignatureOrNamedTypeSignature.getParameters().isEmpty()) {
typeSignatures.add(resolveTypeSignatures(resolvedTypeParameterSignature, typeSignatureMap).getTypeSignature());
}
else {
typeSignatures.add(new TypeSignature(resolvedTypeParameterSignature.getBase(), Collections.emptyList()));
}
}
}
return new NamedTypeSignature(Optional.empty(), new TypeSignature(resolvedTypeSignature.getBase(),
(typeSignatures.isEmpty() ? namedTypeSignatures : typeSignatures).stream().map(
signature -> signature instanceof NamedTypeSignature ?
TypeSignatureParameter.of((NamedTypeSignature) signature)
: TypeSignatureParameter.of((TypeSignature) signature)).collect(Collectors.toList())));
}

/**
* Parameter and argument type mapping must be consistent
*/

public static Map<TypeSignature, TypeSignature> getTypeSignatureMap(List<TypeSignature> parameters, List<TypeSignature> argumentTypes)
{
HashMap<TypeSignature, TypeSignature> typeSignatureMap = new HashMap<>();
if (argumentTypes.size() != parameters.size()) {
throw new IllegalStateException("Parameters size and argumentTypes size do not match!");
}
for (int i = 0; i < argumentTypes.size(); i++) {
TypeSignature parameter = parameters.get(i);
TypeSignature argumentType = argumentTypes.get(i);
if (argumentTypes.get(i).getParameters().isEmpty()) {
typeSignatureMap.put(parameter, argumentType);
}
else {
typeSignatureMap.putAll(getTypeSignatureMap(parameter.getTypeParametersAsTypeSignatures(), argumentType.getTypeParametersAsTypeSignatures()));
}
}
return typeSignatureMap;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ protected AggregationFunctionImplementation sqlInvokedFunctionToAggregationImple
getClass().getSimpleName(),
implementationType));
case CPP:
case REST:
checkArgument(
function.getAggregationMetadata().isPresent(),
"Need aggregationMetadata to get aggregation function implementation");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.facebook.presto.spi.NodeManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.AggregationFunctionImplementation;
import com.facebook.presto.spi.function.AggregationFunctionMetadata;
import com.facebook.presto.spi.function.AlterRoutineCharacteristics;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionMetadata;
Expand All @@ -40,6 +41,7 @@
import com.facebook.presto.spi.function.SqlFunctionHandle;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlFunctionSupplier;
import com.facebook.presto.spi.function.SqlInvokedAggregationFunctionImplementation;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.function.TypeVariableConstraint;
import com.google.common.base.Suppliers;
Expand All @@ -59,7 +61,9 @@
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static com.facebook.presto.common.type.TypeSignatureUtils.resolveIntermediateType;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.spi.function.FunctionVersion.notVersioned;
Expand Down Expand Up @@ -141,19 +145,52 @@ public final AggregationFunctionImplementation getAggregateFunctionImplementatio
checkArgument(functionHandle instanceof SqlFunctionHandle, "Unsupported FunctionHandle type '%s'", functionHandle.getClass().getSimpleName());

SqlFunctionHandle sqlFunctionHandle = (SqlFunctionHandle) functionHandle;
if (functionHandle instanceof NativeFunctionHandle) {
NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle;
return processNativeFunctionHandle(nativeFunctionHandle, typeManager);
}
else {
return processSqlFunctionHandle(sqlFunctionHandle, typeManager);
}
}

private AggregationFunctionImplementation processNativeFunctionHandle(NativeFunctionHandle nativeFunctionHandle, TypeManager typeManager)
{
if (!aggregationImplementationByHandle.containsKey(nativeFunctionHandle)) {
Signature signature = nativeFunctionHandle.getSignature();
SqlFunction function = getSqlFunctionFromSignature(signature);
SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function;

checkArgument(
sqlFunction.getAggregationMetadata().isPresent(),
"Need aggregationMetadata to get aggregation function implementation");

AggregationFunctionMetadata aggregationMetadata = sqlFunction.getAggregationMetadata().get();
TypeSignature intermediateType = aggregationMetadata.getIntermediateType();
List<TypeSignature> typeSignatures = sqlFunction.getParameters().stream().map(Parameter::getType).collect(Collectors.toList());
TypeSignature resolvedIntermediateType = resolveIntermediateType(intermediateType, typeSignatures, signature.getArgumentTypes());
aggregationImplementationByHandle.put(
nativeFunctionHandle,
new SqlInvokedAggregationFunctionImplementation(
typeManager.getType(resolvedIntermediateType),
typeManager.getType(signature.getReturnType()),
aggregationMetadata.isOrderSensitive()));
}
return aggregationImplementationByHandle.get(nativeFunctionHandle);
}

// Cache results if applicable
private AggregationFunctionImplementation processSqlFunctionHandle(SqlFunctionHandle sqlFunctionHandle, TypeManager typeManager)
{
if (!aggregationImplementationByHandle.containsKey(sqlFunctionHandle)) {
SqlFunctionId functionId = sqlFunctionHandle.getFunctionId();
if (!latestFunctions.containsKey(functionId)) {
if (!memoizedFunctionsSupplier.get().containsKey(functionId)) {
throw new PrestoException(GENERIC_USER_ERROR, format("Function '%s' is missing from cache", functionId.getId()));
}

aggregationImplementationByHandle.put(
sqlFunctionHandle,
sqlInvokedFunctionToAggregationImplementation(latestFunctions.get(functionId), typeManager));
sqlInvokedFunctionToAggregationImplementation(memoizedFunctionsSupplier.get().get(functionId), typeManager));
}

return aggregationImplementationByHandle.get(sqlFunctionHandle);
}

Expand Down Expand Up @@ -303,18 +340,23 @@ public final FunctionHandle getFunctionHandle(Optional<? extends FunctionNamespa
return functionHandle;
}

private FunctionMetadata getMetadataFromNativeFunctionHandle(SqlFunctionHandle functionHandle)
private SqlFunction getSqlFunctionFromSignature(Signature signature)
{
NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle;
Signature signature = nativeFunctionHandle.getSignature();
SqlFunctionSupplier functionKey;
try {
functionKey = specializedFunctionKeyCache.getUnchecked(signature);
}
catch (UncheckedExecutionException e) {
throw convertToPrestoException(e, format("Error getting FunctionMetadata for handle: %s", functionHandle));
throw convertToPrestoException(e, format("Error getting FunctionMetadata for signature: %s", signature));
}
SqlFunction function = functionKey.getFunction();
return functionKey.getFunction();
}

private FunctionMetadata getMetadataFromNativeFunctionHandle(SqlFunctionHandle functionHandle)
{
NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle;
Signature signature = nativeFunctionHandle.getSignature();
SqlFunction function = getSqlFunctionFromSignature(signature);

// todo: verify this metadata return
SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -842,15 +842,17 @@ void VeloxQueryPlanConverterBase::toAggregations(
auto pos = functionId.find(";", start + 1);
if (pos == std::string::npos) {
auto argumentType = functionId.substr(start + 1);
aggregate.rawInputTypes.push_back(
if (!argumentType.empty()) {
aggregate.rawInputTypes.push_back(
stringToType(argumentType, typeParser_));
}
break;
}

auto argumentType = functionId.substr(start + 1, pos - start - 1);
aggregate.rawInputTypes.push_back(
stringToType(argumentType, typeParser_));
pos = start + 1;
start = pos;
}
}
} else {
Expand Down

0 comments on commit 49d3b9d

Please sign in to comment.