Skip to content

Commit

Permalink
Merge pull request #33062 from vespa-engine/bratseth/test-packBits
Browse files Browse the repository at this point in the history
Test pack bits combined with a natively binarizing embedder
  • Loading branch information
bratseth authored Jan 7, 2025
2 parents 937b47f + 20a8aff commit c82503d
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,13 @@ public IndexingProcessor(DocumentTypeManager documentTypeManager,
IlscriptsConfig ilscriptsConfig,
Linguistics linguistics,
ComponentRegistry<Embedder> embedders) {
this(documentTypeManager, new ScriptManager(documentTypeManager, ilscriptsConfig, linguistics, toMap(embedders)));
}

public IndexingProcessor(DocumentTypeManager documentTypeManager,
ScriptManager scriptManager) {
this.documentTypeManager = documentTypeManager;
scriptManager = new ScriptManager(this.documentTypeManager, ilscriptsConfig, linguistics, toMap(embedders));
this.scriptManager = scriptManager;
adapterFactory = new SimpleAdapterFactory(new ExpressionSelector());
}

Expand Down Expand Up @@ -132,7 +137,7 @@ private void processRemove(DocumentRemove input, List<DocumentOperation> out) {
out.add(input);
}

private Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) {
private static Map<String, Embedder> toMap(ComponentRegistry<Embedder> embedders) {
var map = embedders.allComponentsById().entrySet().stream()
.collect(Collectors.toMap(e -> e.getKey().stringValue(), Map.Entry::getValue));
if (map.size() > 1) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.docprocs.indexing;

import com.yahoo.component.AbstractComponent;
import com.yahoo.document.DataType;
import com.yahoo.document.Document;
import com.yahoo.document.DocumentOperation;
import com.yahoo.document.DocumentPut;
import com.yahoo.document.DocumentType;
import com.yahoo.document.DocumentTypeManager;
import com.yahoo.document.DocumentUpdate;
import com.yahoo.document.PositionDataType;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.document.update.AssignValueUpdate;
import com.yahoo.document.update.FieldUpdate;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.Tensors;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.configdefinition.IlscriptsConfig;
import org.junit.Test;

import java.util.List;
import java.util.Map;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
Expand Down Expand Up @@ -225,4 +239,61 @@ public void requireThatIndexerForwardsUpdatesOfUnknownType() {
assertSame(input, output);
}

@Test
public void testEmbedBinarizeAndPack() {
var documentTypes = new DocumentTypeManager();
var test = new DocumentType("test");
test.addField("myText", DataType.STRING);
test.addField("embedding", new TensorDataType(TensorType.fromSpec("tensor<int8>(x[16])")));
documentTypes.register(test);

IlscriptsConfig.Builder config = new IlscriptsConfig.Builder();
config.ilscript(new IlscriptsConfig.Ilscript.Builder().doctype("test")
.content("input myText | embed | binarize | pack_bits | attribute embedding")
.docfield("myText"));
var scripts = new ScriptManager(documentTypes, new IlscriptsConfig(config), null, Map.of("test", new TestEmbedder()));
assertNotNull(scripts.getScript(documentTypes.getDocumentType("test")));

var tester = new IndexingProcessorTester(documentTypes, scripts);
DocumentUpdate input = new DocumentUpdate(test, "id:ns:test::");
input.addFieldUpdate(FieldUpdate.createAssign(test.getField("myText"), new StringFieldValue("my text")));
DocumentUpdate output = (DocumentUpdate)tester.process(input);
FieldUpdate embeddingUpdate = output.getFieldUpdate("embedding");
AssignValueUpdate valueUpdate = (AssignValueUpdate)embeddingUpdate.getValueUpdate(0);
assertEquals(Tensor.from("tensor<int8>(x[16]):[-110, 73, 36, -110, 73, 36, -110, 73, 36, -110, 73, 36, -110, 73, 36, -110]"),
valueUpdate.getValue().getWrappedValue());
}

/** An ebedder which also does its own quantization, similar to HuggingFaceEmbedder. */
static class TestEmbedder extends AbstractComponent implements Embedder {

@Override
public List<Integer> embed(String s, Context context) {
throw new UnsupportedOperationException();
}

@Override
public Tensor embed(String text, Context context, TensorType tensorType) {
if (tensorType.dimensions().size() != 1)
throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': should only have one dimension.");
if (!tensorType.dimensions().get(0).isIndexed())
throw new IllegalArgumentException("Error in embedding to type '" + tensorType + "': dimension should be indexed.");
boolean binarize = tensorType.valueType() == TensorType.Value.INT8;
long size = tensorType.dimensions().get(0).size().get();
if (binarize)
size = size * 8;
var embeddedType = new TensorType.Builder().indexed(tensorType.dimensions().get(0).name(), size).build();
var resultBuilder = Tensor.Builder.of(embeddedType);
for (int i = 0; i < size; i++) {
int v = ((i % 3) == 0) ? 1 : 0;
resultBuilder.cell(v, i);
}
var result = resultBuilder.build();
if (binarize)
result = Tensors.packBits(result);
return result;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ public IndexingProcessorTester(String configDir) {
indexer = newProcessor("dir:" + configDir);
}

public IndexingProcessorTester(DocumentTypeManager documentTypes, ScriptManager scripts) {
indexer = newProcessor(documentTypes, scripts);
}

public DocumentType getDocumentType(String name) {
return indexer.getDocumentTypeManager().getDocumentType(name);
}
Expand Down Expand Up @@ -70,4 +74,8 @@ private static IndexingProcessor newProcessor(String configId) {
new ComponentRegistry<>());
}

private static IndexingProcessor newProcessor(DocumentTypeManager documentTypes, ScriptManager scripts) {
return new IndexingProcessor(documentTypes, scripts);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ protected void doExecute(ExecutionContext context) {

/** Returns the type this requires when producing the given output type. */
private TensorType inputType(TensorType givenType) {
var builder = new TensorType.Builder(TensorType.Value.INT8); // Any larger value type is also permissible
var builder = new TensorType.Builder(TensorType.Value.DOUBLE); // Any value type is permissible
for (var d : givenType.dimensions())
builder.dimension(d.size().isPresent() ? d.withSize(d.size().get() * 8) : d);
return builder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public void testEmbedAndBinarize() {
@Test
public void testEmbedBinarizeAndPack_bits() {
var tester = new EmbeddingScriptTester(Map.of("emb1", new EmbeddingScriptTester.MockIndexedEmbedder("myDocument.myTensor", -111)));
tester.testStatement("input myText | embed | binarize | pack_bits | attribute 'myTensor'", "input text", "tensor<int8>(x[2])", "[58, 192]");
tester.testStatement("input myText | embed | binarize | pack_bits | attribute 'myTensor'", "input text", "tensor<int8>(x[2])", "[58, -64]");
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ public void deconstruct() {
tokenizer.close();
}

@SuppressWarnings("unchecked")
@Override
public Tensor embed(String text, Context context, TensorType tensorType) {
if (tensorType.dimensions().size() != 1) {
Expand Down Expand Up @@ -213,6 +212,7 @@ private Tensor binaryQuantization(HuggingFaceEmbedder.HFEmbeddingResult embeddin
/**
* Binary quantization of the embedding into a tensor of type int8 with the specified dimensions.
*/
// TODO: Call Tensors.packBits instead. It is more general and faster.
static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
Tensor.Builder builder = Tensor.Builder.of(tensorType);
BitSet bitSet = new BitSet(8);
Expand Down
2 changes: 1 addition & 1 deletion vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ private static String cellToString(Map.Entry<TensorAddress, Double> cell, Tensor
int hashCode();

/**
* Implement here to make this work across implementations.
* Implemented here to make this work across implementations.
* Implementations must override equals and call this because this is an interface and cannot override equals.
*/
static boolean equals(Tensor a, Tensor b) {
Expand Down
19 changes: 9 additions & 10 deletions vespajlib/src/main/java/com/yahoo/tensor/Tensors.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.yahoo.api.annotations.Beta;

import java.util.Arrays;
import java.util.BitSet;
import java.util.Iterator;

/**
Expand Down Expand Up @@ -47,9 +48,10 @@ public static Tensor toSparse(Tensor tensor, String ... dimensions) {
}

/**
* Converts any tensor containing only ones and zeroes into one where each consecutive 8 values in the
* dense dimension are packed into a single byte. As a consequence the output type of this is a tensor
* where the dense dimension is 1/8th as large.
* Converts any tensor into one where each consecutive 8 values in the
* dense dimension are packed into a single byte,
* by setting a bit to 1 when the tensor has a positive value and 0 otherwise.
* As a consequence the output type of this is a tensor where the dense dimension is 1/8th as large.
*
* @throws IllegalArgumentException if the tensor has the wrong type or contains any other value than 0 or 1
*/
Expand All @@ -71,7 +73,7 @@ public static Tensor packBits(Tensor tensor) {
int packedValue = 0;
for (int j = 0; j < 8 && i < indexed.size(); j++)
packedValue = packInto(packedValue, indexed.get(i), j, i++);
builder.cell(packedValue, packedIndex);
builder.cell((byte)packedValue, packedIndex);
}
}
else if (tensor instanceof MixedTensor mixed) {
Expand All @@ -81,7 +83,7 @@ else if (tensor instanceof MixedTensor mixed) {
int packedValue = 0;
for (int j = 0; j < 8 && i < denseSubspace.cells.length; j++)
packedValue = packInto(packedValue, denseSubspace.cells[i], j, i++);
builder.cell(packedAddress, packedValue);
builder.cell(packedAddress, (byte)packedValue);
}
}
}
Expand All @@ -93,13 +95,10 @@ else if (tensor instanceof MixedTensor mixed) {
}

private static int packInto(int packedValue, double value, int bitPosition, long sourcePosition) {
if (value == 0.0)
if (value <= 0.0)
return packedValue;
else if (value == 1.0)
return packedValue | ( 1 << ( 7 - bitPosition ));
else
throw new IllegalArgumentException("The tensor to be packed can only contain 0 or 1 values, " +
"but has " + value + " at position " + sourcePosition);
return packedValue | ( 1 << ( 7 - bitPosition ));
}

}
22 changes: 8 additions & 14 deletions vespajlib/src/test/java/com/yahoo/tensor/TensorsTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import org.junit.jupiter.api.Test;

import java.util.BitSet;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;

Expand All @@ -27,15 +29,15 @@ void testToSparse() {

@Test
void testPackBits() {
assertPacked("tensor<int8>(x[2]):[129,14]", "tensor(x[16]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1,0]");
assertPacked("tensor<int8>(x[2]):[129,14]", "tensor(x[15]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1]");
assertPacked("tensor<int8>(x[1]):[128]", "tensor(x[1]):[1]");
assertPacked("tensor<int8>(key{},x[2]):{a:[129,14], b:[12, 7]}",
assertPacked("tensor<int8>(x[2]):[-127,14]", "tensor(x[16]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1,0]");
assertPacked("tensor<int8>(x[2]):[-127,14]", "tensor(x[15]):[1,0,0,0,0,0,0,1, 0,0,0,0,1,2,3]");
assertPacked("tensor<int8>(x[1]):[-128]", "tensor(x[1]):[1]");
assertPacked("tensor<int8>(key{},x[2]):{a:[-127,14], b:[12, 7]}",
"tensor(key{},x[16]):{a:[1,0,0,0,0,0,0,1, 0,0,0,0,1,1,1,0]," +
" b:[0,0,0,0,1,1,0,0, 0,0,0,0,0,1,1,1]}");
assertPacked("tensor<int8>(key{},x[1]):{a:[160],b:[32]}",
assertPacked("tensor<int8>(key{},x[1]):{a:[-96],b:[32]}",
"tensor(key{},x[3]):{a:[1,0,1],b:[0,0,1]}");
assertPacked("tensor<int8>(key{},x[1]):{a:[128]}", "tensor(key{}, x[1]):{a:[1]}");
assertPacked("tensor<int8>(key{},x[1]):{a:[-128]}", "tensor(key{}, x[1]):{a:[1]}");

try {
Tensors.packBits(Tensor.from("tensor(x[1],y[1]):[1]"));
Expand All @@ -45,14 +47,6 @@ void testPackBits() {
assertEquals("packBits requires a tensor with one dense dimensions, but got tensor(x[1],y[1])",
e.getMessage());
}
try {
Tensors.packBits(Tensor.from("tensor(x[3]):[0, 1, 2]"));
fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("The tensor to be packed can only contain 0 or 1 values, but has 2.0 at position 2",
e.getMessage());
}
}

void assertConvertedToSparse(String inputType, String outputType, String tensorValue, String ... dimensions) {
Expand Down

0 comments on commit c82503d

Please sign in to comment.