Skip to content

Commit

Permalink
Sync with internal development.
Browse files Browse the repository at this point in the history
The biggest changes:

- Add boolean constant parsing.
- Add context.Context support to rpc code.

Minor changes encountered while parsing a large thrift corpus:

- Harden code generator against '_test'-suffixed thrift schemas.
- Harden code generator against name collisions ('res').
- Rename internal RPC wrapper structs.

Better error handling:

- Log runtime errors encountered by rpc codec serialization.
- Clear buffers more aggressively.
  • Loading branch information
Matt Jones committed Aug 22, 2016
1 parent b029ac8 commit 4f4c1a1
Show file tree
Hide file tree
Showing 9 changed files with 736 additions and 579 deletions.
148 changes: 122 additions & 26 deletions generator/go.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"fmt"
"go/format"
"io"
"log"
"os"
"path/filepath"
"runtime"
Expand All @@ -31,10 +32,11 @@ var (
flagGoImportPrefix = flag.String("go.importprefix", "", "Prefix for thrift-generated go package imports")
flagGoGenerateMethods = flag.Bool("go.generate", false, "Add testing/quick compatible Generate methods to enum types")
flagGoSignedBytes = flag.Bool("go.signedbytes", false, "Interpret Thrift byte as Go signed int8 type")
flagGoRPCContext = flag.Bool("go.rpccontext", false, "Add context.Context objects to rpc wrappers")
)

var (
goNamespaceOrder = []string{"go", "perl", "py", "cpp", "rb", "java"}
goNamespaceOrder = []string{"go", "perl", "py", "cpp", "rb", "java", "cpp2"}
)

type ErrUnknownType string
Expand Down Expand Up @@ -63,6 +65,9 @@ type GoGenerator struct {
Format bool
Pointers bool
SignedBytes bool

// package names imported
packageNames map[string]bool
}

