Skip to content

Commit

Permalink
Merge pull request #31 from aertje/appengine_headers
Browse files Browse the repository at this point in the history
Appengine headers
  • Loading branch information
aertje authored Apr 10, 2021
2 parents 8828b06 + 074864f commit f4f5a49
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 55 deletions.
118 changes: 75 additions & 43 deletions emulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,7 @@ func TestCreateTask(t *testing.T) {
serv, client := setUp(t)
defer tearDown(t, serv)

queue := newQueue(formattedParent, "test")
createQueueRequest := taskspb.CreateQueueRequest{
Parent: formattedParent,
Queue: queue,
}

createdQueue, err := client.CreateQueue(context.Background(), &createQueueRequest)
require.NoError(t, err)
createdQueue := createTestQueue(t, client)

createTaskRequest := taskspb.CreateTaskRequest{
Parent: createdQueue.GetName(),
Expand All @@ -114,14 +107,7 @@ func TestCreateTaskRejectsInvalidName(t *testing.T) {
serv, client := setUp(t)
defer tearDown(t, serv)

queue := newQueue(formattedParent, "test")
createQueueRequest := taskspb.CreateQueueRequest{
Parent: formattedParent,
Queue: queue,
}

createdQueue, err := client.CreateQueue(context.Background(), &createQueueRequest)
require.NoError(t, err)
createdQueue := createTestQueue(t, client)

createTaskRequest := taskspb.CreateTaskRequest{
Parent: createdQueue.GetName(),
Expand Down Expand Up @@ -156,14 +142,7 @@ func TestSuccessTaskExecution(t *testing.T) {
func(req *http.Request) {},
)

queue := newQueue(formattedParent, "test")
createQueueRequest := taskspb.CreateQueueRequest{
Parent: formattedParent,
Queue: queue,
}

createdQueue, err := client.CreateQueue(context.Background(), &createQueueRequest)
require.NoError(t, err)
createdQueue := createTestQueue(t, client)

createTaskRequest := taskspb.CreateTaskRequest{
Parent: createdQueue.GetName(),
Expand Down Expand Up @@ -204,30 +183,76 @@ func TestSuccessTaskExecution(t *testing.T) {
}

assert.Equal(t, expectHeaders, actualHeaders)
assertIsRecentTimestamp(t, receivedRequest.Header.Get("X-CloudTasks-TaskEta"))
assertIsRecentTimestamp(t, receivedRequest.Header.Get("X-CloudTasks-TaskETA"))

srv.Shutdown(context.Background())
}

func TestErrorTaskExecution(t *testing.T) {
func TestSuccessAppEngineTaskExecution(t *testing.T) {
serv, client := setUp(t)
defer tearDown(t, serv)

called := 0
defer os.Unsetenv("APP_ENGINE_EMULATOR_HOST")
os.Setenv("APP_ENGINE_EMULATOR_HOST", "http://localhost:5000")

var receivedRequest *http.Request

srv := startTestServer(
func(req *http.Request) { receivedRequest = req },
func(req *http.Request) {},
func(req *http.Request) { called++ },
)

queue := newQueue(formattedParent, "test")
defer srv.Shutdown(context.Background())

createQueueRequest := taskspb.CreateQueueRequest{
Parent: formattedParent,
Queue: queue,
createdQueue := createTestQueue(t, client)

createTaskRequest := taskspb.CreateTaskRequest{
Parent: createdQueue.GetName(),
Task: &taskspb.Task{
Name: createdQueue.GetName() + "/tasks/my-test-task",
MessageType: &taskspb.Task_AppEngineHttpRequest{
AppEngineHttpRequest: &taskspb.AppEngineHttpRequest{
RelativeUri: "/success",
},
},
},
}

createdQueue, err := client.CreateQueue(context.Background(), &createQueueRequest)
require.NoError(t, err)
createdTask, _ := client.CreateTask(context.Background(), &createTaskRequest)

// Need to give it a chance to make the actual call
time.Sleep(100 * time.Millisecond)

assert.NotNil(t, createdTask)

expectHeaders := map[string]string{
"X-AppEngine-TaskExecutionCount": "0",
"X-AppEngine-TaskRetryCount": "0",
"X-AppEngine-TaskName": "my-test-task",
"X-AppEngine-QueueName": "test",
}
actualHeaders := make(map[string]string)

for hdr := range expectHeaders {
actualHeaders[hdr] = receivedRequest.Header.Get(hdr)
}

assert.Equal(t, expectHeaders, actualHeaders)

assertIsRecentTimestamp(t, receivedRequest.Header.Get("X-AppEngine-TaskETA"))
}

func TestErrorTaskExecution(t *testing.T) {
serv, client := setUp(t)
defer tearDown(t, serv)

called := 0
srv := startTestServer(
func(req *http.Request) {},
func(req *http.Request) { called++ },
)

createdQueue := createTestQueue(t, client)

createTaskRequest := taskspb.CreateTaskRequest{
Parent: createdQueue.GetName(),
Expand Down Expand Up @@ -268,14 +293,7 @@ func TestOIDCAuthenticatedTaskExecution(t *testing.T) {
func(req *http.Request) {},
)

queue := newQueue(formattedParent, "test")
createQueueRequest := taskspb.CreateQueueRequest{
Parent: formattedParent,
Queue: queue,
}

createdQueue, err := client.CreateQueue(context.Background(), &createQueueRequest)
require.NoError(t, err)
createdQueue := createTestQueue(t, client)

createTaskRequest := taskspb.CreateTaskRequest{
Parent: createdQueue.GetName(),
Expand All @@ -292,7 +310,7 @@ func TestOIDCAuthenticatedTaskExecution(t *testing.T) {
},
},
}
_, err = client.CreateTask(context.Background(), &createTaskRequest)
_, err := client.CreateTask(context.Background(), &createTaskRequest)
require.NoError(t, err)

// Need to give it a chance to make the actual call
Expand Down Expand Up @@ -345,6 +363,20 @@ func assertIsRecentTimestamp(t *testing.T, etaString string) {
)
}

func createTestQueue(t *testing.T, client *Client) *taskspb.Queue {
queue := newQueue(formattedParent, "test")

createQueueRequest := taskspb.CreateQueueRequest{
Parent: formattedParent,
Queue: queue,
}

createdQueue, err := client.CreateQueue(context.Background(), &createQueueRequest)
require.NoError(t, err)

return createdQueue
}

func startTestServer(successCallback serverRequestCallback, notFoundCallback serverRequestCallback) *http.Server {
mux := http.NewServeMux()
mux.HandleFunc("/success", func(w http.ResponseWriter, r *http.Request) {
Expand Down
37 changes: 25 additions & 12 deletions task.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ func NewTask(queue *Queue, taskState *tasks.Task, onDone func(task *Task)) *Task
}

func setInitialTaskState(taskState *tasks.Task, queueName string) {
// TODO: more header stuff like X-Appengine-* setting

if taskState.GetName() == "" {
taskID := strconv.FormatUint(uint64(rand.Uint64()), 10)
taskState.Name = queueName + "/tasks/" + taskID
Expand Down Expand Up @@ -298,6 +296,15 @@ func dispatch(retry bool, taskState *tasks.Task) int {
httpRequest := taskState.GetHttpRequest()
appEngineHTTPRequest := taskState.GetAppEngineHttpRequest()

scheduled, _ := ptypes.Timestamp(taskState.GetScheduleTime())
nameParts := parseTaskName(taskState)

headerQueueName := nameParts.queueId
headerTaskName := nameParts.taskId
headerTaskRetryCount := fmt.Sprintf("%v", taskState.GetDispatchCount()-1)
headerTaskExecutionCount := fmt.Sprintf("%v", taskState.GetResponseCount())
headerTaskETA := fmt.Sprintf("%f", float64(scheduled.UnixNano())/1e9)

if httpRequest != nil {
method := toHTTPMethod(httpRequest.GetHttpMethod())

Expand All @@ -309,6 +316,14 @@ func dispatch(retry bool, taskState *tasks.Task) int {
tokenStr := createOIDCToken(auth.ServiceAccountEmail, httpRequest.GetUrl())
headers["Authorization"] = "Bearer " + tokenStr
}

// Headers as per https://cloud.google.com/tasks/docs/creating-http-target-tasks#handler
// TODO: optional headers
req.Header.Set("X-CloudTasks-QueueName", headerQueueName)
req.Header.Set("X-CloudTasks-TaskName", headerTaskName)
req.Header.Set("X-CloudTasks-TaskExecutionCount", headerTaskExecutionCount)
req.Header.Set("X-CloudTasks-TaskRetryCount", headerTaskRetryCount)
req.Header.Set("X-CloudTasks-TaskETA", headerTaskETA)
} else if appEngineHTTPRequest != nil {
method := toHTTPMethod(appEngineHTTPRequest.GetHttpMethod())

Expand All @@ -319,22 +334,20 @@ func dispatch(retry bool, taskState *tasks.Task) int {
req, _ = http.NewRequest(method, url, bytes.NewBuffer(appEngineHTTPRequest.GetBody()))

headers = appEngineHTTPRequest.GetHeaders()

// These headers are only set on dispatch, see https://cloud.google.com/tasks/docs/reference/rpc/google.cloud.tasks.v2#google.cloud.tasks.v2.AppEngineHttpRequest
// TODO: optional headers
headers["X-AppEngine-QueueName"] = headerQueueName
headers["X-AppEngine-TaskName"] = headerTaskName
headers["X-AppEngine-TaskRetryCount"] = headerTaskRetryCount
headers["X-AppEngine-TaskExecutionCount"] = headerTaskExecutionCount
headers["X-AppEngine-TaskETA"] = headerTaskETA
}

for k, v := range headers {
req.Header.Set(k, v)
}

nameParts := parseTaskName(taskState)

// Headers as per https://cloud.google.com/tasks/docs/creating-http-target-tasks#handler
scheduled, _ := ptypes.Timestamp(taskState.GetScheduleTime())
req.Header.Set("X-CloudTasks-QueueName", nameParts.queueId)
req.Header.Set("X-CloudTasks-TaskName", nameParts.taskId)
req.Header.Set("X-CloudTasks-TaskExecutionCount", fmt.Sprintf("%v", taskState.GetResponseCount()))
req.Header.Set("X-CloudTasks-TaskRetryCount", fmt.Sprintf("%v", taskState.GetDispatchCount()-1))
req.Header.Set("X-CloudTasks-TaskEta", fmt.Sprintf("%f", float64(scheduled.UnixNano())/1e9))

resp, err := client.Do(req)
if err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
Expand Down

0 comments on commit f4f5a49

Please sign in to comment.