From 246d73deb85be25b82d35ab300e5bdb81104b038 Mon Sep 17 00:00:00 2001 From: Sumit Aich Date: Wed, 9 Oct 2024 15:53:10 +0530 Subject: [PATCH] feat: add HttpSinkWriter --- .../dagger/core/sink/http/HttpSinkWriter.java | 268 ++++++++++++ .../core/sink/http/HttpSinkWriterTest.java | 413 ++++++++++++++++++ 2 files changed, 681 insertions(+) create mode 100644 dagger-core/src/main/java/com/gotocompany/dagger/core/sink/http/HttpSinkWriter.java create mode 100644 dagger-core/src/test/java/com/gotocompany/dagger/core/sink/http/HttpSinkWriterTest.java diff --git a/dagger-core/src/main/java/com/gotocompany/dagger/core/sink/http/HttpSinkWriter.java b/dagger-core/src/main/java/com/gotocompany/dagger/core/sink/http/HttpSinkWriter.java new file mode 100644 index 000000000..8ddecf9e9 --- /dev/null +++ b/dagger-core/src/main/java/com/gotocompany/dagger/core/sink/http/HttpSinkWriter.java @@ -0,0 +1,268 @@ +package com.gotocompany.dagger.core.sink.http; + +import com.gotocompany.dagger.common.serde.proto.serialization.ProtoSerializer; +import com.gotocompany.dagger.core.exception.HttpSinkWriterException; +import com.gotocompany.dagger.core.metrics.reporters.ErrorReporter; +import com.gotocompany.depot.Sink; +import com.gotocompany.depot.SinkResponse; +import com.gotocompany.depot.error.ErrorInfo; +import com.gotocompany.depot.error.ErrorType; +import com.gotocompany.depot.exception.SinkException; +import com.gotocompany.depot.message.Message; +import lombok.extern.slf4j.Slf4j; +import org.apache.flink.api.connector.sink.SinkWriter; +import org.apache.flink.types.Row; + +import java.io.IOException; +import java.util.*; +import java.util.concurrent.*; +import java.util.function.Function; +import java.util.stream.Collectors; + +@Slf4j +public class HttpSinkWriter implements SinkWriter { + + private static final int DEFAULT_QUEUE_CAPACITY = 10000; + private static final int DEFAULT_THREAD_POOL_SIZE = 5; + private static final long DEFAULT_FLUSH_INTERVAL_MS = 1000; + + private final ProtoSerializer protoSerializer; + private final Sink httpSink; + private final int batchSize; + private final ErrorReporter errorReporter; + private final Set errorTypesForFailing; + private final BlockingQueue messageQueue; + private final ExecutorService executorService; + private final ScheduledExecutorService scheduledExecutorService; + private final AtomicInteger currentBatchSize; + private final Map> customFieldExtractors; + private final HttpSinkWriterMetrics metrics; + private final HttpSinkWriterState state; + + public HttpSinkWriter(ProtoSerializer protoSerializer, Sink httpSink, int batchSize, + ErrorReporter errorReporter, Set errorTypesForFailing) { + this.protoSerializer = protoSerializer; + this.httpSink = httpSink; + this.batchSize = batchSize; + this.errorReporter = errorReporter; + this.errorTypesForFailing = errorTypesForFailing; + this.messageQueue = new LinkedBlockingQueue<>(DEFAULT_QUEUE_CAPACITY); + this.executorService = Executors.newFixedThreadPool(DEFAULT_THREAD_POOL_SIZE); + this.scheduledExecutorService = Executors.newSingleThreadScheduledExecutor(); + this.currentBatchSize = new AtomicInteger(0); + this.customFieldExtractors = initializeCustomFieldExtractors(); + this.metrics = new HttpSinkWriterMetrics(); + this.state = new HttpSinkWriterState(); + + initializePeriodicFlush(); + } + + @Override + public void write(Row element, Context context) throws IOException, InterruptedException { + metrics.incrementTotalRowsReceived(); + byte[] key = protoSerializer.serializeKey(element); + byte[] value = enrichAndSerializeValue(element); + Message message = new Message(key, value); + + if (!messageQueue.offer(message, 1, TimeUnit.SECONDS)) { + metrics.incrementDroppedMessages(); + log.warn("Message queue is full. Dropping message: {}", message); + return; + } + + if (currentBatchSize.incrementAndGet() >= batchSize) { + flushQueueAsync(); + } + } + + @Override + public List prepareCommit(boolean flush) throws IOException, InterruptedException { + if (flush) { + flushQueue(); + } + return Collections.emptyList(); + } + + @Override + public List snapshotState(long checkpointId) { + state.setLastCheckpointId(checkpointId); + state.setLastCheckpointTimestamp(System.currentTimeMillis()); + return Collections.emptyList(); + } + + @Override + public void close() throws Exception { + flushQueue(); + executorService.shutdown(); + scheduledExecutorService.shutdown(); + if (!executorService.awaitTermination(30, TimeUnit.SECONDS)) { + executorService.shutdownNow(); + } + if (!scheduledExecutorService.awaitTermination(30, TimeUnit.SECONDS)) { + scheduledExecutorService.shutdownNow(); + } + httpSink.close(); + } + + private Map> initializeCustomFieldExtractors() { + Map> extractors = new HashMap<>(); + extractors.put("timestamp", row -> System.currentTimeMillis()); + extractors.put("rowHash", row -> Objects.hash(row)); + return extractors; + } + + private byte[] enrichAndSerializeValue(Row element) { + Map enrichedData = new HashMap<>(); + for (Map.Entry> entry : customFieldExtractors.entrySet()) { + enrichedData.put(entry.getKey(), entry.getValue().apply(element)); + } + + return protoSerializer.serializeValue(element); + } + + private void initializePeriodicFlush() { + scheduledExecutorService.scheduleAtFixedRate( + this::flushQueueAsync, + DEFAULT_FLUSH_INTERVAL_MS, + DEFAULT_FLUSH_INTERVAL_MS, + TimeUnit.MILLISECONDS + ); + } + + + private void flushQueueAsync() { + executorService.submit(() -> { + try { + flushQueue(); + } catch (Exception e) { + log.error("Error during async queue flush", e); + metrics.incrementAsyncFlushErrors(); + } + }); + } + + private void flushQueue() throws IOException, InterruptedException { + List batch = new ArrayList<>(batchSize); + messageQueue.drainTo(batch, batchSize); + if (!batch.isEmpty()) { + pushToHttpSink(batch); + } + currentBatchSize.set(0); + } + + private void pushToHttpSink(List batch) throws SinkException, HttpSinkWriterException { + metrics.startBatchProcessing(batch.size()); + SinkResponse sinkResponse; + try { + sinkResponse = httpSink.pushToSink(batch); + } catch (Exception e) { + metrics.incrementTotalErrors(); + errorReporter.reportFatalException(e); + throw e; + } + if (sinkResponse.hasErrors()) { + handleErrors(sinkResponse, batch); + } + metrics.endBatchProcessing(); + } + + private void handleErrors(SinkResponse sinkResponse, List batch) throws HttpSinkWriterException { + logErrors(sinkResponse, batch); + Map> partitionedErrors = partitionErrorsByFailureType(sinkResponse); + + partitionedErrors.get(Boolean.FALSE).forEach(errorInfo -> { + errorReporter.reportNonFatalException(errorInfo.getException()); + metrics.incrementNonFatalErrors(); + }); + + partitionedErrors.get(Boolean.TRUE).forEach(errorInfo -> { + errorReporter.reportFatalException(errorInfo.getException()); + metrics.incrementFatalErrors(); + }); + + if (!partitionedErrors.get(Boolean.TRUE).isEmpty()) { + throw new HttpSinkWriterException("Critical error(s) occurred during HTTP sink write operation"); + } + } + + private void logErrors(SinkResponse sinkResponse, List batch) { + log.error("Failed to push {} records to HttpSink", sinkResponse.getErrors().size()); + sinkResponse.getErrors().forEach((index, errorInfo) -> { + Message message = batch.get(index); + log.error("Failed to push message with metadata {}. Exception: {}. ErrorType: {}", + message.getMetadataString(), + errorInfo.getException().getMessage(), + errorInfo.getErrorType().name()); + }); + } + + private Map> partitionErrorsByFailureType(SinkResponse sinkResponse) { + return sinkResponse.getErrors().values().stream() + .collect(Collectors.partitioningBy(errorInfo -> errorTypesForFailing.contains(errorInfo.getErrorType()))); + } + + private static class HttpSinkWriterMetrics { + private final AtomicLong totalRowsReceived = new AtomicLong(0); + private final AtomicLong totalBatchesProcessed = new AtomicLong(0); + private final AtomicLong totalRecordsSent = new AtomicLong(0); + private final AtomicLong totalErrors = new AtomicLong(0); + private final AtomicLong nonFatalErrors = new AtomicLong(0); + private final AtomicLong fatalErrors = new AtomicLong(0); + private final AtomicLong asyncFlushErrors = new AtomicLong(0); + private final AtomicLong droppedMessages = new AtomicLong(0); + private final AtomicLong totalProcessingTimeMs = new AtomicLong(0); + private final ThreadLocal batchStartTime = new ThreadLocal<>(); + + void incrementTotalRowsReceived() { + totalRowsReceived.incrementAndGet(); + } + + void incrementTotalErrors() { + totalErrors.incrementAndGet(); + } + + void incrementNonFatalErrors() { + nonFatalErrors.incrementAndGet(); + } + + void incrementFatalErrors() { + fatalErrors.incrementAndGet(); + } + + void incrementAsyncFlushErrors() { + asyncFlushErrors.incrementAndGet(); + } + + void incrementDroppedMessages() { + droppedMessages.incrementAndGet(); + } + + void startBatchProcessing(int batchSize) { + totalBatchesProcessed.incrementAndGet(); + totalRecordsSent.addAndGet(batchSize); + batchStartTime.set(System.currentTimeMillis()); + } + + void endBatchProcessing() { + long processingTime = System.currentTimeMillis() - batchStartTime.get(); + totalProcessingTimeMs.addAndGet(processingTime); + batchStartTime.remove(); + } + + } + + + private static class HttpSinkWriterState { + private volatile long lastCheckpointId; + private volatile long lastCheckpointTimestamp; + + void setLastCheckpointId(long checkpointId) { + this.lastCheckpointId = checkpointId; + } + + void setLastCheckpointTimestamp(long timestamp) { + this.lastCheckpointTimestamp = timestamp; + } + + } +} diff --git a/dagger-core/src/test/java/com/gotocompany/dagger/core/sink/http/HttpSinkWriterTest.java b/dagger-core/src/test/java/com/gotocompany/dagger/core/sink/http/HttpSinkWriterTest.java new file mode 100644 index 000000000..06bee26c8 --- /dev/null +++ b/dagger-core/src/test/java/com/gotocompany/dagger/core/sink/http/HttpSinkWriterTest.java @@ -0,0 +1,413 @@ +import com.gotocompany.dagger.common.serde.proto.serialization.ProtoSerializer; +import com.gotocompany.dagger.core.exception.HttpSinkWriterException; +import com.gotocompany.dagger.core.metrics.reporters.ErrorReporter; +import com.gotocompany.depot.Sink; +import com.gotocompany.depot.SinkResponse; +import com.gotocompany.depot.error.ErrorInfo; +import com.gotocompany.depot.error.ErrorType; +import com.gotocompany.depot.message.Message; +import org.apache.flink.types.Row; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.IOException; +import java.util.*; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +public class HttpSinkWriterTest { + + @Mock + private ProtoSerializer protoSerializer; + @Mock + private Sink httpSink; + @Mock + private ErrorReporter errorReporter; + @Mock + private SinkWriter.Context context; + + private HttpSinkWriter httpSinkWriter; + private Set errorTypesForFailing; + + @Before + public void setUp() { + MockitoAnnotations.initMocks(this); + errorTypesForFailing = new HashSet<>(Arrays.asList(ErrorType.SINK_4XX_ERROR, ErrorType.SINK_5XX_ERROR)); + httpSinkWriter = new HttpSinkWriter(protoSerializer, httpSink, 10, errorReporter, errorTypesForFailing); + } + + @Test + public void shouldWriteSuccessfully() throws Exception { + Row row = mock(Row.class); + when(protoSerializer.serializeKey(row)).thenReturn(new byte[]{1}); + when(protoSerializer.serializeValue(row)).thenReturn(new byte[]{2}); + + httpSinkWriter.write(row, context); + + verify(protoSerializer).serializeKey(row); + verify(protoSerializer).serializeValue(row); + } + + @Test + public void shouldHandleQueueFullScenario() throws Exception { + Row row = mock(Row.class); + when(protoSerializer.serializeKey(row)).thenReturn(new byte[]{1}); + when(protoSerializer.serializeValue(row)).thenReturn(new byte[]{2}); + + for (int i = 0; i < 10001; i++) { + httpSinkWriter.write(row, context); + } + + verify(protoSerializer, times(10001)).serializeKey(row); + verify(protoSerializer, times(10001)).serializeValue(row); + } + + @Test + public void shouldPrepareCommitWithFlush() throws Exception { + httpSinkWriter.prepareCommit(true); + verify(httpSink).pushToSink(anyList()); + } + + @Test + public void shouldPrepareCommitWithoutFlush() throws Exception { + httpSinkWriter.prepareCommit(false); + verify(httpSink, never()).pushToSink(anyList()); + } + + @Test + public void shouldSnapshotState() { + List state = httpSinkWriter.snapshotState(123L); + assertTrue(state.isEmpty()); + } + + @Test + public void shouldClose() throws Exception { + httpSinkWriter.close(); + verify(httpSink).close(); + } + + @Test + public void shouldWriteWhenBatchSizeReached() throws Exception { + Row row = mock(Row.class); + when(protoSerializer.serializeKey(row)).thenReturn(new byte[]{1}); + when(protoSerializer.serializeValue(row)).thenReturn(new byte[]{2}); + + for (int i = 0; i < 10; i++) { + httpSinkWriter.write(row, context); + } + + verify(httpSink).pushToSink(anyList()); + } + + @Test + public void shouldHandleSerializationError() throws Exception { + Row row = mock(Row.class); + when(protoSerializer.serializeKey(row)).thenThrow(new RuntimeException("Serialization error")); + + assertThrows(RuntimeException.class, () -> httpSinkWriter.write(row, context)); + } + + @Test + public void shouldPushToHttpSinkSuccessfully() throws Exception { + List batch = Collections.singletonList(new Message(new byte[]{1}, new byte[]{2})); + when(httpSink.pushToSink(batch)).thenReturn(new SinkResponse()); + + httpSinkWriter.getClass().getDeclaredMethod("pushToHttpSink", List.class) + .invoke(httpSinkWriter, batch); + + verify(httpSink).pushToSink(batch); + } + + @Test + public void shouldHandlePushToHttpSinkWithErrors() throws Exception { + List batch = Arrays.asList( + new Message(new byte[]{1}, new byte[]{2}), + new Message(new byte[]{3}, new byte[]{4}) + ); + SinkResponse response = new SinkResponse(); + response.addErrors(0, new ErrorInfo(new Exception("Error 1"), ErrorType.SINK_4XX_ERROR)); + when(httpSink.pushToSink(batch)).thenReturn(response); + + assertThrows(HttpSinkWriterException.class, () -> { + httpSinkWriter.getClass().getDeclaredMethod("pushToHttpSink", List.class) + .invoke(httpSinkWriter, batch); + }); + } + + @Test + public void shouldHandleNonFatalErrors() throws Exception { + SinkResponse response = new SinkResponse(); + response.addErrors(0, new ErrorInfo(new Exception("Non-fatal error"), ErrorType.DESERIALIZATION_ERROR)); + List batch = Collections.singletonList(new Message(new byte[]{1}, new byte[]{2})); + + httpSinkWriter.getClass().getDeclaredMethod("handleErrors", SinkResponse.class, List.class) + .invoke(httpSinkWriter, response, batch); + + verify(errorReporter).reportNonFatalException(any()); + } + + @Test + public void shouldHandleFatalErrors() throws Exception { + SinkResponse response = new SinkResponse(); + response.addErrors(0, new ErrorInfo(new Exception("Fatal error"), ErrorType.SINK_5XX_ERROR)); + List batch = Collections.singletonList(new Message(new byte[]{1}, new byte[]{2})); + + assertThrows(HttpSinkWriterException.class, () -> { + httpSinkWriter.getClass().getDeclaredMethod("handleErrors", SinkResponse.class, List.class) + .invoke(httpSinkWriter, response, batch); + }); + } + + @Test + public void shouldLogErrors() throws Exception { + SinkResponse response = new SinkResponse(); + response.addErrors(0, new ErrorInfo(new Exception("Test error"), ErrorType.SINK_4XX_ERROR)); + List batch = Collections.singletonList(new Message(new byte[]{1}, new byte[]{2})); + + httpSinkWriter.getClass().getDeclaredMethod("logErrors", SinkResponse.class, List.class) + .invoke(httpSinkWriter, response, batch); + } + + @Test + public void shouldPartitionErrorsByFailureType() throws Exception { + SinkResponse response = new SinkResponse(); + response.addErrors(0, new ErrorInfo(new Exception("Fatal error"), ErrorType.SINK_5XX_ERROR)); + response.addErrors(1, new ErrorInfo(new Exception("Non-fatal error"), ErrorType.DESERIALIZATION_ERROR)); + + Map> result = (Map>) httpSinkWriter.getClass() + .getDeclaredMethod("partitionErrorsByFailureType", SinkResponse.class) + .invoke(httpSinkWriter, response); + + assertEquals(1, result.get(Boolean.TRUE).size()); + assertEquals(1, result.get(Boolean.FALSE).size()); + } + + @Test + public void shouldIncrementTotalRowsReceived() throws Exception { + HttpSinkWriter.HttpSinkWriterMetrics metrics = (HttpSinkWriter.HttpSinkWriterMetrics) httpSinkWriter + .getClass().getDeclaredField("metrics").get(httpSinkWriter); + + metrics.incrementTotalRowsReceived(); + assertEquals(1, metrics.totalRowsReceived.get()); + } + + @Test + public void shouldIncrementTotalErrors() throws Exception { + HttpSinkWriter.HttpSinkWriterMetrics metrics = (HttpSinkWriter.HttpSinkWriterMetrics) httpSinkWriter + .getClass().getDeclaredField("metrics").get(httpSinkWriter); + + metrics.incrementTotalErrors(); + assertEquals(1, metrics.totalErrors.get()); + } + + @Test + public void shouldIncrementNonFatalErrors() throws Exception { + HttpSinkWriter.HttpSinkWriterMetrics metrics = (HttpSinkWriter.HttpSinkWriterMetrics) httpSinkWriter + .getClass().getDeclaredField("metrics").get(httpSinkWriter); + + metrics.incrementNonFatalErrors(); + assertEquals(1, metrics.nonFatalErrors.get()); + } + + @Test + public void shouldIncrementFatalErrors() throws Exception { + HttpSinkWriter.HttpSinkWriterMetrics metrics = (HttpSinkWriter.HttpSinkWriterMetrics) httpSinkWriter + .getClass().getDeclaredField("metrics").get(httpSinkWriter); + + metrics.incrementFatalErrors(); + assertEquals(1, metrics.fatalErrors.get()); + } + + @Test + public void shouldIncrementAsyncFlushErrors() throws Exception { + HttpSinkWriter.HttpSinkWriterMetrics metrics = (HttpSinkWriter.HttpSinkWriterMetrics) httpSinkWriter + .getClass().getDeclaredField("metrics").get(httpSinkWriter); + + metrics.incrementAsyncFlushErrors(); + assertEquals(1, metrics.asyncFlushErrors.get()); + } + + @Test + public void shouldIncrementDroppedMessages() throws Exception { + HttpSinkWriter.HttpSinkWriterMetrics metrics = (HttpSinkWriter.HttpSinkWriterMetrics) httpSinkWriter + .getClass().getDeclaredField("metrics").get(httpSinkWriter); + + metrics.incrementDroppedMessages(); + assertEquals(1, metrics.droppedMessages.get()); + } + + @Test + public void shouldStartBatchProcessing() throws Exception { + HttpSinkWriter.HttpSinkWriterMetrics metrics = (HttpSinkWriter.HttpSinkWriterMetrics) httpSinkWriter + .getClass().getDeclaredField("metrics").get(httpSinkWriter); + + metrics.startBatchProcessing(5); + assertEquals(1, metrics.totalBatchesProcessed.get()); + assertEquals(5, metrics.totalRecordsSent.get()); + } + + @Test + public void shouldEndBatchProcessing() throws Exception { + HttpSinkWriter.HttpSinkWriterMetrics metrics = (HttpSinkWriter.HttpSinkWriterMetrics) httpSinkWriter + .getClass().getDeclaredField("metrics").get(httpSinkWriter); + + metrics.startBatchProcessing(5); + Thread.sleep(10); + metrics.endBatchProcessing(); + assertTrue(metrics.totalProcessingTimeMs.get() > 0); + } + + @Test + public void shouldSetLastCheckpointId() throws Exception { + HttpSinkWriter.HttpSinkWriterState state = (HttpSinkWriter.HttpSinkWriterState) httpSinkWriter + .getClass().getDeclaredField("state").get(httpSinkWriter); + + state.setLastCheckpointId(123L); + assertEquals(123L, state.lastCheckpointId); + } + + @Test + public void shouldSetLastCheckpointTimestamp() throws Exception { + HttpSinkWriter.HttpSinkWriterState state = (HttpSinkWriter.HttpSinkWriterState) httpSinkWriter + .getClass().getDeclaredField("state").get(httpSinkWriter); + + long timestamp = System.currentTimeMillis(); + state.setLastCheckpointTimestamp(timestamp); + assertEquals(timestamp, state.lastCheckpointTimestamp); + } + + @Test + public void shouldInitializeCustomFieldExtractors() throws Exception { + Map> extractors = (Map>) httpSinkWriter + .getClass().getDeclaredMethod("initializeCustomFieldExtractors").invoke(httpSinkWriter); + + assertTrue(extractors.containsKey("timestamp")); + assertTrue(extractors.containsKey("rowHash")); + } + + @Test + public void shouldEnrichAndSerializeValue() throws Exception { + Row row = mock(Row.class); + when(protoSerializer.serializeValue(row)).thenReturn(new byte[]{1, 2, 3}); + + byte[] result = (byte[]) httpSinkWriter.getClass().getDeclaredMethod("enrichAndSerializeValue", Row.class) + .invoke(httpSinkWriter, row); + + assertArrayEquals(new byte[]{1, 2, 3}, result); + } + + @Test + public void shouldInitializePeriodicFlush() throws Exception { + httpSinkWriter.getClass().getDeclaredMethod("initializePeriodicFlush").invoke(httpSinkWriter); + + Thread.sleep(1100); + verify(httpSink, atLeastOnce()).pushToSink(anyList()); + } + + @Test + public void shouldFlushQueueAsync() throws Exception { + httpSinkWriter.getClass().getDeclaredMethod("flushQueueAsync").invoke(httpSinkWriter); + + Thread.sleep(100); + verify(httpSink, atLeastOnce()).pushToSink(anyList()); + } + + @Test + public void shouldNotFlushEmptyQueue() throws Exception { + httpSinkWriter.getClass().getDeclaredMethod("flushQueue").invoke(httpSinkWriter); + + verify(httpSink, never()).pushToSink(anyList()); + } + + @Test + public void shouldFlushNonEmptyQueue() throws Exception { + Row row = mock(Row.class); + when(protoSerializer.serializeKey(row)).thenReturn(new byte[]{1}); + when(protoSerializer.serializeValue(row)).thenReturn(new byte[]{2}); + + httpSinkWriter.write(row, context); + + httpSinkWriter.getClass().getDeclaredMethod("flushQueue").invoke(httpSinkWriter); + + verify(httpSink).pushToSink(anyList()); + } + + @Test + public void shouldHandleInterruptedException() throws Exception { + Row row = mock(Row.class); + when(protoSerializer.serializeKey(row)).thenReturn(new byte[]{1}); + when(protoSerializer.serializeValue(row)).thenReturn(new byte[]{2}); + + Thread.currentThread().interrupt(); + + assertThrows(InterruptedException.class, () -> httpSinkWriter.write(row, context)); + + Thread.interrupted(); + } + + @Test + public void shouldShutdownExecutorServiceOnClose() throws Exception { + httpSinkWriter.close(); + + ExecutorService executorService = (ExecutorService) httpSinkWriter.getClass() + .getDeclaredField("executorService").get(httpSinkWriter); + ScheduledExecutorService scheduledExecutorService = (ScheduledExecutorService) httpSinkWriter.getClass() + .getDeclaredField("scheduledExecutorService").get(httpSinkWriter); + + assertTrue(executorService.isShutdown()); + assertTrue(scheduledExecutorService.isShutdown()); + } + + @Test + public void shouldTerminateExecutorServiceOnClose() throws Exception { + ExecutorService executorService = mock(ExecutorService.class); + ScheduledExecutorService scheduledExecutorService = mock(ScheduledExecutorService.class); + + when(executorService.awaitTermination(30, TimeUnit.SECONDS)).thenReturn(false); + when(scheduledExecutorService.awaitTermination(30, TimeUnit.SECONDS)).thenReturn(false); + + httpSinkWriter.getClass().getDeclaredField("executorService").set(httpSinkWriter, executorService); + httpSinkWriter.getClass().getDeclaredField("scheduledExecutorService").set(httpSinkWriter, scheduledExecutorService); + + httpSinkWriter.close(); + + verify(executorService).shutdownNow(); + verify(scheduledExecutorService).shutdownNow(); + } + + @Test + public void shouldHandlePushToHttpSinkException() throws Exception { + List batch = Collections.singletonList(new Message(new byte[]{1}, new byte[]{2})); + when(httpSink.pushToSink(batch)).thenThrow(new RuntimeException("Test exception")); + + assertThrows(RuntimeException.class, () -> { + httpSinkWriter.getClass().getDeclaredMethod("pushToHttpSink", List.class) + .invoke(httpSinkWriter, batch); + }); + + verify(errorReporter).reportFatalException(any(RuntimeException.class)); + } + + @Test + public void shouldHandleMixedErrorTypes() throws Exception { + SinkResponse response = new SinkResponse(); + response.addErrors(0, new ErrorInfo(new Exception("Fatal error"), ErrorType.SINK_5XX_ERROR)); + response.addErrors(1, new ErrorInfo(new Exception("Non-fatal error"), ErrorType.DESERIALIZATION_ERROR)); + List batch = Arrays.asList( + new Message(new byte[]{1}, new byte[]{2}), + new Message(new byte[]{3}, new byte[]{4}) + ); + + assertThrows(HttpSinkWriterException.class, () -> { + httpSinkWriter.getClass().getDeclaredMethod("handleErrors", SinkResponse.class, List.class) + .invoke(httpSinkWriter, response, batch); + }); + + verify(errorReporter).reportFatalException(any()); + verify(errorReporter).reportNonFatalException(any()); + } +}