Skip to content

Commit

Permalink
Added a new nativeFunctionHandle to fix the failing variadic functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Pratik Joseph Dabre committed Aug 1, 2024
1 parent fb763ac commit 9c6435a
Show file tree
Hide file tree
Showing 18 changed files with 771 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ public Optional<UserDefinedType> getUserDefinedType(QualifiedObjectName typeName
}

@Override
public final FunctionHandle getFunctionHandle(Optional<? extends FunctionNamespaceTransactionHandle> transactionHandle, Signature signature)
public FunctionHandle getFunctionHandle(Optional<? extends FunctionNamespaceTransactionHandle> transactionHandle, Signature signature)
{
checkCatalog(signature.getName());
// This is the only assumption in this class that we're dealing with sql-invoked regular function.
Expand Down Expand Up @@ -245,7 +245,7 @@ public CompletableFuture<SqlFunctionResult> executeFunction(String source, Funct
typeManager.getType(functionMetadata.getReturnType()));
}

private static PrestoException convertToPrestoException(UncheckedExecutionException exception, String failureMessage)
protected static PrestoException convertToPrestoException(UncheckedExecutionException exception, String failureMessage)
{
Throwable cause = exception.getCause();
if (cause instanceof PrestoException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.spi.function.AggregationFunctionMetadata;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.RoutineCharacteristics;
import com.facebook.presto.spi.function.TypeVariableConstraint;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -61,6 +62,14 @@ public class JsonBasedUdfFunctionMetadata
* Optional Aggregate-specific metadata (required for aggregation functions)
*/
private final Optional<AggregationFunctionMetadata> aggregateMetadata;
/**
* Optional field: Marked to indicate the arity of the function.
*/
private final Optional<Boolean> variableArity;
/**
* Optional field: List of the typeVariableConstraints.
*/
private final Optional<List<TypeVariableConstraint>> typeVariableConstraints;

@JsonCreator
public JsonBasedUdfFunctionMetadata(
Expand All @@ -70,7 +79,9 @@ public JsonBasedUdfFunctionMetadata(
@JsonProperty("paramTypes") List<TypeSignature> paramTypes,
@JsonProperty("schema") String schema,
@JsonProperty("routineCharacteristics") RoutineCharacteristics routineCharacteristics,
@JsonProperty("aggregateMetadata") Optional<AggregationFunctionMetadata> aggregateMetadata)
@JsonProperty("aggregateMetadata") Optional<AggregationFunctionMetadata> aggregateMetadata,
@JsonProperty("variableArity") Optional<Boolean> variableArity,
@JsonProperty("typeVariableConstraints") Optional<List<TypeVariableConstraint>> typeVariableConstraints)
{
this.docString = requireNonNull(docString, "docString is null");
this.functionKind = requireNonNull(functionKind, "functionKind is null");
Expand All @@ -82,6 +93,8 @@ public JsonBasedUdfFunctionMetadata(
checkArgument(
(functionKind == AGGREGATE && aggregateMetadata.isPresent()) || (functionKind != AGGREGATE && !aggregateMetadata.isPresent()),
"aggregateMetadata must be present for aggregation functions and absent otherwise");
this.variableArity = requireNonNull(variableArity, "variableArity is null");
this.typeVariableConstraints = requireNonNull(typeVariableConstraints, "typeVariableConstraints is null");
}

public String getDocString()
Expand Down Expand Up @@ -123,4 +136,14 @@ public Optional<AggregationFunctionMetadata> getAggregateMetadata()
{
return aggregateMetadata;
}

public Optional<Boolean> getVariableArity()
{
return variableArity;
}

public Optional<List<TypeVariableConstraint>> getTypeVariableConstraints()
{
return typeVariableConstraints;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.functionNamespace.prestissimo;

import com.facebook.presto.common.CatalogSchemaName;
import com.facebook.presto.spi.function.FunctionKind;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlFunctionHandle;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.Objects;

import static java.util.Objects.requireNonNull;

public class NativeFunctionHandle
extends SqlFunctionHandle
{
private final Signature signature;

@JsonCreator
public NativeFunctionHandle(@JsonProperty("signature") Signature signature)
{
super(new SqlFunctionId(signature.getName(), signature.getArgumentTypes()), "1");
this.signature = requireNonNull(signature, "signature is null");
checkArgument(signature.getTypeVariableConstraints().isEmpty(), "%s has unbound type parameters", signature);
}

@JsonProperty
public Signature getSignature()
{
return signature;
}

@Override
public String getName()
{
return signature.getName().toString();
}

@Override
public FunctionKind getKind()
{
return signature.getKind();
}

@Override
public CatalogSchemaName getCatalogSchemaName()
{
return signature.getName().getCatalogSchemaName();
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
NativeFunctionHandle that = (NativeFunctionHandle) o;
return Objects.equals(signature, that.signature);
}

@Override
public int hashCode()
{
return Objects.hash(signature);
}

@Override
public String toString()
{
return signature.toString();
}

private static void checkArgument(boolean condition, String message, Object... args)
{
if (!condition) {
throw new IllegalArgumentException(String.format(message, args));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,28 @@
import com.facebook.presto.spi.function.AlterRoutineCharacteristics;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.FunctionNamespaceTransactionHandle;
import com.facebook.presto.spi.function.Parameter;
import com.facebook.presto.spi.function.ScalarFunctionImplementation;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlFunction;
import com.facebook.presto.spi.function.SqlFunctionHandle;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlFunctionSupplier;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.function.TypeVariableConstraint;
import com.google.common.base.Suppliers;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.UncheckedExecutionException;

import javax.inject.Inject;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -60,6 +71,7 @@
import static java.lang.Long.parseLong;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.HOURS;

public class NativeFunctionNamespaceManager
extends AbstractSqlInvokedFunctionNamespaceManager
Expand All @@ -71,20 +83,33 @@ public class NativeFunctionNamespaceManager
private final NodeManager nodeManager;
private final Map<SqlFunctionId, SqlInvokedFunction> latestFunctions = new ConcurrentHashMap<>();
private final Supplier<Map<SqlFunctionId, SqlInvokedFunction>> memoizedFunctionsSupplier;
private final FunctionMetadataManager functionMetadataManager;
private final LoadingCache<Signature, SqlFunctionSupplier> specializedFunctionKeyCache;

@Inject
public NativeFunctionNamespaceManager(
@ServingCatalog String catalogName,
SqlFunctionExecutors sqlFunctionExecutors,
SqlInvokedFunctionNamespaceManagerConfig config,
FunctionDefinitionProvider functionDefinitionProvider,
NodeManager nodeManager)
NodeManager nodeManager,
FunctionMetadataManager functionMetadataManager)
{
super(catalogName, sqlFunctionExecutors, config);
this.functionDefinitionProvider = requireNonNull(functionDefinitionProvider, "functionDefinitionProvider is null");
this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
this.memoizedFunctionsSupplier = Suppliers.memoizeWithExpiration(this::bootstrapNamespace,
config.getFunctionCacheExpiration().toMillis(), TimeUnit.MILLISECONDS);
this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
this.specializedFunctionKeyCache = CacheBuilder.newBuilder()
.maximumSize(1000)
.expireAfterWrite(1, HOURS)
.build(CacheLoader.from(this::doGetSpecializedFunctionKey));
}

private SqlFunctionSupplier doGetSpecializedFunctionKey(Signature signature)
{
return functionMetadataManager.getSpecializedFunctionKey(signature);
}

private Map<SqlFunctionId, SqlInvokedFunction> bootstrapNamespace()
Expand Down Expand Up @@ -137,6 +162,7 @@ private static SqlInvokedFunction copyFunction(SqlInvokedFunction function)
return new SqlInvokedFunction(
function.getSignature().getName(),
function.getParameters(),
function.getSignature().getTypeVariableConstraints(),
function.getSignature().getReturnType(),
function.getDescription(),
function.getRoutineCharacteristics(),
Expand All @@ -152,15 +178,28 @@ protected SqlInvokedFunction createSqlInvokedFunction(String functionName, JsonB
QualifiedObjectName qualifiedFunctionName = QualifiedObjectName.valueOf(new CatalogSchemaName(getCatalogName(), jsonBasedUdfFunctionMetaData.getSchema()), functionName);
List<String> parameterNameList = jsonBasedUdfFunctionMetaData.getParamNames();
List<TypeSignature> parameterTypeList = jsonBasedUdfFunctionMetaData.getParamTypes();

List<TypeVariableConstraint> typeVariableConstraintsList = jsonBasedUdfFunctionMetaData.getTypeVariableConstraints().isPresent() ?
jsonBasedUdfFunctionMetaData.getTypeVariableConstraints().get() : Collections.emptyList();
ImmutableList.Builder<Parameter> parameterBuilder = ImmutableList.builder();
for (int i = 0; i < parameterNameList.size(); i++) {
parameterBuilder.add(new Parameter(parameterNameList.get(i), parameterTypeList.get(i)));
}

// Todo: Clean this method up
ImmutableList.Builder<TypeVariableConstraint> typeVariableConstraintsBuilder = ImmutableList.builder();
for (TypeVariableConstraint typeVariableConstraint : typeVariableConstraintsList) {
typeVariableConstraintsBuilder.add(new TypeVariableConstraint(
typeVariableConstraint.getName(),
typeVariableConstraint.isComparableRequired(),
typeVariableConstraint.isOrderableRequired(),
null,
typeVariableConstraint.isNonDecimalNumericRequired()));
}

return new SqlInvokedFunction(
qualifiedFunctionName,
parameterBuilder.build(),
typeVariableConstraintsBuilder.build(),
jsonBasedUdfFunctionMetaData.getOutputType(),
jsonBasedUdfFunctionMetaData.getDocString(),
jsonBasedUdfFunctionMetaData.getRoutineCharacteristics(),
Expand Down Expand Up @@ -188,10 +227,13 @@ protected UserDefinedType fetchUserDefinedTypeDirect(QualifiedObjectName typeNam
@Override
protected FunctionMetadata fetchFunctionMetadataDirect(SqlFunctionHandle functionHandle)
{
if (functionHandle instanceof NativeFunctionHandle) {
return getMetadataFromNativeFunctionHandle(functionHandle);
}

return fetchFunctionsDirect(functionHandle.getFunctionId().getFunctionName()).stream()
.filter(function -> function.getRequiredFunctionHandle().equals(functionHandle))
.map(this::sqlInvokedFunctionToMetadata)
.collect(onlyElement());
.map(this::sqlInvokedFunctionToMetadata).collect(onlyElement());
}

@Override
Expand Down Expand Up @@ -248,4 +290,47 @@ public void addUserDefinedType(UserDefinedType userDefinedType)
name);
userDefinedTypes.put(name, userDefinedType);
}

@Override
public final FunctionHandle getFunctionHandle(Optional<? extends FunctionNamespaceTransactionHandle> transactionHandle, Signature signature)
{
FunctionHandle functionHandle = super.getFunctionHandle(transactionHandle, signature);

// only handle variadic signatures here , for normal signature we use the AbstractSqlInvokedFunctionNamespaceManager function handle.
if (functionHandle == null) {
return new NativeFunctionHandle(signature);
}
return functionHandle;
}

private FunctionMetadata getMetadataFromNativeFunctionHandle(SqlFunctionHandle functionHandle)
{
NativeFunctionHandle nativeFunctionHandle = (NativeFunctionHandle) functionHandle;
Signature signature = nativeFunctionHandle.getSignature();
SqlFunctionSupplier functionKey;
try {
functionKey = specializedFunctionKeyCache.getUnchecked(signature);
}
catch (UncheckedExecutionException e) {
throw convertToPrestoException(e, format("Error getting FunctionMetadata for handle: %s", functionHandle));
}
SqlFunction function = functionKey.getFunction();

// todo: verify this metadata return
SqlInvokedFunction sqlFunction = (SqlInvokedFunction) function;
return new FunctionMetadata(
signature.getName(),
signature.getArgumentTypes(),
sqlFunction.getParameters().stream()
.map(Parameter::getName)
.collect(toImmutableList()),
signature.getReturnType(),
function.getSignature().getKind(),
sqlFunction.getRoutineCharacteristics().getLanguage(),
getFunctionImplementationType(sqlFunction),
function.isDeterministic(),
function.isCalledOnNullInput(),
sqlFunction.getVersion(),
function.getComplexTypeFunctionDescriptor());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public FunctionNamespaceManager<?> create(String catalogName, Map<String, String
{
try {
Bootstrap app = new Bootstrap(
new NativeFunctionNamespaceManagerModule(catalogName, context.getNodeManager()),
new NativeFunctionNamespaceManagerModule(catalogName, context.getNodeManager(), context.getFunctionMetadataManager()),
new NoopSqlFunctionExecutorsModule());

Injector injector = app
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.functionNamespace.SqlInvokedFunctionNamespaceManagerConfig;
import com.facebook.presto.functionNamespace.execution.SqlFunctionLanguageConfig;
import com.facebook.presto.spi.NodeManager;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.google.inject.Binder;
import com.google.inject.Module;
import com.google.inject.TypeLiteral;
Expand All @@ -38,11 +39,13 @@ public class NativeFunctionNamespaceManagerModule
private final String catalogName;

private final NodeManager nodeManager;
private final FunctionMetadataManager functionMetadataManager;

public NativeFunctionNamespaceManagerModule(String catalogName, NodeManager nodeManager)
public NativeFunctionNamespaceManagerModule(String catalogName, NodeManager nodeManager, FunctionMetadataManager functionMetadataManager)
{
this.catalogName = requireNonNull(catalogName, "catalogName is null");
this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
}

@Override
Expand All @@ -57,5 +60,6 @@ public void configure(Binder binder)
.toInstance(new JsonCodecFactory().mapJsonCodec(String.class, listJsonCodec(JsonBasedUdfFunctionMetadata.class)));
binder.bind(NativeFunctionNamespaceManager.class).in(SINGLETON);
binder.bind(NodeManager.class).toInstance(nodeManager);
binder.bind(FunctionMetadataManager.class).toInstance(functionMetadataManager);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,12 @@ private SpecializedFunctionKey getSpecializedFunctionKey(Signature signature)

private SpecializedFunctionKey doGetSpecializedFunctionKey(Signature signature)
{
Iterable<SqlFunction> candidates = getFunctions(null, signature.getName());
Collection<SqlFunction> candidates = getFunctions(null, signature.getName());
return doGetSpecializedFunctionKey(signature, candidates);
}

public SpecializedFunctionKey doGetSpecializedFunctionKey(Signature signature, Collection<SqlFunction> candidates)
{
// search for exact match
Type returnType = functionAndTypeManager.getType(signature.getReturnType());
List<TypeSignatureProvider> argumentTypeSignatureProviders = fromTypeSignatures(signature.getArgumentTypes());
Expand Down
Loading

0 comments on commit 9c6435a

Please sign in to comment.