Skip to content

Commit

Permalink
[server] A few WCAA performance improvements (#694)
Browse files Browse the repository at this point in the history
* [server] A few WCAA performance improvements

1. Adopt fast-avro in the AAWC code path.
2. Use a schema-id based serde cache to avoid expensive schema comparison.
3. Dynamically choose the right structure for different logics, such as
   LinkedList vs ArrayList.

* Fixed spotbug issue

* Fixed test coverage issue

* Addressed comments
  • Loading branch information
gaojieliu authored Nov 1, 2023
1 parent b33fd93 commit 216c2c5
Show file tree
Hide file tree
Showing 25 changed files with 268 additions and 97 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ if (project.hasProperty('overrideBuildEnvironment')) {
}

def avroVersion = '1.10.2'
def avroUtilVersion = '0.3.19'
def avroUtilVersion = '0.3.21'
def grpcVersion = '1.49.2'
def kafkaGroup = 'com.linkedin.kafka'
def kafkaVersion = '2.4.1.65'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,17 @@ public ActiveActiveStoreIngestionTask(
StringAnnotatedStoreSchemaCache annotatedReadOnlySchemaRepository =
new StringAnnotatedStoreSchemaCache(storeName, schemaRepository);

this.rmdSerDe = new RmdSerDe(annotatedReadOnlySchemaRepository, rmdProtocolVersionId);
this.rmdSerDe = new RmdSerDe(
annotatedReadOnlySchemaRepository,
rmdProtocolVersionId,
getServerConfig().isComputeFastAvroEnabled());
this.mergeConflictResolver = MergeConflictResolverFactory.getInstance()
.createMergeConflictResolver(
annotatedReadOnlySchemaRepository,
rmdSerDe,
getStoreName(),
isWriteComputationEnabled);
isWriteComputationEnabled,
getServerConfig().isComputeFastAvroEnabled());
this.remoteIngestionRepairService = builder.getRemoteIngestionRepairService();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import com.linkedin.davinci.schema.merge.MergeRecordHelper;
import com.linkedin.davinci.schema.writecompute.WriteComputeProcessor;
import com.linkedin.davinci.schema.writecompute.WriteComputeSchemaValidator;
import com.linkedin.davinci.serializer.avro.MapOrderingPreservingSerDeFactory;
import com.linkedin.davinci.serializer.avro.MapOrderPreservingSerDeFactory;
import com.linkedin.venice.exceptions.VeniceException;
import com.linkedin.venice.meta.ReadOnlySchemaRepository;
import com.linkedin.venice.serializer.AvroSerializer;
Expand Down Expand Up @@ -132,7 +132,7 @@ RecordDeserializer<GenericRecord> getValueDeserializer(Schema writerSchema, Sche
// Map in write compute needs to have consistent ordering. On the sender side, users may not care about ordering
// in their maps. However, on the receiver side, we still want to make sure that the same serialized map bytes
// always get deserialized into maps with the same entry ordering.
return MapOrderingPreservingSerDeFactory.getDeserializer(writerSchema, readerSchema);
return MapOrderPreservingSerDeFactory.getDeserializer(writerSchema, readerSchema);
}

private RecordSerializer<GenericRecord> getValueSerializer(int valueSchemaId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

import com.linkedin.davinci.replication.RmdWithValueSchemaId;
import com.linkedin.davinci.schema.merge.ValueAndRmd;
import com.linkedin.davinci.serializer.avro.MapOrderingPreservingSerDeFactory;
import com.linkedin.davinci.serializer.avro.MapOrderPreservingSerDeFactory;
import com.linkedin.davinci.serializer.avro.fast.MapOrderPreservingFastSerDeFactory;
import com.linkedin.davinci.store.record.ValueRecord;
import com.linkedin.venice.annotation.Threadsafe;
import com.linkedin.venice.exceptions.VeniceException;
Expand All @@ -20,7 +21,11 @@
import com.linkedin.venice.schema.rmd.RmdTimestampType;
import com.linkedin.venice.schema.rmd.RmdUtils;
import com.linkedin.venice.schema.writecompute.WriteComputeOperation;
import com.linkedin.venice.serializer.RecordDeserializer;
import com.linkedin.venice.serializer.RecordSerializer;
import com.linkedin.venice.utils.AvroSchemaUtils;
import com.linkedin.venice.utils.SparseConcurrentList;
import com.linkedin.venice.utils.collections.BiIntKeyCache;
import com.linkedin.venice.utils.lazy.Lazy;
import java.nio.ByteBuffer;
import java.util.ArrayList;
Expand Down Expand Up @@ -52,6 +57,11 @@ public class MergeConflictResolver {
private final MergeResultValueSchemaResolver mergeResultValueSchemaResolver;
private final RmdSerDe rmdSerde;
private final boolean useFieldLevelTimestamp;
private final boolean fastAvroEnabled;

private final SparseConcurrentList<RecordSerializer<GenericRecord>> serializerIndexedByValueSchemaId;
private final BiIntKeyCache<RecordDeserializer<GenericRecord>> deserializerCacheForFullValue;
private final BiIntKeyCache<SparseConcurrentList<RecordDeserializer<GenericRecord>>> deserializerCacheForUpdateValue;

MergeConflictResolver(
StringAnnotatedStoreSchemaCache storeSchemaCache,
Expand All @@ -61,7 +71,8 @@ public class MergeConflictResolver {
MergeByteBuffer mergeByteBuffer,
MergeResultValueSchemaResolver mergeResultValueSchemaResolver,
RmdSerDe rmdSerde,
boolean useFieldLevelTimestamp) {
boolean useFieldLevelTimestamp,
boolean fastAvroEnabled) {
this.storeSchemaCache = Validate.notNull(storeSchemaCache);
this.storeName = Validate.notNull(storeName);
this.newRmdCreator = Validate.notNull(newRmdCreator);
Expand All @@ -70,6 +81,18 @@ public class MergeConflictResolver {
this.mergeByteBuffer = Validate.notNull(mergeByteBuffer);
this.rmdSerde = Validate.notNull(rmdSerde);
this.useFieldLevelTimestamp = useFieldLevelTimestamp;
this.fastAvroEnabled = fastAvroEnabled;

this.serializerIndexedByValueSchemaId = new SparseConcurrentList<>();
this.deserializerCacheForFullValue = new BiIntKeyCache<>((writerSchemaId, readerSchemaId) -> {
Schema writerSchema = getValueSchema(writerSchemaId);
Schema readerSchema = getValueSchema(readerSchemaId);
return this.fastAvroEnabled
? MapOrderPreservingFastSerDeFactory.getDeserializer(writerSchema, readerSchema)
: MapOrderPreservingSerDeFactory.getDeserializer(writerSchema, readerSchema);
});
this.deserializerCacheForUpdateValue =
new BiIntKeyCache<>((writerSchemaId, readerSchemaId) -> new SparseConcurrentList<>());
}

/**
Expand Down Expand Up @@ -247,7 +270,7 @@ public MergeConflictResult update(
}
final ByteBuffer updatedValueBytes = updatedValueAndRmd.getValue() == null
? null
: serializeMergedValueRecord(oldValueSchema, updatedValueAndRmd.getValue());
: serializeMergedValueRecord(oldValueSchemaID, updatedValueAndRmd.getValue());
return new MergeConflictResult(updatedValueBytes, oldValueSchemaID, false, updatedValueAndRmd.getRmd());
}

Expand Down Expand Up @@ -300,14 +323,13 @@ private MergeConflictResult mergePutWithFieldLevelTimestamp(
}
final SchemaEntry mergeResultValueSchemaEntry =
mergeResultValueSchemaResolver.getMergeResultValueSchema(oldValueSchemaID, newValueSchemaID);
final Schema mergeResultValueSchema = mergeResultValueSchemaEntry.getSchema();
final Schema newValueWriterSchema = getValueSchema(newValueSchemaID);
/**
* Note that it is important that the new value record should NOT use {@link mergeResultValueSchema}.
* {@link newValueWriterSchema} is either the same as {@link mergeResultValueSchema} or it is a subset of
* {@link mergeResultValueSchema}.
*/
GenericRecord newValueRecord = deserializeValue(newValueBytes, newValueWriterSchema, newValueWriterSchema);
GenericRecord newValueRecord =
deserializerCacheForFullValue.get(newValueSchemaID, newValueSchemaID).deserialize(newValueBytes);
ValueAndRmd<GenericRecord> oldValueAndRmd = createOldValueAndRmd(
mergeResultValueSchemaEntry.getSchema(),
mergeResultValueSchemaEntry.getId(),
Expand All @@ -325,7 +347,8 @@ private MergeConflictResult mergePutWithFieldLevelTimestamp(
if (mergedValueAndRmd.isUpdateIgnored()) {
return MergeConflictResult.getIgnoredResult();
}
ByteBuffer mergedValueBytes = serializeMergedValueRecord(mergeResultValueSchema, mergedValueAndRmd.getValue());
ByteBuffer mergedValueBytes =
serializeMergedValueRecord(mergeResultValueSchemaEntry.getId(), mergedValueAndRmd.getValue());
return new MergeConflictResult(mergedValueBytes, newValueSchemaID, false, mergedValueAndRmd.getRmd());
}

Expand Down Expand Up @@ -380,7 +403,7 @@ private MergeConflictResult mergeDeleteWithFieldLevelTimestamp(
}
final ByteBuffer mergedValueBytes = mergedValueAndRmd.getValue() == null
? null
: serializeMergedValueRecord(oldValueSchema, mergedValueAndRmd.getValue());
: serializeMergedValueRecord(oldValueSchemaID, mergedValueAndRmd.getValue());
return new MergeConflictResult(mergedValueBytes, oldValueSchemaID, false, mergedValueAndRmd.getRmd());
}

Expand All @@ -402,8 +425,11 @@ private ValueAndRmd<GenericRecord> createOldValueAndRmd(
int oldValueWriterSchemaID,
Lazy<ByteBuffer> oldValueBytesProvider,
GenericRecord oldRmdRecord) {
final GenericRecord oldValueRecord =
createValueRecordFromByteBuffer(readerValueSchema, oldValueWriterSchemaID, oldValueBytesProvider.get());
final GenericRecord oldValueRecord = createValueRecordFromByteBuffer(
readerValueSchema,
readerValueSchemaID,
oldValueWriterSchemaID,
oldValueBytesProvider.get());

// RMD record should contain a per-field timestamp and it should use the RMD schema generated from
// mergeResultValueSchema.
Expand All @@ -418,13 +444,13 @@ private ValueAndRmd<GenericRecord> createOldValueAndRmd(

private GenericRecord createValueRecordFromByteBuffer(
Schema readerValueSchema,
int readerValueSchemaID,
int oldValueWriterSchemaID,
ByteBuffer oldValueBytes) {
if (oldValueBytes == null) {
return AvroSchemaUtils.createGenericRecord(readerValueSchema);
}
final Schema oldValueWriterSchema = getValueSchema(oldValueWriterSchemaID);
return deserializeValue(oldValueBytes, oldValueWriterSchema, readerValueSchema);
return deserializerCacheForFullValue.get(oldValueWriterSchemaID, readerValueSchemaID).deserialize(oldValueBytes);
}

private GenericRecord convertRmdToUseReaderValueSchema(
Expand All @@ -439,13 +465,6 @@ private GenericRecord convertRmdToUseReaderValueSchema(
return rmdSerde.deserializeRmdBytes(writerValueSchemaID, readerValueSchemaID, rmdBytes);
}

private GenericRecord deserializeValue(ByteBuffer bytes, Schema writerSchema, Schema readerSchema) {
/**
* TODO: Refactor this to use {@link com.linkedin.venice.serialization.StoreDeserializerCache}
*/
return MapOrderingPreservingSerDeFactory.getDeserializer(writerSchema, readerSchema).deserialize(bytes);
}

private boolean ignoreNewPut(
final int oldValueSchemaID,
GenericRecord oldValueFieldTimestampsRecord,
Expand Down Expand Up @@ -582,9 +601,17 @@ private GenericRecord deserializeWriteComputeBytes(
int readerValueSchemaId,
int updateProtocolVersion,
ByteBuffer updateBytes) {
Schema writerSchema = getWriteComputeSchema(writerValueSchemaId, updateProtocolVersion);
Schema readerSchema = getWriteComputeSchema(readerValueSchemaId, updateProtocolVersion);
return deserializeValue(updateBytes, writerSchema, readerSchema);
RecordDeserializer<GenericRecord> deserializer =
deserializerCacheForUpdateValue.get(writerValueSchemaId, readerValueSchemaId)
.computeIfAbsent(updateProtocolVersion, ignored -> {
Schema writerSchema = getWriteComputeSchema(writerValueSchemaId, updateProtocolVersion);
Schema readerSchema = getWriteComputeSchema(readerValueSchemaId, updateProtocolVersion);
return this.fastAvroEnabled
? MapOrderPreservingFastSerDeFactory.getDeserializer(writerSchema, readerSchema)
: MapOrderPreservingSerDeFactory.getDeserializer(writerSchema, readerSchema);
});

return deserializer.deserialize(updateBytes);
}

private ValueAndRmd<GenericRecord> prepareValueAndRmdForUpdate(
Expand All @@ -603,8 +630,8 @@ private ValueAndRmd<GenericRecord> prepareValueAndRmdForUpdate(
* case, the value must be retrieved from storage engine, and is prepended with schema ID.
*/
int schemaId = ValueRecord.parseSchemaId(oldValueBytes.array());
Schema writerSchema = getValueSchema(schemaId);
newValue = deserializeValue(oldValueBytes, writerSchema, readerValueSchemaEntry.getSchema());
newValue =
deserializerCacheForFullValue.get(schemaId, readerValueSchemaEntry.getId()).deserialize(oldValueBytes);
}
GenericRecord newRmd = newRmdCreator.apply(readerValueSchemaEntry.getId());
newRmd.put(TIMESTAMP_FIELD_POS, createPerFieldTimestampRecord(newRmd.getSchema(), 0L, newValue));
Expand Down Expand Up @@ -741,10 +768,16 @@ private boolean ignoreNewUpdate(
}
}

private ByteBuffer serializeMergedValueRecord(Schema mergedValueSchema, GenericRecord mergedValue) {
private ByteBuffer serializeMergedValueRecord(int mergedValueSchemaId, GenericRecord mergedValue) {
// TODO: avoid serializing the merged value result here and instead serializing it before persisting it. The goal
// is to avoid back-and-forth ser/de. Because when the merged result is read before it is persisted, we may need
// to deserialize it.
return ByteBuffer.wrap(MapOrderingPreservingSerDeFactory.getSerializer(mergedValueSchema).serialize(mergedValue));
RecordSerializer serializer = serializerIndexedByValueSchemaId.computeIfAbsent(mergedValueSchemaId, ignored -> {
Schema mergedValueSchema = getValueSchema(mergedValueSchemaId);
return fastAvroEnabled
? MapOrderPreservingFastSerDeFactory.getSerializer(mergedValueSchema)
: MapOrderPreservingSerDeFactory.getSerializer(mergedValueSchema);
});
return ByteBuffer.wrap(serializer.serialize(mergedValue));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ public MergeConflictResolver createMergeConflictResolver(
StringAnnotatedStoreSchemaCache annotatedReadOnlySchemaRepository,
RmdSerDe rmdSerDe,
String storeName,
boolean rmdUseFieldLevelTs) {
boolean rmdUseFieldLevelTs,
boolean fastAvroEnabled) {
MergeRecordHelper mergeRecordHelper = new CollectionTimestampMergeRecordHelper();
return new MergeConflictResolver(
annotatedReadOnlySchemaRepository,
Expand All @@ -31,13 +32,14 @@ public MergeConflictResolver createMergeConflictResolver(
new MergeByteBuffer(),
new MergeResultValueSchemaResolverImpl(annotatedReadOnlySchemaRepository, storeName),
rmdSerDe,
rmdUseFieldLevelTs);
rmdUseFieldLevelTs,
fastAvroEnabled);
}

public MergeConflictResolver createMergeConflictResolver(
StringAnnotatedStoreSchemaCache annotatedReadOnlySchemaRepository,
RmdSerDe rmdSerDe,
String storeName) {
return createMergeConflictResolver(annotatedReadOnlySchemaRepository, rmdSerDe, storeName, false);
return createMergeConflictResolver(annotatedReadOnlySchemaRepository, rmdSerDe, storeName, false, true);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package com.linkedin.davinci.replication.merge;

import com.linkedin.davinci.replication.RmdWithValueSchemaId;
import com.linkedin.davinci.serializer.avro.MapOrderingPreservingSerDeFactory;
import com.linkedin.davinci.serializer.avro.MapOrderPreservingSerDeFactory;
import com.linkedin.davinci.serializer.avro.fast.MapOrderPreservingFastSerDeFactory;
import com.linkedin.venice.annotation.Threadsafe;
import com.linkedin.venice.exceptions.VeniceException;
import com.linkedin.venice.schema.rmd.RmdSchemaEntry;
Expand Down Expand Up @@ -30,16 +31,27 @@ public class RmdSerDe {
private final SparseConcurrentList<Schema> rmdSchemaIndexedByValueSchemaId;
private final SparseConcurrentList<RecordSerializer<GenericRecord>> rmdSerializerIndexedByValueSchemaId;
private final BiIntKeyCache<RecordDeserializer<GenericRecord>> deserializerCache;
private final boolean fastAvroEnabled;

public RmdSerDe(StringAnnotatedStoreSchemaCache annotatedStoreSchemaCache, int rmdVersionId) {
this(annotatedStoreSchemaCache, rmdVersionId, true);
}

public RmdSerDe(
StringAnnotatedStoreSchemaCache annotatedStoreSchemaCache,
int rmdVersionId,
boolean fastAvroEnabled) {
this.annotatedStoreSchemaCache = annotatedStoreSchemaCache;
this.rmdVersionId = rmdVersionId;
this.rmdSchemaIndexedByValueSchemaId = new SparseConcurrentList<>();
this.rmdSerializerIndexedByValueSchemaId = new SparseConcurrentList<>();
this.fastAvroEnabled = fastAvroEnabled;
this.deserializerCache = new BiIntKeyCache<>((writerSchemaId, readerSchemaId) -> {
Schema rmdWriterSchema = getRmdSchema(writerSchemaId);
Schema rmdReaderSchema = getRmdSchema(readerSchemaId);
return MapOrderingPreservingSerDeFactory.getDeserializer(rmdWriterSchema, rmdReaderSchema);
return this.fastAvroEnabled
? MapOrderPreservingFastSerDeFactory.getDeserializer(rmdWriterSchema, rmdReaderSchema)
: MapOrderPreservingSerDeFactory.getDeserializer(rmdWriterSchema, rmdReaderSchema);
});
}

Expand Down Expand Up @@ -98,6 +110,8 @@ private RecordDeserializer<GenericRecord> getRmdDeserializer(final int writerSch

private RecordSerializer<GenericRecord> generateRmdSerializer(int valueSchemaId) {
Schema replicationMetadataSchema = getRmdSchema(valueSchemaId);
return MapOrderingPreservingSerDeFactory.getSerializer(replicationMetadataSchema);
return fastAvroEnabled
? MapOrderPreservingFastSerDeFactory.getSerializer(replicationMetadataSchema)
: MapOrderPreservingSerDeFactory.getSerializer(replicationMetadataSchema);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ public UpdateResultStatus handlePutList(
collectionFieldRmd.setPutOnlyPartLength(toPutList.size());
return UpdateResultStatus.COMPLETELY_UPDATED;
}
/**
* LinkedList is more efficient for the following add/remove operations.
*/
if (!toPutList.isEmpty() && !(toPutList instanceof LinkedList)) {
toPutList = new LinkedList<>((toPutList));
}
// The current list is NOT in the put-only state. So we need to de-dup the incoming list.
deDupListFromEnd(toPutList);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.linkedin.davinci.schema.merge;

import com.linkedin.davinci.utils.IndexedHashMap;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;


Expand All @@ -18,6 +20,12 @@ static <T> IndexedHashMap<T, Long> createElementToActiveTsMap(
final int putOnlyPartLength) {
IndexedHashMap<T, Long> activeElementToTsMap = new IndexedHashMap<>(existingElements.size());
int idx = 0;
if (!existingElements.isEmpty() && activeTimestamps instanceof LinkedList) {
/**
* LinkedList is not efficient for get operation
*/
activeTimestamps = new ArrayList<>(activeTimestamps);
}
for (T existingElement: existingElements) {
final long activeTimestamp;
if (idx < putOnlyPartLength) {
Expand All @@ -41,6 +49,12 @@ static <T> IndexedHashMap<T, Long> createDeletedElementToTsMap(
) {
IndexedHashMap<T, Long> elementToTimestampMap = new IndexedHashMap<>();
int idx = 0;
if (!deletedTimestamps.isEmpty() && deletedElements instanceof LinkedList) {
/**
* LinkedList is not efficient for get operation
*/
deletedElements = new ArrayList<>(deletedElements);
}
for (long deletedTimestamp: deletedTimestamps) {
if (deletedTimestamp >= minTimestamp) {
elementToTimestampMap.put(deletedElements.get(idx), deletedTimestamp);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import com.linkedin.davinci.utils.IndexedHashMap;
import com.linkedin.davinci.utils.IndexedMap;
import java.util.Collection;
import java.util.LinkedList;
import java.util.Map;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericDatumReader;
Expand All @@ -18,15 +16,6 @@ public MapOrderPreservingDatumReader(Schema writer, Schema reader) {
super(writer, reader);
}

@Override
protected Object newArray(Object old, int size, Schema schema) {
if (old instanceof Collection) {
((Collection) old).clear();
return old;
} else
return new LinkedList<>();
}

@Override
protected Object newMap(Object old, int size) {
if (old instanceof IndexedMap) {
Expand Down
Loading

0 comments on commit 216c2c5

Please sign in to comment.