From 8fb9678009079d788015ae703433f472f5781ec0 Mon Sep 17 00:00:00 2001 From: Aidan Coyle Date: Mon, 5 Dec 2016 11:21:33 -0600 Subject: [PATCH] Initial Open Source Commit This existed as an internal project before, but the history contained access keys and the like. Currently at 1.3 --- .travis.yml | 20 +++++ Changelog.md | 17 +++++ LICENSE.txt | 8 ++ README.md | 96 ++++++++++++++++++++++++ config.go | 53 +++++++++++++ config_test.go | 87 +++++++++++++++++++++ main.go | 119 +++++++++++++++++++++++++++++ main_test.go | 126 +++++++++++++++++++++++++++++++ mocks_test.go | 50 +++++++++++++ queue.go | 120 +++++++++++++++++++++++++++++ queue_test.go | 170 ++++++++++++++++++++++++++++++++++++++++++ sqs_client.go | 86 +++++++++++++++++++++ sqs_client_test.go | 136 +++++++++++++++++++++++++++++++++ worker_client.go | 52 +++++++++++++ worker_client_test.go | 113 ++++++++++++++++++++++++++++ 15 files changed, 1253 insertions(+) create mode 100644 .travis.yml create mode 100644 Changelog.md create mode 100644 LICENSE.txt create mode 100644 README.md create mode 100644 config.go create mode 100644 config_test.go create mode 100644 main.go create mode 100644 main_test.go create mode 100644 mocks_test.go create mode 100644 queue.go create mode 100644 queue_test.go create mode 100644 sqs_client.go create mode 100644 sqs_client_test.go create mode 100644 worker_client.go create mode 100644 worker_client_test.go diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..08f5929 --- /dev/null +++ b/.travis.yml @@ -0,0 +1,20 @@ +language: go + +go: + - 1.6 + - 1.7 + - tip + +before_script: + - go get github.com/GeertJohan/fgt + - go get github.com/golang/lint + +script: + - fgt gomft -l . + - fgt golint ./.. + - go vet ./... + - go test -v ./... + +matrix: + allow_failures: + - go: tip diff --git a/Changelog.md b/Changelog.md new file mode 100644 index 0000000..d608796 --- /dev/null +++ b/Changelog.md @@ -0,0 +1,17 @@ +# Scout Changes + +1.3 +---------- +- Rewrite the SQS integration to the use the AWS SDK instead of goamz + +1.2 +---------- +- Save jobs in Redis with the Sidekiq `retry` flag set to `true` + +1.1 +---------- + +- Remove the `--quiet` flag in favor of `--log-level` which defaults to `INFO` +- Move some of the more verbose logging to `DEBUG` level logs +- Log full message body after parsing it + diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..d6104b7 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,8 @@ +Copyright (c) 2016 Enova International + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + diff --git a/README.md b/README.md new file mode 100644 index 0000000..ee8043c --- /dev/null +++ b/README.md @@ -0,0 +1,96 @@ +# Scout + +Scout is a daemon for listening to a set of SNS topics and enqueuing anything it +finds into sidekiq jobs. It's meant to extract processing of SQS from the rails +apps that increasingly need to do so. + +## Usage + +``` +NAME: + scout - SQS Listener +Poll SQS queues specified in a config and enqueue Sidekiq jobs with the queue items. +It gracefully stops when sent SIGTERM. + +USAGE: + scout [global options] command [command options] [arguments...] + +VERSION: + 1.3 + +COMMANDS: + help, h Shows a list of commands or help for one command + +GLOBAL OPTIONS: + --config FILE, -c FILE Load config from FILE, required + --freq N, -f N Poll SQS every N milliseconds (default: 100) + --log-level value, -l value Sets log level. Accepts one of: debug, info, warn, error + --json, -j Log in json format + --help, -h show help + --version, -v print the version +``` + +## Configuration + +The configuration requires 3 distinct sets of information. It needs information +about how to connect to redis to enqueue jobs, credentials to talk to AWS and +read SQS, and a mapping from SNS topics to sidekiq worker classes in the +application. The structure looks like this. + +```yaml +redis: + host: "localhost:9000" + namespace: "test" + queue: "background" +aws: + access_key: "super" + secret_key: "secret" + region: "us-best" +queue: + name: "myapp_queue" + topics: + foo-topic: "FooWorker" + bar-topic: "BazWorker" +``` + +None of this information is actually an example of anything other than the +strucure of the file, so if you copy paste it you'll probably be disappointed. + +## Versioning + +Scout uses tagged commits to be compatible with gopkg.in. To pin to version 1, +you can import it as `gopkg.in/enova/scout.v1`. The "first" version is version +1.3, since all other versions were before this project was made open source. +Version 2 is possible at some point and may contain breaking changes, so pinning +to version 1 is recommended unless you want to work with the bleeding edge. + +## Development + +To get set up make sure to run `go get -t -u ./...` to get all the dependencies. + +### Testing + +The normal test suite can be run as expected with go test. There are also two +tagged files with expensive integration tests that require external services. +They can be run as follows + +``` + [FG-386] scout > go test -run=TestSQS -v -tags=sqsint +=== RUN TestSQS_Init +--- PASS: TestSQS_Init (3.84s) +=== RUN TestSQS_FetchDelete +--- PASS: TestSQS_FetchDelete (3.58s) + PASS +ok github.com/enova/scout 7.422s + [FG-386] scout > go test -run=TestWorker -v -tags=redisint +=== RUN TestWorker_Init +--- PASS: TestWorker_Init (0.00s) +=== RUN TestWorker_Push +--- PASS: TestWorker_Push (0.00s) +PASS +ok github.com/enova/scout 0.013s +``` + +The tests themselves (found in `sqs_client_test.go` and `worker_client_test.go`) +explain what is required to run them. In particular, the SQS integration tests +require that you provide AWS credentials to run them. diff --git a/config.go b/config.go new file mode 100644 index 0000000..3c3b9bf --- /dev/null +++ b/config.go @@ -0,0 +1,53 @@ +package main + +import ( + "io/ioutil" + + "gopkg.in/yaml.v2" +) + +// Config is the internal representation of the yaml that determines what +// the app listens to an enqueues +type Config struct { + Redis RedisConfig `yaml:"redis"` + AWS AWSConfig `yaml:"aws"` + Queue QueueConfig `yaml:"queue"` +} + +// RedisConfig is a nested config that contains the necessary parameters to +// connect to a redis instance and enqueue workers. +type RedisConfig struct { + Host string `yaml:"host"` + Namespace string `yaml:"namespace"` + Queue string `yaml:"queue"` +} + +// AWSConfig is a nested config that contains the necessary parameters to +// connect to AWS and read from SQS +type AWSConfig struct { + AccessKey string `yaml:"access_key"` + SecretKey string `yaml:"secret_key"` + Region string `yaml:"region"` +} + +// QueueConfig is a nested config that gives the SQS queue to listen on +// and a mapping of topics to workeers +type QueueConfig struct { + Name string `yaml:"name"` + Topics map[string]string `yaml:"topics"` +} + +// ReadConfig reads from a file with the given name and returns a config or +// an error if the file was unable to be parsed. It does no error checking +// as far as required fields. +func ReadConfig(file string) (*Config, error) { + data, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + + config := new(Config) + + err = yaml.Unmarshal(data, config) + return config, err +} diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..58bab61 --- /dev/null +++ b/config_test.go @@ -0,0 +1,87 @@ +package main + +import ( + "io/ioutil" + "os" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +func TestConfig(t *testing.T) { + suite.Run(t, new(ConfigTestSuite)) +} + +type ConfigTestSuite struct { + suite.Suite + tempfile *os.File + assert *require.Assertions +} + +func (c *ConfigTestSuite) SetupTest() { + c.assert = require.New(c.T()) + + var err error + c.tempfile, err = ioutil.TempFile("", "config") + c.assert.NoError(err) +} + +func (c *ConfigTestSuite) TearDownTest() { + os.Remove(c.tempfile.Name()) +} + +func (c *ConfigTestSuite) WriteTemp(content string) { + _, err := c.tempfile.Write([]byte(content)) + c.assert.NoError(err) + ReadConfig(c.tempfile.Name()) + err = c.tempfile.Close() + c.assert.NoError(err) +} + +var validConfig = ` +redis: + host: "localhost:9000" + namespace: "test" + queue: "background" +aws: + access_key: "super" + secret_key: "secret" + region: "us_best" +queue: + name: "myapp_queue" + topics: + foo_topic: "FooWorker" + bar_topic: "BazWorker"` + +func (c *ConfigTestSuite) TestConfig_Valid() { + c.WriteTemp(validConfig) + config, err := ReadConfig(c.tempfile.Name()) + c.assert.NoError(err) + + // More to convince myself that the yaml package works than anything + c.assert.Equal(config.Redis.Host, "localhost:9000") + c.assert.Equal(config.Redis.Queue, "background") + c.assert.Equal(config.AWS.Region, "us_best") + c.assert.Equal(config.Queue.Name, "myapp_queue") + c.assert.Equal(config.Queue.Topics["foo_topic"], "FooWorker") +} + +var sparseConfig = ` +redis: + host: "localhost:9000" +aws: + access_key: "super" + secret_key: "secret" + region: "us_best"` + +// It's ok for stuff to be missing, we'll check that elsewhere +func (c *ConfigTestSuite) TestConfig_Sparse() { + c.WriteTemp(sparseConfig) + config, err := ReadConfig(c.tempfile.Name()) + c.assert.NoError(err) + + c.assert.Equal(config.Redis.Namespace, "") + c.assert.Equal(config.AWS.Region, "us_best") + c.assert.Equal(len(config.Queue.Topics), 0) +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..d2ab066 --- /dev/null +++ b/main.go @@ -0,0 +1,119 @@ +package main + +import ( + "fmt" + "os" + "os/signal" + "syscall" + "time" + + log "github.com/Sirupsen/logrus" + "gopkg.in/urfave/cli.v1" +) + +var ( + app *cli.App + signals chan os.Signal +) + +func init() { + app = cli.NewApp() + + app.Name = "scout" + app.Usage = `SQS Listener +Poll SQS queues specified in a config and enqueue Sidekiq jobs with the queue items. +It gracefully stops when sent SIGTERM.` + + app.Version = "1.3" + + app.Flags = []cli.Flag{ + cli.StringFlag{ + Name: "config, c", + Usage: "Load config from `FILE`, required", + }, + cli.Int64Flag{ + Name: "freq, f", + Value: 100, + Usage: "Poll SQS every `N` milliseconds", + }, + cli.StringFlag{ + Name: "log-level, l", + Usage: "Sets log level. Accepts one of: debug, info, warn, error", + }, + cli.BoolFlag{ + Name: "json, j", + Usage: "Log in json format", + }, + } + + app.Action = runApp + + signals = make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGTERM) +} + +func main() { + app.Run(os.Args) +} + +func runApp(ctx *cli.Context) error { + configFile := ctx.String("config") + frequency := ctx.Int64("freq") + + if ctx.Bool("json") { + log.SetFormatter(&log.JSONFormatter{}) + } + + logLevel := ctx.String("log-level") + if logLevel == "" { + logLevel = "info" + } + + level, err := log.ParseLevel(logLevel) + if err != nil { + return cli.NewExitError("Could not parse log level", 1) + } + + log.SetLevel(level) + + if configFile == "" { + return cli.NewExitError("Missing required flag --config. Run `scout --help` for more information", 1) + } + + log.Infof("Reading config from %s", configFile) + log.Infof("Polling every %d milliseconds", frequency) + + config, err := ReadConfig(configFile) + if err != nil { + return cli.NewExitError("Failed to parse config file", 1) + } + + queue, err := NewQueue(config) + if err != nil { + return cli.NewExitError(fmt.Sprintf("Initialization error: %s", err.Error()), 1) + } + + log.Info("Now listening on queue: ", config.Queue.Name) + for topic, worker := range config.Queue.Topics { + log.Infof("%s -> %s", topic, worker) + } + + Listen(queue, time.Tick(time.Duration(frequency)*time.Millisecond)) + return nil +} + +// Listen does the work. It only returns if we get a signal +func Listen(queue Queue, freq <-chan time.Time) { + for { + select { + case <-signals: + log.Info("Got TERM") + queue.Semaphore().Wait() + return + case tick := <-freq: + log.Debug("Polling at: ", tick) + queue.Semaphore().Add(1) + go queue.Poll() + } + } +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..0974242 --- /dev/null +++ b/main_test.go @@ -0,0 +1,126 @@ +package main + +import ( + "flag" + "os" + "sync" + "testing" + "time" + + log "github.com/Sirupsen/logrus" + "github.com/stretchr/testify/require" + "gopkg.in/urfave/cli.v1" +) + +func TestFlags(t *testing.T) { + testFlags := flag.NewFlagSet("testflags", flag.PanicOnError) + testFlags.String("config", "", "where to find the config") + testFlags.Int64("freq", 100, "how often to poll") + testFlags.String("log-level", "info", "log level") + + // Ensure we require config + testFlags.Set("config", "") + testFlags.Set("freq", "100") + testFlags.Set("log-level", "info") + + testCtx := cli.NewContext(app, testFlags, nil) + err := runApp(testCtx) + require.Error(t, err) +} + +func TestLogLevel(t *testing.T) { + testFlags := flag.NewFlagSet("testflags", flag.PanicOnError) + testFlags.String("log-level", "", "log level") + + // Check we can set it to warn + testFlags.Set("log-level", "warn") + + testCtx := cli.NewContext(app, testFlags, nil) + runApp(testCtx) + require.Equal(t, log.GetLevel(), log.WarnLevel) + + // Setting it to the wrong thing breaks + testFlags.Set("log-level", "notalevel") + + testCtx = cli.NewContext(app, testFlags, nil) + err := runApp(testCtx) + require.Equal(t, err.Error(), "Could not parse log level") + + // Leaving it unset defaults to info + testFlags.Set("log-level", "") + + testCtx = cli.NewContext(app, testFlags, nil) + runApp(testCtx) + require.Equal(t, log.GetLevel(), log.InfoLevel) +} + +// WaitQueue implements Queue +type WaitQueue struct { + sem *sync.WaitGroup + seq chan int + wait chan int +} + +func (w *WaitQueue) Semaphore() *sync.WaitGroup { + return w.sem +} + +// Poll sends `1` on the sequence channel, then waits on `wait`, then sends `2` +// then downs the semaphore. +func (w *WaitQueue) Poll() { + defer w.sem.Done() + w.seq <- 1 + <-w.wait + w.seq <- 2 +} + +// waitListen just calls listen and then sends `3 on the sequence channel when +// the call exits +func waitListen(q Queue, freq <-chan time.Time, seq chan int) { + Listen(q, freq) + seq <- 3 +} + +// This test is pretty complicated because I'm basically using it to ensure that +// the loop in Listen will only exit once all of the in flight work is done. The +// basic structure of theis test is to have a mock whose call to Poll() blocks +// until I send it a signal. We then send something on the ticker channel to +// kick off a job (which blocks). Then we send on the signal channel to start +// the graceful exit. Then we tell the queue's Poll() to exit, and then the +// Listen should exit. We use a separate sequence channel to ensure the order +// of everything +func TestSignals(t *testing.T) { + seq := make(chan int) + wait := make(chan int) + freq := make(chan time.Time) + + queue := &WaitQueue{ + sem: new(sync.WaitGroup), + seq: seq, + wait: wait, + } + + // begin listening + go waitListen(queue, freq, seq) + + // send a tick, this should start a call to Poll() + freq <- time.Now() + + // when that call starts, we should get `1` on the sequence channel + val := <-seq + require.Equal(t, val, 1) + + // send a signal, this should start the graceful exit + signals <- os.Interrupt + + // tell Poll() that it can exit + wait <- 1 + + // first Poll() should exit + val = <-seq + require.Equal(t, val, 2) + + // then Listen() should exit + val = <-seq + require.Equal(t, val, 3) +} diff --git a/mocks_test.go b/mocks_test.go new file mode 100644 index 0000000..7fc65b5 --- /dev/null +++ b/mocks_test.go @@ -0,0 +1,50 @@ +package main + +import ( + "encoding/json" +) + +func MockMessage(body, topic string) Message { + msg := map[string]string{ + "Message": body, + "TopicArn": topic, + } + + data, err := json.Marshal(msg) + if err != nil { + panic(err) + } + + return Message{Body: string(data)} +} + +type MockSQSClient struct { + Fetchable []Message + FetchError error + Deleted []Message + DeleteError error +} + +func (m *MockSQSClient) Fetch() ([]Message, error) { + if m.FetchError != nil { + return nil, m.FetchError + } + + return m.Fetchable, nil +} + +func (m *MockSQSClient) Delete(message Message) error { + m.Deleted = append(m.Deleted, message) + return m.DeleteError +} + +type MockWorkerClient struct { + Enqueued [][]string + EnqueuedJID string + EnqueueError error +} + +func (m *MockWorkerClient) Push(class, args string) (string, error) { + m.Enqueued = append(m.Enqueued, []string{class, args}) + return m.EnqueuedJID, m.EnqueueError +} diff --git a/queue.go b/queue.go new file mode 100644 index 0000000..943db0b --- /dev/null +++ b/queue.go @@ -0,0 +1,120 @@ +package main + +import ( + "encoding/json" + "errors" + "strings" + "sync" + + log "github.com/Sirupsen/logrus" +) + +// Queue is an encasulation for processing an SQS queue and enqueueing the +// results in sidekiq +type Queue interface { + // Poll gets the next batch of messages from SQS and processes them. + // When it's finished, it downs the sempahore + Poll() + + // Semaphore returns the lock used to ensure that all the work is + // done before terminating the queue + Semaphore() *sync.WaitGroup +} + +// queue is the actual implementation +type queue struct { + WorkerClient WorkerClient + SQSClient SQSClient + Topics map[string]string + Sem *sync.WaitGroup +} + +// NewQueue creates a new Queue from the given Config. Returns an error if +// something about the config is invalid +func NewQueue(config *Config) (Queue, error) { + queue := new(queue) + var err error + + queue.SQSClient, err = NewAWSSQSClient(config.AWS, config.Queue.Name) + if err != nil { + return nil, err + } + + queue.WorkerClient, err = NewRedisWorkerClient(config.Redis) + if err != nil { + return nil, err + } + + queue.Topics = config.Queue.Topics + if len(queue.Topics) == 0 { + return nil, errors.New("No topics defined") + } + + queue.Sem = new(sync.WaitGroup) + + return queue, nil +} + +func (q *queue) Semaphore() *sync.WaitGroup { + return q.Sem +} + +func (q *queue) Poll() { + if q.Sem != nil { + defer q.Sem.Done() + } + + messages, err := q.SQSClient.Fetch() + if err != nil { + log.Error("Error fetching messages: ", err.Error()) + } + + for _, msg := range messages { + log.Info("Processing message: ", msg.MessageID) + deletable := q.EnqueueMessage(msg) + if deletable { + q.DeleteMessage(msg) + } + } +} + +// DeleteMessage deletes a single message from SQS +func (q *queue) DeleteMessage(msg Message) { + err := q.SQSClient.Delete(msg) + if err != nil { + log.Error("Couldn't delete message: ", msg.MessageID) + } else { + log.Info("Deleted message: ", msg.MessageID) + } +} + +// EnqueueMessage pushes a single message from SQS into redis +func (q *queue) EnqueueMessage(msg Message) bool { + ctx := log.WithField("MessageID", msg.MessageID) + body := make(map[string]string) + err := json.Unmarshal([]byte(msg.Body), &body) + if err != nil { + ctx.Warn("Message body could not be parsed: ", err.Error()) + return true + } + + workerClass, ok := q.Topics[topicName(body["TopicArn"])] + if !ok { + ctx.Warn("No worker for topic: ", topicName(body["TopicArn"])) + return true + } + + jid, err := q.WorkerClient.Push(workerClass, body["Message"]) + if err != nil { + ctx.Error("Couldn't enqueue worker: ", workerClass) + return false + } + + ctx.WithField("Args", body["Message"]).Info("Enqueued job: ", jid) + return true +} + +func topicName(topicARN string) string { + toks := strings.Split(topicARN, ":") + return toks[len(toks)-1] +} diff --git a/queue_test.go b/queue_test.go new file mode 100644 index 0000000..4f4ba41 --- /dev/null +++ b/queue_test.go @@ -0,0 +1,170 @@ +package main + +import ( + "errors" + "sync" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +func TestQueue(t *testing.T) { + suite.Run(t, new(QueueTestSuite)) +} + +type QueueTestSuite struct { + suite.Suite + queue *queue + sqsClient *MockSQSClient + workerClient *MockWorkerClient + assert *require.Assertions +} + +func (q *QueueTestSuite) SetupTest() { + q.assert = require.New(q.T()) + + q.queue = new(queue) + + q.sqsClient = &MockSQSClient{ + Fetchable: make([]Message, 0), + FetchError: nil, + Deleted: make([]Message, 0), + DeleteError: nil, + } + + q.queue.SQSClient = q.sqsClient + + q.workerClient = &MockWorkerClient{ + Enqueued: make([][]string, 0), + EnqueuedJID: "jid", + EnqueueError: nil, + } + + q.queue.WorkerClient = q.workerClient + + q.queue.Topics = make(map[string]string) +} + +func (q *QueueTestSuite) TestQueue_Success() { + // make some messages + message1 := MockMessage(`{"foo":"bar"}`, "topicA") + message2 := MockMessage(`{"bar":"baz"}`, "topicA") + message3 := MockMessage(`{"key":"val"}`, "topicB") + + // set the mock to return those + q.sqsClient.Fetchable = []Message{message1, message3, message2} + + // make some topics + q.queue.Topics["topicA"] = "WorkerA" + q.queue.Topics["topicB"] = "WorkerB" + + // do the work + q.queue.Poll() + + // The workers should be enqueued + q.assert.Contains(q.workerClient.Enqueued, []string{"WorkerA", `{"foo":"bar"}`}) + q.assert.Contains(q.workerClient.Enqueued, []string{"WorkerA", `{"bar":"baz"}`}) + q.assert.Contains(q.workerClient.Enqueued, []string{"WorkerB", `{"key":"val"}`}) + + // The messages should be deleted + q.assert.Contains(q.sqsClient.Deleted, message1) + q.assert.Contains(q.sqsClient.Deleted, message2) + q.assert.Contains(q.sqsClient.Deleted, message3) +} + +func (q *QueueTestSuite) TestQueue_NoTopic() { + // make some messages + message1 := MockMessage(`{"foo":"bar"}`, "topicA") + message2 := MockMessage(`{"bar":"baz"}`, "topicA") + message3 := MockMessage(`{"key":"val"}`, "topicB") + + // set the mock to return those + q.sqsClient.Fetchable = []Message{message1, message3, message2} + + // make some topics + // note: there is no topicB + q.queue.Topics["topicA"] = "WorkerA" + + // do the work + q.queue.Poll() + + // The workers should be enqueued + // note: the topic B message is not enqueued + q.assert.Contains(q.workerClient.Enqueued, []string{"WorkerA", `{"foo":"bar"}`}) + q.assert.Contains(q.workerClient.Enqueued, []string{"WorkerA", `{"bar":"baz"}`}) + + // The messages should be deleted + // note: message 3 is still deleted + q.assert.Contains(q.sqsClient.Deleted, message1) + q.assert.Contains(q.sqsClient.Deleted, message2) + q.assert.Contains(q.sqsClient.Deleted, message3) +} + +func (q *QueueTestSuite) TestQueue_UnparseableBody() { + // make some messages + message1 := MockMessage(`{"foo":"bar"}`, "topicA") + message2 := MockMessage(`{"bar":"baz"}`, "topicB") + + // this message has an unparseable body + badMessage := Message{ + Body: `thisain'tjson`, + } + + // set the mock to return those + q.sqsClient.Fetchable = []Message{message1, badMessage, message2} + + // make some topics + q.queue.Topics["topicA"] = "WorkerA" + q.queue.Topics["topicB"] = "WorkerB" + + // do the work + q.queue.Poll() + + // The workers should be enqueued + // note: the unparseable worker is not enqueued + q.assert.Contains(q.workerClient.Enqueued, []string{"WorkerA", `{"foo":"bar"}`}) + q.assert.Contains(q.workerClient.Enqueued, []string{"WorkerB", `{"bar":"baz"}`}) + + // The messages should be deleted + // note: the badMessage is deleted + q.assert.Contains(q.sqsClient.Deleted, message1) + q.assert.Contains(q.sqsClient.Deleted, message2) + q.assert.Contains(q.sqsClient.Deleted, badMessage) +} + +func (q *QueueTestSuite) TestQueue_EnqueueError() { + // make a messages + message1 := MockMessage(`{"foo":"bar"}`, "topicA") + + // set the mock to return that + q.sqsClient.Fetchable = []Message{message1} + + // make a topic + q.queue.Topics["topicA"] = "WorkerA" + + // set the worker client to error out + q.workerClient.EnqueueError = errors.New("oops") + + // do the work + q.queue.Poll() + + // nothing should be deleted + q.assert.Empty(q.sqsClient.Deleted) +} + +func (q *QueueTestSuite) TestQueue_Semaphore() { + q.queue.Sem = new(sync.WaitGroup) + q.queue.Semaphore().Add(1) + q.queue.Poll() + + // Calling Done() on a waitgroup that's at 0 will segfault + q.assert.Panics(func() { + q.queue.Semaphore().Done() + }) +} + +func TestTopicName(t *testing.T) { + // from http://docs.aws.amazon.com/sns/latest/dg/SendMessageToSQS.html + require.Equal(t, topicName("arn:aws:sns:us-west-2:123456789012:MyTopic"), "MyTopic") +} diff --git a/sqs_client.go b/sqs_client.go new file mode 100644 index 0000000..a8b3b37 --- /dev/null +++ b/sqs_client.go @@ -0,0 +1,86 @@ +package main + +import ( + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sqs" +) + +// SQSClient is an interface for SQS +type SQSClient interface { + // Fetch returns the next batch of SQS messages + Fetch() ([]Message, error) + + // Delete deletes a single message from SQS + Delete(Message) error +} + +// Message is the internal representation of an SQS message +type Message struct { + MessageID string + Body string + ReceiptHandle string +} + +type sdkClient struct { + service *sqs.SQS + url string +} + +// NewAWSSQSClient creates an SQS client that talks to AWS on the given queue +func NewAWSSQSClient(conf AWSConfig, queueName string) (SQSClient, error) { + creds := credentials.NewStaticCredentials(conf.AccessKey, conf.SecretKey, "") + sess := session.New(&aws.Config{Region: formatRegion(conf.Region), Credentials: creds}) + + client := new(sdkClient) + client.service = sqs.New(sess) + + resp, err := client.service.GetQueueUrl(&sqs.GetQueueUrlInput{ + QueueName: &queueName, + }) + + if err != nil { + return nil, err + } + + client.url = *resp.QueueUrl + return client, nil +} + +func (s *sdkClient) Fetch() ([]Message, error) { + res, err := s.service.ReceiveMessage(&sqs.ReceiveMessageInput{ + QueueUrl: &s.url, + MaxNumberOfMessages: aws.Int64(10), + }) + if err != nil { + return nil, err + } + + msgs := make([]Message, len(res.Messages)) + + for i, m := range res.Messages { + msgs[i] = Message{ + MessageID: *m.MessageId, + Body: *m.Body, + ReceiptHandle: *m.ReceiptHandle, + } + } + + return msgs, nil +} + +func (s *sdkClient) Delete(message Message) error { + _, err := s.service.DeleteMessage(&sqs.DeleteMessageInput{ + QueueUrl: &s.url, + ReceiptHandle: &message.ReceiptHandle, + }) + return err +} + +func formatRegion(region string) *string { + newRegion := strings.NewReplacer(".", "-", "_", "-").Replace(region) + return &newRegion +} diff --git a/sqs_client_test.go b/sqs_client_test.go new file mode 100644 index 0000000..10559c2 --- /dev/null +++ b/sqs_client_test.go @@ -0,0 +1,136 @@ +// +build sqsint + +package main + +import ( + "testing" + + "github.com/goamz/goamz/sqs" + "github.com/stretchr/testify/require" +) + +// These tests are hardcoded to point to actual sqs queues and sns topics. +// You should think of them as "integration" tests, and in general you don't +// want to be running these all the time. You will need to provide your own +// credentials below to run these. + +var config = AWSConfig{ + AccessKey: "", + SecretKey: "", + Region: "", +} + +var queueName = "" + +func queueHandle() *sqs.Queue { + sqsHandle, err := sqs.NewFrom(config.AccessKey, config.SecretKey, config.Region) + if err != nil { + panic(err) + } + + queue, err := sqsHandle.GetQueue(queueName) + if err != nil { + panic(err) + } + + return queue +} + +func TestSQS_Init(t *testing.T) { + _, err := NewAWSSQSClient(config, queueName) + require.NoError(t, err) + + // wrong region + _, err = NewAWSSQSClient( + AWSConfig{ + AccessKey: "AKIAJNRPNGF7HIWQ5C6Q", + SecretKey: "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", + Region: "us.best", + }, + "test-queue-integration", + ) + require.Error(t, err) + + // wrong creds + _, err = NewAWSSQSClient( + AWSConfig{ + AccessKey: "super", + SecretKey: "secret", + Region: "us.west.2", + }, + "test-queue-integration", + ) + require.Error(t, err) + + // wrong queue + _, err = NewAWSSQSClient( + AWSConfig{ + AccessKey: "AKIAJNRPNGF7HIWQ5C6Q", + SecretKey: "myTX5YypzjqjgtZJ2ABvwqotGazqxtj37yQwyZpa", + Region: "us.west.2", + }, + "fake-queue", // hopefully nobody ever makes this + ) + require.Error(t, err) +} + +func TestSQS_FetchDelete(t *testing.T) { + recd := map[string]int{ + "foo": 0, + "bar": 0, + "baz": 0, + } + queue := queueHandle() + _, err := queue.SendMessage("foo") + require.NoError(t, err) + _, err = queue.SendMessage("bar") + require.NoError(t, err) + _, err = queue.SendMessage("baz") + require.NoError(t, err) + + client, err := NewAWSSQSClient(config, queueName) + require.NoError(t, err) + + // Loop over and read from the queue unitl there are no messages left. + // Doing it this way because even though we set max messages to 10, it + // seems that aws almost always gives us back only one anyway + for { + messages, err := client.Fetch() + require.NoError(t, err) + if len(messages) == 0 { + break + } + + for _, msg := range messages { + recd[msg.Body] += 1 + err := client.Delete(msg) + require.NoError(t, err) + } + } + + require.Equal(t, recd["foo"], 1) + require.Equal(t, recd["bar"], 1) + require.Equal(t, recd["baz"], 1) +} + +func TestSQS_FetchMany(t *testing.T) { + queue := queueHandle() + + // We're filling up the queue to ensure that a call to Fetch will + // actually return 10 messages. More sanity than anything else, don't + // be too concerned if this fails + for i := 0; i < 100; i++ { + _, err := queue.SendMessage("foo") + require.NoError(t, err) + } + + client, err := NewAWSSQSClient(config, queueName) + require.NoError(t, err) + + messages, err := client.Fetch() + require.NoError(t, err) + require.Equal(t, len(messages), 10) + + _, err = queue.Purge() + require.NoError(t, err) +} diff --git a/worker_client.go b/worker_client.go new file mode 100644 index 0000000..10fe74b --- /dev/null +++ b/worker_client.go @@ -0,0 +1,52 @@ +package main + +import ( + "encoding/json" + "errors" + + "github.com/jrallison/go-workers" +) + +// WorkerClient is an interface for enqueueing workers +type WorkerClient interface { + // Push pushes a worker onto the queue + Push(class, args string) (string, error) +} + +type redisWorkerClient struct { + queue string +} + +// NewRedisWorkerClient creates a worker client that pushes the worker to redis +func NewRedisWorkerClient(redis RedisConfig) (WorkerClient, error) { + if redis.Host == "" { + return nil, errors.New("Redis host required") + } + + if redis.Queue == "" { + return nil, errors.New("Sidekiq queue required") + } + + workers.Configure(map[string]string{ + "server": redis.Host, + "database": "0", + "pool": "20", + "process": "1", + "namespace": redis.Namespace, + }) + + return &redisWorkerClient{queue: redis.Queue}, nil +} + +func (r *redisWorkerClient) Push(class, args string) (string, error) { + // This will hopefully deserialize on the ruby end as a hash + jsonArgs := json.RawMessage([]byte(args)) + return workers.EnqueueWithOptions( + r.queue, + class, + []*json.RawMessage{&jsonArgs}, + workers.EnqueueOptions{ + Retry: true, + }, + ) +} diff --git a/worker_client_test.go b/worker_client_test.go new file mode 100644 index 0000000..d916780 --- /dev/null +++ b/worker_client_test.go @@ -0,0 +1,113 @@ +// +build redisint + +package main + +import ( + "encoding/json" + "testing" + + "github.com/jrallison/go-workers" + "github.com/stretchr/testify/require" + "gopkg.in/redis.v5" +) + +// These tests require an existing redis database that matches the config below. +// If one doesn't exist they'll probably fail + +var config = RedisConfig{ + Host: "localhost:6379", + Namespace: "integration", + Queue: "testq", +} + +func TestWorker_Init(t *testing.T) { + _, err := NewRedisWorkerClient(config) + require.NoError(t, err) + + // no host + _, err = NewRedisWorkerClient( + RedisConfig{ + Host: "", + Namespace: "integration", + Queue: "testq", + }, + ) + require.Error(t, err) + + // no queue + _, err = NewRedisWorkerClient( + RedisConfig{ + Host: "localhost:6379", + Namespace: "integration", + Queue: "", + }, + ) + require.Error(t, err) + + // no namespace, this doesn't error + _, err = NewRedisWorkerClient( + RedisConfig{ + Host: "localhost:6379", + Namespace: "", + Queue: "testq", + }, + ) + require.NoError(t, err) +} + +func TestWorker_Push(t *testing.T) { + redisHandle := redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + Password: "", + DB: 0, + }) + + err := redisHandle.Del("integration:queue:testq").Err() + require.NoError(t, err) + + client, err := NewRedisWorkerClient(config) + require.NoError(t, err) + + fooMessage := `{"msg":"foo"}` + barMessage := `{"msg":"bar"}` + + fooJID, err := client.Push("FooWorker", fooMessage) + require.NoError(t, err) + barJID, err := client.Push("BarWorker", barMessage) + require.NoError(t, err) + + require.NotEqual(t, fooJID, barJID) + + // first thing we enqueued + data, err := redisHandle.LPop("integration:queue:testq").Bytes() + require.NoError(t, err) + + fooEnqueued := &workers.EnqueueData{} + err = json.Unmarshal(data, fooEnqueued) + require.NoError(t, err) + + require.Equal(t, fooEnqueued.Jid, fooJID) + require.Equal(t, fooEnqueued.Class, "FooWorker") + require.Equal(t, fooEnqueued.Args, []interface{}{map[string]interface{}{"msg": "foo"}}) + require.Equal(t, fooEnqueued.EnqueueOptions.Retry, true) + + // second thing we enqueued + data, err = redisHandle.LPop("integration:queue:testq").Bytes() + require.NoError(t, err) + + barEnqueued := &workers.EnqueueData{} + err = json.Unmarshal(data, barEnqueued) + require.NoError(t, err) + + require.Equal(t, barEnqueued.Jid, barJID) + require.Equal(t, barEnqueued.Class, "BarWorker") + require.Equal(t, barEnqueued.Args, []interface{}{map[string]interface{}{"msg": "bar"}}) + require.Equal(t, fooEnqueued.EnqueueOptions.Retry, true) + + // verify retry is flat + barFlat := make(map[string]interface{}) + err = json.Unmarshal(data, &barFlat) + require.NoError(t, err) + + require.Equal(t, barFlat["retry"], true) +}