diff --git a/cmd/hz/protobuf/plugin.go b/cmd/hz/protobuf/plugin.go index abb7616c4..5649c6127 100644 --- a/cmd/hz/protobuf/plugin.go +++ b/cmd/hz/protobuf/plugin.go @@ -290,9 +290,31 @@ func (plugin *Plugin) Handle(req *pluginpb.CodeGeneratorRequest, args *config.Ar return nil } +func gopkgIncluded(opt string, gopkg string) bool { + if strings.HasPrefix(opt, gopkg) { + return true + } else if strings.HasPrefix(opt, "/"+gopkg) { + return true + } else { + return false + } +} + +func (plugin *Plugin) processOpt(opt string) string { + gopkg := plugin.Package + if !gopkgIncluded(opt, gopkg) { + if strings.HasPrefix(opt, "/") { + opt = gopkg + opt + } else { + opt = gopkg + "/" + opt + } + } + impt, _ := plugin.fixModelPathAndPackage(opt) + return impt +} + // fixGoPackage will update go_package to store all the model files in ${model_dir} func (plugin *Plugin) fixGoPackage(req *pluginpb.CodeGeneratorRequest, pkgMap map[string]string, trimGoPackage string) { - gopkg := plugin.Package for _, f := range req.ProtoFile { if strings.HasPrefix(f.GetPackage(), "google.protobuf") { continue @@ -300,17 +322,8 @@ func (plugin *Plugin) fixGoPackage(req *pluginpb.CodeGeneratorRequest, pkgMap ma if len(trimGoPackage) != 0 && strings.HasPrefix(f.GetOptions().GetGoPackage(), trimGoPackage) { *f.Options.GoPackage = strings.TrimPrefix(*f.Options.GoPackage, trimGoPackage) } - opt := getGoPackage(f, pkgMap) - if !strings.Contains(opt, gopkg) { - if strings.HasPrefix(opt, "/") { - opt = gopkg + opt - } else { - opt = gopkg + "/" + opt - } - } - impt, _ := plugin.fixModelPathAndPackage(opt) - *f.Options.GoPackage = impt + *f.Options.GoPackage = plugin.processOpt(opt) } } diff --git a/cmd/hz/protobuf/plugin_test.go b/cmd/hz/protobuf/plugin_test.go index b34d8b868..cbb86efa2 100644 --- a/cmd/hz/protobuf/plugin_test.go +++ b/cmd/hz/protobuf/plugin_test.go @@ -47,6 +47,24 @@ func TestPlugin_Handle(t *testing.T) { plu.recvWarningLogger() } +func TestProcessOpt(t *testing.T) { + plu := &Plugin{} + plu.Package = "hello" + plu.ModelDir = meta.ModelDir + tests := [][]string{ + {"a/b/c", "hello/biz/model/a/b/c"}, + {"a/hello/c", "hello/biz/model/a/hello/c"}, + {"biz/model/a/b/c", "hello/biz/model/a/b/c"}, + {"hello/a/b/c", "hello/biz/model/a/b/c"}, + {"hello/biz/model/a/hello/c", "hello/biz/model/a/hello/c"}, + } + for _, test := range tests { + if result := plu.processOpt(test[0]); result != test[1] { + t.Fatalf("want go package: %s, but get: %s", test[1], result) + } + } +} + func TestFixModelPathAndPackage(t *testing.T) { plu := &Plugin{} plu.Package = "cloudwego/hertz"