Skip to content

Commit

Permalink
Merge pull request #78 from AVOlili/avolili
Browse files Browse the repository at this point in the history
新增支持指定返回值和method不用传入receiver
  • Loading branch information
agiledragon authored Feb 20, 2022
2 parents 7052c4a + d6b60c4 commit 0726845
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 2 deletions.
83 changes: 81 additions & 2 deletions patch.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func ApplyMethod(target reflect.Type, methodName string, double interface{}) *Pa
return create().ApplyMethod(target, methodName, double)
}

func ApplyMethodFunc(target reflect.Type, methodName string, doubleFunc interface{}) *Patches {
return create().ApplyMethodFunc(target, methodName, doubleFunc)
}

func ApplyPrivateMethod(target reflect.Type, methodName string, double interface{}) *Patches {
return create().ApplyPrivateMethod(target, methodName, double)
}
Expand All @@ -52,6 +56,18 @@ func ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *Patches {
return create().ApplyFuncVarSeq(target, outputs)
}

func ApplyFuncReturn(target interface{}, output ...interface{}) *Patches {
return create().ApplyFuncReturn(target, output...)
}

func ApplyMethodReturn(target interface{}, methodName string, output ...interface{}) *Patches {
return create().ApplyMethodReturn(target, methodName, output...)
}

func ApplyFuncVarReturn(target interface{}, output ...interface{}) *Patches {
return create().ApplyFuncVarReturn(target, output...)
}

func create() *Patches {
return &Patches{originals: make(map[uintptr][]byte), values: make(map[reflect.Value]reflect.Value), valueHolders: make(map[reflect.Value]reflect.Value)}
}
Expand All @@ -75,6 +91,15 @@ func (this *Patches) ApplyMethod(target reflect.Type, methodName string, double
return this.ApplyCore(m.Func, d)
}

func (this *Patches) ApplyMethodFunc(target reflect.Type, methodName string, doubleFunc interface{}) *Patches {
m, ok := target.MethodByName(methodName)
if !ok {
panic("retrieve method by name failed")
}
d := funcToMethod(m.Type, doubleFunc)
return this.ApplyCore(m.Func, d)
}

