From aff73b6101145b9f7edac28207e65d4dc27d6018 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 11 Dec 2024 21:10:26 +0800 Subject: [PATCH] Allow tokenCh of batch controller be nil Signed-off-by: JmPotato --- client/pkg/batch/batch_controller.go | 29 ++++++----- client/pkg/batch/batch_controller_test.go | 61 ++++++++++++++++++++++- 2 files changed, 76 insertions(+), 14 deletions(-) diff --git a/client/pkg/batch/batch_controller.go b/client/pkg/batch/batch_controller.go index 32f0aaba1ae..afc8c93fea7 100644 --- a/client/pkg/batch/batch_controller.go +++ b/client/pkg/batch/batch_controller.go @@ -66,7 +66,7 @@ func (bc *Controller[T]) FetchPendingRequests(ctx context.Context, requestCh <-c if errRet != nil { // Something went wrong when collecting a batch of requests. Release the token and cancel collected requests // if any. - if tokenAcquired { + if tokenAcquired && tokenCh != nil { tokenCh <- struct{}{} } bc.FinishCollectedRequests(bc.finisher, errRet) @@ -80,6 +80,9 @@ func (bc *Controller[T]) FetchPendingRequests(ctx context.Context, requestCh <-c // If the batch size reaches the maxBatchSize limit but the token haven't arrived yet, don't receive more // requests, and return when token is ready. if bc.collectedRequestCount >= bc.maxBatchSize && !tokenAcquired { + if tokenCh == nil { + return nil + } select { case <-ctx.Done(): return ctx.Err() @@ -88,17 +91,19 @@ func (bc *Controller[T]) FetchPendingRequests(ctx context.Context, requestCh <-c } } - select { - case <-ctx.Done(): - return ctx.Err() - case req := <-requestCh: - // Start to batch when the first request arrives. - bc.pushRequest(req) - // A request arrives but the token is not ready yet. Continue waiting, and also allowing collecting the next - // request if it arrives. - continue - case <-tokenCh: - tokenAcquired = true + if tokenCh != nil { + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-requestCh: + // Start to batch when the first request arrives. + bc.pushRequest(req) + // A request arrives but the token is not ready yet. Continue waiting, and also allowing collecting the next + // request if it arrives. + continue + case <-tokenCh: + tokenAcquired = true + } } // The token is ready. If the first request didn't arrive, wait for it. diff --git a/client/pkg/batch/batch_controller_test.go b/client/pkg/batch/batch_controller_test.go index 7c9ffa6944f..92aef14bd35 100644 --- a/client/pkg/batch/batch_controller_test.go +++ b/client/pkg/batch/batch_controller_test.go @@ -21,9 +21,11 @@ import ( "github.com/stretchr/testify/require" ) +const testMaxBatchSize = 20 + func TestAdjustBestBatchSize(t *testing.T) { re := require.New(t) - bc := NewController[int](20, nil, nil) + bc := NewController[int](testMaxBatchSize, nil, nil) re.Equal(defaultBestBatchSize, bc.bestBatchSize) bc.AdjustBestBatchSize() re.Equal(defaultBestBatchSize-1, bc.bestBatchSize) @@ -52,7 +54,7 @@ type testRequest struct { func TestFinishCollectedRequests(t *testing.T) { re := require.New(t) - bc := NewController[*testRequest](20, nil, nil) + bc := NewController[*testRequest](testMaxBatchSize, nil, nil) // Finish with zero request count. re.Zero(bc.collectedRequestCount) bc.FinishCollectedRequests(nil, nil) @@ -81,3 +83,58 @@ func TestFinishCollectedRequests(t *testing.T) { re.Equal(context.Canceled, requests[i].err) } } + +func TestFetchPendingRequests(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + re := require.New(t) + bc := NewController[int](testMaxBatchSize, nil, nil) + requestCh := make(chan int, testMaxBatchSize+1) + // Fetch a nil `tokenCh`. + requestCh <- 1 + re.NoError(bc.FetchPendingRequests(ctx, requestCh, nil, 0)) + re.Empty(requestCh) + re.Equal(1, bc.collectedRequestCount) + // Fetch a nil `tokenCh` with max batch size. + for i := range testMaxBatchSize { + requestCh <- i + } + re.NoError(bc.FetchPendingRequests(ctx, requestCh, nil, 0)) + re.Empty(requestCh) + re.Equal(testMaxBatchSize, bc.collectedRequestCount) + // Fetch a nil `tokenCh` with max batch size + 1. + for i := range testMaxBatchSize + 1 { + requestCh <- i + } + re.NoError(bc.FetchPendingRequests(ctx, requestCh, nil, 0)) + re.Len(requestCh, 1) + re.Equal(testMaxBatchSize, bc.collectedRequestCount) + // Drain the requestCh. + <-requestCh + // Fetch a non-nil `tokenCh`. + tokenCh := make(chan struct{}, 1) + requestCh <- 1 + tokenCh <- struct{}{} + re.NoError(bc.FetchPendingRequests(ctx, requestCh, tokenCh, 0)) + re.Empty(requestCh) + re.Equal(1, bc.collectedRequestCount) + // Fetch a non-nil `tokenCh` with max batch size. + for i := range testMaxBatchSize { + requestCh <- i + } + tokenCh <- struct{}{} + re.NoError(bc.FetchPendingRequests(ctx, requestCh, tokenCh, 0)) + re.Empty(requestCh) + re.Equal(testMaxBatchSize, bc.collectedRequestCount) + // Fetch a non-nil `tokenCh` with max batch size + 1. + for i := range testMaxBatchSize + 1 { + requestCh <- i + } + tokenCh <- struct{}{} + re.NoError(bc.FetchPendingRequests(ctx, requestCh, tokenCh, 0)) + re.Len(requestCh, 1) + re.Equal(testMaxBatchSize, bc.collectedRequestCount) + // Drain the requestCh. + <-requestCh +}