allow treating bad large message segments as raw payloads (#159)
radai-rosenblatt authored and ambroff committed Nov 21, 2019
1 parent 153bb1e commit ef41616
Showing 8 changed files with 95 additions and 20 deletions.
@@ -33,6 +33,7 @@ public class LiKafkaConsumerConfig extends AbstractConfig {
   public static final String MESSAGE_ASSEMBLER_EXPIRATION_OFFSET_GAP_CONFIG = "message.assembler.expiration.offset.gap";
   public static final String MAX_TRACKED_MESSAGES_PER_PARTITION_CONFIG = "max.tracked.messages.per.partition";
   public static final String EXCEPTION_ON_MESSAGE_DROPPED_CONFIG = "exception.on.message.dropped";
+  public static final String TREAT_BAD_SEGMENTS_AS_PAYLOAD_CONFIG = "treat.bad.segments.as.payload";
   public static final String KEY_DESERIALIZER_CLASS_CONFIG = ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG;
   public static final String VALUE_DESERIALIZER_CLASS_CONFIG = ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG;
   public static final String SEGMENT_DESERIALIZER_CLASS_CONFIG = "segment.deserializer.class";
@@ -71,6 +72,9 @@ public class LiKafkaConsumerConfig extends AbstractConfig {
       "full or the incomplete message has expired. The consumer will throw a LargeMessageDroppedException if this " +
       "configuration is set to true. Otherwise the consumer will drop the message silently.";
 
+  private static final String TREAT_BAD_SEGMENTS_AS_PAYLOAD_DOC = "The message assembler will treat invalid message segments " +
+      "as payload. This can be used as a last resort when arbitrary payloads accidentally pass as a large message segment.";
+
   private static final String KEY_DESERIALIZER_CLASS_DOC = "The key deserializer class for the consumer.";
 
   private static final String VALUE_DESERIALIZER_CLASS_DOC = "The value deserializer class for the consumer.";
@@ -128,6 +132,11 @@ public class LiKafkaConsumerConfig extends AbstractConfig {
           "false",
           Importance.LOW,
           EXCEPTION_ON_MESSAGE_DROPPED_DOC)
+      .define(TREAT_BAD_SEGMENTS_AS_PAYLOAD_CONFIG,
+          Type.BOOLEAN,
+          "false",
+          Importance.LOW,
+          TREAT_BAD_SEGMENTS_AS_PAYLOAD_DOC)
       .define(KEY_DESERIALIZER_CLASS_CONFIG,
           Type.CLASS,
           ByteArrayDeserializer.class.getName(),
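
For reference, a minimal sketch of switching the new flag on from consumer configuration. The property name is taken from this diff; the broker address, group id, and the Properties-based LiKafkaConsumerImpl construction are illustrative assumptions:

    Properties props = new Properties();
    props.setProperty("bootstrap.servers", "localhost:9092"); // placeholder broker
    props.setProperty("group.id", "my-group");                // placeholder group id
    // New in this commit: segments that fail validation are handed back
    // to the application as plain payloads instead of aborting assembly.
    props.setProperty(LiKafkaConsumerConfig.TREAT_BAD_SEGMENTS_AS_PAYLOAD_CONFIG, "true");
    LiKafkaConsumer<byte[], byte[]> consumer = new LiKafkaConsumerImpl<>(props);

The flag defaults to false, so existing consumers keep the strict fail-fast behavior unless they opt in.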
@@ -160,8 +160,9 @@ public Object metricValue() {
     int messageAssemblerCapacity = configs.getInt(LiKafkaConsumerConfig.MESSAGE_ASSEMBLER_BUFFER_CAPACITY_CONFIG);
     int messageAssemblerExpirationOffsetGap = configs.getInt(LiKafkaConsumerConfig.MESSAGE_ASSEMBLER_EXPIRATION_OFFSET_GAP_CONFIG);
     boolean exceptionOnMessageDropped = configs.getBoolean(LiKafkaConsumerConfig.EXCEPTION_ON_MESSAGE_DROPPED_CONFIG);
+    boolean treatBadSegmentsAsPayload = configs.getBoolean(LiKafkaConsumerConfig.TREAT_BAD_SEGMENTS_AS_PAYLOAD_CONFIG);
     MessageAssembler assembler = new MessageAssemblerImpl(messageAssemblerCapacity, messageAssemblerExpirationOffsetGap,
-        exceptionOnMessageDropped, segmentDeserializer);
+        exceptionOnMessageDropped, segmentDeserializer, treatBadSegmentsAsPayload);
 
     // Instantiate delivered message offset tracker if needed.
     int maxTrackedMessagesPerPartition = configs.getInt(LiKafkaConsumerConfig.MAX_TRACKED_MESSAGES_PER_PARTITION_CONFIG);
@@ -179,27 +179,12 @@ private LargeMessage evictEldestMessage() {
   }
 
   private LargeMessage validateSegmentAndGetMessage(TopicPartition tp, LargeMessageSegment segment, long offset) {
-    if (segment.payload == null) {
-      throw new InvalidSegmentException("Payload cannot be null");
-    }
+    segment.sanityCheck();
 
     segment.payload.rewind();
     long segmentSize = segment.payload.remaining();
     UUID messageId = segment.messageId;
     int messageSizeInBytes = segment.messageSizeInBytes;
     int numberOfSegments = segment.numberOfSegments;
     int seq = segment.sequenceNumber;
 
-    if (messageId == null) {
-      throw new InvalidSegmentException("Message Id can not be null");
-    }
-    if (segmentSize > messageSizeInBytes) {
-      throw new InvalidSegmentException("Segment size should not be larger than message size.");
-    }
-
-    if (seq < 0 || seq > numberOfSegments - 1) {
-      throw new InvalidSegmentException("Sequence number " + seq
-          + " should fall between [0," + (numberOfSegments - 1) + "].");
-    }
-
     // Create the incomplete message if needed.
     LargeMessage message = _incompleteMessageMap.get(messageId);
@@ -4,6 +4,7 @@
 
 package com.linkedin.kafka.clients.largemessage;
 
+import com.linkedin.kafka.clients.largemessage.errors.InvalidSegmentException;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
 import java.util.UUID;
@@ -61,6 +62,29 @@ public byte[] payloadArray() {
     }
   }
 
+  public void sanityCheck() throws InvalidSegmentException {
+    if (messageId == null) {
+      throw new InvalidSegmentException("Message Id can not be null");
+    }
+    if (messageSizeInBytes < 0) {
+      throw new InvalidSegmentException("message size (" + messageSizeInBytes + ") should be >= 0");
+    }
+    if (payload == null) {
+      throw new InvalidSegmentException("payload cannot be null");
+    }
+    // This tries to handle cases where the payload has not been flipped/rewound yet.
+    long dataSize = payload.position() > 0 ? payload.position() : payload.limit();
+    if (dataSize > messageSizeInBytes) {
+      throw new InvalidSegmentException("segment size (" + dataSize + ") should not be larger than message size (" + messageSizeInBytes + ")");
+    }
+    if (numberOfSegments <= 0) {
+      throw new InvalidSegmentException("number of segments should be > 0, instead is " + numberOfSegments);
+    }
+    if (sequenceNumber < 0 || sequenceNumber > numberOfSegments - 1) {
+      throw new InvalidSegmentException("Sequence number " + sequenceNumber + " should fall between [0," + (numberOfSegments - 1) + "].");
+    }
+  }
+
   @Override
   public String toString() {
     return "[messageId=" + messageId + ",seq=" + sequenceNumber + ",numSegs=" + numberOfSegments + ",messageSize=" +
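
To illustrate what sanityCheck() rejects, here is a small sketch using the (messageId, sequenceNumber, numberOfSegments, messageSizeInBytes, payload) constructor order seen in the test changes below; the concrete values are made up:

    UUID id = UUID.randomUUID();
    ByteBuffer payload = ByteBuffer.wrap("data".getBytes());
    // Sequence number 5 is outside [0,2] for a 3-segment message, so sanityCheck() throws.
    LargeMessageSegment bad = new LargeMessageSegment(id, 5, 3, 1024, payload);
    try {
      bad.sanityCheck();
    } catch (InvalidSegmentException e) {
      System.out.println("rejected: " + e.getMessage());
    }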
@@ -4,6 +4,7 @@
 
 package com.linkedin.kafka.clients.largemessage;
 
+import com.linkedin.kafka.clients.largemessage.errors.InvalidSegmentException;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.slf4j.Logger;
@@ -19,13 +20,23 @@ public class MessageAssemblerImpl implements MessageAssembler {
   private static final Logger LOG = LoggerFactory.getLogger(MessageAssemblerImpl.class);
   private final LargeMessageBufferPool _messagePool;
   private final Deserializer<LargeMessageSegment> _segmentDeserializer;
+  private final boolean _treatInvalidMessageSegmentsAsPayload;
 
   public MessageAssemblerImpl(long bufferCapacity,
                               long expirationOffsetGap,
                               boolean exceptionOnMessageDropped,
                               Deserializer<LargeMessageSegment> segmentDeserializer) {
+    this(bufferCapacity, expirationOffsetGap, exceptionOnMessageDropped, segmentDeserializer, false);
+  }
+
+  public MessageAssemblerImpl(long bufferCapacity,
+                              long expirationOffsetGap,
+                              boolean exceptionOnMessageDropped,
+                              Deserializer<LargeMessageSegment> segmentDeserializer,
+                              boolean treatInvalidMessageSegmentsAsPayload) {
     _messagePool = new LargeMessageBufferPool(bufferCapacity, expirationOffsetGap, exceptionOnMessageDropped);
     _segmentDeserializer = segmentDeserializer;
+    _treatInvalidMessageSegmentsAsPayload = treatInvalidMessageSegmentsAsPayload;
   }
 
   @Override
@@ -36,8 +47,21 @@ public AssembleResult assemble(TopicPartition tp, long offset, byte[] segmentBytes) {
 
     LargeMessageSegment segment = _segmentDeserializer.deserialize(tp.topic(), segmentBytes);
     if (segment == null) {
+      // not a segment
       return new AssembleResult(segmentBytes, offset, offset);
     } else {
+      // sanity-check the segment
+      try {
+        segment.sanityCheck();
+      } catch (InvalidSegmentException e) {
+        if (_treatInvalidMessageSegmentsAsPayload) {
+          // behave as if this didn't look like a segment to us
+          LOG.warn("message at {}/{} may have been a false-positive segment and will be treated as regular payload", tp, offset, e);
+          return new AssembleResult(segmentBytes, offset, offset);
+        } else {
+          throw e;
+        }
+      }
       // Return immediately if it is a single segment message.
       if (segment.numberOfSegments == 1) {
        return new AssembleResult(segment.payloadArray(), offset, offset);
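
As a usage sketch of the widened constructor (names and signature per this diff; the capacity and offset-gap values are arbitrary):

    Deserializer<LargeMessageSegment> segmentDeserializer = new DefaultSegmentDeserializer();
    // The trailing true enables the new fallback: a segment that fails
    // sanityCheck() comes back as a plain payload instead of throwing.
    MessageAssembler assembler =
        new MessageAssemblerImpl(1024 * 1024, 1000, true, segmentDeserializer, true);

The four-argument constructor delegates with the flag set to false, so existing callers are unaffected.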
@@ -92,7 +92,7 @@ public List<ProducerRecord<byte[], byte[]>> split(String topic,
     // Get original message size in bytes
     int messageSizeInBytes = serializedRecord.length;
     ByteBuffer bytebuffer = ByteBuffer.wrap(serializedRecord);
-
+    // Messages with more than one segment must carry a non-null key to guarantee all segments land in the same partition.
     byte[] segmentKey = (key == null && numberOfSegments > 1) ? LiKafkaClientsUtils.uuidToBytes(segmentMessageId) : key;
     // Sequence number starts from 0.
     for (int seq = 0; seq < numberOfSegments; seq++) {
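
For context on the new comment, a minimal sketch of why a shared key pins every segment to the same partition, assuming Kafka's default partitioner (murmur2 hash of the serialized key modulo the partition count):

    import org.apache.kafka.common.utils.Utils;

    // Identical key bytes hash to the same partition for a fixed partition count,
    // so all segments produced with segmentKey land together and can be reassembled.
    static int partitionFor(byte[] keyBytes, int numPartitions) {
      return Utils.toPositive(Utils.murmur2(keyBytes)) % numPartitions;
    }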
@@ -220,7 +220,8 @@ public void testSegmentSizeTooLarge() {
       pool.tryCompleteMessage(tp, 0, segment);
       fail("Should throw large message exception for wrong segment size.");
     } catch (InvalidSegmentException ise) {
-      assertTrue(ise.getMessage().startsWith("Segment size should not be larger"));
+      assertTrue(ise.getMessage().toLowerCase().contains("segment size"));
+      assertTrue(ise.getMessage().toLowerCase().contains("should not be larger than message size"));
     }
   }
@@ -4,10 +4,13 @@
 
 package com.linkedin.kafka.clients.largemessage;
 
+import com.linkedin.kafka.clients.largemessage.errors.InvalidSegmentException;
+import com.linkedin.kafka.clients.utils.LiKafkaClientsUtils;
+import java.util.UUID;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.Deserializer;
 import org.apache.kafka.common.serialization.Serializer;
+import org.testng.Assert;
 import org.testng.annotations.Test;
 
 import java.nio.ByteBuffer;
@@ -19,6 +22,7 @@
  * Unit test for message assembler.
  */
 public class MessageAssemblerTest {
+
   @Test
   public void testSingleMessageSegment() {
     // Create serializer/deserializers.
@@ -36,6 +40,33 @@ public void testSingleMessageSegment() {
     assertEquals(assembleResult.messageEndingOffset(), 0, "The message ending offset should be 0");
   }
 
+  @Test
+  public void testTreatBadSegmentAsPayload() {
+    Serializer<LargeMessageSegment> segmentSerializer = new DefaultSegmentSerializer();
+    Deserializer<LargeMessageSegment> segmentDeserializer = new DefaultSegmentDeserializer();
+    MessageAssembler messageAssembler = new MessageAssemblerImpl(100, 100, true, segmentDeserializer, false);
+    TopicPartition tp = new TopicPartition("topic", 0);
+
+    UUID uuid = UUID.randomUUID();
+    byte[] realPayload = "message".getBytes();
+    LargeMessageSegment badSegment = new LargeMessageSegment(uuid, -1, 100, -1, ByteBuffer.wrap(realPayload));
+    byte[] messageWrappedBytes = segmentSerializer.serialize(tp.topic(), badSegment);
+    Assert.assertTrue(messageWrappedBytes.length > realPayload.length); // wrapping has been done
+
+    try {
+      messageAssembler.assemble(tp, 0, messageWrappedBytes);
+      Assert.fail("expected to throw");
+    } catch (InvalidSegmentException expected) {
+      // expected when the fallback is disabled
+    }
+
+    messageAssembler = new MessageAssemblerImpl(100, 100, true, segmentDeserializer, true);
+    MessageAssembler.AssembleResult assembleResult = messageAssembler.assemble(tp, 0, messageWrappedBytes);
+    Assert.assertEquals(assembleResult.messageBytes(), messageWrappedBytes);
+    Assert.assertEquals(assembleResult.messageStartingOffset(), 0);
+    Assert.assertEquals(assembleResult.messageEndingOffset(), 0);
+  }
+
   @Test
   public void testNonLargeMessageSegmentBytes() {
     // Create serializer/deserializers.
