diff --git a/transport/http/client.go b/transport/http/client.go index 0d8eb2a3e..737147a5c 100644 --- a/transport/http/client.go +++ b/transport/http/client.go @@ -1,6 +1,9 @@ package http import ( + "bytes" + "encoding/json" + "io/ioutil" "net/http" "net/url" @@ -113,3 +116,19 @@ func (c Client) Endpoint() endpoint.Endpoint { return response, nil } } + +// EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a +// JSON object to the Request body. Many JSON-over-HTTP services can use it as +// a sensible default. If the request implements Headerer, the provided headers +// will be applied to the request. +func EncodeJSONRequest(c context.Context, r *http.Request, request interface{}) error { + r.Header.Set("Content-Type", "application/json; charset=utf-8") + if headerer, ok := request.(Headerer); ok { + for k := range headerer.Headers() { + r.Header.Set(k, headerer.Headers().Get(k)) + } + } + var b bytes.Buffer + r.Body = ioutil.NopCloser(&b) + return json.NewEncoder(&b).Encode(request) +} diff --git a/transport/http/client_test.go b/transport/http/client_test.go index 3ec224e01..048399d48 100644 --- a/transport/http/client_test.go +++ b/transport/http/client_test.go @@ -2,6 +2,7 @@ package http_test import ( "io" + "io/ioutil" "net/http" "net/http/httptest" "net/url" @@ -140,6 +141,68 @@ func TestHTTPClientBufferedStream(t *testing.T) { } } +func TestEncodeJSONRequest(t *testing.T) { + var header http.Header + var body string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadAll(r.Body) + if err != nil && err != io.EOF { + t.Fatal(err) + } + header = r.Header + body = string(b) + })) + + defer server.Close() + + serverURL, err := url.Parse(server.URL) + + if err != nil { + t.Fatal(err) + } + + client := httptransport.NewClient( + "POST", + serverURL, + httptransport.EncodeJSONRequest, + func(context.Context, *http.Response) (interface{}, error) { return nil, nil }, + ).Endpoint() + + for _, test := range []struct { + value interface{} + body string + }{ + {nil, "null\n"}, + {12, "12\n"}, + {1.2, "1.2\n"}, + {true, "true\n"}, + {"test", "\"test\"\n"}, + {enhancedRequest{Foo: "foo"}, "{\"foo\":\"foo\"}\n"}, + } { + if _, err := client(context.Background(), test.value); err != nil { + t.Error(err) + continue + } + + if body != test.body { + t.Errorf("%v: actual %#v, expected %#v", test.value, body, test.body) + } + } + + if _, err := client(context.Background(), enhancedRequest{Foo: "foo"}); err != nil { + t.Fatal(err) + } + + if _, ok := header["X-Edward"]; !ok { + t.Fatalf("X-Edward value: actual %v, expected %v", nil, []string{"Snowden"}) + } + + if v := header.Get("X-Edward"); v != "Snowden" { + t.Errorf("X-Edward string: actual %v, expected %v", v, "Snowden") + } +} + func mustParse(s string) *url.URL { u, err := url.Parse(s) if err != nil { @@ -147,3 +210,9 @@ func mustParse(s string) *url.URL { } return u } + +type enhancedRequest struct { + Foo string `json:"foo"` +} + +func (e enhancedRequest) Headers() http.Header { return http.Header{"X-Edward": []string{"Snowden"}} }