From 229a649443ffa8da1dd849ff26231e7ede389885 Mon Sep 17 00:00:00 2001 From: Francis Gallagher Date: Wed, 22 Nov 2023 23:13:56 +0000 Subject: [PATCH] fix: small fix for graceful interupts + some whitespacing fixes --- cmd/listen.go | 13 +- messages/listener/rmq.go | 274 ++++++++++++++++++++------------------- messages/listener/sqs.go | 3 +- 3 files changed, 149 insertions(+), 141 deletions(-) diff --git a/cmd/listen.go b/cmd/listen.go index 20f747b..0ace41b 100644 --- a/cmd/listen.go +++ b/cmd/listen.go @@ -39,12 +39,17 @@ var ListenCmd = &cobra.Command{ return fmt.Errorf("error creating listener: %s", err) } - // Create a channel to handle program termination or interruption signals so we can kill any connections if needed + //Create a channel to handle program termination or interruption signals so we can kill any connections if needed sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) - go listener.ListenToService(queueNames) - <-sigChan - listener.Close() + go func() { + <-sigChan + logger.Infof("Received interruption signal. Shutting down gracefully...") + listener.Close() + os.Exit(0) + }() + + listener.ListenToService(queueNames) return nil }, diff --git a/messages/listener/rmq.go b/messages/listener/rmq.go index a78ee0f..c121246 100644 --- a/messages/listener/rmq.go +++ b/messages/listener/rmq.go @@ -1,165 +1,167 @@ package listener import ( - "errors" - "fmt" - "github.com/streadway/amqp" - "magecomm/config_manager/system_limits" - "magecomm/logger" - "magecomm/messages/handler" - "magecomm/messages/queues" - "magecomm/services" - "sync" + "errors" + "fmt" + "github.com/streadway/amqp" + "magecomm/config_manager/system_limits" + "magecomm/logger" + "magecomm/messages/handler" + "magecomm/messages/queues" + "magecomm/services" + "sync" ) type RmqListener struct { - ChannelPool *services.RabbitMQChannelPool - done chan struct{} - wg sync.WaitGroup + ChannelPool *services.RabbitMQChannelPool + stopChan chan struct{} + waitGroup sync.WaitGroup } func (listener *RmqListener) shouldExecutionBeDelayed() error { - totalDeferTime := 0 - for system_limits.CheckIfOutsideOperationalLimits() { - system_limits.SystemLimitCheckSleep() - totalDeferTime += int(system_limits.WaitTimeBetweenChecks) + totalDeferTime := 0 + for system_limits.CheckIfOutsideOperationalLimits() { + system_limits.SystemLimitCheckSleep() + totalDeferTime += int(system_limits.WaitTimeBetweenChecks) - if totalDeferTime > int(system_limits.MaxDeferralTime) { - return errors.New("max deferral time exceeded") - } - } + if totalDeferTime > int(system_limits.MaxDeferralTime) { + return errors.New("max deferral time exceeded") + } + } - return nil + return nil } func (listener *RmqListener) processRmqMessage(message amqp.Delivery, channel *amqp.Channel, queueName string) { - logger.Debugf("Message received from %s", queueName) - correlationID := message.CorrelationId - if message.Headers == nil { - message.Headers = make(amqp.Table) - } - - retryCount, ok := message.Headers["RetryCount"] - if !ok { - retryCount = int32(0) - } - - err := listener.shouldExecutionBeDelayed() - if err != nil { - logger.Warnf("Message deferral time exceeded. Dropping hold on the message.") - message.Headers["RetryCount"] = retryCount.(int32) + 1 - _, err := services.PublishRmqMessage(channel, queueName, message.Body, message.Headers, correlationID) - if err != nil { - logger.Warnf("Failed to republish publish message: %v", err) - } - return - } - if err := handler.HandleReceivedMessage(string(message.Body), queueName, correlationID); err != nil { - logger.Warnf("Failed to process message: %v", err) - if retryCount.(int32) < handler.MessageRetryLimit { - message.Headers["RetryCount"] = retryCount.(int32) + 1 - _, err := services.PublishRmqMessage(channel, queueName, message.Body, message.Headers, correlationID) - if err != nil { - logger.Warnf("Failed to republish publish message: %v", err) - } - } else { - logger.Warnf("Retry count exceeded. Discarding the message.") - } - } + logger.Debugf("Message received from %s", queueName) + correlationID := message.CorrelationId + if message.Headers == nil { + message.Headers = make(amqp.Table) + } + + retryCount, ok := message.Headers["RetryCount"] + if !ok { + retryCount = int32(0) + } + + err := listener.shouldExecutionBeDelayed() + if err != nil { + logger.Warnf("Message deferral time exceeded. Dropping hold on the message.") + message.Headers["RetryCount"] = retryCount.(int32) + 1 + _, err := services.PublishRmqMessage(channel, queueName, message.Body, message.Headers, correlationID) + if err != nil { + logger.Warnf("Failed to republish publish message: %v", err) + } + return + } + if err := handler.HandleReceivedMessage(string(message.Body), queueName, correlationID); err != nil { + logger.Warnf("Failed to process message: %v", err) + if retryCount.(int32) < handler.MessageRetryLimit { + message.Headers["RetryCount"] = retryCount.(int32) + 1 + _, err := services.PublishRmqMessage(channel, queueName, message.Body, message.Headers, correlationID) + if err != nil { + logger.Warnf("Failed to republish publish message: %v", err) + } + } else { + logger.Warnf("Retry count exceeded. Discarding the message.") + } + } } func (listener *RmqListener) listenToQueue(queueName string) { - defer listener.wg.Done() - - channel, err := listener.ChannelPool.Get() - if err != nil { - logger.Warnf("Error getting channel from pool: %v", err) - return - } - defer listener.ChannelPool.Put(channel) - - queueNameWithConfigPrefix, err := services.CreateRmqQueue(channel, queueName) - if err != nil { - return - } - msgs, err := channel.Consume( - queueNameWithConfigPrefix, - "", - true, - false, - false, - false, - nil, - ) - if err != nil { - logger.Fatalf("%s: %s", "Failed to register a consumer", err) - } - - for { - select { - case message, ok := <-msgs: - if !ok { - return - } - listener.processRmqMessage(message, channel, queueName) - case <-listener.done: - return - } - } + defer listener.waitGroup.Done() + + channel, err := listener.ChannelPool.Get() + if err != nil { + logger.Warnf("Error getting channel from pool: %v", err) + return + } + defer listener.ChannelPool.Put(channel) + + queueNameWithConfigPrefix, err := services.CreateRmqQueue(channel, queueName) + if err != nil { + return + } + msgs, err := channel.Consume( + queueNameWithConfigPrefix, + "", + true, + false, + false, + false, + nil, + ) + if err != nil { + logger.Fatalf("%s: %s", "Failed to register a consumer", err) + } + + for { + select { + case message, ok := <-msgs: + if !ok { + return + } + listener.processRmqMessage(message, channel, queueName) + case <-listener.stopChan: + return + } + } } func (listener *RmqListener) ListenForOutputByCorrelationID(queueName string, correlationID string) (string, error) { - queueName = queues.MapQueueToOutputQueue(queueName) - channel, err := listener.ChannelPool.Get() - if err != nil { - logger.Warnf("Error getting channel from pool: %v", err) - return "", err - } - defer listener.ChannelPool.Put(channel) - - queueNameWithConfigPrefix, err := services.CreateRmqQueue(channel, queueName) - if err != nil { - return "", err - } - msgs, err := channel.Consume( - queueNameWithConfigPrefix, - "", - false, - false, - false, - false, - nil, - ) - if err != nil { - return "", fmt.Errorf("failed to consume messages: %s", err) - } - - for msg := range msgs { - if correlationID == msg.CorrelationId { - output := string(msg.Body) - err = msg.Ack(false) - if err != nil { - return "", fmt.Errorf("failed to acknowledge message: %s", err) - } - - return output, nil - } - } - - return "", fmt.Errorf("failed to receive message with correlation ID: %s", correlationID) + queueName = queues.MapQueueToOutputQueue(queueName) + channel, err := listener.ChannelPool.Get() + if err != nil { + logger.Warnf("Error getting channel from pool: %v", err) + return "", err + } + defer listener.ChannelPool.Put(channel) + + queueNameWithConfigPrefix, err := services.CreateRmqQueue(channel, queueName) + if err != nil { + return "", err + } + msgs, err := channel.Consume( + queueNameWithConfigPrefix, + "", + false, + false, + false, + false, + nil, + ) + if err != nil { + return "", fmt.Errorf("failed to consume messages: %s", err) + } + + for msg := range msgs { + if correlationID == msg.CorrelationId { + output := string(msg.Body) + err = msg.Ack(false) + if err != nil { + return "", fmt.Errorf("failed to acknowledge message: %s", err) + } + + return output, nil + } + } + + return "", fmt.Errorf("failed to receive message with correlation ID: %s", correlationID) } func (listener *RmqListener) ListenToService(queueNames []string) { - listener.done = make(chan struct{}) + listener.stopChan = make(chan struct{}) - for _, queueName := range queueNames { - listener.wg.Add(1) - go listener.listenToQueue(queueName) - } + for _, queueName := range queueNames { + listener.waitGroup.Add(1) + go listener.listenToQueue(queueName) + } - listener.wg.Wait() + listener.waitGroup.Wait() } func (listener *RmqListener) Close() { - close(listener.done) + close(listener.stopChan) + logger.Infof("Stopped listening to queues") + fmt.Println("Stopped listening to queues") } diff --git a/messages/listener/sqs.go b/messages/listener/sqs.go index 6c24814..2d5494d 100644 --- a/messages/listener/sqs.go +++ b/messages/listener/sqs.go @@ -200,5 +200,6 @@ func (listener *SqsListener) ListenToService(queueNames []string) { func (listener *SqsListener) Close() { close(listener.stopChan) - listener.waitGroup.Wait() + logger.Infof("Stopped listening to queues") + fmt.Println("Stopped listening to queues") }