func (this *Patches) ApplyPrivateMethod(target reflect.Type, methodName string, double interface{}) *Patches {
m, ok := creflect.MethodByName(target, methodName)
if !ok {
Expand Down Expand Up @@ -136,6 +161,40 @@ func (this *Patches) ApplyFuncVarSeq(target interface{}, outputs []OutputCell) *
return this.ApplyGlobalVar(target, double)
}

func (this *Patches) ApplyFuncReturn(target interface{}, returns ...interface{}) *Patches {
funcType := reflect.TypeOf(target)
t := reflect.ValueOf(target)
outputs := []OutputCell{{Values: returns, Times: -1}}
d := getDoubleFunc(funcType, outputs)
return this.ApplyCore(t, d)
}

func (this *Patches) ApplyMethodReturn(target interface{}, methodName string, returns ...interface{}) *Patches {
m, ok := reflect.TypeOf(target).MethodByName(methodName)
if !ok {
panic("retrieve method by name failed")
}

outputs := []OutputCell{{Values: returns, Times: -1}}
d := getDoubleFunc(m.Type, outputs)
return this.ApplyCore(m.Func, d)
}

func (this *Patches) ApplyFuncVarReturn(target interface{}, returns ...interface{}) *Patches {
t := reflect.ValueOf(target)
if t.Type().Kind() != reflect.Ptr {
panic("target is not a pointer")
}
if t.Elem().Kind() != reflect.Func {
panic("target is not a func")
}

funcType := reflect.TypeOf(target).Elem()
outputs := []OutputCell{{Values: returns, Times: -1}}
double := getDoubleFunc(funcType, outputs).Interface()
return this.ApplyGlobalVar(target, double)
}

func (this *Patches) Reset() {
for target, bytes := range this.originals {
modifyBinary(target, bytes)
Expand Down Expand Up @@ -203,8 +262,14 @@ func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value {
funcType.NumOut(), len(outputs[0].Values)))
}

needReturn := false
slice := make([]Params, 0)
for _, output := range outputs {
if output.Times == -1 {
needReturn = true
slice = []Params{output.Values}
break
}
t := 0
if output.Times <= 1 {
t = 1
Expand All @@ -217,9 +282,12 @@ func getDoubleFunc(funcType reflect.Type, outputs []OutputCell) reflect.Value {
}

i := 0
len := len(slice)
lenOutputs := len(slice)
return reflect.MakeFunc(funcType, func(_ []reflect.Value) []reflect.Value {
if i < len {
if needReturn {
return GetResultValues(funcType, slice[0]...)
}
if i < lenOutputs {
i++
return GetResultValues(funcType, slice[i-1]...)
}
Expand Down Expand Up @@ -259,3 +327,14 @@ func entryAddress(p uintptr, l int) []byte {
func pageStart(ptr uintptr) uintptr {
return ptr & ^(uintptr(syscall.Getpagesize() - 1))
}

func funcToMethod(funcType reflect.Type, doubleFunc interface{}) reflect.Value {
rf := reflect.TypeOf(doubleFunc)
if rf.Kind() != reflect.Func {
panic("doubleFunc is not a func")
}
vf := reflect.ValueOf(doubleFunc)
return reflect.MakeFunc(funcType, func(in []reflect.Value) []reflect.Value {
return vf.Call(in[1:])
})
}
39 changes: 39 additions & 0 deletions test/apply_func_return_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package test

import (
"testing"

. "github.com/agiledragon/gomonkey/v2"
"github.com/agiledragon/gomonkey/v2/test/fake"
. "github.com/smartystreets/goconvey/convey"
)

/*
compare with apply_func_seq_test.go
*/
func TestApplyFuncReturn(t *testing.T) {
Convey("TestApplyFuncReturn", t, func() {

Convey("declares the values to be returned", func() {
info1 := "hello cpp"

patches := ApplyFuncReturn(fake.ReadLeaf, info1, nil)
defer patches.Reset()

for i := 0; i < 10; i++ {
output, err := fake.ReadLeaf("")
So(err, ShouldEqual, nil)
So(output, ShouldEqual, info1)
}

patches.Reset() // if not reset will occur:patch has been existed
info2 := "hello golang"
patches.ApplyFuncReturn(fake.ReadLeaf, info2, nil)
for i := 0; i < 10; i++ {
output, err := fake.ReadLeaf("")
So(err, ShouldEqual, nil)
So(output, ShouldEqual, info2)
}
})
})
}
38 changes: 38 additions & 0 deletions test/apply_func_var_return_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package test

import (
"testing"

. "github.com/agiledragon/gomonkey/v2"
"github.com/agiledragon/gomonkey/v2/test/fake"
. "github.com/smartystreets/goconvey/convey"
)

/*
compare with apply_func_var_seq_test.go
*/
func TestApplyFuncVarReturn(t *testing.T) {
Convey("TestApplyFuncVarReturn", t, func() {

Convey("declares the values to be returned", func() {
info1 := "hello cpp"

patches := ApplyFuncVarReturn(&fake.Marshal, []byte(info1), nil)
defer patches.Reset()
for i := 0; i < 10; i++ {
bytes, err := fake.Marshal("")
So(err, ShouldEqual, nil)
So(string(bytes), ShouldEqual, info1)
}

info2 := "hello golang"
patches.ApplyFuncVarReturn(&fake.Marshal, []byte(info2), nil)
for i := 0; i < 10; i++ {
bytes, err := fake.Marshal("")
So(err, ShouldEqual, nil)
So(string(bytes), ShouldEqual, info2)
}
})

})
}
87 changes: 87 additions & 0 deletions test/apply_method_func_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package test

import (
"reflect"
"testing"

. "github.com/agiledragon/gomonkey/v2"
"github.com/agiledragon/gomonkey/v2/test/fake"
. "github.com/smartystreets/goconvey/convey"
)

/*
compare with apply_method_test.go, no need pass receiver
*/

func TestApplyMethodFunc(t *testing.T) {
slice := fake.NewSlice()
var s *fake.Slice
Convey("TestApplyMethodFunc", t, func() {
Convey("for succ", func() {
err := slice.Add(1)
So(err, ShouldEqual, nil)
patches := ApplyMethodFunc(reflect.TypeOf(s), "Add", func(_ int) error {
return nil
})
defer patches.Reset()
err = slice.Add(1)
So(err, ShouldEqual, nil)
err = slice.Remove(1)
So(err, ShouldEqual, nil)
So(len(slice), ShouldEqual, 0)
})

Convey("for already exist", func() {
err := slice.Add(2)
So(err, ShouldEqual, nil)
patches := ApplyMethodFunc(reflect.TypeOf(s), "Add", func(_ int) error {
return fake.ErrElemExsit
})
defer patches.Reset()
err = slice.Add(1)
So(err, ShouldEqual, fake.ErrElemExsit)
err = slice.Remove(2)
So(err, ShouldEqual, nil)
So(len(slice), ShouldEqual, 0)
})

Convey("two methods", func() {
err := slice.Add(3)
So(err, ShouldEqual, nil)
defer slice.Remove(3)
patches := ApplyMethodFunc(reflect.TypeOf(s), "Add", func(_ int) error {
return fake.ErrElemExsit
})
defer patches.Reset()
patches.ApplyMethodFunc(reflect.TypeOf(s), "Remove", func(_ int) error {
return fake.ErrElemNotExsit
})
err = slice.Add(2)
So(err, ShouldEqual, fake.ErrElemExsit)
err = slice.Remove(1)
So(err, ShouldEqual, fake.ErrElemNotExsit)
So(len(slice), ShouldEqual, 1)
So(slice[0], ShouldEqual, 3)
})

Convey("one func and one method", func() {
err := slice.Add(4)
So(err, ShouldEqual, nil)
defer slice.Remove(4)
patches := ApplyFunc(fake.Exec, func(_ string, _ ...string) (string, error) {
return outputExpect, nil
})
defer patches.Reset()
patches.ApplyMethodFunc(reflect.TypeOf(s), "Remove", func(_ int) error {
return fake.ErrElemNotExsit
})
output, err := fake.Exec("", "")
So(err, ShouldEqual, nil)
So(output, ShouldEqual, outputExpect)
err = slice.Remove(1)
So(err, ShouldEqual, fake.ErrElemNotExsit)
So(len(slice), ShouldEqual, 1)
So(slice[0], ShouldEqual, 4)
})
})
}
38 changes: 38 additions & 0 deletions test/apply_method_return_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package test

import (
"testing"

. "github.com/agiledragon/gomonkey/v2"
"github.com/agiledragon/gomonkey/v2/test/fake"
. "github.com/smartystreets/goconvey/convey"
)

/*
compare with apply_method_seq_test.go
*/

func TestApplyMethodReturn(t *testing.T) {
e := &fake.Etcd{}
Convey("TestApplyMethodReturn", t, func() {
Convey("declares the values to be returned", func() {
info1 := "hello cpp"
patches := ApplyMethodReturn(e, "Retrieve", info1, nil)
defer patches.Reset()
for i := 0; i < 10; i++ {
output1, err1 := e.Retrieve("")
So(err1, ShouldEqual, nil)
So(output1, ShouldEqual, info1)
}

patches.Reset() // if not reset will occur:patch has been existed
info2 := "hello golang"
patches.ApplyMethodReturn(e, "Retrieve", info2, nil)
for i := 0; i < 10; i++ {
output2, err2 := e.Retrieve("")
So(err2, ShouldEqual, nil)
So(output2, ShouldEqual, info2)
}
})
})
}

0 comments on commit 0726845

Please sign in to comment.