Skip to content

Commit

Permalink
fixes to code gen, unsafe access of unexported template data (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
jhump authored Jul 15, 2019
1 parent 17282ff commit 389d8d8
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 18 deletions.
6 changes: 3 additions & 3 deletions code.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ func (l line) qualify(imports *Imports) {
for i := range l.args {
switch a := l.args[i].(type) {
case Package:
p := imports.RegisterImportForPackage(a)
p := imports.registerPackage(a)
if p != a.Name {
l.args[i] = Package{Name: p, ImportPath: a.ImportPath}
}
case *Package:
p := imports.RegisterImportForPackage(*a)
p := imports.registerPackage(*a)
if p != a.Name {
l.args[i] = Package{Name: p, ImportPath: a.ImportPath}
}
Expand Down Expand Up @@ -209,7 +209,7 @@ func (l line) qualify(imports *Imports) {
}
case *types.Package:
p := PackageForGoType(a)
l.args[i] = imports.RegisterImportForPackage(p)
l.args[i] = imports.registerPackage(p)
}
}
}
Expand Down
11 changes: 11 additions & 0 deletions imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ func (i *Imports) RegisterImportForPackage(pkg Package) string {
return i.RegisterImport(pkg.ImportPath, pkg.Name)
}

// registerPackage is like RegisterImportForPackage, but instead of returning
// the prefix (which includes a trailing dot if non-empty), it just returns the
// package alias.
func (i *Imports) registerPackage(pkg Package) string {
p := i.RegisterImportForPackage(pkg)
if len(p) > 0 && p[len(p)-1] == '.' {
p = p[:len(p)-1]
}
return p
}

// RegisterImport "imports" the specified package and returns the package prefix
// to use for symbols in the imported package. It is safe to import the same
// package repeatedly -- the same prefix will be returned every time. If an
Expand Down
45 changes: 30 additions & 15 deletions templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b
switch data.Kind() {
case reflect.Interface:
if data.Elem().IsValid() {
qualified := true
var newRv reflect.Value
switch d := data.Interface().(type) {
case TypeName:
Expand Down Expand Up @@ -44,20 +45,25 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b
}
case *types.Package:
origP := PackageForGoType(d)
newP := imports.RegisterImportForPackage(origP)
newP := imports.registerPackage(origP)
if newP != origP.Name {
newRv = reflect.ValueOf(Package{Name: newP, ImportPath: origP.ImportPath})
}
default:
qualified = false
}

if newRv.IsValid() && newRv.Type().Implements(data.Type()) {
// For TypeName, this should always be true; but for cases
// where we've changed the type of the value, if we try to
// return an incompatible type, the result will be a panic
// with a location and message that is not awesome for
// users of this package. So we'll ignore the new value if
// it's not the right type.
return newRv, true
if qualified {
if newRv.IsValid() && newRv.Type().Implements(data.Type()) {
// For TypeName, this should always be true; but for cases
// where we've changed the type of the value, if we try to
// return an incompatible type, the result will be a panic
// with a location and message that is not awesome for
// users of this package. So we'll ignore the new value if
// it's not the right type.
return newRv, true
}
return data, false
}

return qualifyTemplateData(imports, data.Elem())
Expand All @@ -66,7 +72,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b
case reflect.Struct:
switch t := data.Interface().(type) {
case Package:
p := imports.RegisterImportForPackage(t)
p := imports.registerPackage(t)
if p != t.Name {
return reflect.ValueOf(&Package{Name: p, ImportPath: t.ImportPath}).Elem(), true
}
Expand All @@ -89,7 +95,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b
case ConstSpec:
if t.parent != nil {
oldPkg := t.parent.PackageName
newPkg := imports.RegisterImportForPackage(t.parent.Package())
newPkg := imports.registerPackage(t.parent.Package())
if newPkg != oldPkg {
newCs := t
newCs.parent = &GoFile{PackageName: newPkg}
Expand All @@ -99,7 +105,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b
case VarSpec:
if t.parent != nil {
oldPkg := t.parent.PackageName
newPkg := imports.RegisterImportForPackage(t.parent.Package())
newPkg := imports.registerPackage(t.parent.Package())
if newPkg != oldPkg {
newVs := t
newVs.parent = &GoFile{PackageName: newPkg}
Expand All @@ -109,7 +115,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b
case TypeSpec:
if t.parent != nil {
oldPkg := t.parent.PackageName
newPkg := imports.RegisterImportForPackage(t.parent.Package())
newPkg := imports.registerPackage(t.parent.Package())
if newPkg != oldPkg {
newTs := t
newTs.parent = &GoFile{PackageName: newPkg}
Expand All @@ -119,7 +125,7 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b
case FuncSpec:
if t.parent != nil {
oldPkg := t.parent.PackageName
newPkg := imports.RegisterImportForPackage(t.parent.Package())
newPkg := imports.registerPackage(t.parent.Package())
if newPkg != oldPkg {
newFs := t
newFs.parent = &GoFile{PackageName: newPkg}
Expand All @@ -141,7 +147,16 @@ func qualifyTemplateData(imports *Imports, data reflect.Value) (reflect.Value, b
default:
var newStruct reflect.Value
for i := 0; i < data.NumField(); i++ {
newV, changedV := qualifyTemplateData(imports, data.Field(i))
var newV reflect.Value
var changedV bool
fld, ok := getField(data, i)
if !ok {
// do not recurse
newV = data.Field(i)
changedV = false
} else {
newV, changedV = qualifyTemplateData(imports, fld)
}
if newStruct.IsValid() {
newStruct.Field(i).Set(newV)
} else if changedV {
Expand Down
15 changes: 15 additions & 0 deletions templates_no_unsafe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//+build appengine gopherjs purego
// NB: other environments where unsafe is unappropriate should use "purego" build tag
// https://github.com/golang/go/issues/23172

package gopoet

import (
"reflect"
)

func getField(v reflect.Value, index int) (reflect.Value, bool) {
fld := v.Field(index)
// We can't use unsafe, so return false for unexported fields :(
return fld, !fld.IsValid() || fld.CanInterface()
}
25 changes: 25 additions & 0 deletions templates_unsafe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//+build !appengine,!gopherjs,!purego
// NB: other environments where unsafe is unappropriate should use "purego" build tag
// https://github.com/golang/go/issues/23172

package gopoet

import (
"reflect"
"unsafe"
)

func getField(v reflect.Value, index int) (reflect.Value, bool) {
fld := v.Field(index)
if !fld.IsValid() || fld.CanInterface() {
return fld, true
}

// NB: We are being super-sneaky. Go reflection will not let us call
// fld.Interface() if fld was obtained via unexported fields (which it
// was!). So we use unsafe to create an alternate reflect.Value instance
// that represents the same value (same type and address). We can then
// call Interface() on *that*.
val := reflect.NewAt(fld.Type(), unsafe.Pointer(fld.UnsafeAddr())).Elem()
return val, true
}

0 comments on commit 389d8d8

Please sign in to comment.