var goKeywords = map[string]bool{
Expand Down Expand Up @@ -92,8 +97,12 @@ var goKeywords = map[string]bool{
"return": true,
"var": true,

// request arguments are hardcoded to 'req'; blacklist it to prevent accidental name collisions
// request arguments are hardcoded to 'req' and the response to 'res'
"req": true,
"res": true,
// ctx is passed as the first argument, SetContext methods are generated iff flagGoRPCContext is set
"ctx": true,
"SetContext": true,
}

var basicTypes = map[string]bool{
Expand Down Expand Up @@ -200,7 +209,7 @@ func (g *GoGenerator) formatType(pkg string, thrift *parser.Thrift, typ *parser.
}

if t := thrift.Typedefs[typ.Name]; t != nil {
name := typ.Name
name := camelCase(typ.Name)
if pkg != g.pkg {
name = pkg + "." + name
}
Expand Down Expand Up @@ -356,6 +365,11 @@ func (g *GoGenerator) formatValue(v interface{}, t *parser.Type) (string, error)
return strconv.Quote(v2), nil
case int:
return strconv.Itoa(v2), nil
case bool:
if v2 {
return "true", nil
}
return "false", nil
case int64:
if t.Name == "bool" {
if v2 == 0 {
Expand Down Expand Up @@ -386,10 +400,11 @@ func (g *GoGenerator) formatValue(v interface{}, t *parser.Type) (string, error)
return buf.String(), nil
case []parser.KeyValue:
buf := &bytes.Buffer{}
buf.WriteString(g.formatType(g.pkg, g.thrift, t, 0))
buf.WriteString(g.formatType(g.pkg, g.thrift, t, toNoPointer))
buf.WriteString("{\n")
for _, kv := range v2 {
buf.WriteString("\t\t")

s, err := g.formatValue(kv.Key, t.KeyType)
if err != nil {
return "", err
Expand All @@ -400,19 +415,30 @@ func (g *GoGenerator) formatValue(v interface{}, t *parser.Type) (string, error)
if err != nil {
return "", err
}

// struct values are pointers
if t.ValueType == nil && *flagGoPointers {
s += ".Ptr()"
}

buf.WriteString(s)
buf.WriteString(",\n")
}
buf.WriteString("\t}")
return buf.String(), nil
case parser.Identifier:
parts := strings.SplitN(string(v2), ".", 2)
if len(parts) == 1 {
return camelCase(parts[0]), nil
ident := string(v2)
idx := strings.LastIndex(ident, ".")
if idx == -1 {
return camelCase(ident), nil
}

scope := ident[:idx]
if g.packageNames[scope] {
scope += "."
}

resolved := parts[0] + camelCase(parts[1])
return resolved, nil
return scope + camelCase(ident[idx+1:]), nil
}
return "", fmt.Errorf("unsupported value type %T", v)
}
Expand Down Expand Up @@ -516,10 +542,14 @@ func (e *%s) Generate(rand *rand.Rand, size int) reflect.Value {
return nil
}

func (g *GoGenerator) writeStruct(out io.Writer, st *parser.Struct) error {
func (g *GoGenerator) writeStruct(out io.Writer, st *parser.Struct, includeContext bool) error {
structName := camelCase(st.Name)

g.write(out, "\ntype %s struct {\n", structName)
if includeContext {
g.write(out, "\tctx context.Context\n")
}

for _, field := range st.Fields {
g.write(out, "\t%s\n", g.formatField(field))
}
Expand All @@ -532,11 +562,19 @@ func (g *GoGenerator) writeStruct(out io.Writer, st *parser.Struct) error {
g.write(out, "%s\n", g.formatFieldGetter(receiver, structName, field))
}

if includeContext {
g.write(out, `
func (%s *%s) SetContext(ctx context.Context) {
%s.ctx = ctx
}
`, receiver, structName, receiver)
}

return g.write(out, "\n")
}

func (g *GoGenerator) writeException(out io.Writer, ex *parser.Struct) error {
if err := g.writeStruct(out, ex); err != nil {
if err := g.writeStruct(out, ex, false); err != nil {
return err
}

Expand All @@ -550,7 +588,7 @@ func (g *GoGenerator) writeException(out io.Writer, ex *parser.Struct) error {
fieldVars := make([]string, len(ex.Fields))
for i, field := range ex.Fields {
fieldNames[i] = camelCase(field.Name) + ": %+v"
fieldVars[i] = "e." + camelCase(field.Name)
fieldVars[i] = "e.Get" + camelCase(field.Name) + "()"
}
g.write(out, "\treturn fmt.Sprintf(\"%s{%s}\", %s)\n",
exName, strings.Join(fieldNames, ", "), strings.Join(fieldVars, ", "))
Expand All @@ -570,9 +608,14 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
methodNames := sortedKeys(svc.Methods)
for _, k := range methodNames {
method := svc.Methods[k]
args := g.formatArguments(method.Arguments)
if *flagGoRPCContext {
args = "ctx context.Context, " + args
}

g.write(out,
"\t%s(%s) %s\n",
camelCase(method.Name), g.formatArguments(method.Arguments),
camelCase(method.Name), args,
g.formatReturnType(method.ReturnType, false))
}
g.write(out, "}\n")
Expand All @@ -590,12 +633,21 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
for _, k := range methodNames {
method := svc.Methods[k]
mName := camelCase(method.Name)

requestStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Request"
responseStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Response"

resArg := ""
if !method.Oneway {
resArg = fmt.Sprintf(", res *%s%sResponse", svcName, mName)
resArg = fmt.Sprintf(", res *%s", responseStructName)
}
g.write(out, "\nfunc (s *%sServer) %s(req *%s%sRequest%s) error {\n", svcName, mName, svcName, mName, resArg)
g.write(out, "\nfunc (s *%sServer) %s(req *%s%s) error {\n", svcName, mName, requestStructName, resArg)
var args []string

if *flagGoRPCContext {
args = append(args, "req.ctx")
}

for _, arg := range method.Arguments {
aName := camelCase(arg.Name)
args = append(args, "req."+aName)
Expand Down Expand Up @@ -626,13 +678,16 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
for _, k := range methodNames {
// Request struct
method := svc.Methods[k]
reqStructName := svcName + camelCase(method.Name) + "Request"
if err := g.writeStruct(out, &parser.Struct{Name: reqStructName, Fields: method.Arguments}); err != nil {

requestStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Request"
responseStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Response"

if err := g.writeStruct(out, &parser.Struct{Name: requestStructName, Fields: method.Arguments}, *flagGoRPCContext); err != nil {
return err
}

if method.Oneway {
g.write(out, "\nfunc (r *%s) Oneway() bool {\n\treturn true\n}\n", reqStructName)
g.write(out, "\nfunc (r *%s) Oneway() bool {\n\treturn true\n}\n", requestStructName)
} else {
// Response struct
args := make([]*parser.Field, 0, len(method.Exceptions))
Expand All @@ -642,8 +697,8 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
for _, ex := range method.Exceptions {
args = append(args, ex)
}
res := &parser.Struct{Name: svcName + camelCase(method.Name) + "Response", Fields: args}
if err := g.writeStruct(out, res); err != nil {
res := &parser.Struct{Name: responseStructName, Fields: args}
if err := g.writeStruct(out, res, false); err != nil {
return err
}
}
Expand All @@ -657,8 +712,12 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {

for _, k := range methodNames {
method := svc.Methods[k]

requestStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Request"
responseStructName := "InternalRPC" + svcName + camelCase(method.Name) + "Response"

methodName := camelCase(method.Name)
returnType := "err error"
returnType := "(err error)"
if !method.Oneway {
returnType = g.formatReturnType(method.ReturnType, true)
}
Expand All @@ -668,7 +727,7 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
returnType)

// Request
g.write(out, "\treq := &%s%sRequest{\n", svcName, methodName)
g.write(out, "\treq := &%s{\n", requestStructName)
for _, arg := range method.Arguments {
g.write(out, "\t\t%s: %s,\n", camelCase(arg.Name), validGoIdent(lowerCamelCase(arg.Name)))
}
Expand All @@ -679,7 +738,7 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
// g.write(out, "\tvar res *%s%sResponse = nil\n", svcName, methodName)
g.write(out, "\tvar res interface{} = nil\n")
} else {
g.write(out, "\tres := &%s%sResponse{}\n", svcName, methodName)
g.write(out, "\tres := &%s{}\n", responseStructName)
}

// Call
Expand Down Expand Up @@ -709,6 +768,26 @@ func (g *GoGenerator) writeService(out io.Writer, svc *parser.Service) error {
return nil
}

var validMapKeys = map[string]bool{
"string": true,
"i32": true,
"i64": true,
"bool": true,
"double": true,
}

func (g *GoGenerator) isValidGoType(typ *parser.Type) bool {
if typ.KeyType == nil {
return true
}

if _, ok := g.thrift.Enums[g.resolveType(typ.KeyType)]; ok {
return true
}

return validMapKeys[g.resolveType(typ.KeyType)]
}

func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *parser.Thrift) {
packageName := g.Packages[thriftPath].Name
g.thrift = thrift
Expand All @@ -726,6 +805,11 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
imports = append(imports, "math/rand", "reflect")
}
}

if len(thrift.Services) > 0 && *flagGoRPCContext {
imports = append(imports, "golang.org/x/net/context")
}

if len(thrift.Includes) > 0 {
for _, path := range thrift.Includes {
pkg := g.Packages[path].Name
Expand Down Expand Up @@ -760,6 +844,12 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p
if len(thrift.Constants) > 0 {
for _, k := range sortedKeys(thrift.Constants) {
c := thrift.Constants[k]

if !g.isValidGoType(c.Type) {
log.Printf("Skipping generation for constant %s - type is not a valid go type (%s)\n", c.Name, g.resolveType(c.Type.KeyType))
continue
}

v, err := g.formatValue(c.Value, c.Type)
if err != nil {
g.error(err)
Expand All @@ -783,7 +873,7 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p

for _, k := range sortedKeys(thrift.Structs) {
st := thrift.Structs[k]
if err := g.writeStruct(out, st); err != nil {
if err := g.writeStruct(out, st, false); err != nil {
g.error(err)
}
}
Expand All @@ -797,7 +887,7 @@ func (g *GoGenerator) generateSingle(out io.Writer, thriftPath string, thrift *p

for _, k := range sortedKeys(thrift.Unions) {
un := thrift.Unions[k]
if err := g.writeStruct(out, un); err != nil {
if err := g.writeStruct(out, un, false); err != nil {
g.error(err)
}
}
Expand All @@ -822,7 +912,8 @@ func (g *GoGenerator) Generate(outPath string) (err error) {

// Generate package namespace mapping if necessary
if g.Packages == nil {
g.Packages = make(map[string]GoPackage)
g.Packages = map[string]GoPackage{}
g.packageNames = map[string]bool{}
}
for path, th := range g.ThriftFiles {
if pkg, ok := g.Packages[path]; !ok || pkg.Name == "" {
Expand All @@ -843,6 +934,7 @@ func (g *GoGenerator) Generate(outPath string) (err error) {
}
pkg.Name = validIdentifier(strings.ToLower(pkg.Name), "_")
g.Packages[path] = pkg
g.packageNames[pkg.Name] = true
}
}

Expand All @@ -856,6 +948,10 @@ func (g *GoGenerator) Generate(outPath string) (err error) {
filename = filename[:i]
}
}
if strings.HasSuffix(filename, "_test") {
filename = filename[:len(filename)-len("_test")]
}

filename += ".go"
pkgpath := filepath.Join(outPath, pkg.Path, pkg.Name)
outfile := filepath.Join(pkgpath, filename)
Expand Down
Loading

0 comments on commit 4f4c1a1

Please sign in to comment.