Skip to content

Commit

Permalink
fix segment deserializer bug for case when checksum=0 (#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
radai-rosenblatt authored Jan 10, 2020
1 parent 2cbbfdc commit 1f981de
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,24 @@ public LargeMessageSegment deserialize(String s, byte[] bytes) {
int checksum = byteBuffer.getInt();
long messageIdMostSignificantBits = byteBuffer.getLong();
long messageIdLeastSignificantBits = byteBuffer.getLong();
if (checksum == 0 ||
checksum != ((int) (messageIdMostSignificantBits + messageIdLeastSignificantBits))) {
if (checksum != ((int) (messageIdMostSignificantBits + messageIdLeastSignificantBits))) {
LOG.debug("Serialized segment checksum does not match. not large message segment.");
return null;
}
UUID messageId = new UUID(messageIdMostSignificantBits, messageIdLeastSignificantBits);
int sequenceNumber = byteBuffer.getInt();
int numberOfSegments = byteBuffer.getInt();
int messageSizeInBytes = byteBuffer.getInt();
int sequenceNumber = byteBuffer.getInt(); //expected to be [0, numberOfSegments)
int numberOfSegments = byteBuffer.getInt(); //expected to be >0
int messageSizeInBytes = byteBuffer.getInt(); //expected to be >= bytes.length - headerLength
if (sequenceNumber < 0 || numberOfSegments <= 0 || sequenceNumber >= numberOfSegments) {
LOG.warn("Serialized segment sequence {} not in [0, {}). treating as regular payload", sequenceNumber, numberOfSegments);
return null;
}
int segmentPayloadSize = bytes.length - headerLength; //how much user data in this record
if (messageSizeInBytes < segmentPayloadSize) {
//there cannot be more data in a single segment than the total size of the assembled msg
LOG.warn("Serialized segment size {} bigger than assembled msg size {}, treating as regular payload", segmentPayloadSize, messageSizeInBytes);
return null;
}
ByteBuffer payload = byteBuffer.slice();
return new LargeMessageSegment(messageId, sequenceNumber, numberOfSegments, messageSizeInBytes, payload);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

package com.linkedin.kafka.clients.largemessage;

import com.linkedin.kafka.clients.largemessage.errors.InvalidSegmentException;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.serialization.Deserializer;
import org.slf4j.Logger;
Expand All @@ -20,23 +19,23 @@ public class MessageAssemblerImpl implements MessageAssembler {
private static final Logger LOG = LoggerFactory.getLogger(MessageAssemblerImpl.class);
private final LargeMessageBufferPool _messagePool;
private final Deserializer<LargeMessageSegment> _segmentDeserializer;
private final boolean _treatInvalidMessageSegmentsAsPayload;

public MessageAssemblerImpl(long bufferCapacity,
long expirationOffsetGap,
boolean exceptionOnMessageDropped,
Deserializer<LargeMessageSegment> segmentDeserializer) {
this(bufferCapacity, expirationOffsetGap, exceptionOnMessageDropped, segmentDeserializer, false);
_messagePool = new LargeMessageBufferPool(bufferCapacity, expirationOffsetGap, exceptionOnMessageDropped);
_segmentDeserializer = segmentDeserializer;
}

@Deprecated
public MessageAssemblerImpl(long bufferCapacity,
long expirationOffsetGap,
boolean exceptionOnMessageDropped,
Deserializer<LargeMessageSegment> segmentDeserializer,
boolean treatInvalidMessageSegmentsAsPayload) {
@SuppressWarnings("unused") boolean treatInvalidMessageSegmentsAsPayload) {
_messagePool = new LargeMessageBufferPool(bufferCapacity, expirationOffsetGap, exceptionOnMessageDropped);
_segmentDeserializer = segmentDeserializer;
_treatInvalidMessageSegmentsAsPayload = treatInvalidMessageSegmentsAsPayload;
}

@Override
Expand All @@ -51,17 +50,8 @@ public AssembleResult assemble(TopicPartition tp, long offset, byte[] segmentByt
return new AssembleResult(segmentBytes, offset, offset);
} else {
//sanity-check the segment
try {
segment.sanityCheck();
} catch (InvalidSegmentException e) {
if (_treatInvalidMessageSegmentsAsPayload) {
//behave as if this didnt look like a segment to us
LOG.warn("message at {}/{} may have been a false-positive segment and will be treated as regular payload", tp, offset, e);
return new AssembleResult(segmentBytes, offset, offset);
} else {
throw e;
}
}
segment.sanityCheck();

// Return immediately if it is a single segment message.
if (segment.numberOfSegments == 1) {
return new AssembleResult(segment.payloadArray(), offset, offset);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
 * Copyright 2020 LinkedIn Corp. Licensed under the BSD 2-Clause License (the "License").
 * See License in the project root for license information.
 */

package com.linkedin.kafka.clients.largemessage;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.UUID;
import org.testng.Assert;
import org.testng.annotations.Test;


public class DefaultSegmentDeserializerTest {

  /**
   * Regression test: a segment whose header checksum is legitimately 0 must still
   * deserialize as a large-message segment rather than being rejected as a
   * regular payload.
   */
  @Test
  public void testZeroChecksum() {
    //pick two longs whose sum, truncated to an int, is exactly 0 so the
    //serialized header carries a checksum of 0
    long msb = (Long.MAX_VALUE / 2) - 1;
    long lsb = (Long.MAX_VALUE / 2) + 3;
    int checksum = (int) (msb + lsb);
    Assert.assertEquals(checksum, 0, "projected checksum should be 0. instead was " + checksum);

    UUID messageId = new UUID(msb, lsb);
    byte[] userData = new byte[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    LargeMessageSegment original =
        new LargeMessageSegment(messageId, 0, 1, 10, ByteBuffer.wrap(userData));

    //round-trip the segment through the default serializer/deserializer pair
    DefaultSegmentSerializer serializer = new DefaultSegmentSerializer();
    byte[] wireBytes = serializer.serialize("topic", original);

    DefaultSegmentDeserializer deserializer = new DefaultSegmentDeserializer();
    LargeMessageSegment roundTripped = deserializer.deserialize("topic", wireBytes);

    //a zero checksum must not be mistaken for "not a segment"
    Assert.assertNotNull(roundTripped);
    Assert.assertEquals(roundTripped.messageId, messageId);
    Assert.assertTrue(Arrays.equals(userData, roundTripped.payloadArray()));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

package com.linkedin.kafka.clients.largemessage;

import com.linkedin.kafka.clients.largemessage.errors.InvalidSegmentException;
import com.linkedin.kafka.clients.utils.LiKafkaClientsUtils;
import java.util.UUID;
import org.apache.kafka.common.TopicPartition;
Expand Down Expand Up @@ -44,7 +43,7 @@ public void testSingleMessageSegment() {
public void testTreatBadSegmentAsPayload() {
Serializer<LargeMessageSegment> segmentSerializer = new DefaultSegmentSerializer();
Deserializer<LargeMessageSegment> segmentDeserializer = new DefaultSegmentDeserializer();
MessageAssembler messageAssembler = new MessageAssemblerImpl(100, 100, true, segmentDeserializer, false);
MessageAssembler messageAssembler = new MessageAssemblerImpl(100, 100, true, segmentDeserializer);
TopicPartition tp = new TopicPartition("topic", 0);

UUID uuid = UUID.randomUUID();
Expand All @@ -53,14 +52,8 @@ public void testTreatBadSegmentAsPayload() {
byte[] messageWrappedBytes = segmentSerializer.serialize(tp.topic(), badSegment);
Assert.assertTrue(messageWrappedBytes.length > realPayload.length); //wrapping has been done

try {
messageAssembler.assemble(tp, 0, messageWrappedBytes);
Assert.fail("expected to throw");
} catch (InvalidSegmentException expected) {

}
messageAssembler.assemble(tp, 0, messageWrappedBytes);

messageAssembler = new MessageAssemblerImpl(100, 100, true, segmentDeserializer, true);
MessageAssembler.AssembleResult assembleResult = messageAssembler.assemble(tp, 0, messageWrappedBytes);
Assert.assertEquals(assembleResult.messageBytes(), messageWrappedBytes);
Assert.assertEquals(assembleResult.messageStartingOffset(), 0);
Expand Down

0 comments on commit 1f981de

Please sign in to comment.