diff --git a/config-model/src/main/java/com/yahoo/schema/Application.java b/config-model/src/main/java/com/yahoo/schema/Application.java index dbc21743d967..7142b0c5a2d7 100644 --- a/config-model/src/main/java/com/yahoo/schema/Application.java +++ b/config-model/src/main/java/com/yahoo/schema/Application.java @@ -30,6 +30,7 @@ public class Application { private final ApplicationPackage applicationPackage; private final Map schemas; private final DocumentModel documentModel; + private final RankProfileRegistry rankProfileRegistry; public Application(ApplicationPackage applicationPackage, List schemas, @@ -41,6 +42,7 @@ public Application(ApplicationPackage applicationPackage, Set> processorsToSkip, DeployLogger logger) { this.applicationPackage = applicationPackage; + this.rankProfileRegistry = rankProfileRegistry; Map schemaMap = new LinkedHashMap<>(); for (Schema schema : schemas) { @@ -87,6 +89,8 @@ public Application(ApplicationPackage applicationPackage, public ApplicationPackage applicationPackage() { return applicationPackage; } + public RankProfileRegistry rankProfileRegistry() { return rankProfileRegistry; } + /** Returns an unmodifiable list of the schemas of this application */ public Map schemas() { return schemas; } diff --git a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java index 2a8dd49a0c1a..f3e8c0a2f489 100644 --- a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java @@ -135,12 +135,13 @@ MapEvaluationTypeContext getParent(String forArgument, String boundTo) { () -> new IllegalArgumentException("argument "+forArgument+" is bound to "+boundTo+" but there is no parent context")); } - String resolveBinding(String argument) { - String bound = getBinding(argument); + @Override + public String resolveBinding(String name) { + String bound = getBinding(name); if (bound == null) { - return argument; + return name; } - return getParent(argument, bound).resolveBinding(bound); + return getParent(name, bound).resolveBinding(bound); } private TensorType resolveType(Reference reference) { @@ -148,7 +149,6 @@ private TensorType resolveType(Reference reference) { throw new IllegalArgumentException("Invocation loop: " + currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) + " -> " + reference); - // Bound to a function argument? Optional binding = boundIdentifier(reference); if (binding.isPresent()) { @@ -156,8 +156,7 @@ private TensorType resolveType(Reference reference) { // This is not pretty, but changing to bind expressions rather // than their string values requires deeper changes var expr = new RankingExpression(binding.get()); - var type = expr.type(getParent(reference.name(), binding.get())); - return type; + return expr.type(getParent(reference.name(), binding.get())); } catch (ParseException e) { throw new IllegalArgumentException(e); } @@ -180,8 +179,7 @@ private TensorType resolveType(Reference reference) { if (function.isPresent()) { var body = function.get().getBody(); var child = this.withBindings(bind(function.get().arguments(), reference.arguments())); - var type = body.type(child); - return type; + return body.type(child); } // A reference to an ONNX model? diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index ed1a4e98b49b..0d812699f887 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1184,7 +1184,6 @@ private RankingExpressionFunction compile(RankingExpressionFunction function, Map inlineFunctions, ExpressionTransforms expressionTransforms) { if (function == null) return null; - RankProfileTransformContext context = new RankProfileTransformContext(this, queryProfiles, featureTypes, diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java index 15e5891a3e34..d2f6f3b18e7f 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java @@ -288,9 +288,9 @@ private void deriveFunctionProperties(Map e : functions.entrySet()) { String propertyName = RankingExpression.propertyName(e.getKey()); if (! context.serializedFunctions().containsKey(propertyName)) { - String expressionString = e.getValue().function().getBody().getRoot().toString(context).toString(); context.addFunctionSerialization(propertyName, expressionString); + e.getValue().function().argumentTypes().entrySet().stream().sorted(Map.Entry.comparingByKey()) .forEach(argumentType -> context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue())); } diff --git a/config-model/src/test/java/com/yahoo/schema/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankProfileTestCase.java index 564bfd4a9904..f07ed4cf42ef 100644 --- a/config-model/src/test/java/com/yahoo/schema/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/RankProfileTestCase.java @@ -8,6 +8,7 @@ import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.document.DataType; +import com.yahoo.schema.derived.DerivedConfiguration; import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.search.query.profile.types.FieldDescription; @@ -417,6 +418,69 @@ void requireThatConfigIsDerivedForQueryFeatureTypeSettings() throws ParseExcepti assertQueryFeatureTypeSettings(registry.get(schema, "p2"), schema); } + @Test + void dimensionArgumentResolution() throws ParseException{ + RankProfileRegistry registry = new RankProfileRegistry(); + ApplicationBuilder builder = new ApplicationBuilder(registry); + builder.addSchema(""" +schema test { +document test { + field embeddings type tensor(d1[384]) { + indexing: attribute + } +} +rank-profile feature_logging { + inputs { + query(query_embedding_int8) tensor(d0[384]) + query(query_embedding) tensor(d0{}, d1[384]) + } + first-phase { + expression: fakeRankResult + } + function query_field_cosine_similarity(field_name, query_tensor, dimension) { + expression: cosine_similarity(attribute(field_name), query_tensor, dimension) + } + function query_field_cos_distances(field_name, query_tensor, dimension){ + expression: max(1 - query_field_cosine_similarity(field_name, query_tensor, dimension), 0.0) + } + function query_field_acos_distances(field_name, query_tensor, dimension) { + expression: acos(query_field_cosine_similarity(field_name, query_tensor, dimension)) + } + function query_field_closeness(field_name, query_tensor, dimension) { + expression: reduce(1/(1+query_field_acos_distances(field_name, query_tensor, dimension)), max) + } + summary-features { + query_field_closeness(embeddings, query(query_embedding), d1) + } +} +}"""); + Application application = builder.build(true); + RankProfile profile = application.rankProfileRegistry().get("test", "feature_logging"); + + // Rank profile content is unbound, as written: + assertEquals("join(reduce(join(attribute(field_name), query_tensor, f(a,b)(a * b)), sum, dimension), " + + "map(join(reduce(join(attribute(field_name), attribute(field_name), f(a,b)(a * b)), sum, dimension), " + + "reduce(join(query_tensor, query_tensor, f(a,b)(a * b)), sum, dimension), " + + "f(a,b)(a * b)), f(a)(sqrt(a))), f(a,b)(a / b))", + profile.findFunction("query_field_cosine_similarity").function().getBody().getRoot().toString()); + + // Derived rank profile content is bound: attribute(field_name) -> attribute(embeddings), dimension -> d1 + assertEquals("join(reduce(join(attribute(embeddings), query(query_embedding), f(a,b)(a * b)), sum, d1), " + + "map(join(reduce(join(attribute(embeddings), attribute(embeddings), f(a,b)(a * b)), sum, d1), " + + "reduce(join(query(query_embedding), query(query_embedding), f(a,b)(a * b)), sum, d1), " + + "f(a,b)(a * b)), f(a)(sqrt(a))), f(a,b)(a / b))", + findDerivedFunction(application, "feature_logging", "query_field_cosine_similarity")); + } + + private String findDerivedFunction(Application application, String rankProfileName, String functionName) { + var derived = new DerivedConfiguration(application.schemas().get("test"), application.rankProfileRegistry()); + for (var line : derived.getRankProfileList().getRankProfiles().get("feature_logging").configProperties()) { + if (line.getFirst().startsWith("rankingExpression(query_field_cosine_similarity@")) + return line.getSecond(); + } + return null; + } + private static QueryProfileRegistry setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index 667712d0daae..226bc5e60683 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -36,6 +36,7 @@ "public com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)", "public double getDouble(int)", "public int getIndex(java.lang.String)", + "public java.lang.String resolveBinding(java.lang.String)", "public int size()", "public java.util.Set names()", "public java.util.Set arguments()", diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 18a3f01cf92d..1a57800a9e23 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -112,6 +112,11 @@ public int getIndex(String name) { return requireIndexOf(name); } + @Override + public String resolveBinding(String argument) { + return null; + } + @Override public int size() { return indexedBindings.names().size(); diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 2d67abb0e048..7d06db8971c3 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -257,7 +257,7 @@ "public" ], "methods" : [ - "public void (com.yahoo.searchlib.rankingexpression.ExpressionFunction, java.lang.String, java.lang.String)", + "public void (java.lang.String, java.lang.String)", "public java.lang.String getName()", "public java.lang.String getExpressionString()" ], @@ -424,6 +424,7 @@ "public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)", "public final com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)", "public final double getDouble(int)", + "public java.lang.String resolveBinding(java.lang.String)", "public com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext clone()", "public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext clone()", "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", @@ -537,6 +538,7 @@ "public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)", "public final com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)", + "public java.lang.String resolveBinding(java.lang.String)", "public com.yahoo.searchlib.rankingexpression.evaluation.DoubleOnlyArrayContext clone()", "public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext clone()", "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", @@ -629,6 +631,7 @@ "public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value get(java.lang.String)", "public void put(java.lang.String, com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public java.lang.String resolveBinding(java.lang.String)", "public java.util.Map bindings()", "public com.yahoo.searchlib.rankingexpression.evaluation.MapContext thawedCopy()", "public java.util.Set names()", @@ -651,6 +654,7 @@ "public void setType(com.yahoo.searchlib.rankingexpression.Reference, com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.TensorType getType(java.lang.String)", "public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)", + "public java.lang.String resolveBinding(java.lang.String)", "public java.util.Map bindings()", "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)" ], diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 840eacd9dd9c..fd173bca2267 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -216,7 +216,7 @@ public String toString() { * An instance of a serialization of this function, using a particular serialization context (by {@link * ExpressionFunction#expand}) */ - public class Instance { + public static class Instance { private final String name; private final String expressionString; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java index d32b14886cae..bb0faf2a608d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java @@ -120,6 +120,11 @@ public final double getDouble(int index) { return value; } + @Override + public String resolveBinding(String argument) { + return null; + } + /** * Creates a clone of this context suitable for evaluating against the same ranking expression * in a different thread (i.e, name name to index map, different value set. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java index f0685ea77fd9..0988c58e73d0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java @@ -93,6 +93,11 @@ public final Value get(int index) { return new DoubleValue(getDouble(index)); } + @Override + public String resolveBinding(String argument) { + return null; + } + /** * Creates a clone of this context suitable for evaluating against the same ranking expression * in a different thread (i.e, name name to index map, different value set. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index 2b620a6d8f05..c8a3eab381f3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -73,6 +73,11 @@ public void put(String key, Value value) { bindings.put(key, value.freeze()); } + @Override + public String resolveBinding(String argument) { + return null; + } + /** Returns an immutable view of the bindings of this. */ public Map bindings() { if (frozen) return bindings; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java index 6c980181b47e..4a723eae578e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java @@ -32,6 +32,11 @@ public TensorType getType(Reference reference) { return featureTypes.get(reference); } + @Override + public String resolveBinding(String argument) { + return null; + } + /** Returns an unmodifiable map of the bindings in this */ public Map bindings() { return Collections.unmodifiableMap(featureTypes); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index b1585233a1e6..54703141cbdb 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -88,7 +88,6 @@ public StringBuilder toString(StringBuilder string, SerializationContext context if ( needSerialization ) { ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path); functionName = instance.getName(); - context.addFunctionSerialization(RankingExpression.propertyName(functionName), instance.getExpressionString()); for (Map.Entry argumentType : function.argumentTypes().entrySet()) context.addArgumentTypeSerialization(functionName, argumentType.getKey(), argumentType.getValue()); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 202dbebc3116..3c17c7830f2e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -59,8 +59,7 @@ private ExpressionNode toExpressionNode(TensorFunction f) { } private static ScalarFunction transform(ScalarFunction input, - Function transformer) - { + Function transformer) { if (input instanceof ExpressionScalarFunction wrapper) { ExpressionNode transformed = transformer.apply(wrapper.expression); return new ExpressionScalarFunction(transformed); @@ -411,6 +410,12 @@ public Value get(String name) { public TensorType getType(Reference name) { return delegate.getType(name); } + + @Override + public String resolveBinding(String argument) { + return delegate.resolveBinding(argument); + } + } private static Context asContext(EvaluationContext generic) { diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 06cf7d0d71a1..c3bb75b29781 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1570,7 +1570,8 @@ "public void put(java.lang.String, com.yahoo.tensor.Tensor)", "public com.yahoo.tensor.TensorType getType(java.lang.String)", "public com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", - "public com.yahoo.tensor.Tensor getTensor(java.lang.String)" + "public com.yahoo.tensor.Tensor getTensor(java.lang.String)", + "public java.lang.String resolveBinding(java.lang.String)" ], "fields" : [ ] }, @@ -1599,7 +1600,8 @@ ], "methods" : [ "public abstract com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", - "public abstract com.yahoo.tensor.TensorType getType(java.lang.String)" + "public abstract com.yahoo.tensor.TensorType getType(java.lang.String)", + "public abstract java.lang.String resolveBinding(java.lang.String)" ], "fields" : [ ] }, @@ -1685,7 +1687,7 @@ ], "methods" : [ "public void ()", - "public final com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)" ], "fields" : [ ] @@ -1809,6 +1811,7 @@ "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public final com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", "public int hashCode()" ], @@ -2920,6 +2923,7 @@ "methods" : [ "public static com.yahoo.tensor.functions.ToStringContext empty()", "public abstract java.lang.String getBinding(java.lang.String)", + "public java.lang.String resolveBinding(java.lang.String)", "public java.util.Optional typeContext()", "public abstract com.yahoo.tensor.functions.ToStringContext parent()" ], diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index 8cdb06143788..3d7705c42b0a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -23,11 +23,12 @@ public TensorType getType(String name) { } @Override - public TensorType getType(NAMETYPE name) { - return getType(name.name()); - } + public TensorType getType(NAMETYPE name) { return getType(name.name()); } @Override public Tensor getTensor(String name) { return bindings.get(name); } + @Override + public String resolveBinding(String name) { return name; } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java index d875f1ef4eb3..eddfb9df276a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -26,4 +26,7 @@ public interface TypeContext { */ TensorType getType(String name); + /** Returns the string a parameter is bound to, or the input name if none. */ + String resolveBinding(String name); + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index 88b5a385e9f3..0bd360ef15f0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -47,7 +47,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; + return "argmax(" + argument.toString(context) + Reduce.commaSeparatedNames(dimensions, context) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index ffee606e8f6b..8e1ad71d3848 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -47,7 +47,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; + return "argmin(" + argument.toString(context) + Reduce.commaSeparatedNames(dimensions, context) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index 5655bb020a4f..ff0fe95bc4ed 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -95,14 +95,11 @@ private Tensor castFromSomeFloat(Tensor tensor, TensorType type) { } static private Function selectRestrict(TensorType.Value toValueType) { - switch (toValueType) { - case BFLOAT16: - return val -> Float.intBitsToFloat(Float.floatToRawIntBits(val) & ~0xffff); - case INT8: - return val -> (float)val.byteValue(); - default: - return val -> val; - } + return switch (toValueType) { + case BFLOAT16 -> val -> Float.intBitsToFloat(Float.floatToRawIntBits(val) & ~0xffff); + case INT8 -> val -> (float) val.byteValue(); + default -> val -> val; + }; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 23d90e634884..87b0210cf603 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -17,7 +17,7 @@ public abstract class CompositeTensorFunction extends Ten /** Finds the type this produces by first converting it to a primitive function */ @Override - public final TensorType type(TypeContext context) { + public TensorType type(TypeContext context) { return toPrimitive().type(context); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 2635cbecb94b..0b128f77d120 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -62,7 +62,8 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")"; + return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + + ", " + context.resolveBinding(dimension) + ")"; } @Override @@ -70,7 +71,7 @@ public String toString(ToStringContext context) { @Override public TensorType type(TypeContext context) { - return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension); + return TypeResolver.concat(argumentA.type(context), argumentB.type(context), context.resolveBinding(dimension)); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java index d84b1bbdc163..2bdf5266ffcf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.MapEvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.Tensor; @@ -10,10 +11,12 @@ import java.util.List; import java.util.Objects; +import java.util.Optional; /** * Convenience for cosine similarity between vectors. - * cosine_similarity(a, b, mydim) == sum(a*b, mydim) / sqrt(sum(a*a, mydim) * sum(b*b, mydim)) + * cosine_similarity(a, b, mydim) == sum(a*b, mydim) / sqrt(sum(a*a, mydim) * sum(b*b, mydim)). + * * @author arnej */ public class CosineSimilarity extends TensorFunction { @@ -45,18 +48,18 @@ public TensorFunction withArguments(List> arg public TensorType type(TypeContext context) { TensorType t1 = arg1.toPrimitive().type(context); TensorType t2 = arg2.toPrimitive().type(context); - var d1 = t1.dimension(dimension); - var d2 = t2.dimension(dimension); + var resolvedDimension = context.resolveBinding(dimension); + var d1 = t1.dimension(resolvedDimension); + var d2 = t2.dimension(resolvedDimension); if (d1.isEmpty() || d2.isEmpty() || d1.get().type() != Dimension.Type.indexedBound || d2.get().type() != Dimension.Type.indexedBound || ! d1.get().size().equals(d2.get().size())) { throw new IllegalArgumentException("cosine_similarity expects both arguments to have the '" - + dimension + "' dimension with same size, but input types were " + + resolvedDimension + "' dimension with same size, but input types were " + t1 + " and " + t2); } - // Finds the type this produces by first converting it to a primitive function return toPrimitive().type(context); } @@ -83,7 +86,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "cosine_similarity(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + dimension + ")"; + return "cosine_similarity(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java index b6655a153616..3d6e44d86587 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java @@ -37,6 +37,7 @@ class MyTypeContext implements TypeContext { MyTypeContext(TensorType subspaceType) { this.subspaceType = subspaceType; } public TensorType getType(NAMETYPE name) { return getType(name.name()); } public TensorType getType(String name) { return argName.equals(name) ? subspaceType : null; } + public String resolveBinding(String name) { return name; } } TensorType outputType(TensorType subspaceType) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java index d627e0093bff..dc213db4cc0e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java @@ -45,15 +45,16 @@ public TensorFunction withArguments(List> arg public TensorType type(TypeContext context) { TensorType t1 = arg1.toPrimitive().type(context); TensorType t2 = arg2.toPrimitive().type(context); - var d1 = t1.dimension(dimension); - var d2 = t2.dimension(dimension); + String resolvedDimension = context.resolveBinding(dimension); + var d1 = t1.dimension(resolvedDimension); + var d2 = t2.dimension(resolvedDimension); if (d1.isEmpty() || d2.isEmpty() || d1.get().type() != Dimension.Type.indexedBound || d2.get().type() != Dimension.Type.indexedBound || ! d1.get().size().equals(d2.get().size())) { throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '" - + dimension + "' dimension with same size, but input types were " + + resolvedDimension + "' dimension with same size, but input types were " + t1 + " and " + t2); } // Finds the type this produces by first converting it to a primitive function @@ -79,7 +80,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "euclidean_distance(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + dimension + ")"; + return "euclidean_distance(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java index f5a33dde064e..da2295e66db8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java @@ -3,6 +3,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; import java.util.Objects; @@ -16,11 +17,11 @@ public class Expand extends CompositeTensorFunction { private final TensorFunction argument; - private final String dimensionName; + private final String dimension; public Expand(TensorFunction argument, String dimension) { this.argument = argument; - this.dimensionName = dimension; + this.dimension = dimension; } @Override @@ -30,22 +31,31 @@ public Expand(TensorFunction argument, String dimension) { public TensorFunction withArguments(List> arguments) { if (arguments.size() != 1) throw new IllegalArgumentException("Expand must have 1 argument, got " + arguments.size()); - return new Expand<>(arguments.get(0), dimensionName); + return new Expand<>(arguments.get(0), dimension); } @Override public PrimitiveTensorFunction toPrimitive() { - TensorType type = new TensorType.Builder(TensorType.Value.INT8).indexed(dimensionName, 1).build(); + return toPrimitive(dimension); + } + + @Override + public final TensorType type(TypeContext context) { + return toPrimitive(context.resolveBinding(dimension)).type(context); + } + + private PrimitiveTensorFunction toPrimitive(String dimension) { + TensorType type = new TensorType.Builder(TensorType.Value.INT8).indexed(dimension, 1).build(); Generate expansion = new Generate<>(type, ScalarFunctions.constant(1.0)); return new Join<>(expansion, argument, ScalarFunctions.multiply()); } @Override public String toString(ToStringContext context) { - return "expand(" + argument.toString(context) + ", " + dimensionName + ")"; + return "expand(" + argument.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override - public int hashCode() { return Objects.hash("expand", argument, dimensionName); } + public int hashCode() { return Objects.hash("expand", argument, dimension); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 947a39dafb20..fb6963fdbcb8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -183,6 +183,11 @@ public TensorType getType(String name) { return context.getType(name); } + @Override + public String resolveBinding(String name) { + return context.resolveBinding(name); + } + } /** A context which adds the bindings of the generate dimension names to the given context. */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index a5afeb6d2a42..b1ea52e880f4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -40,7 +40,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; + return "l1_normalize(" + argument.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index 47e341732ca9..c25871590816 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -42,7 +42,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; + return "l2_normalize(" + argument.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index fbf3b461a353..d97c85d64e14 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -46,7 +46,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; + return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 947fd6e00123..af1f20850851 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -87,19 +87,20 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; + return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparatedNames(dimensions, context) + ")"; } - static String commaSeparated(List list) { + static String commaSeparatedNames(List list, ToStringContext context) { StringBuilder b = new StringBuilder(); for (String element : list) - b.append(", ").append(element); + b.append(", ").append(context.resolveBinding(element)); return b.toString(); } @Override public TensorType type(TypeContext context) { - return outputType(argument.type(context), dimensions); + List resolvedDimensions = dimensions.stream().map(d -> context.resolveBinding(d)).toList(); + return outputType(argument.type(context), resolvedDimensions); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index 2d5a05187471..e6fa448fef3f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -317,10 +317,10 @@ private boolean reduceDimensionIsInnermost(Tensor a, Tensor b) { @Override public String toString(ToStringContext context) { return "reduce_join(" + argumentA.toString(context) + ", " + - argumentB.toString(context) + ", " + - combinator + ", " + - aggregator + - Reduce.commaSeparated(dimensions) + ")"; + argumentB.toString(context) + ", " + + combinator + ", " + + aggregator + + Reduce.commaSeparatedNames(dimensions, context) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 05db61f53956..eabf2e88739e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -14,6 +14,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; /** * The rename tensor function returns a tensor where some dimensions are assigned new names. @@ -71,18 +72,16 @@ public TensorFunction withArguments(List> arg @Override public TensorType type(TypeContext context) { - return type(argument.type(context)); - } - - private TensorType type(TensorType type) { - return TypeResolver.rename(type, fromDimensions, toDimensions); + List resolvedFromDimensions = fromDimensions.stream().map(d -> context.resolveBinding(d)).toList(); + List resolvedToDimensions = toDimensions.stream().map(d -> context.resolveBinding(d)).toList(); + return TypeResolver.rename(argument.type(context), resolvedFromDimensions, resolvedToDimensions); } @Override public Tensor evaluate(EvaluationContext context) { Tensor tensor = argument.evaluate(context); - TensorType renamedType = type(tensor.type()); + TensorType renamedType = TypeResolver.rename(tensor.type(), fromDimensions, toDimensions); // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; @@ -118,12 +117,12 @@ private boolean simpleRenameIsPossible(int[] toIndexes) { return true; } - private String toVectorString(List elements) { + private String toVectorString(List elements, ToStringContext context) { if (elements.size() == 1) - return elements.get(0); + return context.resolveBinding(elements.get(0)); StringBuilder b = new StringBuilder("("); for (String element : elements) - b.append(element).append(", "); + b.append(context.resolveBinding(element)).append(", "); b.setLength(b.length() - 2); b.append(")"); return b.toString(); @@ -132,7 +131,7 @@ private String toVectorString(List elements) { @Override public String toString(ToStringContext context) { return "rename(" + argument.toString(context) + ", " + - toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; + toVectorString(fromDimensions, context) + ", " + toVectorString(toDimensions, context) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index 150bf82f0e89..a0ef87d6e0b7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -46,7 +46,7 @@ public PrimitiveTensorFunction toPrimitive() { @Override public String toString(ToStringContext context) { - return "softmax(" + argument.toString(context) + ", " + dimension + ")"; + return "softmax(" + argument.toString(context) + ", " + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java index eac012a450b2..1faf7e051c35 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java @@ -18,6 +18,12 @@ public interface ToStringContext { /** Returns the name an identifier is bound to, or null if not bound in this context */ String getBinding(String name); + /** Returns the name an identifier is bound to, or the input name if none */ + default String resolveBinding(String name) { + String binding = getBinding(name); + return binding == null ? name : binding; + } + /** * Returns the context used to resolve types in this, if present. * In some functions serialization depends on type information. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 3913a16f35a4..d33d2e678fc0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -48,7 +48,7 @@ public String toString(ToStringContext context) { return "xw_plus_b(" + x.toString(context) + ", " + w.toString(context) + ", " + b.toString(context) + ", " + - dimension + ")"; + context.resolveBinding(dimension) + ")"; } @Override diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java index 8bacc24c3212..5244cf358cfb 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java @@ -12,6 +12,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import static org.junit.Assert.assertEquals; @@ -51,8 +52,10 @@ public void testSimilarityInMixed() { static class MyContext implements TypeContext { Map map = new HashMap<>(); + Map bindings = new HashMap<>(); public TensorType getType(Name name) { return getType(name.name()); } public TensorType getType(String name) { return map.get(name); } + public String resolveBinding(String name) { return Optional.ofNullable(bindings.get(name)).orElse(name); } } @Test @@ -80,4 +83,29 @@ public void testExpansion() { assertEquals("tensor(foo{},z[4])", resType.toString()); } + @Test + public void testExpansionWithDimensionBinding() { + var tTypeA = TensorType.fromSpec("tensor(foo{},vecdim[128])"); + var tTypeB = TensorType.fromSpec("tensor(vecdim[128],z[4])"); + var a = new VariableTensor<>("left", tTypeA); + var b = new VariableTensor<>("right", tTypeB); + var op = new CosineSimilarity<>(a, b, "dimensionArgument"); + assertEquals("join(" + + ( "reduce(join(left, right, f(a,b)(a * b)), sum, dimensionArgument), " + + "map(" + + ( "join(" + + ( "reduce(join(left, left, f(a,b)(a * b)), sum, dimensionArgument), " + + "reduce(join(right, right, f(a,b)(a * b)), sum, dimensionArgument), " + + "f(a,b)(a * b)), " ) + + "f(a)(sqrt(a))), " ) + + "f(a,b)(a / b)" ) + + ")", + op.toPrimitive().toString()); + var context = new MyContext(); + context.map.put("left", tTypeA); + context.map.put("right", tTypeB); + context.bindings.put("dimensionArgument", "vecdim"); + var resType = op.type(context); + assertEquals("tensor(foo{},z[4])", resType.toString()); + } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java index 42f9ef33ff1b..f7554a1b6b3e 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java @@ -50,6 +50,7 @@ static class MyContext implements TypeContext { Map map = new HashMap<>(); public TensorType getType(Name name) { return getType(name.name()); } public TensorType getType(String name) { return map.get(name); } + public String resolveBinding(String name) { return name; } } @Test