From 51105f1d1e90348ef9f9c399e35001f47e34d76b Mon Sep 17 00:00:00 2001 From: Jialin Liu Date: Tue, 17 Dec 2024 11:18:15 -0800 Subject: [PATCH] [server] Fix NPE in StoreAwarePartitionWiseKafkaConsumerService#handleUnsubscription --- .../kafka/consumer/SharedKafkaConsumer.java | 2 +- ...warePartitionWiseKafkaConsumerService.java | 9 +++++-- .../consumer/KafkaConsumerServiceTest.java | 27 +++++++++++-------- 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/SharedKafkaConsumer.java b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/SharedKafkaConsumer.java index 753effcf2e1..86e76c3acd6 100644 --- a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/SharedKafkaConsumer.java +++ b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/SharedKafkaConsumer.java @@ -190,8 +190,8 @@ public synchronized void batchUnsubscribe(Set pubSubTopicP */ protected synchronized void unSubscribeAction(Supplier> supplier, long timeoutMs) { long currentPollTimes = pollTimes; - Set topicPartitions = supplier.get(); long startTime = System.currentTimeMillis(); + Set topicPartitions = supplier.get(); long elapsedTime = System.currentTimeMillis() - startTime; LOGGER.info( "Shared consumer {} unsubscribed {} partition(s): ({}) in {} ms", diff --git a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/StoreAwarePartitionWiseKafkaConsumerService.java b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/StoreAwarePartitionWiseKafkaConsumerService.java index 78e600083b2..12fe64a43a0 100644 --- a/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/StoreAwarePartitionWiseKafkaConsumerService.java +++ b/clients/da-vinci-client/src/main/java/com/linkedin/davinci/kafka/consumer/StoreAwarePartitionWiseKafkaConsumerService.java @@ -122,7 +122,7 @@ void handleUnsubscription( PubSubTopic versionTopic, PubSubTopicPartition pubSubTopicPartition) { super.handleUnsubscription(consumer, versionTopic, pubSubTopicPartition); - decreaseConsumerStoreLoad(consumer, versionTopic.getStoreName()); + decreaseConsumerStoreLoad(consumer, versionTopic); } int getConsumerStoreLoad(SharedKafkaConsumer consumer, String storeName) { @@ -138,7 +138,12 @@ void increaseConsumerStoreLoad(SharedKafkaConsumer consumer, String storeName) { .compute(storeName, (k, v) -> (v == null) ? 1 : v + 1); } - void decreaseConsumerStoreLoad(SharedKafkaConsumer consumer, String storeName) { + void decreaseConsumerStoreLoad(SharedKafkaConsumer consumer, PubSubTopic versionTopic) { + if (versionTopic == null) { + getLOGGER().warn("Incoming versionTopic is null, will skip decreasing store load for consumer: {}", consumer); + return; + } + String storeName = versionTopic.getStoreName(); if (!getConsumerToBaseLoadCount().containsKey(consumer)) { throw new IllegalStateException("Consumer to base load count map does not contain consumer: " + consumer); } diff --git a/clients/da-vinci-client/src/test/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerServiceTest.java b/clients/da-vinci-client/src/test/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerServiceTest.java index 5cc2a4ae52c..e874e31d198 100644 --- a/clients/da-vinci-client/src/test/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerServiceTest.java +++ b/clients/da-vinci-client/src/test/java/com/linkedin/davinci/kafka/consumer/KafkaConsumerServiceTest.java @@ -232,7 +232,7 @@ private KafkaConsumerService getKafkaConsumerServiceWithSingleConsumer( poolType, factory, properties, - 1000l, + 1000L, 1, mockIngestionThrottler, mock(KafkaClusterBasedRecordThrottler.class), @@ -436,6 +436,8 @@ public void testStoreAwarePartitionWiseGetConsumer() { PubSubTopic pubSubTopicForStoreName3 = pubSubTopicRepository.getTopic(topicForStoreName3); String storeName4 = Utils.getUniqueString("test_consumer_service4"); + String topicForStoreName4 = Version.composeKafkaTopic(storeName4, 1); + PubSubTopic pubSubTopicForStoreName4 = pubSubTopicRepository.getTopic(topicForStoreName4); SharedKafkaConsumer consumer1 = mock(SharedKafkaConsumer.class); SharedKafkaConsumer consumer2 = mock(SharedKafkaConsumer.class); @@ -461,9 +463,9 @@ public void testStoreAwarePartitionWiseGetConsumer() { when(consumerService.getLOGGER()) .thenReturn(LogManager.getLogger(StoreAwarePartitionWiseKafkaConsumerService.class)); doCallRealMethod().when(consumerService).pickConsumerForPartition(any(), any()); - doCallRealMethod().when(consumerService).getConsumerStoreLoad(any(), anyString()); - doCallRealMethod().when(consumerService).increaseConsumerStoreLoad(any(), anyString()); - doCallRealMethod().when(consumerService).decreaseConsumerStoreLoad(any(), anyString()); + doCallRealMethod().when(consumerService).getConsumerStoreLoad(any(), any()); + doCallRealMethod().when(consumerService).increaseConsumerStoreLoad(any(), any()); + doCallRealMethod().when(consumerService).decreaseConsumerStoreLoad(any(), any()); consumerToBasicLoadMap.put(consumer1, 1); Map innerMap1 = new VeniceConcurrentHashMap<>(); @@ -508,25 +510,28 @@ public void testStoreAwarePartitionWiseGetConsumer() { Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName3), 10003); // Validate decrease consumer entry - Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, storeName4)); + Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, pubSubTopicForStoreName4)); - consumerService.decreaseConsumerStoreLoad(consumer1, storeName1); + consumerService.decreaseConsumerStoreLoad(consumer1, pubSubTopicForStoreName1); Assert.assertEquals(consumerToBasicLoadMap.get(consumer1).intValue(), 2); Assert.assertNull(consumerToStoreLoadMap.get(consumer1).get(storeName1)); Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName1), 2); - Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, storeName1)); + Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, pubSubTopicForStoreName1)); - consumerService.decreaseConsumerStoreLoad(consumer1, storeName2); + consumerService.decreaseConsumerStoreLoad(consumer1, pubSubTopicForStoreName2); Assert.assertEquals(consumerToBasicLoadMap.get(consumer1).intValue(), 1); Assert.assertNull(consumerToStoreLoadMap.get(consumer1).get(storeName2)); Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName2), 1); - Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, storeName2)); + Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, pubSubTopicForStoreName2)); - consumerService.decreaseConsumerStoreLoad(consumer1, storeName3); + consumerService.decreaseConsumerStoreLoad(consumer1, pubSubTopicForStoreName3); Assert.assertNull(consumerToBasicLoadMap.get(consumer1)); Assert.assertNull(consumerToStoreLoadMap.get(consumer1)); Assert.assertEquals(consumerService.getConsumerStoreLoad(consumer1, storeName3), 0); - Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, storeName3)); + Assert.assertThrows(() -> consumerService.decreaseConsumerStoreLoad(consumer1, pubSubTopicForStoreName3)); + + // Make sure invalid versionTopic won't throw NPE. + consumerService.decreaseConsumerStoreLoad(consumer1, null); // Validate increase consumer entry consumerService.increaseConsumerStoreLoad(consumer1, storeName1);