From e899d523a9908e5c5f2c36e0bdca3c4fe1fedfaa Mon Sep 17 00:00:00 2001 From: Bryan White Date: Sat, 16 Nov 2024 16:12:14 +0100 Subject: [PATCH 01/18] wip: add protocheck status-err subcommand --- tools/scripts/protocheck/cmd/status_errors.go | 286 ++++++++++++++++++ 1 file changed, 286 insertions(+) create mode 100644 tools/scripts/protocheck/cmd/status_errors.go diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go new file mode 100644 index 000000000..9598a8556 --- /dev/null +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -0,0 +1,286 @@ +package main + +import ( + "context" + "fmt" + "go/ast" + "go/parser" + "go/token" + "log" + "path/filepath" + + "github.com/spf13/cobra" + "golang.org/x/tools/go/packages" +) + +var ( + flagModule = "module" + flagModuleShorthand = "m" + flagModuleValue = "*" + flagModuleUsage = "If present, only check message handlers of the given module." + + statusErrorsCheckCmd = &cobra.Command{ + Use: "status-errors [flags]", + Short: "Checks that all message handler function errors are wrapped in gRPC status errors.", + RunE: runStatusErrorsCheck, + } + + poktrollModules = map[string]struct{}{ + "application": {}, + //"gateway": {}, + //"service": {}, + //"session": {}, + //"shared": {}, + //"supplier": {}, + //"proof": {}, + //"tokenomics": {}, + } +) + +func init() { + statusErrorsCheckCmd.Flags().StringVarP(&flagModule, flagModuleShorthand, "m", flagModuleValue, flagModuleUsage) + rootCmd.AddCommand(statusErrorsCheckCmd) +} + +// TODO_IN_THIS_COMMIT: pre-run: drop patch version in go.mod; post-run: restore. +func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + + if flagModule != "*" { + if _, ok := poktrollModules[flagModule]; !ok { + return fmt.Errorf("unknown module %q", flagModule) + } + + if err := checkModule(ctx, flagModule); err != nil { + return err + } + } + + for module := range poktrollModules { + if err := checkModule(ctx, module); err != nil { + return err + } + } + + return nil +} + +func checkModule(_ context.Context, moduleName string) error { + + // 0. Get the package info for the given module's keeper package. + // 1. Find the message server struct for the given module. + // 2. Recursively traverse `msg_server_*.go` files to find all of its methods. + // 3. Recursively traverse the method body to find all of its error returns. + // 4. Lookup error assignments to ensure that they are wrapped in gRPC status errors. + + // TODO: import polyzero for side effects. + //logger := polylog.Ctx(ctx) + + moduleDir := filepath.Join(".", "x", moduleName) + keeperDir := filepath.Join(moduleDir, "keeper") + + // TODO_IN_THIS_COMMIT: extract --- BEGIN + // Set up the package configuration + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedTypes | packages.NeedTypesInfo, + Tests: false, // Set to true if you also want to load test files + } + + // Load the package containing the target file or directory + poktrollPkgPathRoot := "github.com/pokt-network/poktroll" + moduleKeeperPkgPath := filepath.Join(poktrollPkgPathRoot, keeperDir) + pkgs, err := packages.Load(cfg, moduleKeeperPkgPath) + if err != nil { + log.Fatalf("Failed to load package: %v", err) + } + + // Iterate over the packages + for _, pkg := range pkgs { + if len(pkg.Errors) > 0 { + for _, pkgErr := range pkg.Errors { + log.Printf("Package error: %v", pkgErr) + } + continue + } + + // Print the package name and path + fmt.Printf("Package: %s (Path: %s)\n", pkg.Name, pkg.PkgPath) + + // Access type information + info := pkg.TypesInfo + if info == nil { + log.Println("No type information available") + continue + } + + // Inspect the type information + //for ident, obj := range info.Defs { + // if obj != nil { + // fmt.Printf("Identifier: %s, Type: %s\n", ident.Name, obj.Type()) + // } + //} + } + // TODO_IN_THIS_COMMIT: assert only 1 pkg: module's keeper... + typeInfo := pkgs[0].TypesInfo + // --- END + + msgServerGlob := filepath.Join(keeperDir, "msg_server_*.go") + + matches, err := filepath.Glob(msgServerGlob) + if err != nil { + return err + } + + // TODO_IN_THIS_COMMIT: extract --- BEGIN + for _, matchFilePath := range matches { + fset := token.NewFileSet() + + astFile, err := parser.ParseFile(fset, matchFilePath, nil, parser.AllErrors) + if err != nil { + return err + } + + //fmt.Println("BEFORE...") + //typeInfo, err := getTypeInfo(fset, matchFilePath, astFile) + //if err != nil { + // return err + //} + ////typeInfo := types.Info{} + //fmt.Println("AFTER...") + + //ast.Walk + ast.Inspect(astFile, func(n ast.Node) bool { + fnNode, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + // Skip functions which are not methods. + if fnNode.Recv == nil { + return false + } + + fnType := fnNode.Recv.List[0].Type + typeIdentNode, ok := fnType.(*ast.Ident) + if !ok { + return false + } + + if typeIdentNode.Name != "msgServer" { + return false + } + + fmt.Printf("Found msgServer method %q in %s\n", fnNode.Name.Name, matchFilePath) + + // Recursively traverse the function body, looking for non-nil error returns. + var errorReturns []*ast.IfStmt + // TODO_IN_THIS_COMMIT: extract --- BEGIN + ast.Inspect(fnNode.Body, func(n ast.Node) bool { + // Search for a return statement. + rtrnStmt, ok := n.(*ast.ReturnStmt) + if !ok { + return true + } + + ifStmt, ok := n.(*ast.IfStmt) + if !ok { + // Skip AST branches which are not logically conditional branches. + //fmt.Println("non if") + return true + } + //fmt.Println("yes if") + + // Match on `if err != nil` statements. + // TODO_IN_THIS_COMMIT: extract --- BEGIN + if ifStmt.Cond == nil { + return false + } + + errorReturn, ok := ifStmt.Cond.(*ast.BinaryExpr) + if !ok { + return false + } + + if errorReturn.Op != token.NEQ { + return false + } + + // Check that the left operand is an error type. + // TODO_IN_THIS_COMMIT: extract --- BEGIN + errIdentNode, ok := errorReturn.X.(*ast.Ident) + if !ok { + return false + } + + //errIdentNode.Obj.Kind.String() + obj := typeInfo.Uses[errIdentNode] + fmt.Sprintf("obj: %+v", obj) + // --- END + // --- END + + errorReturns = append(errorReturns, ifStmt) + + return false + }) + // --- END + + // TODO_IN_THIS_COMMIT: extract --- BEGIN + for _, errorReturn := range errorReturns { + // Check if the error return is wrapped in a gRPC status error. + //ifStmt, ok := errorReturn.If.(*ast.IfStmt) + //if !ok { + // return false + //} + ifStmt := errorReturn //.If.(*ast.IfStmt) + + switch node := ifStmt.Cond.(type) { + case *ast.BinaryExpr: + if node.Op != token.NEQ { + return false + } + + //statusErrorIdentNode, ok := ifStmtCond.X.(*ast.Ident) + //if !ok { + // continue + //} + + //fmt.Printf("Found error return %q in %s\n", statusErrorIdentNode.Name, matchFilePath) + } + } + // --- END + + return false + }) + } + // --- END + + return nil +} + +//// TODO_IN_THIS_COMMIT: move & refactor... +//var _ ast.Visitor = (*Visitor)(nil) +// +//type Visitor struct{} +// +//// TODO_IN_THIS_COMMIT: move & refactor... +//func (v *Visitor) Visit(node ast.Node) ast.Visitor { +// +//} + +// TODO_IN_THIS_COMMIT: move & godoc... +//func getTypeInfo(fset *token.FileSet, filePath string, fileNode *ast.File) (*types.Info, error) { +// //conf := types.Config{ +// // Importer: importer.For("source", nil), +// //} +// //info := &types.Info{ +// // Types: make(map[ast.Expr]types.TypeAndValue), +// // Defs: make(map[*ast.Ident]types.Object), +// // Uses: make(map[*ast.Ident]types.Object), +// //} +// //if _, err := conf.Check(fileNode.Name.Name, fset, []*ast.File{fileNode}, info); err != nil { +// // return nil, err +// //} +// // +// //return info, nil +// return &types.Info{}, nil +//} From 62c0eaef7afc646fd58342c88710a69350b3f513 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Sun, 17 Nov 2024 14:31:47 +0100 Subject: [PATCH 02/18] wip: ... --- tools/scripts/protocheck/cmd/status_errors.go | 446 ++++++++++++++---- .../keeper/msg_server_delegate_to_gateway.go | 1 + 2 files changed, 343 insertions(+), 104 deletions(-) diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index 9598a8556..eca76758c 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -4,10 +4,11 @@ import ( "context" "fmt" "go/ast" - "go/parser" "go/token" + "go/types" "log" "path/filepath" + "strings" "github.com/spf13/cobra" "golang.org/x/tools/go/packages" @@ -82,7 +83,8 @@ func checkModule(_ context.Context, moduleName string) error { // TODO_IN_THIS_COMMIT: extract --- BEGIN // Set up the package configuration cfg := &packages.Config{ - Mode: packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedTypes | packages.NeedTypesInfo, + Mode: packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedTypes | packages.NeedTypesInfo | packages.LoadSyntax, + //Mode: packages.LoadAllSyntax, Tests: false, // Set to true if you also want to load test files } @@ -119,138 +121,295 @@ func checkModule(_ context.Context, moduleName string) error { // fmt.Printf("Identifier: %s, Type: %s\n", ident.Name, obj.Type()) // } //} - } - // TODO_IN_THIS_COMMIT: assert only 1 pkg: module's keeper... - typeInfo := pkgs[0].TypesInfo - // --- END - - msgServerGlob := filepath.Join(keeperDir, "msg_server_*.go") - - matches, err := filepath.Glob(msgServerGlob) - if err != nil { - return err - } - - // TODO_IN_THIS_COMMIT: extract --- BEGIN - for _, matchFilePath := range matches { - fset := token.NewFileSet() - - astFile, err := parser.ParseFile(fset, matchFilePath, nil, parser.AllErrors) - if err != nil { - return err - } + // TODO_IN_THIS_COMMIT: assert only 1 pkg: module's keeper... + typeInfo := pkgs[0].TypesInfo + // --- END - //fmt.Println("BEFORE...") - //typeInfo, err := getTypeInfo(fset, matchFilePath, astFile) + //msgServerGlob := filepath.Join(keeperDir, "msg_server_*.go") + // + //matches, err := filepath.Glob(msgServerGlob) //if err != nil { // return err //} - ////typeInfo := types.Info{} - //fmt.Println("AFTER...") - - //ast.Walk - ast.Inspect(astFile, func(n ast.Node) bool { - fnNode, ok := n.(*ast.FuncDecl) - if !ok { - return true - } - - // Skip functions which are not methods. - if fnNode.Recv == nil { - return false - } - - fnType := fnNode.Recv.List[0].Type - typeIdentNode, ok := fnType.(*ast.Ident) - if !ok { - return false - } - - if typeIdentNode.Name != "msgServer" { - return false - } - - fmt.Printf("Found msgServer method %q in %s\n", fnNode.Name.Name, matchFilePath) - // Recursively traverse the function body, looking for non-nil error returns. - var errorReturns []*ast.IfStmt - // TODO_IN_THIS_COMMIT: extract --- BEGIN - ast.Inspect(fnNode.Body, func(n ast.Node) bool { - // Search for a return statement. - rtrnStmt, ok := n.(*ast.ReturnStmt) + offendingPkgErrLines := make([]string, 0) + + // TODO_IN_THIS_COMMIT: extract --- BEGIN + //for _, matchFilePath := range matches[:1] { + //for _, astFile := range pkgs[0].Syntax { + for _, astFile := range pkg.Syntax { + //fset := token.NewFileSet() + // + //astFile, err := parser.ParseFile(fset, matchFilePath, nil, parser.AllErrors) + //if err != nil { + // return err + //} + + //fmt.Println("BEFORE...") + //typeInfo, err := getTypeInfo(fset, matchFilePath, astFile) + //if err != nil { + // return err + //} + ////typeInfo := types.Info{} + //fmt.Println("AFTER...") + + //// Skip files which don't match the msg_server_*.go pattern. + //if !strings.HasPrefix(astFile.Name.Name, "msg_server_") { + // continue + //} + + //ast.Walk + ast.Inspect(astFile, func(n ast.Node) bool { + fnNode, ok := n.(*ast.FuncDecl) if !ok { return true } - ifStmt, ok := n.(*ast.IfStmt) - if !ok { - // Skip AST branches which are not logically conditional branches. - //fmt.Println("non if") - return true - } - //fmt.Println("yes if") - - // Match on `if err != nil` statements. - // TODO_IN_THIS_COMMIT: extract --- BEGIN - if ifStmt.Cond == nil { + // Skip functions which are not methods. + if fnNode.Recv == nil { return false } - errorReturn, ok := ifStmt.Cond.(*ast.BinaryExpr) + fnType := fnNode.Recv.List[0].Type + typeIdentNode, ok := fnType.(*ast.Ident) if !ok { return false } - if errorReturn.Op != token.NEQ { + if typeIdentNode.Name != "msgServer" { return false } - // Check that the left operand is an error type. - // TODO_IN_THIS_COMMIT: extract --- BEGIN - errIdentNode, ok := errorReturn.X.(*ast.Ident) - if !ok { - return false - } + //fmt.Printf("Found msgServer method %q in %s\n", fnNode.Name.Name, matchFilePath) + fmt.Printf("in %q in %s\n", fnNode.Name.Name, astFile.Name.Name) - //errIdentNode.Obj.Kind.String() - obj := typeInfo.Uses[errIdentNode] - fmt.Sprintf("obj: %+v", obj) - // --- END - // --- END + condition := func(sel *ast.Ident, typeObj types.Object) bool { + isStatusError := sel.Name == "Error" && typeObj.Pkg().Path() == "google.golang.org/grpc/status" + pos := pkg.Fset.Position(sel.Pos()) + if !isStatusError { + fmt.Printf("fnNode: %+v", fnNode) + fmt.Printf("typeIdentNode: %+v", typeIdentNode) + offendingPkgErrLines = append(offendingPkgErrLines, fmt.Sprintf("%s:%d:%d", pos.Filename, pos.Line, pos.Column)) + } - errorReturns = append(errorReturns, ifStmt) + return isStatusError + //return true + //return false + } + + // Recursively traverse the function body, looking for non-nil error returns. + var errorReturns []*ast.IfStmt + // TODO_IN_THIS_COMMIT: extract --- BEGIN + ast.Inspect(fnNode.Body, func(n ast.Node) bool { + switch n := n.(type) { + case *ast.BlockStmt: + return true + // Search for a return statement. + case *ast.ReturnStmt: + lastReturnArg := n.Results[len(n.Results)-1] + + switch lastReturnArgNode := lastReturnArg.(type) { + // `return nil, err` <-- last arg is an *ast.Ident. + case *ast.Ident: + fmt.Printf("ast.Ident: %T: %+v\n", lastReturnArg, lastReturnArgNode) + //return true + + defs := typeInfo.Defs[lastReturnArgNode] + fmt.Printf("type defs: %+v\n", defs) + + use := typeInfo.Uses[lastReturnArgNode] + fmt.Printf("type use: %+v\n", use) + + // TODO_IN_THIS_COMMIT: No need to check that the last return + // arg is an error type if we checked that the function returns + // an error as the last arg. + if lastReturnArgNode.Name == "err" { + def := typeInfo.Defs[lastReturnArgNode] + fmt.Printf("def: %+v\n", def) + + if lastReturnArgNode.Obj == nil { + return true + } + + // TODO_IN_THIS_COMMIT: factor out and call in a case in the switch above where we handle *ast.AssignStmt + switch node := lastReturnArgNode.Obj.Decl.(type) { + case *ast.AssignStmt: + // TODO_IN_THIS_COMMIT: extract --- BEGIN + //errAssignStmt, ok := node.(*ast.AssignStmt) + //if !ok { + // panic(fmt.Sprintf("not an ast.AssignStmt: %T: %+v", node, node)) + //} + //errAssignStmt := node + + //use := typeInfo.Uses[errAssignStmt.Rhs[0]] + //def := typeInfo.Defs[errAssignStmt.Rhs[0]] + //_type := typeInfo.Types[errAssignStmt.Rhs[0]] + //impl := typeInfo.Implicits[errAssignStmt.Rhs[0]] + //inst := typeInfo.Instances[errAssignStmt.Rhs[0]] + + fmt.Printf("errAssignStmt found: %+v\n", node) + //fmt.Printf("use: %+v\n", use) + //fmt.Printf("def: %+v\n", def) + //fmt.Printf("_type: %+v\n", _type) + //fmt.Printf("impl: %+v\n", impl) + //fmt.Printf("inst: %+v\n", inst) + // --- END + + selection := typeInfo.Selections[node.Rhs[0].(*ast.CallExpr).Fun.(*ast.SelectorExpr)] + fmt.Printf("type selection: %+v\n", selection) + + // TODO_IN_THIS_COMMIT: account for other cases... + //posNode := GetNodeAtPos(astFile, pkg.Fset, node.Rhs[0].(*ast.CallExpr).Fun.Pos()) + //fmt.Printf("posNode: %+v\n", posNode) + + traverseFunctionBody(selection.Obj().(*types.Func), pkg, 0, condition) + + return false + //default: + //return true + } + //errAssignStmt, ok := lastReturnIdent.Obj.Decl.(*ast.AssignStmt) + //if !ok { + // panic(fmt.Sprintf("not an ast.AssignStmt: %T: %+v", lastReturnIdent.Obj.Decl, lastReturnIdent.Obj.Decl)) + //} + // + ////use := typeInfo.Uses[errAssignStmt.Rhs[0]] + //def := typeInfo.Defs[lastReturnArgNode] + //_type := typeInfo.Types[errAssignStmt.Rhs[0]] + //impl := typeInfo.Implicits[errAssignStmt.Rhs[0]] + ////inst := typeInfo.Instances[errAssignStmt.Rhs[0]] + // + //fmt.Printf("return found: %+v\n", n) + ////fmt.Printf("use: %+v\n", use) + //fmt.Printf("def: %+v\n", def) + //fmt.Printf("_type: %+v\n", _type) + //fmt.Printf("impl: %+v\n", impl) + ////fmt.Printf("inst: %+v\n", inst) + // + ////errAssignStmt.Rhs + + return false + //return true + } + // `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. + case *ast.CallExpr: + fmt.Printf("ast.CallExpr: %T: %+v\n", lastReturnArg, lastReturnArgNode) + + TraverseCallStack(lastReturnArgNode, pkg, 0, condition) + + //// TODO_IN_THIS_COMMIT: handle other types of CallExprs + //switch sel := lastReturnArgNode.Fun.(type) { + //case *ast.SelectorExpr: + // _type := typeInfo.Types[sel] + // fmt.Printf("sel types: %T: %+v\n", _type, _type) + // + // selections := typeInfo.Selections[sel] + // fmt.Printf("sel selections: %+v\n", selections) + //default: + // panic(fmt.Sprintf("unknown AST node type: %T: %+v", lastReturnArg, lastReturnArg)) + //} + // + //return true + return false + //return true + default: + //panic(fmt.Sprintf("unknown AST node type: %T: %+v", lastReturnArg, lastReturnArg)) + fmt.Printf("unknown AST node type: %T: %+v\n", lastReturnArg, lastReturnArg) + } + + //use := typeInfo.Uses[lastReturnIdent] + //def := typeInfo.Defs[lastReturnIdent] + //_type := typeInfo.Types[lastReturnIdent] + //impl := typeInfo.Implicits[lastReturnIdent] + //inst := typeInfo.Instances[lastReturnIdent] + // + ////fmt.Printf("return found: %+v\n", n) + //fmt.Printf("use: %+v\n", use) + //fmt.Printf("def: %+v\n", def) + //fmt.Printf("_type: %+v\n", _type) + //fmt.Printf("impl: %+v\n", impl) + //fmt.Printf("inst: %+v\n", inst) - return false - }) - // --- END - - // TODO_IN_THIS_COMMIT: extract --- BEGIN - for _, errorReturn := range errorReturns { - // Check if the error return is wrapped in a gRPC status error. - //ifStmt, ok := errorReturn.If.(*ast.IfStmt) - //if !ok { - // return false - //} - ifStmt := errorReturn //.If.(*ast.IfStmt) - - switch node := ifStmt.Cond.(type) { - case *ast.BinaryExpr: - if node.Op != token.NEQ { return false + //return true } - //statusErrorIdentNode, ok := ifStmtCond.X.(*ast.Ident) + return true + + //ifStmt, ok := n.(*ast.IfStmt) + //if !ok { + // // Skip AST branches which are not logically conditional branches. + // //fmt.Println("non if") + // return true + //} + ////fmt.Println("yes if") + // + //// Match on `if err != nil` statements. + //// TODO_IN_THIS_COMMIT: extract --- BEGIN + //if ifStmt.Cond == nil { + // return false + //} + // + //errorReturn, ok := ifStmt.Cond.(*ast.BinaryExpr) + //if !ok { + // return false + //} + // + //if errorReturn.Op != token.NEQ { + // return false + //} + // + //// Check that the left operand is an error type. + //// TODO_IN_THIS_COMMIT: extract --- BEGIN + //errIdentNode, ok := errorReturn.X.(*ast.Ident) + //if !ok { + // return false + //} + // + ////errIdentNode.Obj.Kind.String() + //obj := typeInfo.Uses[errIdentNode] + //fmt.Sprintf("obj: %+v", obj) + //// --- END + //// --- END + // + //errorReturns = append(errorReturns, ifStmt) + // + //return false + }) + // --- END + + // TODO_IN_THIS_COMMIT: extract --- BEGIN + for _, errorReturn := range errorReturns { + // Check if the error return is wrapped in a gRPC status error. + //ifStmt, ok := errorReturn.If.(*ast.IfStmt) //if !ok { - // continue + // return false //} + ifStmt := errorReturn //.If.(*ast.IfStmt) + + switch node := ifStmt.Cond.(type) { + case *ast.BinaryExpr: + if node.Op != token.NEQ { + return false + } + + //statusErrorIdentNode, ok := ifStmtCond.X.(*ast.Ident) + //if !ok { + // continue + //} - //fmt.Printf("Found error return %q in %s\n", statusErrorIdentNode.Name, matchFilePath) + //fmt.Printf("Found error return %q in %s\n", statusErrorIdentNode.Name, matchFilePath) + } } - } - // --- END + // --- END - return false - }) + return false + //return true + }) + } + + // Print offending lines in package + fmt.Printf("offending lines in %s:\n%s\n", pkg.PkgPath, strings.Join(offendingPkgErrLines, "\n")) } // --- END @@ -284,3 +443,82 @@ func checkModule(_ context.Context, moduleName string) error { // //return info, nil // return &types.Info{}, nil //} + +// TraverseCallStack recursively traverses the call stack starting from a *ast.CallExpr. +func TraverseCallStack(call *ast.CallExpr, pkg *packages.Package, indent int, condition func(*ast.Ident, types.Object) bool) { + fun := call.Fun + switch fn := fun.(type) { + case *ast.Ident: + // Local or top-level function + obj := pkg.TypesInfo.Uses[fn] + if obj != nil { + fmt.Printf("%sFunction: %s\n", indentSpaces(indent), obj.Name()) + if fnDecl, ok := obj.(*types.Func); ok { + traverseFunctionBody(fnDecl, pkg, indent+2, condition) + } + } + case *ast.SelectorExpr: + // Method call like obj.Method() + sel := fn.Sel + obj := pkg.TypesInfo.Selections[fn] + if obj != nil { + // Instance method + fmt.Printf("%sMethod: %s on %s\n", indentSpaces(indent), sel.Name, obj.Recv()) + if method, ok := obj.Obj().(*types.Func); ok { + traverseFunctionBody(method, pkg, indent+2, condition) + } + } else { + // Static or package-level call + typeObj := pkg.TypesInfo.Uses[sel] + if typeObj != nil { + fmt.Printf("%sFunction: %s (package-level: %s)\n", indentSpaces(indent), sel.Name, typeObj.Pkg().Path()) + if condition(sel, typeObj) { + fmt.Println(">>> STATUS ERROR FOUND!") + return + } + + if fnDecl, ok := typeObj.(*types.Func); ok { + traverseFunctionBody(fnDecl, pkg, indent+2, condition) + } + } + } + default: + fmt.Printf("%sUnknown function type: %T\n", indentSpaces(indent), fun) + } + + // Recursively inspect arguments for nested calls + for _, arg := range call.Args { + if nestedCall, ok := arg.(*ast.CallExpr); ok { + TraverseCallStack(nestedCall, pkg, indent+2, condition) + } + } +} + +// traverseFunctionBody analyzes the body of a function or method to find further calls. +func traverseFunctionBody(fn *types.Func, pkg *packages.Package, indent int, condition func(*ast.Ident, types.Object) bool) { + // Find the declaration of the function in the AST + for _, file := range pkg.Syntax { + ast.Inspect(file, func(node ast.Node) bool { + funcDecl, ok := node.(*ast.FuncDecl) + if !ok || pkg.Fset.Position(funcDecl.Pos()).Filename != pkg.Fset.Position(fn.Pos()).Filename { + return true // Not the target function, continue + } + if funcDecl.Name.Name == fn.Name() { + // Found the function, inspect its body for calls + ast.Inspect(funcDecl.Body, func(n ast.Node) bool { + if call, ok := n.(*ast.CallExpr); ok { + TraverseCallStack(call, pkg, indent, condition) + } + return true + }) + return false // Stop after finding the target function + } + return true + }) + } +} + +// Helper function to generate indentation +func indentSpaces(indent int) string { + return strings.Repeat(" ", indent) +} diff --git a/x/application/keeper/msg_server_delegate_to_gateway.go b/x/application/keeper/msg_server_delegate_to_gateway.go index f109d0ce3..502692bf6 100644 --- a/x/application/keeper/msg_server_delegate_to_gateway.go +++ b/x/application/keeper/msg_server_delegate_to_gateway.go @@ -25,6 +25,7 @@ func (k msgServer) DelegateToGateway(ctx context.Context, msg *apptypes.MsgDeleg if err := msg.ValidateBasic(); err != nil { logger.Error(fmt.Sprintf("Delegation Message failed basic validation: %v", err)) return nil, err + //return nil, status.Error(codes.InvalidArgument, err.Error()) } // Retrieve the application from the store From a25688bc51d9c963068baa59db6a1395346521e1 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Sun, 17 Nov 2024 23:16:56 +0100 Subject: [PATCH 03/18] wip: PoC working --- go.mod | 2 +- tools/scripts/protocheck/cmd/status_errors.go | 226 +++++++++++------- 2 files changed, 146 insertions(+), 82 deletions(-) diff --git a/go.mod b/go.mod index 024364e1b..79bd630b7 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/pokt-network/poktroll -go 1.23.0 +go 1.23 // replace ( // DEVELOPER_TIP: Uncomment to use a local copy of shannon-sdk for development purposes. diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index eca76758c..040897d05 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -7,7 +7,6 @@ import ( "go/token" "go/types" "log" - "path/filepath" "strings" "github.com/spf13/cobra" @@ -47,26 +46,28 @@ func init() { func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() - if flagModule != "*" { - if _, ok := poktrollModules[flagModule]; !ok { - return fmt.Errorf("unknown module %q", flagModule) - } - - if err := checkModule(ctx, flagModule); err != nil { - return err - } - } - - for module := range poktrollModules { - if err := checkModule(ctx, module); err != nil { - return err - } + // TODO_IN_THIS_COMMIT: to support this, need to load all modules but only inspect target module. + //if flagModule != "*" { + // if _, ok := poktrollModules[flagModule]; !ok { + // return fmt.Errorf("unknown module %q", flagModule) + // } + // + // if err := checkModule(ctx, flagModule); err != nil { + // return err + // } + //} + + //for module := range poktrollModules { + // if err := checkModule(ctx, module); err != nil { + if err := checkModule(ctx); err != nil { + return err } + //} return nil } -func checkModule(_ context.Context, moduleName string) error { +func checkModule(_ context.Context) error { // 0. Get the package info for the given module's keeper package. // 1. Find the message server struct for the given module. @@ -77,8 +78,11 @@ func checkModule(_ context.Context, moduleName string) error { // TODO: import polyzero for side effects. //logger := polylog.Ctx(ctx) - moduleDir := filepath.Join(".", "x", moduleName) - keeperDir := filepath.Join(moduleDir, "keeper") + //xDir := filepath.Join(".", "x") + //xDir := filepath.Join(".", "x", "application") + //moduleDir := filepath.Join(".", "x", "application") + //moduleDir := filepath.Join(".", "x", moduleName) + //keeperDir := filepath.Join(moduleDir, "keeper") // TODO_IN_THIS_COMMIT: extract --- BEGIN // Set up the package configuration @@ -89,15 +93,30 @@ func checkModule(_ context.Context, moduleName string) error { } // Load the package containing the target file or directory - poktrollPkgPathRoot := "github.com/pokt-network/poktroll" - moduleKeeperPkgPath := filepath.Join(poktrollPkgPathRoot, keeperDir) - pkgs, err := packages.Load(cfg, moduleKeeperPkgPath) + poktrollPkgPathPattern := "github.com/pokt-network/poktroll/x/..." + //moduleKeeperPkgPath := filepath.Join(poktrollPkgPathPattern, keeperDir) + //xPkgPath := filepath.Join(poktrollPkgPathPattern, xDir) + //fmt.Printf(">>> pkg path: %s\n", moduleKeeperPkgPath) + //pkgs, err := packages.Load(cfg, moduleKeeperPkgPath) + pkgs, err := packages.Load(cfg, poktrollPkgPathPattern) + //pkgs, err := packages.Load(cfg, "github.com/pokt-network/poktroll/x/application") if err != nil { log.Fatalf("Failed to load package: %v", err) } - // Iterate over the packages + offendingPkgErrLines := make([]string, 0) + + // Iterate over the keeper packages + // E.g.: + // - github.com/pokt-network/poktroll/x/application/keeper + // - github.com/pokt-network/poktroll/x/gateway/keeper + // - ... for _, pkg := range pkgs { + fmt.Printf("pkg: %+v\n", pkg) + if pkg.Name != "keeper" { + continue + } + if len(pkg.Errors) > 0 { for _, pkgErr := range pkg.Errors { log.Printf("Package error: %v", pkgErr) @@ -122,7 +141,8 @@ func checkModule(_ context.Context, moduleName string) error { // } //} // TODO_IN_THIS_COMMIT: assert only 1 pkg: module's keeper... - typeInfo := pkgs[0].TypesInfo + //typeInfo := pkgs[0].TypesInfo + typeInfo := pkg.TypesInfo // --- END //msgServerGlob := filepath.Join(keeperDir, "msg_server_*.go") @@ -132,8 +152,6 @@ func checkModule(_ context.Context, moduleName string) error { // return err //} - offendingPkgErrLines := make([]string, 0) - // TODO_IN_THIS_COMMIT: extract --- BEGIN //for _, matchFilePath := range matches[:1] { //for _, astFile := range pkgs[0].Syntax { @@ -153,10 +171,11 @@ func checkModule(_ context.Context, moduleName string) error { ////typeInfo := types.Info{} //fmt.Println("AFTER...") - //// Skip files which don't match the msg_server_*.go pattern. - //if !strings.HasPrefix(astFile.Name.Name, "msg_server_") { + // Skip files which don't match the msg_server_*.go pattern. + //if !strings.HasPrefix(pkg., "msg_server_") { // continue //} + //fmt.Printf(">>> %s\n", pkg.PkgPath) //ast.Walk ast.Inspect(astFile, func(n ast.Node) bool { @@ -183,18 +202,20 @@ func checkModule(_ context.Context, moduleName string) error { //fmt.Printf("Found msgServer method %q in %s\n", fnNode.Name.Name, matchFilePath) fmt.Printf("in %q in %s\n", fnNode.Name.Name, astFile.Name.Name) - condition := func(sel *ast.Ident, typeObj types.Object) bool { - isStatusError := sel.Name == "Error" && typeObj.Pkg().Path() == "google.golang.org/grpc/status" - pos := pkg.Fset.Position(sel.Pos()) - if !isStatusError { - fmt.Printf("fnNode: %+v", fnNode) - fmt.Printf("typeIdentNode: %+v", typeIdentNode) - offendingPkgErrLines = append(offendingPkgErrLines, fmt.Sprintf("%s:%d:%d", pos.Filename, pos.Line, pos.Column)) - } + condition := func(returnErrNode ast.Node) func(*ast.Ident, types.Object) bool { + return func(sel *ast.Ident, typeObj types.Object) bool { + isStatusError := sel.Name == "Error" && typeObj.Pkg().Path() == "google.golang.org/grpc/status" + pos := pkg.Fset.Position(returnErrNode.Pos()) + if !isStatusError { + fmt.Printf("fnNode: %+v", fnNode) + fmt.Printf("typeIdentNode: %+v", typeIdentNode) + offendingPkgErrLines = append(offendingPkgErrLines, fmt.Sprintf("%s:%d:%d", pos.Filename, pos.Line, pos.Column)) + } - return isStatusError - //return true - //return false + return isStatusError + //return true + //return false + } } // Recursively traverse the function body, looking for non-nil error returns. @@ -224,9 +245,6 @@ func checkModule(_ context.Context, moduleName string) error { // arg is an error type if we checked that the function returns // an error as the last arg. if lastReturnArgNode.Name == "err" { - def := typeInfo.Defs[lastReturnArgNode] - fmt.Printf("def: %+v\n", def) - if lastReturnArgNode.Obj == nil { return true } @@ -262,7 +280,13 @@ func checkModule(_ context.Context, moduleName string) error { //posNode := GetNodeAtPos(astFile, pkg.Fset, node.Rhs[0].(*ast.CallExpr).Fun.Pos()) //fmt.Printf("posNode: %+v\n", posNode) - traverseFunctionBody(selection.Obj().(*types.Func), pkg, 0, condition) + if selection == nil { + fmt.Printf("ERROR: selection is nil\n") + //return true + return false + } + + traverseFunctionBody(selection.Obj().(*types.Func), pkgs, 0, condition(lastReturnArgNode)) return false //default: @@ -295,7 +319,7 @@ func checkModule(_ context.Context, moduleName string) error { case *ast.CallExpr: fmt.Printf("ast.CallExpr: %T: %+v\n", lastReturnArg, lastReturnArgNode) - TraverseCallStack(lastReturnArgNode, pkg, 0, condition) + TraverseCallStack(lastReturnArgNode, pkgs, 0, condition(lastReturnArgNode)) //// TODO_IN_THIS_COMMIT: handle other types of CallExprs //switch sel := lastReturnArgNode.Fun.(type) { @@ -445,40 +469,59 @@ func checkModule(_ context.Context, moduleName string) error { //} // TraverseCallStack recursively traverses the call stack starting from a *ast.CallExpr. -func TraverseCallStack(call *ast.CallExpr, pkg *packages.Package, indent int, condition func(*ast.Ident, types.Object) bool) { +func TraverseCallStack(call *ast.CallExpr, pkgs []*packages.Package, indent int, condition func(*ast.Ident, types.Object) bool) { fun := call.Fun switch fn := fun.(type) { case *ast.Ident: // Local or top-level function - obj := pkg.TypesInfo.Uses[fn] - if obj != nil { - fmt.Printf("%sFunction: %s\n", indentSpaces(indent), obj.Name()) - if fnDecl, ok := obj.(*types.Func); ok { - traverseFunctionBody(fnDecl, pkg, indent+2, condition) + + var useObj types.Object + for _, pkg := range pkgs { + useObj = pkg.TypesInfo.Uses[fn] + if useObj != nil { + break + } + } + if useObj != nil { + fmt.Printf("%sFunction: %s\n", indentSpaces(indent), useObj.Name()) + if fnDecl, ok := useObj.(*types.Func); ok { + traverseFunctionBody(fnDecl, pkgs, indent+2, condition) } } case *ast.SelectorExpr: // Method call like obj.Method() sel := fn.Sel - obj := pkg.TypesInfo.Selections[fn] - if obj != nil { + var selection *types.Selection + for _, pkg := range pkgs { + selection = pkg.TypesInfo.Selections[fn] + if selection != nil { + break + } + } + if selection != nil { // Instance method - fmt.Printf("%sMethod: %s on %s\n", indentSpaces(indent), sel.Name, obj.Recv()) - if method, ok := obj.Obj().(*types.Func); ok { - traverseFunctionBody(method, pkg, indent+2, condition) + fmt.Printf("%sMethod: %s on %s\n", indentSpaces(indent), sel.Name, selection.Recv()) + if method, ok := selection.Obj().(*types.Func); ok { + traverseFunctionBody(method, pkgs, indent+2, condition) } } else { // Static or package-level call - typeObj := pkg.TypesInfo.Uses[sel] - if typeObj != nil { - fmt.Printf("%sFunction: %s (package-level: %s)\n", indentSpaces(indent), sel.Name, typeObj.Pkg().Path()) - if condition(sel, typeObj) { + var useObj types.Object + for _, pkg := range pkgs { + useObj = pkg.TypesInfo.Uses[sel] + if useObj != nil { + break + } + } + if useObj != nil { + fmt.Printf("%sFunction: %s (package-level: %s)\n", indentSpaces(indent), sel.Name, useObj.Pkg().Path()) + if condition(sel, useObj) { fmt.Println(">>> STATUS ERROR FOUND!") return } - if fnDecl, ok := typeObj.(*types.Func); ok { - traverseFunctionBody(fnDecl, pkg, indent+2, condition) + if fnDecl, ok := useObj.(*types.Func); ok { + traverseFunctionBody(fnDecl, pkgs, indent+2, condition) } } } @@ -489,32 +532,53 @@ func TraverseCallStack(call *ast.CallExpr, pkg *packages.Package, indent int, co // Recursively inspect arguments for nested calls for _, arg := range call.Args { if nestedCall, ok := arg.(*ast.CallExpr); ok { - TraverseCallStack(nestedCall, pkg, indent+2, condition) + TraverseCallStack(nestedCall, pkgs, indent+2, condition) } } } // traverseFunctionBody analyzes the body of a function or method to find further calls. -func traverseFunctionBody(fn *types.Func, pkg *packages.Package, indent int, condition func(*ast.Ident, types.Object) bool) { - // Find the declaration of the function in the AST - for _, file := range pkg.Syntax { - ast.Inspect(file, func(node ast.Node) bool { - funcDecl, ok := node.(*ast.FuncDecl) - if !ok || pkg.Fset.Position(funcDecl.Pos()).Filename != pkg.Fset.Position(fn.Pos()).Filename { - return true // Not the target function, continue - } - if funcDecl.Name.Name == fn.Name() { - // Found the function, inspect its body for calls - ast.Inspect(funcDecl.Body, func(n ast.Node) bool { - if call, ok := n.(*ast.CallExpr); ok { - TraverseCallStack(call, pkg, indent, condition) - } - return true - }) - return false // Stop after finding the target function - } - return true - }) +func traverseFunctionBody(fn *types.Func, pkgs []*packages.Package, indent int, condition func(*ast.Ident, types.Object) bool) { + fmt.Printf("fn package path: %s\n", fn.Pkg().Path()) + fmt.Printf("path has prefix: %v\n", strings.HasPrefix(fn.Pkg().Path(), "github.com/pokt-network/poktroll")) + // Don't traverse beyond poktroll module root (i.e. assume deps won't return status errors). + if !strings.HasPrefix(fn.Pkg().Path(), "github.com/pokt-network/poktroll") { + return + } + + // TODO_IN_THIS_COMMIT: Implement & log when this happens. + // DEV_NOTE: If targetFileName is not present in any package, + // we assume that a status error will not be returned by the + // function; so we MUST mark it as offending. + + for _, pkg := range pkgs { + // Find the declaration of the function in the AST + for _, file := range pkg.Syntax { + ast.Inspect(file, func(node ast.Node) bool { + funcDecl, ok := node.(*ast.FuncDecl) + if !ok { + return true // Not the target function, continue + } + targetFileName := pkg.Fset.Position(fn.Pos()).Filename + nodeFileName := pkg.Fset.Position(funcDecl.Pos()).Filename + //fmt.Printf("nodeFileName: %s\n", nodeFileName) + if nodeFileName != targetFileName { + return true // Not the target function, continue + } + + if funcDecl.Name.Name == fn.Name() { + // Found the function, inspect its body for calls + ast.Inspect(funcDecl.Body, func(n ast.Node) bool { + if call, ok := n.(*ast.CallExpr); ok { + TraverseCallStack(call, pkgs, indent, condition) + } + return true + }) + return false // Stop after finding the target function + } + return true + }) + } } } From c31b3ef87302da0d7687fecee026b883dd118999 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Sun, 17 Nov 2024 23:32:37 +0100 Subject: [PATCH 04/18] wip: cleanup --- tools/scripts/protocheck/cmd/status_errors.go | 274 ++++-------------- 1 file changed, 55 insertions(+), 219 deletions(-) diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index 040897d05..58a1025cf 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "go/ast" - "go/token" "go/types" "log" "strings" @@ -125,7 +124,7 @@ func checkModule(_ context.Context) error { } // Print the package name and path - fmt.Printf("Package: %s (Path: %s)\n", pkg.Name, pkg.PkgPath) + //fmt.Printf("Package: %s (Path: %s)\n", pkg.Name, pkg.PkgPath) // Access type information info := pkg.TypesInfo @@ -134,50 +133,11 @@ func checkModule(_ context.Context) error { continue } - // Inspect the type information - //for ident, obj := range info.Defs { - // if obj != nil { - // fmt.Printf("Identifier: %s, Type: %s\n", ident.Name, obj.Type()) - // } - //} - // TODO_IN_THIS_COMMIT: assert only 1 pkg: module's keeper... - //typeInfo := pkgs[0].TypesInfo typeInfo := pkg.TypesInfo // --- END - //msgServerGlob := filepath.Join(keeperDir, "msg_server_*.go") - // - //matches, err := filepath.Glob(msgServerGlob) - //if err != nil { - // return err - //} - // TODO_IN_THIS_COMMIT: extract --- BEGIN - //for _, matchFilePath := range matches[:1] { - //for _, astFile := range pkgs[0].Syntax { for _, astFile := range pkg.Syntax { - //fset := token.NewFileSet() - // - //astFile, err := parser.ParseFile(fset, matchFilePath, nil, parser.AllErrors) - //if err != nil { - // return err - //} - - //fmt.Println("BEFORE...") - //typeInfo, err := getTypeInfo(fset, matchFilePath, astFile) - //if err != nil { - // return err - //} - ////typeInfo := types.Info{} - //fmt.Println("AFTER...") - - // Skip files which don't match the msg_server_*.go pattern. - //if !strings.HasPrefix(pkg., "msg_server_") { - // continue - //} - //fmt.Printf(">>> %s\n", pkg.PkgPath) - - //ast.Walk ast.Inspect(astFile, func(n ast.Node) bool { fnNode, ok := n.(*ast.FuncDecl) if !ok { @@ -200,15 +160,15 @@ func checkModule(_ context.Context) error { } //fmt.Printf("Found msgServer method %q in %s\n", fnNode.Name.Name, matchFilePath) - fmt.Printf("in %q in %s\n", fnNode.Name.Name, astFile.Name.Name) + //fmt.Printf("in %q in %s\n", fnNode.Name.Name, astFile.Name.Name) condition := func(returnErrNode ast.Node) func(*ast.Ident, types.Object) bool { return func(sel *ast.Ident, typeObj types.Object) bool { isStatusError := sel.Name == "Error" && typeObj.Pkg().Path() == "google.golang.org/grpc/status" pos := pkg.Fset.Position(returnErrNode.Pos()) if !isStatusError { - fmt.Printf("fnNode: %+v", fnNode) - fmt.Printf("typeIdentNode: %+v", typeIdentNode) + //fmt.Printf("fnNode: %+v", fnNode) + //fmt.Printf("typeIdentNode: %+v", typeIdentNode) offendingPkgErrLines = append(offendingPkgErrLines, fmt.Sprintf("%s:%d:%d", pos.Filename, pos.Line, pos.Column)) } @@ -219,7 +179,7 @@ func checkModule(_ context.Context) error { } // Recursively traverse the function body, looking for non-nil error returns. - var errorReturns []*ast.IfStmt + //var errorReturns []*ast.IfStmt // TODO_IN_THIS_COMMIT: extract --- BEGIN ast.Inspect(fnNode.Body, func(n ast.Node) bool { switch n := n.(type) { @@ -232,14 +192,14 @@ func checkModule(_ context.Context) error { switch lastReturnArgNode := lastReturnArg.(type) { // `return nil, err` <-- last arg is an *ast.Ident. case *ast.Ident: - fmt.Printf("ast.Ident: %T: %+v\n", lastReturnArg, lastReturnArgNode) + //fmt.Printf("ast.Ident: %T: %+v\n", lastReturnArg, lastReturnArgNode) //return true - defs := typeInfo.Defs[lastReturnArgNode] - fmt.Printf("type defs: %+v\n", defs) + //defs := typeInfo.Defs[lastReturnArgNode] + //fmt.Printf("type defs: %+v\n", defs) - use := typeInfo.Uses[lastReturnArgNode] - fmt.Printf("type use: %+v\n", use) + //use := typeInfo.Uses[lastReturnArgNode] + //fmt.Printf("type use: %+v\n", use) // TODO_IN_THIS_COMMIT: No need to check that the last return // arg is an error type if we checked that the function returns @@ -252,33 +212,12 @@ func checkModule(_ context.Context) error { // TODO_IN_THIS_COMMIT: factor out and call in a case in the switch above where we handle *ast.AssignStmt switch node := lastReturnArgNode.Obj.Decl.(type) { case *ast.AssignStmt: - // TODO_IN_THIS_COMMIT: extract --- BEGIN - //errAssignStmt, ok := node.(*ast.AssignStmt) - //if !ok { - // panic(fmt.Sprintf("not an ast.AssignStmt: %T: %+v", node, node)) - //} - //errAssignStmt := node - - //use := typeInfo.Uses[errAssignStmt.Rhs[0]] - //def := typeInfo.Defs[errAssignStmt.Rhs[0]] - //_type := typeInfo.Types[errAssignStmt.Rhs[0]] - //impl := typeInfo.Implicits[errAssignStmt.Rhs[0]] - //inst := typeInfo.Instances[errAssignStmt.Rhs[0]] - - fmt.Printf("errAssignStmt found: %+v\n", node) - //fmt.Printf("use: %+v\n", use) - //fmt.Printf("def: %+v\n", def) - //fmt.Printf("_type: %+v\n", _type) - //fmt.Printf("impl: %+v\n", impl) - //fmt.Printf("inst: %+v\n", inst) - // --- END + //fmt.Printf("errAssignStmt found: %+v\n", node) selection := typeInfo.Selections[node.Rhs[0].(*ast.CallExpr).Fun.(*ast.SelectorExpr)] - fmt.Printf("type selection: %+v\n", selection) + //fmt.Printf("type selection: %+v\n", selection) // TODO_IN_THIS_COMMIT: account for other cases... - //posNode := GetNodeAtPos(astFile, pkg.Fset, node.Rhs[0].(*ast.CallExpr).Fun.Pos()) - //fmt.Printf("posNode: %+v\n", posNode) if selection == nil { fmt.Printf("ERROR: selection is nil\n") @@ -292,48 +231,16 @@ func checkModule(_ context.Context) error { //default: //return true } - //errAssignStmt, ok := lastReturnIdent.Obj.Decl.(*ast.AssignStmt) - //if !ok { - // panic(fmt.Sprintf("not an ast.AssignStmt: %T: %+v", lastReturnIdent.Obj.Decl, lastReturnIdent.Obj.Decl)) - //} - // - ////use := typeInfo.Uses[errAssignStmt.Rhs[0]] - //def := typeInfo.Defs[lastReturnArgNode] - //_type := typeInfo.Types[errAssignStmt.Rhs[0]] - //impl := typeInfo.Implicits[errAssignStmt.Rhs[0]] - ////inst := typeInfo.Instances[errAssignStmt.Rhs[0]] - // - //fmt.Printf("return found: %+v\n", n) - ////fmt.Printf("use: %+v\n", use) - //fmt.Printf("def: %+v\n", def) - //fmt.Printf("_type: %+v\n", _type) - //fmt.Printf("impl: %+v\n", impl) - ////fmt.Printf("inst: %+v\n", inst) - // - ////errAssignStmt.Rhs return false //return true } // `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. case *ast.CallExpr: - fmt.Printf("ast.CallExpr: %T: %+v\n", lastReturnArg, lastReturnArgNode) + //fmt.Printf("ast.CallExpr: %T: %+v\n", lastReturnArg, lastReturnArgNode) TraverseCallStack(lastReturnArgNode, pkgs, 0, condition(lastReturnArgNode)) - //// TODO_IN_THIS_COMMIT: handle other types of CallExprs - //switch sel := lastReturnArgNode.Fun.(type) { - //case *ast.SelectorExpr: - // _type := typeInfo.Types[sel] - // fmt.Printf("sel types: %T: %+v\n", _type, _type) - // - // selections := typeInfo.Selections[sel] - // fmt.Printf("sel selections: %+v\n", selections) - //default: - // panic(fmt.Sprintf("unknown AST node type: %T: %+v", lastReturnArg, lastReturnArg)) - //} - // - //return true return false //return true default: @@ -341,133 +248,62 @@ func checkModule(_ context.Context) error { fmt.Printf("unknown AST node type: %T: %+v\n", lastReturnArg, lastReturnArg) } - //use := typeInfo.Uses[lastReturnIdent] - //def := typeInfo.Defs[lastReturnIdent] - //_type := typeInfo.Types[lastReturnIdent] - //impl := typeInfo.Implicits[lastReturnIdent] - //inst := typeInfo.Instances[lastReturnIdent] - // - ////fmt.Printf("return found: %+v\n", n) - //fmt.Printf("use: %+v\n", use) - //fmt.Printf("def: %+v\n", def) - //fmt.Printf("_type: %+v\n", _type) - //fmt.Printf("impl: %+v\n", impl) - //fmt.Printf("inst: %+v\n", inst) - return false //return true } return true - - //ifStmt, ok := n.(*ast.IfStmt) - //if !ok { - // // Skip AST branches which are not logically conditional branches. - // //fmt.Println("non if") - // return true - //} - ////fmt.Println("yes if") - // - //// Match on `if err != nil` statements. - //// TODO_IN_THIS_COMMIT: extract --- BEGIN - //if ifStmt.Cond == nil { - // return false - //} - // - //errorReturn, ok := ifStmt.Cond.(*ast.BinaryExpr) - //if !ok { - // return false - //} - // - //if errorReturn.Op != token.NEQ { - // return false - //} - // - //// Check that the left operand is an error type. - //// TODO_IN_THIS_COMMIT: extract --- BEGIN - //errIdentNode, ok := errorReturn.X.(*ast.Ident) - //if !ok { - // return false - //} - // - ////errIdentNode.Obj.Kind.String() - //obj := typeInfo.Uses[errIdentNode] - //fmt.Sprintf("obj: %+v", obj) - //// --- END - //// --- END - // - //errorReturns = append(errorReturns, ifStmt) - // - //return false }) // --- END - // TODO_IN_THIS_COMMIT: extract --- BEGIN - for _, errorReturn := range errorReturns { - // Check if the error return is wrapped in a gRPC status error. - //ifStmt, ok := errorReturn.If.(*ast.IfStmt) - //if !ok { - // return false - //} - ifStmt := errorReturn //.If.(*ast.IfStmt) - - switch node := ifStmt.Cond.(type) { - case *ast.BinaryExpr: - if node.Op != token.NEQ { - return false - } - - //statusErrorIdentNode, ok := ifStmtCond.X.(*ast.Ident) - //if !ok { - // continue - //} - - //fmt.Printf("Found error return %q in %s\n", statusErrorIdentNode.Name, matchFilePath) - } - } - // --- END + //// TODO_IN_THIS_COMMIT: extract --- BEGIN + //for _, errorReturn := range errorReturns { + // // Check if the error return is wrapped in a gRPC status error. + // //ifStmt, ok := errorReturn.If.(*ast.IfStmt) + // //if !ok { + // // return false + // //} + // ifStmt := errorReturn //.If.(*ast.IfStmt) + // + // switch node := ifStmt.Cond.(type) { + // case *ast.BinaryExpr: + // if node.Op != token.NEQ { + // return false + // } + // } + //} + //// --- END return false //return true }) } - // Print offending lines in package - fmt.Printf("offending lines in %s:\n%s\n", pkg.PkgPath, strings.Join(offendingPkgErrLines, "\n")) } // --- END + // Print offending lines in package + // TODO_IN_THIS_COMMIT: refactor to const. + pkgsPattern := "github.com/pokt-network/poktroll/x/..." + numOffendingLines := len(offendingPkgErrLines) + if numOffendingLines == 0 { + fmt.Printf("No offending lines in %s\n", pkgsPattern) + } else { + msg := fmt.Sprintf( + "\nFound %d offending lines in %s:", + numOffendingLines, pkgsPattern, + ) + fmt.Printf( + "\n%s\n%s\n%s\n", + msg, + strings.Join(offendingPkgErrLines, "\n"), + msg, + ) + } + return nil } -//// TODO_IN_THIS_COMMIT: move & refactor... -//var _ ast.Visitor = (*Visitor)(nil) -// -//type Visitor struct{} -// -//// TODO_IN_THIS_COMMIT: move & refactor... -//func (v *Visitor) Visit(node ast.Node) ast.Visitor { -// -//} - -// TODO_IN_THIS_COMMIT: move & godoc... -//func getTypeInfo(fset *token.FileSet, filePath string, fileNode *ast.File) (*types.Info, error) { -// //conf := types.Config{ -// // Importer: importer.For("source", nil), -// //} -// //info := &types.Info{ -// // Types: make(map[ast.Expr]types.TypeAndValue), -// // Defs: make(map[*ast.Ident]types.Object), -// // Uses: make(map[*ast.Ident]types.Object), -// //} -// //if _, err := conf.Check(fileNode.Name.Name, fset, []*ast.File{fileNode}, info); err != nil { -// // return nil, err -// //} -// // -// //return info, nil -// return &types.Info{}, nil -//} - // TraverseCallStack recursively traverses the call stack starting from a *ast.CallExpr. func TraverseCallStack(call *ast.CallExpr, pkgs []*packages.Package, indent int, condition func(*ast.Ident, types.Object) bool) { fun := call.Fun @@ -483,7 +319,7 @@ func TraverseCallStack(call *ast.CallExpr, pkgs []*packages.Package, indent int, } } if useObj != nil { - fmt.Printf("%sFunction: %s\n", indentSpaces(indent), useObj.Name()) + //fmt.Printf("%sFunction: %s\n", indentSpaces(indent), useObj.Name()) if fnDecl, ok := useObj.(*types.Func); ok { traverseFunctionBody(fnDecl, pkgs, indent+2, condition) } @@ -500,7 +336,7 @@ func TraverseCallStack(call *ast.CallExpr, pkgs []*packages.Package, indent int, } if selection != nil { // Instance method - fmt.Printf("%sMethod: %s on %s\n", indentSpaces(indent), sel.Name, selection.Recv()) + //fmt.Printf("%sMethod: %s on %s\n", indentSpaces(indent), sel.Name, selection.Recv()) if method, ok := selection.Obj().(*types.Func); ok { traverseFunctionBody(method, pkgs, indent+2, condition) } @@ -514,9 +350,9 @@ func TraverseCallStack(call *ast.CallExpr, pkgs []*packages.Package, indent int, } } if useObj != nil { - fmt.Printf("%sFunction: %s (package-level: %s)\n", indentSpaces(indent), sel.Name, useObj.Pkg().Path()) + //fmt.Printf("%sFunction: %s (package-level: %s)\n", indentSpaces(indent), sel.Name, useObj.Pkg().Path()) if condition(sel, useObj) { - fmt.Println(">>> STATUS ERROR FOUND!") + //fmt.Println(">>> STATUS ERROR FOUND!") return } @@ -539,8 +375,8 @@ func TraverseCallStack(call *ast.CallExpr, pkgs []*packages.Package, indent int, // traverseFunctionBody analyzes the body of a function or method to find further calls. func traverseFunctionBody(fn *types.Func, pkgs []*packages.Package, indent int, condition func(*ast.Ident, types.Object) bool) { - fmt.Printf("fn package path: %s\n", fn.Pkg().Path()) - fmt.Printf("path has prefix: %v\n", strings.HasPrefix(fn.Pkg().Path(), "github.com/pokt-network/poktroll")) + //fmt.Printf("fn package path: %s\n", fn.Pkg().Path()) + //fmt.Printf("path has prefix: %v\n", strings.HasPrefix(fn.Pkg().Path(), "github.com/pokt-network/poktroll")) // Don't traverse beyond poktroll module root (i.e. assume deps won't return status errors). if !strings.HasPrefix(fn.Pkg().Path(), "github.com/pokt-network/poktroll") { return From e24e241c7818850b693011ea6372f70eee795439 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Mon, 18 Nov 2024 01:21:25 +0100 Subject: [PATCH 05/18] wip: cleanup --- tools/scripts/protocheck/cmd/status_errors.go | 268 ++++++++++-------- 1 file changed, 155 insertions(+), 113 deletions(-) diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index 58a1025cf..25f871cc0 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -6,6 +6,7 @@ import ( "go/ast" "go/types" "log" + "sort" "strings" "github.com/spf13/cobra" @@ -23,17 +24,6 @@ var ( Short: "Checks that all message handler function errors are wrapped in gRPC status errors.", RunE: runStatusErrorsCheck, } - - poktrollModules = map[string]struct{}{ - "application": {}, - //"gateway": {}, - //"service": {}, - //"session": {}, - //"shared": {}, - //"supplier": {}, - //"proof": {}, - //"tokenomics": {}, - } ) func init() { @@ -47,13 +37,7 @@ func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { // TODO_IN_THIS_COMMIT: to support this, need to load all modules but only inspect target module. //if flagModule != "*" { - // if _, ok := poktrollModules[flagModule]; !ok { - // return fmt.Errorf("unknown module %q", flagModule) - // } - // - // if err := checkModule(ctx, flagModule); err != nil { - // return err - // } + // ... //} //for module := range poktrollModules { @@ -77,33 +61,22 @@ func checkModule(_ context.Context) error { // TODO: import polyzero for side effects. //logger := polylog.Ctx(ctx) - //xDir := filepath.Join(".", "x") - //xDir := filepath.Join(".", "x", "application") - //moduleDir := filepath.Join(".", "x", "application") - //moduleDir := filepath.Join(".", "x", moduleName) - //keeperDir := filepath.Join(moduleDir, "keeper") - // TODO_IN_THIS_COMMIT: extract --- BEGIN // Set up the package configuration cfg := &packages.Config{ - Mode: packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedTypes | packages.NeedTypesInfo | packages.LoadSyntax, - //Mode: packages.LoadAllSyntax, - Tests: false, // Set to true if you also want to load test files + //Mode: packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedTypes | packages.NeedTypesInfo | packages.LoadSyntax, + Mode: packages.LoadSyntax, + Tests: false, } // Load the package containing the target file or directory poktrollPkgPathPattern := "github.com/pokt-network/poktroll/x/..." - //moduleKeeperPkgPath := filepath.Join(poktrollPkgPathPattern, keeperDir) - //xPkgPath := filepath.Join(poktrollPkgPathPattern, xDir) - //fmt.Printf(">>> pkg path: %s\n", moduleKeeperPkgPath) - //pkgs, err := packages.Load(cfg, moduleKeeperPkgPath) pkgs, err := packages.Load(cfg, poktrollPkgPathPattern) - //pkgs, err := packages.Load(cfg, "github.com/pokt-network/poktroll/x/application") if err != nil { log.Fatalf("Failed to load package: %v", err) } - offendingPkgErrLines := make([]string, 0) + offendingPkgErrLineSet := make(map[string]struct{}) // Iterate over the keeper packages // E.g.: @@ -111,7 +84,6 @@ func checkModule(_ context.Context) error { // - github.com/pokt-network/poktroll/x/gateway/keeper // - ... for _, pkg := range pkgs { - fmt.Printf("pkg: %+v\n", pkg) if pkg.Name != "keeper" { continue } @@ -123,9 +95,6 @@ func checkModule(_ context.Context) error { continue } - // Print the package name and path - //fmt.Printf("Package: %s (Path: %s)\n", pkg.Name, pkg.PkgPath) - // Access type information info := pkg.TypesInfo if info == nil { @@ -149,6 +118,31 @@ func checkModule(_ context.Context) error { return false } + fnNodeTypeObj, ok := typeInfo.Defs[fnNode.Name] //.Type.Results.List[0].Type + if !ok { + fmt.Printf("ERROR: unable to find fnNode type def: %s\n", fnNode.Name.Name) + return true + } + + // Skip methods which are not exported. + if !fnNodeTypeObj.Exported() { + return false + } + + // Skip methods which have no return arguments. + if fnNode.Type.Results == nil { + return false + } + + // TODO_IN_THIS_COMMIT: check the signature of the method to ensure it returns an error type. + fnResultsList := fnNode.Type.Results.List + fnLastResultType := fnResultsList[len(fnResultsList)-1].Type + if fnLastResultIdent, ok := fnLastResultType.(*ast.Ident); ok { + if fnLastResultIdent.Name != "error" { + return false + } + } + fnType := fnNode.Recv.List[0].Type typeIdentNode, ok := fnType.(*ast.Ident) if !ok { @@ -159,27 +153,18 @@ func checkModule(_ context.Context) error { return false } - //fmt.Printf("Found msgServer method %q in %s\n", fnNode.Name.Name, matchFilePath) - //fmt.Printf("in %q in %s\n", fnNode.Name.Name, astFile.Name.Name) - condition := func(returnErrNode ast.Node) func(*ast.Ident, types.Object) bool { return func(sel *ast.Ident, typeObj types.Object) bool { isStatusError := sel.Name == "Error" && typeObj.Pkg().Path() == "google.golang.org/grpc/status" - pos := pkg.Fset.Position(returnErrNode.Pos()) if !isStatusError { - //fmt.Printf("fnNode: %+v", fnNode) - //fmt.Printf("typeIdentNode: %+v", typeIdentNode) - offendingPkgErrLines = append(offendingPkgErrLines, fmt.Sprintf("%s:%d:%d", pos.Filename, pos.Line, pos.Column)) + offendingPkgErrLineSet[pkg.Fset.Position(returnErrNode.Pos()).String()] = struct{}{} } return isStatusError - //return true - //return false } } // Recursively traverse the function body, looking for non-nil error returns. - //var errorReturns []*ast.IfStmt // TODO_IN_THIS_COMMIT: extract --- BEGIN ast.Inspect(fnNode.Body, func(n ast.Node) bool { switch n := n.(type) { @@ -187,119 +172,162 @@ func checkModule(_ context.Context) error { return true // Search for a return statement. case *ast.ReturnStmt: - lastReturnArg := n.Results[len(n.Results)-1] + lastResult := n.Results[len(n.Results)-1] - switch lastReturnArgNode := lastReturnArg.(type) { + switch lastReturnArgNode := lastResult.(type) { // `return nil, err` <-- last arg is an *ast.Ident. case *ast.Ident: - //fmt.Printf("ast.Ident: %T: %+v\n", lastReturnArg, lastReturnArgNode) - //return true - - //defs := typeInfo.Defs[lastReturnArgNode] - //fmt.Printf("type defs: %+v\n", defs) - - //use := typeInfo.Uses[lastReturnArgNode] - //fmt.Printf("type use: %+v\n", use) - // TODO_IN_THIS_COMMIT: No need to check that the last return // arg is an error type if we checked that the function returns // an error as the last arg. - if lastReturnArgNode.Name == "err" { - if lastReturnArgNode.Obj == nil { - return true - } - - // TODO_IN_THIS_COMMIT: factor out and call in a case in the switch above where we handle *ast.AssignStmt - switch node := lastReturnArgNode.Obj.Decl.(type) { - case *ast.AssignStmt: - //fmt.Printf("errAssignStmt found: %+v\n", node) - - selection := typeInfo.Selections[node.Rhs[0].(*ast.CallExpr).Fun.(*ast.SelectorExpr)] - //fmt.Printf("type selection: %+v\n", selection) + //if lastReturnArgNode.Name == "err" { + if lastReturnArgNode.Obj == nil { + return true + } - // TODO_IN_THIS_COMMIT: account for other cases... + // TODO_IN_THIS_COMMIT: factor out and call in a case in the switch above where we handle *ast.AssignStmt + switch lastReturnArgDecl := lastReturnArgNode.Obj.Decl.(type) { + case *ast.AssignStmt: + switch lastReturnArg := lastReturnArgDecl.Rhs[0].(type) { + case *ast.Ident: + // TODO_IN_THIS_COMMIT: recurse into the outer ident case. + // TODO_IN_THIS_COMMIT: recurse into the outer ident case. + // TODO_IN_THIS_COMMIT: recurse into the outer ident case. + // TODO_IN_THIS_COMMIT: recurse into the outer ident case. + return true + case *ast.SelectorExpr: + // TODO_IN_THIS_COMMIT: recurse into the outer ident case. + // TODO_IN_THIS_COMMIT: recurse into the outer ident case. + // TODO_IN_THIS_COMMIT: recurse into the outer ident case. + // TODO_IN_THIS_COMMIT: recurse into the outer ident case. + selection := typeInfo.Selections[lastReturnArg] + traverseFunctionBody(selection.Obj().(*types.Func), pkgs, 0, condition(lastReturnArgNode)) + return false + case *ast.CallExpr: + switch lastReturnArgFun := lastReturnArg.Fun.(type) { + case *ast.SelectorExpr: + var selection *types.Selection + for _, srcPkg := range pkgs { + selection = srcPkg.TypesInfo.Selections[lastReturnArgFun] + if selection != nil { + break + } + } + //fmt.Printf("type selection: %+v\n", selection) + + // TODO_IN_THIS_COMMIT: account for other cases... + + if selection == nil { + //fmt.Printf("ERROR: selection is nil\n") + printNodeSource("lastReturnArgFun selection is nil", pkg, lastReturnArgFun) + return false + } + + traverseFunctionBody(selection.Obj().(*types.Func), pkgs, 0, condition(lastReturnArgNode)) - if selection == nil { - fmt.Printf("ERROR: selection is nil\n") - //return true return false - } - - traverseFunctionBody(selection.Obj().(*types.Func), pkgs, 0, condition(lastReturnArgNode)) + default: + printNodeSource( + "lastReturnArgFun", + pkg, lastReturnArgFun, + ) + return true + } + default: + printNodeSource( + "lastReturnArg", + pkg, lastReturnArg, + ) return false - //default: - //return true } - + case *ast.ValueSpec: + // Ignore return false - //return true + case *ast.Field: + // Ignore + return false + default: + printNodeSource( + fmt.Sprintf("unknown return arg decl node type: %T: %+v", lastReturnArgNode.Obj.Decl, lastReturnArgNode.Obj.Decl), + pkg, lastReturnArgNode.Obj.Decl, + ) + return true } + // `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. case *ast.CallExpr: - //fmt.Printf("ast.CallExpr: %T: %+v\n", lastReturnArg, lastReturnArgNode) - TraverseCallStack(lastReturnArgNode, pkgs, 0, condition(lastReturnArgNode)) - return false - //return true + + case *ast.SelectorExpr: + var selection *types.Selection + for _, srcPkg := range pkgs { + selection = srcPkg.TypesInfo.Selections[lastReturnArgNode] + if selection != nil { + break + } + } + + if selection == nil { + printNodeSource( + fmt.Sprintf("lastReturnArgNode selection is nil: %T: %+v", lastReturnArgNode, lastReturnArgNode), + pkg, lastReturnArgNode, + ) + + return true + } + + traverseFunctionBody(selection.Obj().(*types.Func), pkgs, 0, condition(lastReturnArgNode)) + default: - //panic(fmt.Sprintf("unknown AST node type: %T: %+v", lastReturnArg, lastReturnArg)) - fmt.Printf("unknown AST node type: %T: %+v\n", lastReturnArg, lastReturnArg) + printNodeSource( + fmt.Sprintf("unknown return arg node type: %T: %+v", lastResult, lastResult), + pkg, lastResult, + ) + return false } - - return false - //return true } return true }) // --- END - //// TODO_IN_THIS_COMMIT: extract --- BEGIN - //for _, errorReturn := range errorReturns { - // // Check if the error return is wrapped in a gRPC status error. - // //ifStmt, ok := errorReturn.If.(*ast.IfStmt) - // //if !ok { - // // return false - // //} - // ifStmt := errorReturn //.If.(*ast.IfStmt) - // - // switch node := ifStmt.Cond.(type) { - // case *ast.BinaryExpr: - // if node.Op != token.NEQ { - // return false - // } - // } - //} - //// --- END - return false - //return true }) } } // --- END + // TODO_IN_THIS_COMMIT: extract --- BEGIN + // TODO_IN_THIS_COMMIT: figure out why there are duplicate offending lines. // Print offending lines in package // TODO_IN_THIS_COMMIT: refactor to const. pkgsPattern := "github.com/pokt-network/poktroll/x/..." - numOffendingLines := len(offendingPkgErrLines) + numOffendingLines := len(offendingPkgErrLineSet) if numOffendingLines == 0 { - fmt.Printf("No offending lines in %s\n", pkgsPattern) + fmt.Printf("No offending lines in %q\n", pkgsPattern) } else { + offendingPkgErrLines := make([]string, 0, len(offendingPkgErrLineSet)) + for offendingPkgErrLine := range offendingPkgErrLineSet { + offendingPkgErrLines = append(offendingPkgErrLines, offendingPkgErrLine) + } + + sort.Strings(offendingPkgErrLines) + msg := fmt.Sprintf( - "\nFound %d offending lines in %s:", + "\nFound %d offending lines in %q", numOffendingLines, pkgsPattern, ) fmt.Printf( - "\n%s\n%s\n%s\n", + "%s:\n%s%s\n", msg, strings.Join(offendingPkgErrLines, "\n"), msg, ) } + // --- END return nil } @@ -422,3 +450,17 @@ func traverseFunctionBody(fn *types.Func, pkgs []*packages.Package, indent int, func indentSpaces(indent int) string { return strings.Repeat(" ", indent) } + +// TODO_IN_THIS_COMMIT: remove or move to a testutil package. +func printNodeSource(msg string, pkg *packages.Package, queryNode any) { + node, ok := queryNode.(ast.Node) + if !ok { + fmt.Printf("ERROR: queryNode is not an ast.Node: %T: %+v\n", queryNode, queryNode) + return + } + + fmt.Printf( + "not traversing %+v\n\t%s\n", + pkg.Fset.Position(node.Pos()), msg, + ) +} From 6378150e50c59d1aeceeca844394841bce739542 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Mon, 18 Nov 2024 01:22:32 +0100 Subject: [PATCH 06/18] wip: todo --- tools/scripts/protocheck/cmd/status_errors.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index 25f871cc0..c1ec95862 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -35,6 +35,10 @@ func init() { func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() + // TODO_IN_THIS_COMMIT: add hack/work-around to temporarily strip patch version from go.mod. + // TODO_IN_THIS_COMMIT: add hack/work-around to temporarily strip patch version from go.mod. + // TODO_IN_THIS_COMMIT: add hack/work-around to temporarily strip patch version from go.mod. + // TODO_IN_THIS_COMMIT: to support this, need to load all modules but only inspect target module. //if flagModule != "*" { // ... From 1e823d6886977496d06db637413db5e737b0d094 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Mon, 25 Nov 2024 13:17:35 +0100 Subject: [PATCH 07/18] wip:... --- go.mod | 2 +- tools/scripts/protocheck/cmd/ast_tracing.go | 215 ++++++++++ tools/scripts/protocheck/cmd/status_errors.go | 393 +++++++----------- 3 files changed, 369 insertions(+), 241 deletions(-) create mode 100644 tools/scripts/protocheck/cmd/ast_tracing.go diff --git a/go.mod b/go.mod index 79bd630b7..024364e1b 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/pokt-network/poktroll -go 1.23 +go 1.23.0 // replace ( // DEVELOPER_TIP: Uncomment to use a local copy of shannon-sdk for development purposes. diff --git a/tools/scripts/protocheck/cmd/ast_tracing.go b/tools/scripts/protocheck/cmd/ast_tracing.go new file mode 100644 index 000000000..0634e34ad --- /dev/null +++ b/tools/scripts/protocheck/cmd/ast_tracing.go @@ -0,0 +1,215 @@ +package main + +import ( + "fmt" + "go/ast" + "go/token" + "strings" + + "golang.org/x/tools/go/packages" +) + +const grpcStatusImportPath = "google.golang.org/grpc/status" + +// Helper function to trace selector expressions +func traceSelectorExpr( + expr *ast.SelectorExpr, + pkgs []*packages.Package, + candidateNode ast.Node, + offendingPositions map[string]struct{}, +) bool { + // Resolve the base expression + switch x := expr.X.(type) { + case *ast.Ident: // e.g., `pkg.Func` + for _, pkg := range pkgs { + if obj := pkg.TypesInfo.Uses[x]; obj != nil { + pkgParts := strings.Split(obj.String(), " ") + + var pkgStr string + switch { + // e.g., package (error) google.golang.org/grpc/status + case strings.HasPrefix(obj.String(), "package ("): + pkgStr = pkgParts[2] + // e.g. package fmt + default: + pkgStr = pkgParts[1] + } + //fmt.Printf(">>> pkgStr: %s\n", pkgStr) + + logger := logger.With( + "node_type", fmt.Sprintf("%T", x), + "position", fmt.Sprintf(" %s ", pkg.Fset.Position(x.Pos()).String()), + "package", strings.Trim(pkgStr, "\"()"), + ) + logger.Debug().Msg("tracing selector expression") + + //if obj := pkg.TypesInfo.Uses[x]; obj A!= nil { + //fmt.Printf("Base identifier %s resolved to: %s\n", x.Name, obj) + //fmt.Printf(">>> obj.String(): %s\n", obj.String()) + //fmt.Printf(">>> strings.Contains(obj.String(), grpcStatusImportPath): %v\n", strings.Contains(obj.String(), grpcStatusImportPath)) + isMatch := strings.Contains(obj.String(), grpcStatusImportPath) && + expr.Sel.Name == "Error" + //if isMatch { + if isMatch { + candidateNodePosition := pkg.Fset.Position(candidateNode.Pos()).String() + //inspectionPosition := pkg.Fset.Position(x.Pos()).String() + //fmt.Printf("Found target function %s in selector at %s\n", expr.Sel.Name, pkg.Fset.Position(expr.Pos())) + //fmt.Printf("Found offending candidate %s in selector at %s\n", expr.Sel.Name, pkg.Fset.Position(candidateNode.Pos()).String()) + //fmt.Printf("!!! - offendingPositions: %+v\n", offendingPositions) + //fmt.Printf("!!! - candPosition: %s", currentPosition) + //logger.Debug(). + // Str("candidate_pos", fmt.Sprintf(" %s ", candidateNodePosition)). + // //Str("offending_positions", fmt.Sprintf("%+v", offendingPositions)). + // Send() + if _, ok := offendingPositions[candidateNodePosition]; ok { + logger.Debug().Msgf("exhonerating %s", candidateNodePosition) + delete(offendingPositions, candidateNodePosition) + } + return false + } + //} + //case *ast.SelectorExpr: // e.g., `obj.Method.Func` + // if traceSelectorExpr(x, info, fset) { + // return true + // } + } else if obj = pkg.TypesInfo.Defs[x]; obj != nil { + logger.Warn().Msgf("no use but def: %+v", obj) + } + } + case *ast.SelectorExpr: // e.g., `obj.Method.Func` + return traceSelectorExpr(x, pkgs, candidateNode, offendingPositions) + case *ast.CallExpr: + logger.Debug().Msgf("tracing call expression: %+v", expr) + switch callExpr := x.Fun.(type) { + case *ast.SelectorExpr: // e.g., `obj.Method.Func` + return traceSelectorExpr(callExpr, pkgs, candidateNode, offendingPositions) + default: + logger.Warn().Msgf("skipping sub-selector call expression X type: %T", x) + } + default: + logger.Warn().Msgf("skipping selector expression X type: %T", x) + } + return true +} + +// Trace any expression recursively, including selector expressions +func traceExpressionStack( + exprToTrace ast.Expr, + pkgs []*packages.Package, + candidateNode ast.Node, + offendingPositions map[string]struct{}, +) bool { + if exprToTrace == nil { + return false + } + + logger := logger.With( + "node_type", fmt.Sprintf("%T", exprToTrace), + //"position", pkg.Fset.Position(x.Pos()).String(), + //"package", strings.Trim(strings.Split(obj.String(), " ")[2], "\"()"), + ) + logger.Debug().Msg("tracing expression stack") + + switch expr := exprToTrace.(type) { + case nil: + return false + case *ast.CallExpr: + if sel, ok := expr.Fun.(*ast.SelectorExpr); ok { + //logger.Debug().Msg("tracing selector expression") + return traceSelectorExpr(sel, pkgs, candidateNode, offendingPositions) + } + //logger.Debug().Msgf("tracing expression args: %+v", expr) + //for _, arg := range expr.Args { + // // TODO_IN_THIS_COMMIT: return traceExpressionStack... ? + // traceExpressionStack(arg, pkgs, candidateNode, offendingPositions) + // //return true + //} + return true + case *ast.BinaryExpr: + logger.Debug().Msg("tracing binary expression") + // TODO_IN_THIS_COMMIT: return traceExpressionStack... ? + if traceExpressionStack(expr.X, pkgs, candidateNode, offendingPositions) { + traceExpressionStack(expr.Y, pkgs, candidateNode, offendingPositions) + //return true + } + return true + case *ast.ParenExpr: + logger.Debug().Msg("tracing paren expression") + return traceExpressionStack(expr.X, pkgs, candidateNode, offendingPositions) + case *ast.SelectorExpr: + logger.Debug().Msg("tracing selector expression") + return traceSelectorExpr(expr, pkgs, candidateNode, offendingPositions) + case *ast.Ident: + logger.Debug().Msg("tracing ident") + //fmt.Printf(">>> exprToTrace: %+v\n", expr) + //var srcPkg *packages.Package + for _, pkg := range pkgs { + //if obj := pkg.TypesInfo.Defs[expr]; obj != nil { + // srcPkg = pkg + for _, fileNode := range pkg.Syntax { + declOrAssign, _ := findDeclOrAssign(expr, fileNode, pkg.Fset) + if declOrAssign == nil { + continue + } + logger.Debug(). + Str("pkg_path", pkg.PkgPath). + Str("file_path", pkg.Fset.File(fileNode.Pos()).Name()). + Str("decl_or_assign_pos", pkg.Fset.Position(declOrAssign.Pos()).String()). + Send() + //Msgf("found decl or assign: %+v", declOrAssign) + traceExpressionStack(declOrAssign, pkgs, candidateNode, offendingPositions) + } + //} + } + //if srcPkg == nil { + // logger.Warn().Msgf("no pkg found for expr: %+v", expr) + //} + return true + //case *ast.SliceExpr: + // logger.Debug().Msgf("tracing slice expression: %+v", expr) + // return true + default: + logger.Warn().Msgf("unknown node type 2: %T", exprToTrace) + return true + } +} + +// Find the declaration or assignment of an identifier +func findDeclOrAssign(ident *ast.Ident, node ast.Node, fset *token.FileSet) (ast.Expr, *token.Position) { + //logger.Debug().Msg("finding decl or assign") + + //fmt.Println("!!!! findDeclOrAssign begin") + var declOrAssign ast.Expr + var foundPos token.Position + + //fmt.Println("!!!! findDeclOrAssign inspect") + ast.Inspect(node, func(n ast.Node) bool { + switch stmt := n.(type) { + case *ast.AssignStmt: // Look for assignments + //fmt.Println("!!!! findDeclOrAssign case assign") + for i, lhs := range stmt.Lhs { + if lhsIdent, ok := lhs.(*ast.Ident); ok && lhsIdent.Name == ident.Name { + //fmt.Printf("len(rhs): %d len(lhs): %d\n", len(stmt.Rhs), len(stmt.Lhs)) + if len(stmt.Lhs) != len(stmt.Rhs) { + declOrAssign = stmt.Rhs[0] + } else { + declOrAssign = stmt.Rhs[i] + } + foundPos = fset.Position(stmt.Pos()) + } + } + case *ast.ValueSpec: // Look for declarations with initialization + //fmt.Println("!!!! findDeclOrAssign case value") + for i, name := range stmt.Names { + if name.Name == ident.Name && i < len(stmt.Values) { + declOrAssign = stmt.Values[i] + foundPos = fset.Position(stmt.Pos()) + return false + } + } + } + return true + }) + + return declOrAssign, &foundPos +} diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index c1ec95862..74d7576cd 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -4,33 +4,55 @@ import ( "context" "fmt" "go/ast" - "go/types" - "log" + "os" + "path/filepath" "sort" "strings" + "github.com/rs/zerolog" "github.com/spf13/cobra" "golang.org/x/tools/go/packages" + + "github.com/pokt-network/poktroll/pkg/polylog" + "github.com/pokt-network/poktroll/pkg/polylog/polyzero" ) var ( flagModule = "module" flagModuleShorthand = "m" - flagModuleValue = "*" - flagModuleUsage = "If present, only check message handlers of the given module." + // TODO_IN_THIS_COMMIT: support this flag. + flagModuleValue = "*" + flagModuleUsage = "If present, only check message handlers of the given module." + + flagLogLevel = "log-level" + flagLogLevelShorthand = "l" + flagLogLevelValue = "info" + flagLogLevelUsage = "The logging level (debug|info|warn|error)" statusErrorsCheckCmd = &cobra.Command{ - Use: "status-errors [flags]", - Short: "Checks that all message handler function errors are wrapped in gRPC status errors.", - RunE: runStatusErrorsCheck, + Use: "status-errors [flags]", + Short: "Checks that all message handler function errors are wrapped in gRPC status errors.", + PreRun: setupLogger, + RunE: runStatusErrorsCheck, } + + logger polylog.Logger + offendingPkgErrLineSet = make(map[string]struct{}) ) func init() { - statusErrorsCheckCmd.Flags().StringVarP(&flagModule, flagModuleShorthand, "m", flagModuleValue, flagModuleUsage) + statusErrorsCheckCmd.Flags().StringVarP(&flagModuleValue, flagModule, flagModuleShorthand, flagModuleValue, flagModuleUsage) + statusErrorsCheckCmd.Flags().StringVarP(&flagLogLevelValue, flagLogLevel, flagLogLevelShorthand, flagLogLevelValue, flagLogLevelUsage) rootCmd.AddCommand(statusErrorsCheckCmd) } +func setupLogger(_ *cobra.Command, _ []string) { + logger = polyzero.NewLogger( + polyzero.WithWriter(zerolog.ConsoleWriter{Out: os.Stderr}), + polyzero.WithLevel(polyzero.ParseLevel(flagLogLevelValue)), + ) +} + // TODO_IN_THIS_COMMIT: pre-run: drop patch version in go.mod; post-run: restore. func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() @@ -54,6 +76,13 @@ func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { return nil } +// TODO_IN_THIS_COMMIT: 2-step check +// 1. Collect all return statements from `msgServer` methods and `Keeper` methods in `query_*.go` files. +// 2. For each return statement, check the type: +// *ast.Ident: search this package ... +// *ast.SelectorExpr: search the package of its declaration... +// ...for an *ast.AssignStmt with the given *ast.Ident as the left-hand side. + func checkModule(_ context.Context) error { // 0. Get the package info for the given module's keeper package. @@ -75,13 +104,13 @@ func checkModule(_ context.Context) error { // Load the package containing the target file or directory poktrollPkgPathPattern := "github.com/pokt-network/poktroll/x/..." + logger.Info().Msgf("Loading package(s) in %s", poktrollPkgPathPattern) + pkgs, err := packages.Load(cfg, poktrollPkgPathPattern) if err != nil { - log.Fatalf("Failed to load package: %v", err) + return fmt.Errorf("failed to load package: %w", err) } - offendingPkgErrLineSet := make(map[string]struct{}) - // Iterate over the keeper packages // E.g.: // - github.com/pokt-network/poktroll/x/application/keeper @@ -94,7 +123,7 @@ func checkModule(_ context.Context) error { if len(pkg.Errors) > 0 { for _, pkgErr := range pkg.Errors { - log.Printf("Package error: %v", pkgErr) + logger.Error().Msgf("Package error: %v", pkgErr) } continue } @@ -102,15 +131,37 @@ func checkModule(_ context.Context) error { // Access type information info := pkg.TypesInfo if info == nil { - log.Println("No type information available") + logger.Warn().Msgf("No type information available, skipping package %q", pkg.PkgPath) continue } - typeInfo := pkg.TypesInfo // --- END + //filenames := make([]string, 0) + //for _, astFile := range pkg.Syntax { + // filenames = append(filenames, filepath.Base(pkg.Fset.Position(astFile.Pos()).Filename)) + //} + //fmt.Printf(">>> filenames:\n%s\n", strings.Join(filenames, "\n")) + // TODO_IN_THIS_COMMIT: extract --- BEGIN + // TODO_IN_THIS_COMMIT: check the filename and only inspect each once! for _, astFile := range pkg.Syntax { + filename := pkg.Fset.Position(astFile.Pos()).Filename + + // Ignore protobuf generated files. + if strings.HasSuffix(filepath.Base(filename), ".pb.go") { + continue + } + if strings.HasSuffix(filepath.Base(filename), ".pb.gw.go") { + continue + } + + // TODO_IN_THIS_COMMIT: remove! + //fmt.Printf(">>> filename: %s\n", filename) + //if filename != "/home/bwhite/Projects/pokt/poktroll/x/application/keeper/msg_server_delegate_to_gateway.go" { + // continue + //} + ast.Inspect(astFile, func(n ast.Node) bool { fnNode, ok := n.(*ast.FuncDecl) if !ok { @@ -122,7 +173,7 @@ func checkModule(_ context.Context) error { return false } - fnNodeTypeObj, ok := typeInfo.Defs[fnNode.Name] //.Type.Results.List[0].Type + fnNodeTypeObj, ok := info.Defs[fnNode.Name] //.Type.Results.List[0].Type if !ok { fmt.Printf("ERROR: unable to find fnNode type def: %s\n", fnNode.Name.Name) return true @@ -138,6 +189,11 @@ func checkModule(_ context.Context) error { return false } + //fmt.Printf(">>> fNode.Name.Name: %s\n", fnNode.Name.Name) + //if fnNode.Name.Name != "AllApplications" { + // return false + //} + // TODO_IN_THIS_COMMIT: check the signature of the method to ensure it returns an error type. fnResultsList := fnNode.Type.Results.List fnLastResultType := fnResultsList[len(fnResultsList)-1].Type @@ -153,34 +209,63 @@ func checkModule(_ context.Context) error { return false } - if typeIdentNode.Name != "msgServer" { + fnPos := pkg.Fset.Position(fnNode.Pos()) + //fmt.Printf(">>> fnNode.Pos(): %s\n", fnPos.String()) + fnFilename := filepath.Base(fnPos.Filename) + fnSourceHasQueryHandlerPrefix := strings.HasPrefix(fnFilename, "query_") + //fnSourceHasQueryHandlerPrefix := false + + if typeIdentNode.Name != "msgServer" && !fnSourceHasQueryHandlerPrefix { return false } - condition := func(returnErrNode ast.Node) func(*ast.Ident, types.Object) bool { - return func(sel *ast.Ident, typeObj types.Object) bool { - isStatusError := sel.Name == "Error" && typeObj.Pkg().Path() == "google.golang.org/grpc/status" - if !isStatusError { - offendingPkgErrLineSet[pkg.Fset.Position(returnErrNode.Pos()).String()] = struct{}{} - } - - return isStatusError + // TODO_IN_THIS_COMMIT: figure out why this file hangs the command. + isExcludedFile := false + for _, excludedFile := range []string{"query_get_session.go"} { + if fnFilename == excludedFile { + isExcludedFile = true } } + if isExcludedFile { + return false + } // Recursively traverse the function body, looking for non-nil error returns. // TODO_IN_THIS_COMMIT: extract --- BEGIN + //fmt.Printf(">>> walking func from file: %s\n", pkg.Fset.Position(astFile.Pos()).Filename) ast.Inspect(fnNode.Body, func(n ast.Node) bool { + if n == nil { + return false + } + //inspectPos := pkg.Fset.Position(n.Pos()) + //fmt.Printf(">>> inspecting %T at %s\n", n, inspectPos) + + //inspectLocation := inspectPos.String() + //if inspectLocation == "/home/bwhite/Projects/pokt/poktroll/x/proof/keeper/msg_server_create_claim.go:44:3" { + // fmt.Printf(">>> found it!\n") + //} + switch n := n.(type) { case *ast.BlockStmt: return true // Search for a return statement. case *ast.ReturnStmt: lastResult := n.Results[len(n.Results)-1] + inspectPosition := pkg.Fset.Position(lastResult.Pos()).String() + + logger := logger.With( + "node_type", fmt.Sprintf("%T", lastResult), + "inspectPosition", fmt.Sprintf(" %s ", inspectPosition), + ) switch lastReturnArgNode := lastResult.(type) { // `return nil, err` <-- last arg is an *ast.Ident. case *ast.Ident: + //logger.Debug().Fields(map[string]any{ + // "node_type": fmt.Sprintf("%T", lastReturnArgNode), + // "inspectPosition": pkg.Fset.Position(lastReturnArgNode.Pos()).String(), + //}).Msg("traversing ast node") + // TODO_IN_THIS_COMMIT: No need to check that the last return // arg is an error type if we checked that the function returns // an error as the last arg. @@ -189,108 +274,49 @@ func checkModule(_ context.Context) error { return true } - // TODO_IN_THIS_COMMIT: factor out and call in a case in the switch above where we handle *ast.AssignStmt - switch lastReturnArgDecl := lastReturnArgNode.Obj.Decl.(type) { - case *ast.AssignStmt: - switch lastReturnArg := lastReturnArgDecl.Rhs[0].(type) { - case *ast.Ident: - // TODO_IN_THIS_COMMIT: recurse into the outer ident case. - // TODO_IN_THIS_COMMIT: recurse into the outer ident case. - // TODO_IN_THIS_COMMIT: recurse into the outer ident case. - // TODO_IN_THIS_COMMIT: recurse into the outer ident case. - return true - case *ast.SelectorExpr: - // TODO_IN_THIS_COMMIT: recurse into the outer ident case. - // TODO_IN_THIS_COMMIT: recurse into the outer ident case. - // TODO_IN_THIS_COMMIT: recurse into the outer ident case. - // TODO_IN_THIS_COMMIT: recurse into the outer ident case. - selection := typeInfo.Selections[lastReturnArg] - traverseFunctionBody(selection.Obj().(*types.Func), pkgs, 0, condition(lastReturnArgNode)) - return false - case *ast.CallExpr: - switch lastReturnArgFun := lastReturnArg.Fun.(type) { - case *ast.SelectorExpr: - var selection *types.Selection - for _, srcPkg := range pkgs { - selection = srcPkg.TypesInfo.Selections[lastReturnArgFun] - if selection != nil { - break - } - } - //fmt.Printf("type selection: %+v\n", selection) - - // TODO_IN_THIS_COMMIT: account for other cases... - - if selection == nil { - //fmt.Printf("ERROR: selection is nil\n") - printNodeSource("lastReturnArgFun selection is nil", pkg, lastReturnArgFun) - return false - } - - traverseFunctionBody(selection.Obj().(*types.Func), pkgs, 0, condition(lastReturnArgNode)) - - return false - default: - printNodeSource( - "lastReturnArgFun", - pkg, lastReturnArgFun, - ) - - return true - } - default: - printNodeSource( - "lastReturnArg", - pkg, lastReturnArg, - ) - return false - } - case *ast.ValueSpec: - // Ignore - return false - case *ast.Field: - // Ignore - return false - default: - printNodeSource( - fmt.Sprintf("unknown return arg decl node type: %T: %+v", lastReturnArgNode.Obj.Decl, lastReturnArgNode.Obj.Decl), - pkg, lastReturnArgNode.Obj.Decl, - ) + def := pkg.TypesInfo.Uses[lastReturnArgNode] + if def == nil { + logger.Warn().Msg("def is nil") return true } - // `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. - case *ast.CallExpr: - TraverseCallStack(lastReturnArgNode, pkgs, 0, condition(lastReturnArgNode)) - return false - - case *ast.SelectorExpr: - var selection *types.Selection - for _, srcPkg := range pkgs { - selection = srcPkg.TypesInfo.Selections[lastReturnArgNode] - if selection != nil { - break - } + if def.Type().String() != "error" { + //logger.Warn().Msg("def is not error") + //inspectPosition := pkg.Fset.Position(lastReturnArgNode.Pos()).String() + //break + return false } - if selection == nil { - printNodeSource( - fmt.Sprintf("lastReturnArgNode selection is nil: %T: %+v", lastReturnArgNode, lastReturnArgNode), - pkg, lastReturnArgNode, - ) + logger.Debug().Msg("appending potential offending line") + appendOffendingLine(inspectPosition) + traceExpressionStack(lastReturnArgNode, pkgs, lastReturnArgNode, offendingPkgErrLineSet) + return true - return true - } + // `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. + case *ast.CallExpr: + //logger.Debug().Msg("inspecting ast node") + logger.Debug().Msg("appending potential offending line") + appendOffendingLine(inspectPosition) + traceExpressionStack(lastReturnArgNode, pkgs, lastReturnArgNode, offendingPkgErrLineSet) + //TraverseCallStack(lastReturnArgNode, pkgs, 0, condition(lastReturnArgNode)) + //return false + return true - traverseFunctionBody(selection.Obj().(*types.Func), pkgs, 0, condition(lastReturnArgNode)) + case *ast.SelectorExpr: + logger.Debug().Msg("appending potential offending line") + appendOffendingLine(inspectPosition) + traceSelectorExpr(lastReturnArgNode, pkgs, lastReturnArgNode, offendingPkgErrLineSet) + return true default: - printNodeSource( - fmt.Sprintf("unknown return arg node type: %T: %+v", lastResult, lastResult), - pkg, lastResult, - ) - return false + logger.Warn().Msg("NOT traversing ast node") + return true } + + //logger.Debug().Msg("appending potential offending line") + //appendOffendingLine(inspectPosition) + // + //return true } return true @@ -302,6 +328,7 @@ func checkModule(_ context.Context) error { } } + // --- END // TODO_IN_THIS_COMMIT: extract --- BEGIN @@ -321,11 +348,11 @@ func checkModule(_ context.Context) error { sort.Strings(offendingPkgErrLines) msg := fmt.Sprintf( - "\nFound %d offending lines in %q", + "Found %d offending lines in %q", numOffendingLines, pkgsPattern, ) - fmt.Printf( - "%s:\n%s%s\n", + logger.Info().Msgf( + "%s:\n%s\n%s", msg, strings.Join(offendingPkgErrLines, "\n"), msg, @@ -336,125 +363,6 @@ func checkModule(_ context.Context) error { return nil } -// TraverseCallStack recursively traverses the call stack starting from a *ast.CallExpr. -func TraverseCallStack(call *ast.CallExpr, pkgs []*packages.Package, indent int, condition func(*ast.Ident, types.Object) bool) { - fun := call.Fun - switch fn := fun.(type) { - case *ast.Ident: - // Local or top-level function - - var useObj types.Object - for _, pkg := range pkgs { - useObj = pkg.TypesInfo.Uses[fn] - if useObj != nil { - break - } - } - if useObj != nil { - //fmt.Printf("%sFunction: %s\n", indentSpaces(indent), useObj.Name()) - if fnDecl, ok := useObj.(*types.Func); ok { - traverseFunctionBody(fnDecl, pkgs, indent+2, condition) - } - } - case *ast.SelectorExpr: - // Method call like obj.Method() - sel := fn.Sel - var selection *types.Selection - for _, pkg := range pkgs { - selection = pkg.TypesInfo.Selections[fn] - if selection != nil { - break - } - } - if selection != nil { - // Instance method - //fmt.Printf("%sMethod: %s on %s\n", indentSpaces(indent), sel.Name, selection.Recv()) - if method, ok := selection.Obj().(*types.Func); ok { - traverseFunctionBody(method, pkgs, indent+2, condition) - } - } else { - // Static or package-level call - var useObj types.Object - for _, pkg := range pkgs { - useObj = pkg.TypesInfo.Uses[sel] - if useObj != nil { - break - } - } - if useObj != nil { - //fmt.Printf("%sFunction: %s (package-level: %s)\n", indentSpaces(indent), sel.Name, useObj.Pkg().Path()) - if condition(sel, useObj) { - //fmt.Println(">>> STATUS ERROR FOUND!") - return - } - - if fnDecl, ok := useObj.(*types.Func); ok { - traverseFunctionBody(fnDecl, pkgs, indent+2, condition) - } - } - } - default: - fmt.Printf("%sUnknown function type: %T\n", indentSpaces(indent), fun) - } - - // Recursively inspect arguments for nested calls - for _, arg := range call.Args { - if nestedCall, ok := arg.(*ast.CallExpr); ok { - TraverseCallStack(nestedCall, pkgs, indent+2, condition) - } - } -} - -// traverseFunctionBody analyzes the body of a function or method to find further calls. -func traverseFunctionBody(fn *types.Func, pkgs []*packages.Package, indent int, condition func(*ast.Ident, types.Object) bool) { - //fmt.Printf("fn package path: %s\n", fn.Pkg().Path()) - //fmt.Printf("path has prefix: %v\n", strings.HasPrefix(fn.Pkg().Path(), "github.com/pokt-network/poktroll")) - // Don't traverse beyond poktroll module root (i.e. assume deps won't return status errors). - if !strings.HasPrefix(fn.Pkg().Path(), "github.com/pokt-network/poktroll") { - return - } - - // TODO_IN_THIS_COMMIT: Implement & log when this happens. - // DEV_NOTE: If targetFileName is not present in any package, - // we assume that a status error will not be returned by the - // function; so we MUST mark it as offending. - - for _, pkg := range pkgs { - // Find the declaration of the function in the AST - for _, file := range pkg.Syntax { - ast.Inspect(file, func(node ast.Node) bool { - funcDecl, ok := node.(*ast.FuncDecl) - if !ok { - return true // Not the target function, continue - } - targetFileName := pkg.Fset.Position(fn.Pos()).Filename - nodeFileName := pkg.Fset.Position(funcDecl.Pos()).Filename - //fmt.Printf("nodeFileName: %s\n", nodeFileName) - if nodeFileName != targetFileName { - return true // Not the target function, continue - } - - if funcDecl.Name.Name == fn.Name() { - // Found the function, inspect its body for calls - ast.Inspect(funcDecl.Body, func(n ast.Node) bool { - if call, ok := n.(*ast.CallExpr); ok { - TraverseCallStack(call, pkgs, indent, condition) - } - return true - }) - return false // Stop after finding the target function - } - return true - }) - } - } -} - -// Helper function to generate indentation -func indentSpaces(indent int) string { - return strings.Repeat(" ", indent) -} - // TODO_IN_THIS_COMMIT: remove or move to a testutil package. func printNodeSource(msg string, pkg *packages.Package, queryNode any) { node, ok := queryNode.(ast.Node) @@ -463,8 +371,13 @@ func printNodeSource(msg string, pkg *packages.Package, queryNode any) { return } - fmt.Printf( - "not traversing %+v\n\t%s\n", + logger.Warn().Msgf( + "not traversing %+v\n\t%s", pkg.Fset.Position(node.Pos()), msg, ) } + +// TODO_IN_THIS_COMMIT: move & godoc... +func appendOffendingLine(sourceLine string) { + offendingPkgErrLineSet[sourceLine] = struct{}{} +} From a9c2e2ec2d30a3dee930f047ff9886c25e9d9ff8 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Mon, 25 Nov 2024 13:18:20 +0100 Subject: [PATCH 08/18] wip: chore: support custom polyzero writer --- pkg/polylog/polyzero/options.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pkg/polylog/polyzero/options.go b/pkg/polylog/polyzero/options.go index 5f474938d..93f9ef7f8 100644 --- a/pkg/polylog/polyzero/options.go +++ b/pkg/polylog/polyzero/options.go @@ -50,3 +50,11 @@ func WithSetupFn(fn func(logger *zerolog.Logger)) polylog.LoggerOption { fn(&logger.(*zerologLogger).Logger) } } + +// TODO_IN_THIS_COMMIT: godoc & test... +func WithWriter(writer io.Writer) polylog.LoggerOption { + return func(logger polylog.Logger) { + zl := logger.(*zerologLogger).Logger + logger.(*zerologLogger).Logger = zl.Output(writer) + } +} From 16ba81256e7e01e16c640bc2611040e0555eee54 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Mon, 25 Nov 2024 15:12:30 +0100 Subject: [PATCH 09/18] wip: ... --- tools/scripts/protocheck/cmd/status_errors.go | 56 ++++++++++++------- 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index 74d7576cd..e06136c66 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -48,7 +48,7 @@ func init() { func setupLogger(_ *cobra.Command, _ []string) { logger = polyzero.NewLogger( - polyzero.WithWriter(zerolog.ConsoleWriter{Out: os.Stderr}), + polyzero.WithWriter(zerolog.ConsoleWriter{Out: os.Stdout}), polyzero.WithLevel(polyzero.ParseLevel(flagLogLevelValue)), ) } @@ -57,6 +57,22 @@ func setupLogger(_ *cobra.Command, _ []string) { func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() + // TODO_IN_THIS_COMMIT: extract to validation function. + if flagModuleValue != "*" { + switch flagModuleValue { + case "application": + case "gateway": + case "proof": + case "service": + case "session": + case "shared": + case "supplier": + case "tokenomics": + default: + return fmt.Errorf("ERROR: invalid module name: %s", flagModuleValue) + } + } + // TODO_IN_THIS_COMMIT: add hack/work-around to temporarily strip patch version from go.mod. // TODO_IN_THIS_COMMIT: add hack/work-around to temporarily strip patch version from go.mod. // TODO_IN_THIS_COMMIT: add hack/work-around to temporarily strip patch version from go.mod. @@ -104,7 +120,7 @@ func checkModule(_ context.Context) error { // Load the package containing the target file or directory poktrollPkgPathPattern := "github.com/pokt-network/poktroll/x/..." - logger.Info().Msgf("Loading package(s) in %s", poktrollPkgPathPattern) + //logger.Info().Msgf("Loading package(s) in %s", poktrollPkgPathPattern) pkgs, err := packages.Load(cfg, poktrollPkgPathPattern) if err != nil { @@ -117,6 +133,13 @@ func checkModule(_ context.Context) error { // - github.com/pokt-network/poktroll/x/gateway/keeper // - ... for _, pkg := range pkgs { + if flagModuleValue != "*" { + moduleRootPath := fmt.Sprintf("github.com/pokt-network/poktroll/x/%s", flagModuleValue) + if !strings.HasPrefix(pkg.PkgPath, moduleRootPath) { + continue + } + } + if pkg.Name != "keeper" { continue } @@ -336,9 +359,13 @@ func checkModule(_ context.Context) error { // Print offending lines in package // TODO_IN_THIS_COMMIT: refactor to const. pkgsPattern := "github.com/pokt-network/poktroll/x/..." + if flagModuleValue != "*" { + pkgsPattern = fmt.Sprintf("github.com/pokt-network/poktroll/x/%s/...", flagModuleValue) + } + numOffendingLines := len(offendingPkgErrLineSet) if numOffendingLines == 0 { - fmt.Printf("No offending lines in %q\n", pkgsPattern) + logger.Info().Msgf("🎉 No offending lines in %q 🎉", pkgsPattern) } else { offendingPkgErrLines := make([]string, 0, len(offendingPkgErrLineSet)) for offendingPkgErrLine := range offendingPkgErrLineSet { @@ -348,35 +375,24 @@ func checkModule(_ context.Context) error { sort.Strings(offendingPkgErrLines) msg := fmt.Sprintf( - "Found %d offending lines in %q", + "🚨 Found %d offending lines in %q 🚨", numOffendingLines, pkgsPattern, ) logger.Info().Msgf( - "%s:\n%s\n%s", + "%s:\n%s", msg, strings.Join(offendingPkgErrLines, "\n"), - msg, ) + + if numOffendingLines > 5 { + logger.Info().Msg(msg) + } } // --- END return nil } -// TODO_IN_THIS_COMMIT: remove or move to a testutil package. -func printNodeSource(msg string, pkg *packages.Package, queryNode any) { - node, ok := queryNode.(ast.Node) - if !ok { - fmt.Printf("ERROR: queryNode is not an ast.Node: %T: %+v\n", queryNode, queryNode) - return - } - - logger.Warn().Msgf( - "not traversing %+v\n\t%s", - pkg.Fset.Position(node.Pos()), msg, - ) -} - // TODO_IN_THIS_COMMIT: move & godoc... func appendOffendingLine(sourceLine string) { offendingPkgErrLineSet[sourceLine] = struct{}{} From c5c659fd92eb9905432a69687d533ec5de27bfb6 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Mon, 25 Nov 2024 16:04:05 +0100 Subject: [PATCH 10/18] chore: add check_grpc_status_errors make target --- makefiles/checks.mk | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/makefiles/checks.mk b/makefiles/checks.mk index d50e5320a..782ddb31e 100644 --- a/makefiles/checks.mk +++ b/makefiles/checks.mk @@ -153,3 +153,10 @@ check_proto_unstable_marshalers: ## Check that all protobuf files have the `stab fix_proto_unstable_marshalers: ## Ensure the `stable_marshaler_all` option is present on all protobuf files. go run ./tools/scripts/protocheck/cmd unstable --fix ${MAKE} proto_regen + +MODULE ?= "*" +LEVEL ?= "info" + +.PHONY: check_grpc_status_errors +check_grpc_status_errors: ## Check that all gRPC service handlers return gRPC status errors. + go run ./tools/scripts/protocheck/cmd status-errors -m ${MODULE} -l ${LEVEL} \ No newline at end of file From 68d76e041f48f2fdf69d32955d04e470a50ba604 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Tue, 26 Nov 2024 18:52:08 +0100 Subject: [PATCH 11/18] wip: more better --- tools/scripts/protocheck/cmd/ast_tracing.go | 449 ++++++++++++++++-- tools/scripts/protocheck/cmd/status_errors.go | 89 +--- .../keeper/msg_server_stake_supplier.go | 4 +- 3 files changed, 416 insertions(+), 126 deletions(-) diff --git a/tools/scripts/protocheck/cmd/ast_tracing.go b/tools/scripts/protocheck/cmd/ast_tracing.go index 0634e34ad..e52ff1f70 100644 --- a/tools/scripts/protocheck/cmd/ast_tracing.go +++ b/tools/scripts/protocheck/cmd/ast_tracing.go @@ -11,13 +11,193 @@ import ( const grpcStatusImportPath = "google.golang.org/grpc/status" +var TRACE bool + +// TODO_IN_THIS_COMMIT: move & godoc... +func walkFuncBody( + pkg *packages.Package, + pkgs []*packages.Package, + shouldAppend, + shouldExhonerate bool, +) func(ast.Node) bool { + return func(n ast.Node) bool { + if n == nil { + return false + } + + logger.Debug(). + Str("position", fmt.Sprintf(" %s ", pkg.Fset.Position(n.Pos()).String())). + Str("node_type", fmt.Sprintf("%T", n)). + Bool("shouldAppend", shouldAppend). + Msg("walking function body") + + //position := pkg.Fset.Position(n.Pos()) + //logger.Warn().Msgf("position: %s", position.String()) + // + //inspectPos := pkg.Fset.Position(n.Pos()) + //fmt.Printf(">>> inspecting %T at %s\n", n, inspectPos) + // + //inspectLocation := inspectPos.String() + //if inspectLocation == "/home/bwhite/Projects/pokt/poktroll/x/proof/keeper/msg_server_create_claim.go:44:3" { + // fmt.Printf(">>> found it!\n") + //} + + switch n := n.(type) { + //case *ast.BlockStmt: + // return true + //// Search for a return statement. + //case *ast.IfStmt: + // return true + case *ast.ReturnStmt: + lastResult := n.Results[len(n.Results)-1] + inspectPosition := pkg.Fset.Position(lastResult.Pos()).String() + + logger := logger.With( + "node_type", fmt.Sprintf("%T", lastResult), + "inspectPosition", fmt.Sprintf(" %s ", inspectPosition), + ) + + logger.Debug().Msgf("lastResult: %+v", lastResult) + + switch lastReturnArgNode := lastResult.(type) { + // `return nil, err` <-- last arg is an *ast.Ident. + case *ast.Ident: + //logger.Debug().Fields(map[string]any{ + // "node_type": fmt.Sprintf("%T", lastReturnArgNode), + // "inspectPosition": pkg.Fset.Position(lastReturnArgNode.Pos()).String(), + //}).Msg("traversing ast node") + + // TODO_IN_THIS_COMMIT: No need to check that the last return + // arg is an error type if we checked that the function returns + // an error as the last arg. + //if lastReturnArgNode.Name == "err" { + if lastReturnArgNode.Obj == nil { + logger.Warn().Msg("lastReturnArgNode.Obj is nil") + return true + } + + def := pkg.TypesInfo.Uses[lastReturnArgNode] + if def == nil { + logger.Warn().Msg("def is nil") + return true + } + + if def.Type().String() != "error" { + logger.Warn().Msg("def is not error") + //inspectPosition := pkg.Fset.Position(lastReturnArgNode.Pos()).String() + //break + return false + } + + if shouldAppend { + logger.Debug().Msg("appending potential offending line") + appendOffendingLine(inspectPosition) + } + traceExpressionStack(lastReturnArgNode, pkgs, nil, pkg, lastReturnArgNode, offendingPkgErrLineSet) + return true + + // `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. + case *ast.CallExpr: + //logger.Debug().Msg("inspecting ast node") + if shouldAppend { + logger.Debug().Msg("appending potential offending line") + appendOffendingLine(inspectPosition) + } + traceExpressionStack(lastReturnArgNode, pkgs, nil, pkg, lastReturnArgNode, offendingPkgErrLineSet) + //TraverseCallStack(lastReturnArgNode, pkgs, 0, condition(lastReturnArgNode)) + //return false + return true + + case *ast.SelectorExpr: + if shouldAppend { + logger.Debug().Msg("appending potential offending line") + appendOffendingLine(inspectPosition) + } + traceSelectorExpr(lastReturnArgNode, pkg, pkgs, lastReturnArgNode, offendingPkgErrLineSet) + return true + } + + default: + //logger.Warn().Str("node_types", fmt.Sprintf("%T", n)).Msg("NOT traversing ast node") + return true + } + + //logger.Debug().Msg("appending potential offending line") + //appendOffendingLine(inspectPosition) + // + //return true + + return true + } +} + // Helper function to trace selector expressions func traceSelectorExpr( expr *ast.SelectorExpr, + //scopeNode ast.Node, + candidatePkg *packages.Package, pkgs []*packages.Package, candidateNode ast.Node, offendingPositions map[string]struct{}, ) bool { + logger.Debug().Msg("tracing selector expression") + //fmt.Println(">>>>>>>>> TRACE SELECTOR EXPR") + for _, pkg := range pkgs { + if selection := pkg.TypesInfo.Selections[expr]; selection != nil { + //logger.Warn().Msgf("<<<<<<< has selection: %s", selection.String()) + for _, pkg2 := range pkgs { + position := pkg2.Fset.Position(selection.Obj().Pos()) + + var foundNode ast.Node + for _, fileNode := range pkg2.Syntax { + foundNode = findNodeByPosition(pkg2.Fset, fileNode, position) + if foundNode != nil { + //traceExpressionStack(foundNode, pkgs, expr, pkg, foundNode, offendingPositions) + logger.Warn(). + Str("node_type", fmt.Sprintf("%T", foundNode)). + Str("selection_position", fmt.Sprintf(" %s ", position)). + Str("expr_position", fmt.Sprintf(" %s ", pkg.Fset.Position(expr.Pos()).String())). + Str("found_node_position", fmt.Sprintf(" %s ", pkg2.Fset.Position(foundNode.Pos()).String())). + Msg("found node") + //fmt.Printf(">>>>>>>>>> found node %T %+v\n", foundNode, foundNode) + //traceExpressionStack(foundNode.(*ast.Ident), pkgs, expr, pkg2, foundNode, offendingPositions) + var declNode *ast.FuncDecl + ast.Inspect(fileNode, func(n ast.Node) bool { + if declNode != nil { + return false + } + + if decl, ok := n.(*ast.FuncDecl); ok { + if decl.Name.Name == foundNode.(*ast.Ident).Name && + decl.Pos() < foundNode.Pos() && + foundNode.Pos() <= decl.End() { + declNode = decl + return false + } + } + return true + }) + + if declNode != nil { + logger.Warn().Str("decl_position", pkg2.Fset.Position(declNode.Pos()).String()).Msg("tracing decl node") + logger.Warn().Str("decl_body", pkg2.Fset.Position(declNode.Body.Pos()).String()).Msg("tracing decl node body") + ast.Inspect(declNode.Body, walkFuncBody(pkg, pkgs, false, false)) + //walkFuncBody(pkg, pkgs)(declNode.Body) + } else { + logger.Warn().Msg("could not find decl node") + } + + //return false + return true + } + } + } + return true + } + } + + // TODO_IN_THIS_COMMIT: refactor; below happens when the selector is not found within any package. + // Resolve the base expression switch x := expr.X.(type) { case *ast.Ident: // e.g., `pkg.Func` @@ -64,9 +244,14 @@ func traceSelectorExpr( if _, ok := offendingPositions[candidateNodePosition]; ok { logger.Debug().Msgf("exhonerating %s", candidateNodePosition) delete(offendingPositions, candidateNodePosition) + } else { + logger.Warn().Msgf("can't exhonerating %s", candidateNodePosition) } return false } + + //traceSelectorExpr(expr.Sel, candidatePkg, pkgs, candidateNode, offendingPositions) + //} //case *ast.SelectorExpr: // e.g., `obj.Method.Func` // if traceSelectorExpr(x, info, fset) { @@ -74,18 +259,28 @@ func traceSelectorExpr( // } } else if obj = pkg.TypesInfo.Defs[x]; obj != nil { logger.Warn().Msgf("no use but def: %+v", obj) + } else if obj = pkg.TypesInfo.Defs[expr.Sel]; obj != nil { + logger.Warn(). + Str("pkg_path", pkg.PkgPath). + Str("name", expr.Sel.Name). + Msgf("sel def") + traceExpressionStack(expr.Sel, pkgs, expr, candidatePkg, candidateNode, offendingPositions) + //} else { + // logger.Warn().Msgf("no use or def: %+v, sel: %+v", x, expr.Sel) } } case *ast.SelectorExpr: // e.g., `obj.Method.Func` - return traceSelectorExpr(x, pkgs, candidateNode, offendingPositions) + logger.Debug().Msgf("tracing recursive selector expression: %+v", expr) + return traceSelectorExpr(x, candidatePkg, pkgs, candidateNode, offendingPositions) case *ast.CallExpr: logger.Debug().Msgf("tracing call expression: %+v", expr) - switch callExpr := x.Fun.(type) { - case *ast.SelectorExpr: // e.g., `obj.Method.Func` - return traceSelectorExpr(callExpr, pkgs, candidateNode, offendingPositions) - default: - logger.Warn().Msgf("skipping sub-selector call expression X type: %T", x) - } + //switch callExpr := x.Fun.(type) { + //case *ast.SelectorExpr: // e.g., `obj.Method.Func` + // return traceSelectorExpr(callExpr, pkgs, candidateNode, offendingPositions) + //default: + // logger.Warn().Msgf("skipping sub-selector call expression X type: %T", x) + //} + traceExpressionStack(x.Fun, pkgs, expr, candidatePkg, candidateNode, offendingPositions) default: logger.Warn().Msgf("skipping selector expression X type: %T", x) } @@ -96,6 +291,8 @@ func traceSelectorExpr( func traceExpressionStack( exprToTrace ast.Expr, pkgs []*packages.Package, + _ ast.Node, + candidatePkg *packages.Package, candidateNode ast.Node, offendingPositions map[string]struct{}, ) bool { @@ -105,7 +302,7 @@ func traceExpressionStack( logger := logger.With( "node_type", fmt.Sprintf("%T", exprToTrace), - //"position", pkg.Fset.Position(x.Pos()).String(), + //"position", candidatePkg.Fset.Position(exprToTrace.Pos()).String(), //"package", strings.Trim(strings.Split(obj.String(), " ")[2], "\"()"), ) logger.Debug().Msg("tracing expression stack") @@ -116,7 +313,7 @@ func traceExpressionStack( case *ast.CallExpr: if sel, ok := expr.Fun.(*ast.SelectorExpr); ok { //logger.Debug().Msg("tracing selector expression") - return traceSelectorExpr(sel, pkgs, candidateNode, offendingPositions) + return traceSelectorExpr(sel, candidatePkg, pkgs, candidateNode, offendingPositions) } //logger.Debug().Msgf("tracing expression args: %+v", expr) //for _, arg := range expr.Args { @@ -128,41 +325,83 @@ func traceExpressionStack( case *ast.BinaryExpr: logger.Debug().Msg("tracing binary expression") // TODO_IN_THIS_COMMIT: return traceExpressionStack... ? - if traceExpressionStack(expr.X, pkgs, candidateNode, offendingPositions) { - traceExpressionStack(expr.Y, pkgs, candidateNode, offendingPositions) + if traceExpressionStack(expr.X, pkgs, expr, candidatePkg, candidateNode, offendingPositions) { + traceExpressionStack(expr.Y, pkgs, expr, candidatePkg, candidateNode, offendingPositions) //return true } return true case *ast.ParenExpr: logger.Debug().Msg("tracing paren expression") - return traceExpressionStack(expr.X, pkgs, candidateNode, offendingPositions) + return traceExpressionStack(expr.X, pkgs, expr, candidatePkg, candidateNode, offendingPositions) case *ast.SelectorExpr: logger.Debug().Msg("tracing selector expression") - return traceSelectorExpr(expr, pkgs, candidateNode, offendingPositions) + return traceSelectorExpr(expr, candidatePkg, pkgs, candidateNode, offendingPositions) case *ast.Ident: - logger.Debug().Msg("tracing ident") - //fmt.Printf(">>> exprToTrace: %+v\n", expr) - //var srcPkg *packages.Package - for _, pkg := range pkgs { - //if obj := pkg.TypesInfo.Defs[expr]; obj != nil { - // srcPkg = pkg - for _, fileNode := range pkg.Syntax { - declOrAssign, _ := findDeclOrAssign(expr, fileNode, pkg.Fset) - if declOrAssign == nil { - continue + logger.Debug().Str("name", expr.Name).Msg("tracing ident") + //def := candidatePkg.TypesInfo.Defs[expr] + // TODO_IN_THIS_COMMIT: handle no def... + //x := def.Parent().Lookup(expr.Name) + // TODO_IN_THIS_COMMIT: handle no lookup... + //x.Pos() + + //var candidateFileNode ast.Node + //for _, fileNode := range candidatePkg.Syntax { + // ast.Inspect(fileNode, func(n ast.Node) bool { + // if n == candidateNode { + // candidateFileNode = fileNode + // return false + // } + // return true + // }) + //} + //for _, pkg := range pkgs { + // for _, fileNode := range pkg.Syntax { + //declOrAssign, _ := findDeclOrAssign(expr, candidateFileNode, candidatePkg) + //declOrAssign, _ := findDeclOrAssign(expr, expr, candidatePkg) + //declOrAssign, declOrAssignPos := findDeclOrAssign(expr, scopeNode, candidatePkg) + + // TODO_IN_THIS_COMMIT: return a slice of all decls and assignments + // and their respective files/pkgs. + declOrAssign, _ := newFindDeclOrAssign(expr, candidatePkg) + if declOrAssign == nil { + logger.Warn().Msgf("no declaration or assignment found for ident %q", expr.String()) + return false + } + //logger.Debug(). + // Str("pkg_path", candidatePkg.PkgPath). + // //Str("file_path", fmt.Sprintf(" %s", candidatePkg.Fset.File(candidateFileNode.Pos()).Name())). + // Str("decl_or_assign_pos", fmt.Sprintf(" %s ", declOrAssignPos)). + // Send() + //Msgf("found decl or assign: %+v", declOrAssign) + switch doa := declOrAssign.(type) { + case ast.Expr: + traceExpressionStack(doa, pkgs, expr, candidatePkg, candidateNode, offendingPositions) + case *ast.AssignStmt: + logger.Warn().Msgf(">>>>>>> assign stmt: %+v", doa) + // TODO_IN_THIS_COMMIT: what about len(Rhs) > 1? + traceExpressionStack(doa.Rhs[0], pkgs, expr, candidatePkg, candidateNode, offendingPositions) + case *ast.ValueSpec: + // TODO_RESUME_HERE!!!! + // TODO_RESUME_HERE!!!! + // TODO_RESUME_HERE!!!! + // TODO_RESUME_HERE!!!! + // + // find "closest" previous assignment... + // + logger.Warn(). + //Str("position", fmt.Sprintf(" %s ", pkg.Fset.Position(doa.Pos()).String())). + Int("len(values)", len(doa.Values)). + Msgf(">>>>>>> value spec: %+v", doa) + + if doa.Values != nil { + for _, value := range doa.Values { + traceExpressionStack(value, pkgs, expr, candidatePkg, candidateNode, offendingPositions) } - logger.Debug(). - Str("pkg_path", pkg.PkgPath). - Str("file_path", pkg.Fset.File(fileNode.Pos()).Name()). - Str("decl_or_assign_pos", pkg.Fset.Position(declOrAssign.Pos()).String()). - Send() - //Msgf("found decl or assign: %+v", declOrAssign) - traceExpressionStack(declOrAssign, pkgs, candidateNode, offendingPositions) } - //} + default: + logger.Warn().Msgf("unknown node type 3: %T", doa) } - //if srcPkg == nil { - // logger.Warn().Msgf("no pkg found for expr: %+v", expr) + //} //} return true //case *ast.SliceExpr: @@ -174,8 +413,117 @@ func traceExpressionStack( } } +//func newNewFindDeclOrAssign( +// targetIdent *ast.Ident, +// scopeNode ast.Node, +// pkg *packages.Package, +//) (declNode ast.Node, declPos token.Position) { +// var nodes []ast.Node +// ast.Inspect(scopeNode, func(n ast.Node) bool { +// if n != nil { +// nodes = append(nodes, n) +// } +// return true +// }) +// +// for i := len(nodes) - 1; i >= 0; i-- { +// +// } +//} + +// TODO_IN_THIS_COMMIT: move & godoc... +func newFindDeclOrAssign( + targetIdent *ast.Ident, + //pkgs []*packages.Package, + pkg *packages.Package, +) (declNode ast.Node, declPos token.Position) { + //var closestDeclNode ast.Node + + for _, fileNode := range pkg.Syntax { + if declNode != nil { + return declNode, declPos + } + + //fmt.Println(">>>>>>>>> NEW FILE NODE") + //ast.Inspect(fileNode, func(n ast.Node) bool { + // //if declNode != nil { + // // //fmt.Println(">>>>>>>>> EXITING EARLY") + // // return false + // //} + // + // if ident, ok := n.(*ast.Ident); ok && + // ident.Name == targetIdent.Name { + //if obj := pkg.TypesInfo.Defs[targetIdent]; obj != nil { + // declPos = pkg.Fset.Position(obj.Pos()) + // //logger.Debug().Fields(map[string]any{ + // // //"pkg_path": pkg.PkgPath, + // // "file_path": fmt.Sprintf(" %s ", pkg.Fset.File(fileNode.Pos()).Name()), + // // "target_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(targetIdent.Pos()).String()), + // // "decl_pos": fmt.Sprintf(" %s ", declPos.String()), + // //}).Msg("defs") + // declNode = findNodeByPosition(pkg.Fset, fileNode, declPos) + // return false + //} else if obj := pkg.TypesInfo.Uses[targetIdent]; obj != nil { + if obj := pkg.TypesInfo.Uses[targetIdent]; obj != nil { + // TODO_IN_THIS_COMMIT: figure out why this is called so frequently. + logger.Debug().Fields(map[string]any{ + //"pkg_path": pkg.PkgPath, + "file_path": fmt.Sprintf(" %s ", pkg.Fset.File(fileNode.Pos()).Name()), + "target_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(targetIdent.Pos()).String()), + "decl_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(obj.Pos()).String()), + }).Msg("uses") + declPos = pkg.Fset.Position(obj.Pos()) + declNode = findNodeByPosition(pkg.Fset, fileNode, declPos) + logger.Warn(). + Str("decl_node", fmt.Sprintf("%+v", declNode)). + Str("decl_pos", fmt.Sprintf(" %s ", declPos)). + Msg("found decl node") + //return false + } + // } + // return true + //}) + } + //fmt.Println(">>>>>>>>> DONE") + + return declNode, declPos +} + +// TODO_IN_THIS_COMMIT: move & godoc... +// search for targetIdent by position +func findNodeByPosition( + fset *token.FileSet, + fileNode *ast.File, + position token.Position, +) (targetNode ast.Node) { + //fmt.Println(">>>>>>>>> FIND NODE BY POSITION") + + ast.Inspect(fileNode, func(n ast.Node) bool { + if targetNode != nil { + return false + } + + if n == nil { + return true + } + + if n != nil && fset.Position(n.Pos()) == position { + targetNode = n + return false + } + + if targetNode != nil { + return false + } + + return true + }) + return targetNode +} + // Find the declaration or assignment of an identifier -func findDeclOrAssign(ident *ast.Ident, node ast.Node, fset *token.FileSet) (ast.Expr, *token.Position) { +// func findDeclOrAssign(ident *ast.Ident, fileNode ast.Node, pkg *packages.Package) (ast.Expr, *token.Position) { +func findDeclOrAssign(ident *ast.Ident, scopeNode ast.Node, pkg *packages.Package) (ast.Expr, *token.Position) { //logger.Debug().Msg("finding decl or assign") //fmt.Println("!!!! findDeclOrAssign begin") @@ -183,7 +531,9 @@ func findDeclOrAssign(ident *ast.Ident, node ast.Node, fset *token.FileSet) (ast var foundPos token.Position //fmt.Println("!!!! findDeclOrAssign inspect") - ast.Inspect(node, func(n ast.Node) bool { + //for _, fileNode := range pkg.Syntax { + //ast.Inspect(fileNode, func(n ast.Node) bool { + ast.Inspect(scopeNode, func(n ast.Node) bool { switch stmt := n.(type) { case *ast.AssignStmt: // Look for assignments //fmt.Println("!!!! findDeclOrAssign case assign") @@ -195,7 +545,7 @@ func findDeclOrAssign(ident *ast.Ident, node ast.Node, fset *token.FileSet) (ast } else { declOrAssign = stmt.Rhs[i] } - foundPos = fset.Position(stmt.Pos()) + foundPos = pkg.Fset.Position(stmt.Pos()) } } case *ast.ValueSpec: // Look for declarations with initialization @@ -203,13 +553,38 @@ func findDeclOrAssign(ident *ast.Ident, node ast.Node, fset *token.FileSet) (ast for i, name := range stmt.Names { if name.Name == ident.Name && i < len(stmt.Values) { declOrAssign = stmt.Values[i] - foundPos = fset.Position(stmt.Pos()) + foundPos = pkg.Fset.Position(stmt.Pos()) return false } } } return true }) + //} return declOrAssign, &foundPos } + +// TODO_IN_THIS_COMMIT: move & godoc... +//func getNodeFromPosition(fset *token.FileSet, position token.Position) ast.Node { +// file := fset.File(position) +// if file == nil { +// return nil +// } +// +// var node ast.Node +// ast.Inspect(file, func(n ast.Node) bool { +// if n == nil { +// return false +// } +// +// if fset.Position(n.Pos()).String() == position.String() { +// node = n +// return false +// } +// +// return true +// }) +// +// return node +//} diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index e06136c66..f4f9a20ca 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -256,94 +256,7 @@ func checkModule(_ context.Context) error { // Recursively traverse the function body, looking for non-nil error returns. // TODO_IN_THIS_COMMIT: extract --- BEGIN //fmt.Printf(">>> walking func from file: %s\n", pkg.Fset.Position(astFile.Pos()).Filename) - ast.Inspect(fnNode.Body, func(n ast.Node) bool { - if n == nil { - return false - } - //inspectPos := pkg.Fset.Position(n.Pos()) - //fmt.Printf(">>> inspecting %T at %s\n", n, inspectPos) - - //inspectLocation := inspectPos.String() - //if inspectLocation == "/home/bwhite/Projects/pokt/poktroll/x/proof/keeper/msg_server_create_claim.go:44:3" { - // fmt.Printf(">>> found it!\n") - //} - - switch n := n.(type) { - case *ast.BlockStmt: - return true - // Search for a return statement. - case *ast.ReturnStmt: - lastResult := n.Results[len(n.Results)-1] - inspectPosition := pkg.Fset.Position(lastResult.Pos()).String() - - logger := logger.With( - "node_type", fmt.Sprintf("%T", lastResult), - "inspectPosition", fmt.Sprintf(" %s ", inspectPosition), - ) - - switch lastReturnArgNode := lastResult.(type) { - // `return nil, err` <-- last arg is an *ast.Ident. - case *ast.Ident: - //logger.Debug().Fields(map[string]any{ - // "node_type": fmt.Sprintf("%T", lastReturnArgNode), - // "inspectPosition": pkg.Fset.Position(lastReturnArgNode.Pos()).String(), - //}).Msg("traversing ast node") - - // TODO_IN_THIS_COMMIT: No need to check that the last return - // arg is an error type if we checked that the function returns - // an error as the last arg. - //if lastReturnArgNode.Name == "err" { - if lastReturnArgNode.Obj == nil { - return true - } - - def := pkg.TypesInfo.Uses[lastReturnArgNode] - if def == nil { - logger.Warn().Msg("def is nil") - return true - } - - if def.Type().String() != "error" { - //logger.Warn().Msg("def is not error") - //inspectPosition := pkg.Fset.Position(lastReturnArgNode.Pos()).String() - //break - return false - } - - logger.Debug().Msg("appending potential offending line") - appendOffendingLine(inspectPosition) - traceExpressionStack(lastReturnArgNode, pkgs, lastReturnArgNode, offendingPkgErrLineSet) - return true - - // `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. - case *ast.CallExpr: - //logger.Debug().Msg("inspecting ast node") - logger.Debug().Msg("appending potential offending line") - appendOffendingLine(inspectPosition) - traceExpressionStack(lastReturnArgNode, pkgs, lastReturnArgNode, offendingPkgErrLineSet) - //TraverseCallStack(lastReturnArgNode, pkgs, 0, condition(lastReturnArgNode)) - //return false - return true - - case *ast.SelectorExpr: - logger.Debug().Msg("appending potential offending line") - appendOffendingLine(inspectPosition) - traceSelectorExpr(lastReturnArgNode, pkgs, lastReturnArgNode, offendingPkgErrLineSet) - return true - - default: - logger.Warn().Msg("NOT traversing ast node") - return true - } - - //logger.Debug().Msg("appending potential offending line") - //appendOffendingLine(inspectPosition) - // - //return true - } - - return true - }) + ast.Inspect(fnNode.Body, walkFuncBody(pkg, pkgs, true, true)) // --- END return false diff --git a/x/supplier/keeper/msg_server_stake_supplier.go b/x/supplier/keeper/msg_server_stake_supplier.go index d92a41201..4a7fb9992 100644 --- a/x/supplier/keeper/msg_server_stake_supplier.go +++ b/x/supplier/keeper/msg_server_stake_supplier.go @@ -119,7 +119,9 @@ func (k msgServer) StakeSupplier(ctx context.Context, msg *types.MsgStakeSupplie msg.Signer, msg.GetStake(), supplier.Stake, ) logger.Info(fmt.Sprintf("WARN: %s", err)) - return nil, status.Error(codes.InvalidArgument, err.Error()) + //return nil, status.Error(codes.InvalidArgument, err.Error()) + err = status.Error(codes.InvalidArgument, err.Error()) + return nil, err } // MUST ALWAYS have at least minimum stake. From cc2af0ced2ed8b7efa70949f9500d3a8732d017d Mon Sep 17 00:00:00 2001 From: Bryan White Date: Wed, 27 Nov 2024 10:34:40 +0100 Subject: [PATCH 12/18] =?UTF-8?q?wip:=20error=20assignment=20case=20wrangl?= =?UTF-8?q?ed=20=F0=9F=99=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/scripts/protocheck/cmd/ast_tracing.go | 253 ++++++++++++++---- tools/scripts/protocheck/cmd/codes.go | 1 + tools/scripts/protocheck/cmd/status_errors.go | 2 + tools/scripts/protocheck/cmd/unstable.go | 4 +- 4 files changed, 211 insertions(+), 49 deletions(-) diff --git a/tools/scripts/protocheck/cmd/ast_tracing.go b/tools/scripts/protocheck/cmd/ast_tracing.go index e52ff1f70..04b2fa783 100644 --- a/tools/scripts/protocheck/cmd/ast_tracing.go +++ b/tools/scripts/protocheck/cmd/ast_tracing.go @@ -4,6 +4,7 @@ import ( "fmt" "go/ast" "go/token" + "slices" "strings" "golang.org/x/tools/go/packages" @@ -32,7 +33,7 @@ func walkFuncBody( Msg("walking function body") //position := pkg.Fset.Position(n.Pos()) - //logger.Warn().Msgf("position: %s", position.String()) + //logger.Debug().Msgf("position: %s", position.String()) // //inspectPos := pkg.Fset.Position(n.Pos()) //fmt.Printf(">>> inspecting %T at %s\n", n, inspectPos) @@ -72,18 +73,18 @@ func walkFuncBody( // an error as the last arg. //if lastReturnArgNode.Name == "err" { if lastReturnArgNode.Obj == nil { - logger.Warn().Msg("lastReturnArgNode.Obj is nil") + logger.Debug().Msg("lastReturnArgNode.Obj is nil") return true } def := pkg.TypesInfo.Uses[lastReturnArgNode] if def == nil { - logger.Warn().Msg("def is nil") + logger.Debug().Msg("def is nil") return true } if def.Type().String() != "error" { - logger.Warn().Msg("def is not error") + logger.Debug().Msg("def is not error") //inspectPosition := pkg.Fset.Position(lastReturnArgNode.Pos()).String() //break return false @@ -118,7 +119,7 @@ func walkFuncBody( } default: - //logger.Warn().Str("node_types", fmt.Sprintf("%T", n)).Msg("NOT traversing ast node") + //logger.Debug().Str("node_types", fmt.Sprintf("%T", n)).Msg("NOT traversing ast node") return true } @@ -144,7 +145,7 @@ func traceSelectorExpr( //fmt.Println(">>>>>>>>> TRACE SELECTOR EXPR") for _, pkg := range pkgs { if selection := pkg.TypesInfo.Selections[expr]; selection != nil { - //logger.Warn().Msgf("<<<<<<< has selection: %s", selection.String()) + //logger.Debug().Msgf("<<<<<<< has selection: %s", selection.String()) for _, pkg2 := range pkgs { position := pkg2.Fset.Position(selection.Obj().Pos()) @@ -153,7 +154,7 @@ func traceSelectorExpr( foundNode = findNodeByPosition(pkg2.Fset, fileNode, position) if foundNode != nil { //traceExpressionStack(foundNode, pkgs, expr, pkg, foundNode, offendingPositions) - logger.Warn(). + logger.Debug(). Str("node_type", fmt.Sprintf("%T", foundNode)). Str("selection_position", fmt.Sprintf(" %s ", position)). Str("expr_position", fmt.Sprintf(" %s ", pkg.Fset.Position(expr.Pos()).String())). @@ -179,12 +180,12 @@ func traceSelectorExpr( }) if declNode != nil { - logger.Warn().Str("decl_position", pkg2.Fset.Position(declNode.Pos()).String()).Msg("tracing decl node") - logger.Warn().Str("decl_body", pkg2.Fset.Position(declNode.Body.Pos()).String()).Msg("tracing decl node body") + logger.Debug().Str("decl_position", pkg2.Fset.Position(declNode.Pos()).String()).Msg("tracing decl node") + logger.Debug().Str("decl_body", pkg2.Fset.Position(declNode.Body.Pos()).String()).Msg("tracing decl node body") ast.Inspect(declNode.Body, walkFuncBody(pkg, pkgs, false, false)) //walkFuncBody(pkg, pkgs)(declNode.Body) } else { - logger.Warn().Msg("could not find decl node") + logger.Debug().Msg("could not find decl node") } //return false @@ -245,7 +246,7 @@ func traceSelectorExpr( logger.Debug().Msgf("exhonerating %s", candidateNodePosition) delete(offendingPositions, candidateNodePosition) } else { - logger.Warn().Msgf("can't exhonerating %s", candidateNodePosition) + logger.Warn().Msgf("can't exhonerate %s", candidateNodePosition) } return false } @@ -258,15 +259,15 @@ func traceSelectorExpr( // return true // } } else if obj = pkg.TypesInfo.Defs[x]; obj != nil { - logger.Warn().Msgf("no use but def: %+v", obj) + logger.Debug().Msgf("no use but def: %+v", obj) } else if obj = pkg.TypesInfo.Defs[expr.Sel]; obj != nil { - logger.Warn(). + logger.Debug(). Str("pkg_path", pkg.PkgPath). Str("name", expr.Sel.Name). Msgf("sel def") traceExpressionStack(expr.Sel, pkgs, expr, candidatePkg, candidateNode, offendingPositions) //} else { - // logger.Warn().Msgf("no use or def: %+v, sel: %+v", x, expr.Sel) + // logger.Debug().Msgf("no use or def: %+v, sel: %+v", x, expr.Sel) } } case *ast.SelectorExpr: // e.g., `obj.Method.Func` @@ -278,7 +279,7 @@ func traceSelectorExpr( //case *ast.SelectorExpr: // e.g., `obj.Method.Func` // return traceSelectorExpr(callExpr, pkgs, candidateNode, offendingPositions) //default: - // logger.Warn().Msgf("skipping sub-selector call expression X type: %T", x) + // logger.Debug().Msgf("skipping sub-selector call expression X type: %T", x) //} traceExpressionStack(x.Fun, pkgs, expr, candidatePkg, candidateNode, offendingPositions) default: @@ -364,9 +365,10 @@ func traceExpressionStack( // and their respective files/pkgs. declOrAssign, _ := newFindDeclOrAssign(expr, candidatePkg) if declOrAssign == nil { - logger.Warn().Msgf("no declaration or assignment found for ident %q", expr.String()) return false } + + //traceExpressionStack(declOrAssign.(ast.Expr), pkgs, expr, candidatePkg, candidateNode, offendingPositions) //logger.Debug(). // Str("pkg_path", candidatePkg.PkgPath). // //Str("file_path", fmt.Sprintf(" %s", candidatePkg.Fset.File(candidateFileNode.Pos()).Name())). @@ -377,27 +379,27 @@ func traceExpressionStack( case ast.Expr: traceExpressionStack(doa, pkgs, expr, candidatePkg, candidateNode, offendingPositions) case *ast.AssignStmt: - logger.Warn().Msgf(">>>>>>> assign stmt: %+v", doa) + //logger.Debug().Msgf(">>>>>>> assign stmt: %+v", doa) // TODO_IN_THIS_COMMIT: what about len(Rhs) > 1? traceExpressionStack(doa.Rhs[0], pkgs, expr, candidatePkg, candidateNode, offendingPositions) - case *ast.ValueSpec: - // TODO_RESUME_HERE!!!! - // TODO_RESUME_HERE!!!! - // TODO_RESUME_HERE!!!! - // TODO_RESUME_HERE!!!! - // - // find "closest" previous assignment... - // - logger.Warn(). - //Str("position", fmt.Sprintf(" %s ", pkg.Fset.Position(doa.Pos()).String())). - Int("len(values)", len(doa.Values)). - Msgf(">>>>>>> value spec: %+v", doa) - - if doa.Values != nil { - for _, value := range doa.Values { - traceExpressionStack(value, pkgs, expr, candidatePkg, candidateNode, offendingPositions) - } - } + //case *ast.ValueSpec: + // // TODO_RESUME_HERE!!!! + // // TODO_RESUME_HERE!!!! + // // TODO_RESUME_HERE!!!! + // // TODO_RESUME_HERE!!!! + // // + // // find "closest" previous assignment... + // // + // logger.Debug(). + // //Str("position", fmt.Sprintf(" %s ", pkg.Fset.Position(doa.Pos()).String())). + // Int("len(values)", len(doa.Values)). + // Msgf(">>>>>>> value spec: %+v", doa) + // + // if doa.Values != nil { + // for _, value := range doa.Values { + // traceExpressionStack(value, pkgs, expr, candidatePkg, candidateNode, offendingPositions) + // } + // } default: logger.Warn().Msgf("unknown node type 3: %T", doa) } @@ -441,7 +443,8 @@ func newFindDeclOrAssign( for _, fileNode := range pkg.Syntax { if declNode != nil { - return declNode, declPos + //return declNode, declPos + break } //fmt.Println(">>>>>>>>> NEW FILE NODE") @@ -455,14 +458,14 @@ func newFindDeclOrAssign( // ident.Name == targetIdent.Name { //if obj := pkg.TypesInfo.Defs[targetIdent]; obj != nil { // declPos = pkg.Fset.Position(obj.Pos()) - // //logger.Debug().Fields(map[string]any{ - // // //"pkg_path": pkg.PkgPath, - // // "file_path": fmt.Sprintf(" %s ", pkg.Fset.File(fileNode.Pos()).Name()), - // // "target_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(targetIdent.Pos()).String()), - // // "decl_pos": fmt.Sprintf(" %s ", declPos.String()), - // //}).Msg("defs") + // logger.Debug().Fields(map[string]any{ + // //"pkg_path": pkg.PkgPath, + // "file_path": fmt.Sprintf(" %s ", pkg.Fset.File(fileNode.Pos()).Name()), + // "target_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(targetIdent.Pos()).String()), + // "decl_pos": fmt.Sprintf(" %s ", declPos.String()), + // }).Msg("defs") // declNode = findNodeByPosition(pkg.Fset, fileNode, declPos) - // return false + // //return false //} else if obj := pkg.TypesInfo.Uses[targetIdent]; obj != nil { if obj := pkg.TypesInfo.Uses[targetIdent]; obj != nil { // TODO_IN_THIS_COMMIT: figure out why this is called so frequently. @@ -474,11 +477,27 @@ func newFindDeclOrAssign( }).Msg("uses") declPos = pkg.Fset.Position(obj.Pos()) declNode = findNodeByPosition(pkg.Fset, fileNode, declPos) - logger.Warn(). - Str("decl_node", fmt.Sprintf("%+v", declNode)). - Str("decl_pos", fmt.Sprintf(" %s ", declPos)). - Msg("found decl node") + if declNode != nil { + logger.Debug(). + Str("decl_node", fmt.Sprintf("%+v", declNode)). + Str("decl_pos", fmt.Sprintf(" %s ", declPos)). + Msg("found decl node") + } //return false + //} else if obj := pkg.TypesInfo.Defs[targetIdent]; obj != nil { + // declPos = pkg.Fset.Position(obj.Pos()) + // logger.Debug().Fields(map[string]any{ + // //"pkg_path": pkg.PkgPath, + // "file_path": fmt.Sprintf(" %s ", pkg.Fset.File(fileNode.Pos()).Name()), + // "target_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(targetIdent.Pos()).String()), + // "decl_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(obj.Pos()).String()), + // }).Msg("defs") + // declPos = pkg.Fset.Position(obj.Pos()) + // declNode = findNodeByPosition(pkg.Fset, fileNode, declPos) + // logger.Debug(). + // Str("decl_node", fmt.Sprintf("%+v", declNode)). + // Str("decl_pos", fmt.Sprintf(" %s ", declPos)). + // Msg("found decl node (def)") } // } // return true @@ -486,6 +505,146 @@ func newFindDeclOrAssign( } //fmt.Println(">>>>>>>>> DONE") + // TODO_IN_THIS_COMMIT: improve comment... + // Look through decl node to see if it contains a valudspec with values. + // If it does, return the value(s). + if declNode != nil { + ast.Inspect(declNode, func(n ast.Node) bool { + //if declNode != nil { + // return true + //} + + switch doa := n.(type) { + //case ast.Expr: + // traceExpressionStack(doa, pkgs, expr, candidatePkg, candidateNode, offendingPositions) + //case *ast.AssignStmt: + // logger.Debug().Msgf(">>>>>>> assign stmt: %+v", doa) + // // TODO_IN_THIS_COMMIT: what about len(Rhs) > 1? + // traceExpressionStack(doa.Rhs[0], pkgs, expr, candidatePkg, candidateNode, offendingPositions) + case *ast.ValueSpec: + // TODO_RESUME_HERE!!!! + // TODO_RESUME_HERE!!!! + // TODO_RESUME_HERE!!!! + // TODO_RESUME_HERE!!!! + // + // find "closest" previous assignment... + // + logger.Debug(). + //Str("position", fmt.Sprintf(" %s ", pkg.Fset.Position(doa.Pos()).String())). + Int("len(values)", len(doa.Values)). + Msgf(">>>>>>> value spec: %+v", doa) + + if doa.Values != nil { + logger.Debug().Msg("dao.Values != nil") + for _, value := range doa.Values { + //traceExpressionStack(value, pkgs, expr, candidatePkg, candidateNode, offendingPositions) + declPos = pkg.Fset.Position(value.Pos()) + declNode = value + } + } else { + logger.Debug().Msg("dao.Values == nil") + declNode = nil + } + } + + return true + }) + } else { + logger.Debug().Msgf("no declaration or assignment found for ident %q", targetIdent.String()) + } + + // TODO_IN_THIS_COMMIT: improve comment... + // If it does not, search the package for + // the ident and return the closest assignment. + if declNode == nil { + var assignsRhs []ast.Expr + for _, fileNode := range pkg.Syntax { + ast.Inspect(fileNode, func(n ast.Node) bool { + if assign, ok := n.(*ast.AssignStmt); ok { + for lhsIdx, lhs := range assign.Lhs { + // TODO_TECHDEBT: Ignoring assignments via selectors for now. + // E.g., `a.b = c` will not be considered. + lhsIdent, lhsIsIdent := lhs.(*ast.Ident) + if !lhsIsIdent { + continue + } + + if lhsIdent.Name != targetIdent.Name { + continue + } + + rhsIdx := 0 + if len(assign.Lhs) == len(assign.Rhs) { + rhsIdx = lhsIdx + } + + rhs := assign.Rhs[rhsIdx] + assignsRhs = append(assignsRhs, rhs) + } + } + return true + }) + } + + if len(assignsRhs) > 0 { + // TODO_IN_THIS_COMMIT: comment explaining what's going on here... + slices.SortFunc[[]ast.Expr, ast.Expr](assignsRhs, func(a, b ast.Expr) int { + aPos := pkg.Fset.Position(a.Pos()) + bPos := pkg.Fset.Position(b.Pos()) + + if aPos.Filename == bPos.Filename { + switch { + case aPos.Line < bPos.Line: + return -1 + case aPos.Line > bPos.Line: + return 1 + default: + return 0 + } + } else { + return 1 + } + + }) + + // DeclNode is the closest assignment whose position is less than or equal to the declPos. + var ( + closestAssignPos token.Position + closestAssignNode ast.Expr + targetIdentPos = pkg.Fset.Position(targetIdent.Pos()) + ) + for _, rhs := range assignsRhs { + if rhs == nil { + continue + } + + // DEV_NOTE: using pkg here assumes that rhs is in the same file as targetIdent. + // This SHOULD ALWAYS be the case for error type non-initialization declarations + // (e.g. var err error). I.e. we SHOULD NEVER be assigning an error value directly + // from aa pkg-level error variable. + rhsPos := pkg.Fset.Position(rhs.Pos()) + switch { + case rhsPos.Filename != targetIdentPos.Filename: + // TODO_TECHDEBT: handle case where rhs ident is defined in a different file. + logger.Debug(). + Str("assignment_position", rhsPos.String()). + Msg("ignoring assignment from different file") + continue + case rhsPos.Line < targetIdentPos.Line: + closestAssignPos = rhsPos + closestAssignNode = rhs + case rhsPos.Line == targetIdentPos.Line: + if rhsPos.Column <= targetIdentPos.Column { + closestAssignPos = rhsPos + closestAssignNode = rhs + } + } + } + declPos = closestAssignPos + declNode = closestAssignNode + } + } + return declNode, declPos } diff --git a/tools/scripts/protocheck/cmd/codes.go b/tools/scripts/protocheck/cmd/codes.go index a119ce4fa..2f8da9cab 100644 --- a/tools/scripts/protocheck/cmd/codes.go +++ b/tools/scripts/protocheck/cmd/codes.go @@ -6,4 +6,5 @@ const ( CodeRootCmdErr = ExitCode(iota + 1) CodePathWalkErr CodeUnstableProtosFound + CodeNonStatusGRPCErrorsFound ) diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index f4f9a20ca..bb9e85841 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -300,6 +300,8 @@ func checkModule(_ context.Context) error { if numOffendingLines > 5 { logger.Info().Msg(msg) } + + os.Exit(CodeNonStatusGRPCErrorsFound) } // --- END diff --git a/tools/scripts/protocheck/cmd/unstable.go b/tools/scripts/protocheck/cmd/unstable.go index f0e39d414..2e3f454a1 100644 --- a/tools/scripts/protocheck/cmd/unstable.go +++ b/tools/scripts/protocheck/cmd/unstable.go @@ -214,7 +214,7 @@ func excludeFileIfStableVisitFn( optName, optNameOk := getOptNodeName(optNode) if !optNameOk { - logger.Warn().Msgf( + logger.Debug().Msgf( "unable to extract option name from option node at %s:%d:%d", protoFilePath, optSrc.Line, optSrc.Col, ) @@ -238,7 +238,7 @@ func excludeFileIfStableVisitFn( if optValue != "true" { // Not the value we're looking for, continue traversing... - logger.Warn().Msgf( + logger.Debug().Msgf( "discovered an unstable_marshaler_all option with unexpected value %q at %s:%d:%d", optValue, protoFilePath, optSrc.Line, optSrc.Col, ) From 669765b91c732eaadb70d002ae3e1160cb3b2e3d Mon Sep 17 00:00:00 2001 From: Bryan White Date: Wed, 27 Nov 2024 10:54:46 +0100 Subject: [PATCH 13/18] chore: comment cleanup --- tools/scripts/protocheck/cmd/ast_tracing.go | 273 +----------------- tools/scripts/protocheck/cmd/status_errors.go | 49 +--- 2 files changed, 12 insertions(+), 310 deletions(-) diff --git a/tools/scripts/protocheck/cmd/ast_tracing.go b/tools/scripts/protocheck/cmd/ast_tracing.go index 04b2fa783..85a5df8c8 100644 --- a/tools/scripts/protocheck/cmd/ast_tracing.go +++ b/tools/scripts/protocheck/cmd/ast_tracing.go @@ -12,14 +12,11 @@ import ( const grpcStatusImportPath = "google.golang.org/grpc/status" -var TRACE bool - // TODO_IN_THIS_COMMIT: move & godoc... func walkFuncBody( pkg *packages.Package, pkgs []*packages.Package, - shouldAppend, - shouldExhonerate bool, + shouldAppend bool, ) func(ast.Node) bool { return func(n ast.Node) bool { if n == nil { @@ -32,23 +29,7 @@ func walkFuncBody( Bool("shouldAppend", shouldAppend). Msg("walking function body") - //position := pkg.Fset.Position(n.Pos()) - //logger.Debug().Msgf("position: %s", position.String()) - // - //inspectPos := pkg.Fset.Position(n.Pos()) - //fmt.Printf(">>> inspecting %T at %s\n", n, inspectPos) - // - //inspectLocation := inspectPos.String() - //if inspectLocation == "/home/bwhite/Projects/pokt/poktroll/x/proof/keeper/msg_server_create_claim.go:44:3" { - // fmt.Printf(">>> found it!\n") - //} - switch n := n.(type) { - //case *ast.BlockStmt: - // return true - //// Search for a return statement. - //case *ast.IfStmt: - // return true case *ast.ReturnStmt: lastResult := n.Results[len(n.Results)-1] inspectPosition := pkg.Fset.Position(lastResult.Pos()).String() @@ -61,17 +42,10 @@ func walkFuncBody( logger.Debug().Msgf("lastResult: %+v", lastResult) switch lastReturnArgNode := lastResult.(type) { - // `return nil, err` <-- last arg is an *ast.Ident. + // E.g. `return nil, err` <-- last arg is an *ast.Ident. case *ast.Ident: - //logger.Debug().Fields(map[string]any{ - // "node_type": fmt.Sprintf("%T", lastReturnArgNode), - // "inspectPosition": pkg.Fset.Position(lastReturnArgNode.Pos()).String(), - //}).Msg("traversing ast node") - - // TODO_IN_THIS_COMMIT: No need to check that the last return - // arg is an error type if we checked that the function returns - // an error as the last arg. - //if lastReturnArgNode.Name == "err" { + // DEV_NOTE: No need to check that the last return arg is an error type + // if we checked that the function returns an error as the last arg. if lastReturnArgNode.Obj == nil { logger.Debug().Msg("lastReturnArgNode.Obj is nil") return true @@ -85,8 +59,6 @@ func walkFuncBody( if def.Type().String() != "error" { logger.Debug().Msg("def is not error") - //inspectPosition := pkg.Fset.Position(lastReturnArgNode.Pos()).String() - //break return false } @@ -97,16 +69,13 @@ func walkFuncBody( traceExpressionStack(lastReturnArgNode, pkgs, nil, pkg, lastReturnArgNode, offendingPkgErrLineSet) return true - // `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. + // E.g. `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. case *ast.CallExpr: - //logger.Debug().Msg("inspecting ast node") if shouldAppend { logger.Debug().Msg("appending potential offending line") appendOffendingLine(inspectPosition) } traceExpressionStack(lastReturnArgNode, pkgs, nil, pkg, lastReturnArgNode, offendingPkgErrLineSet) - //TraverseCallStack(lastReturnArgNode, pkgs, 0, condition(lastReturnArgNode)) - //return false return true case *ast.SelectorExpr: @@ -119,15 +88,9 @@ func walkFuncBody( } default: - //logger.Debug().Str("node_types", fmt.Sprintf("%T", n)).Msg("NOT traversing ast node") return true } - //logger.Debug().Msg("appending potential offending line") - //appendOffendingLine(inspectPosition) - // - //return true - return true } } @@ -142,10 +105,8 @@ func traceSelectorExpr( offendingPositions map[string]struct{}, ) bool { logger.Debug().Msg("tracing selector expression") - //fmt.Println(">>>>>>>>> TRACE SELECTOR EXPR") for _, pkg := range pkgs { if selection := pkg.TypesInfo.Selections[expr]; selection != nil { - //logger.Debug().Msgf("<<<<<<< has selection: %s", selection.String()) for _, pkg2 := range pkgs { position := pkg2.Fset.Position(selection.Obj().Pos()) @@ -153,15 +114,13 @@ func traceSelectorExpr( for _, fileNode := range pkg2.Syntax { foundNode = findNodeByPosition(pkg2.Fset, fileNode, position) if foundNode != nil { - //traceExpressionStack(foundNode, pkgs, expr, pkg, foundNode, offendingPositions) logger.Debug(). Str("node_type", fmt.Sprintf("%T", foundNode)). Str("selection_position", fmt.Sprintf(" %s ", position)). Str("expr_position", fmt.Sprintf(" %s ", pkg.Fset.Position(expr.Pos()).String())). Str("found_node_position", fmt.Sprintf(" %s ", pkg2.Fset.Position(foundNode.Pos()).String())). Msg("found node") - //fmt.Printf(">>>>>>>>>> found node %T %+v\n", foundNode, foundNode) - //traceExpressionStack(foundNode.(*ast.Ident), pkgs, expr, pkg2, foundNode, offendingPositions) + var declNode *ast.FuncDecl ast.Inspect(fileNode, func(n ast.Node) bool { if declNode != nil { @@ -182,13 +141,11 @@ func traceSelectorExpr( if declNode != nil { logger.Debug().Str("decl_position", pkg2.Fset.Position(declNode.Pos()).String()).Msg("tracing decl node") logger.Debug().Str("decl_body", pkg2.Fset.Position(declNode.Body.Pos()).String()).Msg("tracing decl node body") - ast.Inspect(declNode.Body, walkFuncBody(pkg, pkgs, false, false)) - //walkFuncBody(pkg, pkgs)(declNode.Body) + ast.Inspect(declNode.Body, walkFuncBody(pkg, pkgs, false)) } else { logger.Debug().Msg("could not find decl node") } - //return false return true } } @@ -215,7 +172,6 @@ func traceSelectorExpr( default: pkgStr = pkgParts[1] } - //fmt.Printf(">>> pkgStr: %s\n", pkgStr) logger := logger.With( "node_type", fmt.Sprintf("%T", x), @@ -224,24 +180,10 @@ func traceSelectorExpr( ) logger.Debug().Msg("tracing selector expression") - //if obj := pkg.TypesInfo.Uses[x]; obj A!= nil { - //fmt.Printf("Base identifier %s resolved to: %s\n", x.Name, obj) - //fmt.Printf(">>> obj.String(): %s\n", obj.String()) - //fmt.Printf(">>> strings.Contains(obj.String(), grpcStatusImportPath): %v\n", strings.Contains(obj.String(), grpcStatusImportPath)) isMatch := strings.Contains(obj.String(), grpcStatusImportPath) && expr.Sel.Name == "Error" - //if isMatch { if isMatch { candidateNodePosition := pkg.Fset.Position(candidateNode.Pos()).String() - //inspectionPosition := pkg.Fset.Position(x.Pos()).String() - //fmt.Printf("Found target function %s in selector at %s\n", expr.Sel.Name, pkg.Fset.Position(expr.Pos())) - //fmt.Printf("Found offending candidate %s in selector at %s\n", expr.Sel.Name, pkg.Fset.Position(candidateNode.Pos()).String()) - //fmt.Printf("!!! - offendingPositions: %+v\n", offendingPositions) - //fmt.Printf("!!! - candPosition: %s", currentPosition) - //logger.Debug(). - // Str("candidate_pos", fmt.Sprintf(" %s ", candidateNodePosition)). - // //Str("offending_positions", fmt.Sprintf("%+v", offendingPositions)). - // Send() if _, ok := offendingPositions[candidateNodePosition]; ok { logger.Debug().Msgf("exhonerating %s", candidateNodePosition) delete(offendingPositions, candidateNodePosition) @@ -250,14 +192,6 @@ func traceSelectorExpr( } return false } - - //traceSelectorExpr(expr.Sel, candidatePkg, pkgs, candidateNode, offendingPositions) - - //} - //case *ast.SelectorExpr: // e.g., `obj.Method.Func` - // if traceSelectorExpr(x, info, fset) { - // return true - // } } else if obj = pkg.TypesInfo.Defs[x]; obj != nil { logger.Debug().Msgf("no use but def: %+v", obj) } else if obj = pkg.TypesInfo.Defs[expr.Sel]; obj != nil { @@ -266,8 +200,6 @@ func traceSelectorExpr( Str("name", expr.Sel.Name). Msgf("sel def") traceExpressionStack(expr.Sel, pkgs, expr, candidatePkg, candidateNode, offendingPositions) - //} else { - // logger.Debug().Msgf("no use or def: %+v, sel: %+v", x, expr.Sel) } } case *ast.SelectorExpr: // e.g., `obj.Method.Func` @@ -275,12 +207,6 @@ func traceSelectorExpr( return traceSelectorExpr(x, candidatePkg, pkgs, candidateNode, offendingPositions) case *ast.CallExpr: logger.Debug().Msgf("tracing call expression: %+v", expr) - //switch callExpr := x.Fun.(type) { - //case *ast.SelectorExpr: // e.g., `obj.Method.Func` - // return traceSelectorExpr(callExpr, pkgs, candidateNode, offendingPositions) - //default: - // logger.Debug().Msgf("skipping sub-selector call expression X type: %T", x) - //} traceExpressionStack(x.Fun, pkgs, expr, candidatePkg, candidateNode, offendingPositions) default: logger.Warn().Msgf("skipping selector expression X type: %T", x) @@ -301,11 +227,7 @@ func traceExpressionStack( return false } - logger := logger.With( - "node_type", fmt.Sprintf("%T", exprToTrace), - //"position", candidatePkg.Fset.Position(exprToTrace.Pos()).String(), - //"package", strings.Trim(strings.Split(obj.String(), " ")[2], "\"()"), - ) + logger := logger.With("node_type", fmt.Sprintf("%T", exprToTrace)) logger.Debug().Msg("tracing expression stack") switch expr := exprToTrace.(type) { @@ -313,22 +235,14 @@ func traceExpressionStack( return false case *ast.CallExpr: if sel, ok := expr.Fun.(*ast.SelectorExpr); ok { - //logger.Debug().Msg("tracing selector expression") return traceSelectorExpr(sel, candidatePkg, pkgs, candidateNode, offendingPositions) } - //logger.Debug().Msgf("tracing expression args: %+v", expr) - //for _, arg := range expr.Args { - // // TODO_IN_THIS_COMMIT: return traceExpressionStack... ? - // traceExpressionStack(arg, pkgs, candidateNode, offendingPositions) - // //return true - //} return true case *ast.BinaryExpr: logger.Debug().Msg("tracing binary expression") // TODO_IN_THIS_COMMIT: return traceExpressionStack... ? if traceExpressionStack(expr.X, pkgs, expr, candidatePkg, candidateNode, offendingPositions) { traceExpressionStack(expr.Y, pkgs, expr, candidatePkg, candidateNode, offendingPositions) - //return true } return true case *ast.ParenExpr: @@ -339,27 +253,6 @@ func traceExpressionStack( return traceSelectorExpr(expr, candidatePkg, pkgs, candidateNode, offendingPositions) case *ast.Ident: logger.Debug().Str("name", expr.Name).Msg("tracing ident") - //def := candidatePkg.TypesInfo.Defs[expr] - // TODO_IN_THIS_COMMIT: handle no def... - //x := def.Parent().Lookup(expr.Name) - // TODO_IN_THIS_COMMIT: handle no lookup... - //x.Pos() - - //var candidateFileNode ast.Node - //for _, fileNode := range candidatePkg.Syntax { - // ast.Inspect(fileNode, func(n ast.Node) bool { - // if n == candidateNode { - // candidateFileNode = fileNode - // return false - // } - // return true - // }) - //} - //for _, pkg := range pkgs { - // for _, fileNode := range pkg.Syntax { - //declOrAssign, _ := findDeclOrAssign(expr, candidateFileNode, candidatePkg) - //declOrAssign, _ := findDeclOrAssign(expr, expr, candidatePkg) - //declOrAssign, declOrAssignPos := findDeclOrAssign(expr, scopeNode, candidatePkg) // TODO_IN_THIS_COMMIT: return a slice of all decls and assignments // and their respective files/pkgs. @@ -368,109 +261,33 @@ func traceExpressionStack( return false } - //traceExpressionStack(declOrAssign.(ast.Expr), pkgs, expr, candidatePkg, candidateNode, offendingPositions) - //logger.Debug(). - // Str("pkg_path", candidatePkg.PkgPath). - // //Str("file_path", fmt.Sprintf(" %s", candidatePkg.Fset.File(candidateFileNode.Pos()).Name())). - // Str("decl_or_assign_pos", fmt.Sprintf(" %s ", declOrAssignPos)). - // Send() - //Msgf("found decl or assign: %+v", declOrAssign) switch doa := declOrAssign.(type) { case ast.Expr: traceExpressionStack(doa, pkgs, expr, candidatePkg, candidateNode, offendingPositions) case *ast.AssignStmt: - //logger.Debug().Msgf(">>>>>>> assign stmt: %+v", doa) - // TODO_IN_THIS_COMMIT: what about len(Rhs) > 1? traceExpressionStack(doa.Rhs[0], pkgs, expr, candidatePkg, candidateNode, offendingPositions) - //case *ast.ValueSpec: - // // TODO_RESUME_HERE!!!! - // // TODO_RESUME_HERE!!!! - // // TODO_RESUME_HERE!!!! - // // TODO_RESUME_HERE!!!! - // // - // // find "closest" previous assignment... - // // - // logger.Debug(). - // //Str("position", fmt.Sprintf(" %s ", pkg.Fset.Position(doa.Pos()).String())). - // Int("len(values)", len(doa.Values)). - // Msgf(">>>>>>> value spec: %+v", doa) - // - // if doa.Values != nil { - // for _, value := range doa.Values { - // traceExpressionStack(value, pkgs, expr, candidatePkg, candidateNode, offendingPositions) - // } - // } default: logger.Warn().Msgf("unknown node type 3: %T", doa) } - //} - //} return true - //case *ast.SliceExpr: - // logger.Debug().Msgf("tracing slice expression: %+v", expr) - // return true default: logger.Warn().Msgf("unknown node type 2: %T", exprToTrace) return true } } -//func newNewFindDeclOrAssign( -// targetIdent *ast.Ident, -// scopeNode ast.Node, -// pkg *packages.Package, -//) (declNode ast.Node, declPos token.Position) { -// var nodes []ast.Node -// ast.Inspect(scopeNode, func(n ast.Node) bool { -// if n != nil { -// nodes = append(nodes, n) -// } -// return true -// }) -// -// for i := len(nodes) - 1; i >= 0; i-- { -// -// } -//} - // TODO_IN_THIS_COMMIT: move & godoc... func newFindDeclOrAssign( targetIdent *ast.Ident, - //pkgs []*packages.Package, pkg *packages.Package, ) (declNode ast.Node, declPos token.Position) { - //var closestDeclNode ast.Node - for _, fileNode := range pkg.Syntax { if declNode != nil { - //return declNode, declPos break } - //fmt.Println(">>>>>>>>> NEW FILE NODE") - //ast.Inspect(fileNode, func(n ast.Node) bool { - // //if declNode != nil { - // // //fmt.Println(">>>>>>>>> EXITING EARLY") - // // return false - // //} - // - // if ident, ok := n.(*ast.Ident); ok && - // ident.Name == targetIdent.Name { - //if obj := pkg.TypesInfo.Defs[targetIdent]; obj != nil { - // declPos = pkg.Fset.Position(obj.Pos()) - // logger.Debug().Fields(map[string]any{ - // //"pkg_path": pkg.PkgPath, - // "file_path": fmt.Sprintf(" %s ", pkg.Fset.File(fileNode.Pos()).Name()), - // "target_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(targetIdent.Pos()).String()), - // "decl_pos": fmt.Sprintf(" %s ", declPos.String()), - // }).Msg("defs") - // declNode = findNodeByPosition(pkg.Fset, fileNode, declPos) - // //return false - //} else if obj := pkg.TypesInfo.Uses[targetIdent]; obj != nil { if obj := pkg.TypesInfo.Uses[targetIdent]; obj != nil { - // TODO_IN_THIS_COMMIT: figure out why this is called so frequently. logger.Debug().Fields(map[string]any{ - //"pkg_path": pkg.PkgPath, "file_path": fmt.Sprintf(" %s ", pkg.Fset.File(fileNode.Pos()).Name()), "target_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(targetIdent.Pos()).String()), "decl_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(obj.Pos()).String()), @@ -483,61 +300,23 @@ func newFindDeclOrAssign( Str("decl_pos", fmt.Sprintf(" %s ", declPos)). Msg("found decl node") } - //return false - //} else if obj := pkg.TypesInfo.Defs[targetIdent]; obj != nil { - // declPos = pkg.Fset.Position(obj.Pos()) - // logger.Debug().Fields(map[string]any{ - // //"pkg_path": pkg.PkgPath, - // "file_path": fmt.Sprintf(" %s ", pkg.Fset.File(fileNode.Pos()).Name()), - // "target_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(targetIdent.Pos()).String()), - // "decl_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(obj.Pos()).String()), - // }).Msg("defs") - // declPos = pkg.Fset.Position(obj.Pos()) - // declNode = findNodeByPosition(pkg.Fset, fileNode, declPos) - // logger.Debug(). - // Str("decl_node", fmt.Sprintf("%+v", declNode)). - // Str("decl_pos", fmt.Sprintf(" %s ", declPos)). - // Msg("found decl node (def)") } - // } - // return true - //}) } - //fmt.Println(">>>>>>>>> DONE") // TODO_IN_THIS_COMMIT: improve comment... // Look through decl node to see if it contains a valudspec with values. // If it does, return the value(s). if declNode != nil { ast.Inspect(declNode, func(n ast.Node) bool { - //if declNode != nil { - // return true - //} - switch doa := n.(type) { - //case ast.Expr: - // traceExpressionStack(doa, pkgs, expr, candidatePkg, candidateNode, offendingPositions) - //case *ast.AssignStmt: - // logger.Debug().Msgf(">>>>>>> assign stmt: %+v", doa) - // // TODO_IN_THIS_COMMIT: what about len(Rhs) > 1? - // traceExpressionStack(doa.Rhs[0], pkgs, expr, candidatePkg, candidateNode, offendingPositions) case *ast.ValueSpec: - // TODO_RESUME_HERE!!!! - // TODO_RESUME_HERE!!!! - // TODO_RESUME_HERE!!!! - // TODO_RESUME_HERE!!!! - // - // find "closest" previous assignment... - // logger.Debug(). - //Str("position", fmt.Sprintf(" %s ", pkg.Fset.Position(doa.Pos()).String())). Int("len(values)", len(doa.Values)). Msgf(">>>>>>> value spec: %+v", doa) if doa.Values != nil { logger.Debug().Msg("dao.Values != nil") for _, value := range doa.Values { - //traceExpressionStack(value, pkgs, expr, candidatePkg, candidateNode, offendingPositions) declPos = pkg.Fset.Position(value.Pos()) declNode = value } @@ -655,8 +434,6 @@ func findNodeByPosition( fileNode *ast.File, position token.Position, ) (targetNode ast.Node) { - //fmt.Println(">>>>>>>>> FIND NODE BY POSITION") - ast.Inspect(fileNode, func(n ast.Node) bool { if targetNode != nil { return false @@ -683,22 +460,14 @@ func findNodeByPosition( // Find the declaration or assignment of an identifier // func findDeclOrAssign(ident *ast.Ident, fileNode ast.Node, pkg *packages.Package) (ast.Expr, *token.Position) { func findDeclOrAssign(ident *ast.Ident, scopeNode ast.Node, pkg *packages.Package) (ast.Expr, *token.Position) { - //logger.Debug().Msg("finding decl or assign") - - //fmt.Println("!!!! findDeclOrAssign begin") var declOrAssign ast.Expr var foundPos token.Position - //fmt.Println("!!!! findDeclOrAssign inspect") - //for _, fileNode := range pkg.Syntax { - //ast.Inspect(fileNode, func(n ast.Node) bool { ast.Inspect(scopeNode, func(n ast.Node) bool { switch stmt := n.(type) { case *ast.AssignStmt: // Look for assignments - //fmt.Println("!!!! findDeclOrAssign case assign") for i, lhs := range stmt.Lhs { if lhsIdent, ok := lhs.(*ast.Ident); ok && lhsIdent.Name == ident.Name { - //fmt.Printf("len(rhs): %d len(lhs): %d\n", len(stmt.Rhs), len(stmt.Lhs)) if len(stmt.Lhs) != len(stmt.Rhs) { declOrAssign = stmt.Rhs[0] } else { @@ -708,7 +477,6 @@ func findDeclOrAssign(ident *ast.Ident, scopeNode ast.Node, pkg *packages.Packag } } case *ast.ValueSpec: // Look for declarations with initialization - //fmt.Println("!!!! findDeclOrAssign case value") for i, name := range stmt.Names { if name.Name == ident.Name && i < len(stmt.Values) { declOrAssign = stmt.Values[i] @@ -719,31 +487,6 @@ func findDeclOrAssign(ident *ast.Ident, scopeNode ast.Node, pkg *packages.Packag } return true }) - //} return declOrAssign, &foundPos } - -// TODO_IN_THIS_COMMIT: move & godoc... -//func getNodeFromPosition(fset *token.FileSet, position token.Position) ast.Node { -// file := fset.File(position) -// if file == nil { -// return nil -// } -// -// var node ast.Node -// ast.Inspect(file, func(n ast.Node) bool { -// if n == nil { -// return false -// } -// -// if fset.Position(n.Pos()).String() == position.String() { -// node = n -// return false -// } -// -// return true -// }) -// -// return node -//} diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index bb9e85841..5b163e18b 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -20,9 +20,8 @@ import ( var ( flagModule = "module" flagModuleShorthand = "m" - // TODO_IN_THIS_COMMIT: support this flag. - flagModuleValue = "*" - flagModuleUsage = "If present, only check message handlers of the given module." + flagModuleValue = "*" + flagModuleUsage = "If present, only check message handlers of the given module." flagLogLevel = "log-level" flagLogLevelShorthand = "l" @@ -53,7 +52,6 @@ func setupLogger(_ *cobra.Command, _ []string) { ) } -// TODO_IN_THIS_COMMIT: pre-run: drop patch version in go.mod; post-run: restore. func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { ctx := cmd.Context() @@ -73,21 +71,10 @@ func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { } } - // TODO_IN_THIS_COMMIT: add hack/work-around to temporarily strip patch version from go.mod. - // TODO_IN_THIS_COMMIT: add hack/work-around to temporarily strip patch version from go.mod. - // TODO_IN_THIS_COMMIT: add hack/work-around to temporarily strip patch version from go.mod. - - // TODO_IN_THIS_COMMIT: to support this, need to load all modules but only inspect target module. - //if flagModule != "*" { - // ... - //} - - //for module := range poktrollModules { - // if err := checkModule(ctx, module); err != nil { + // TODO_IN_THIS_COMMIT: refactor... if err := checkModule(ctx); err != nil { return err } - //} return nil } @@ -107,20 +94,15 @@ func checkModule(_ context.Context) error { // 3. Recursively traverse the method body to find all of its error returns. // 4. Lookup error assignments to ensure that they are wrapped in gRPC status errors. - // TODO: import polyzero for side effects. - //logger := polylog.Ctx(ctx) - // TODO_IN_THIS_COMMIT: extract --- BEGIN // Set up the package configuration cfg := &packages.Config{ - //Mode: packages.NeedName | packages.NeedFiles | packages.NeedImports | packages.NeedTypes | packages.NeedTypesInfo | packages.LoadSyntax, Mode: packages.LoadSyntax, Tests: false, } // Load the package containing the target file or directory poktrollPkgPathPattern := "github.com/pokt-network/poktroll/x/..." - //logger.Info().Msgf("Loading package(s) in %s", poktrollPkgPathPattern) pkgs, err := packages.Load(cfg, poktrollPkgPathPattern) if err != nil { @@ -160,12 +142,6 @@ func checkModule(_ context.Context) error { // --- END - //filenames := make([]string, 0) - //for _, astFile := range pkg.Syntax { - // filenames = append(filenames, filepath.Base(pkg.Fset.Position(astFile.Pos()).Filename)) - //} - //fmt.Printf(">>> filenames:\n%s\n", strings.Join(filenames, "\n")) - // TODO_IN_THIS_COMMIT: extract --- BEGIN // TODO_IN_THIS_COMMIT: check the filename and only inspect each once! for _, astFile := range pkg.Syntax { @@ -179,12 +155,6 @@ func checkModule(_ context.Context) error { continue } - // TODO_IN_THIS_COMMIT: remove! - //fmt.Printf(">>> filename: %s\n", filename) - //if filename != "/home/bwhite/Projects/pokt/poktroll/x/application/keeper/msg_server_delegate_to_gateway.go" { - // continue - //} - ast.Inspect(astFile, func(n ast.Node) bool { fnNode, ok := n.(*ast.FuncDecl) if !ok { @@ -212,11 +182,6 @@ func checkModule(_ context.Context) error { return false } - //fmt.Printf(">>> fNode.Name.Name: %s\n", fnNode.Name.Name) - //if fnNode.Name.Name != "AllApplications" { - // return false - //} - // TODO_IN_THIS_COMMIT: check the signature of the method to ensure it returns an error type. fnResultsList := fnNode.Type.Results.List fnLastResultType := fnResultsList[len(fnResultsList)-1].Type @@ -233,10 +198,8 @@ func checkModule(_ context.Context) error { } fnPos := pkg.Fset.Position(fnNode.Pos()) - //fmt.Printf(">>> fnNode.Pos(): %s\n", fnPos.String()) fnFilename := filepath.Base(fnPos.Filename) fnSourceHasQueryHandlerPrefix := strings.HasPrefix(fnFilename, "query_") - //fnSourceHasQueryHandlerPrefix := false if typeIdentNode.Name != "msgServer" && !fnSourceHasQueryHandlerPrefix { return false @@ -254,10 +217,7 @@ func checkModule(_ context.Context) error { } // Recursively traverse the function body, looking for non-nil error returns. - // TODO_IN_THIS_COMMIT: extract --- BEGIN - //fmt.Printf(">>> walking func from file: %s\n", pkg.Fset.Position(astFile.Pos()).Filename) - ast.Inspect(fnNode.Body, walkFuncBody(pkg, pkgs, true, true)) - // --- END + ast.Inspect(fnNode.Body, walkFuncBody(pkg, pkgs, true)) return false }) @@ -268,7 +228,6 @@ func checkModule(_ context.Context) error { // --- END // TODO_IN_THIS_COMMIT: extract --- BEGIN - // TODO_IN_THIS_COMMIT: figure out why there are duplicate offending lines. // Print offending lines in package // TODO_IN_THIS_COMMIT: refactor to const. pkgsPattern := "github.com/pokt-network/poktroll/x/..." From 4169627180613c42d7f6372a16fdfa30e30ed486 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Wed, 27 Nov 2024 13:49:54 +0100 Subject: [PATCH 14/18] wip: refactoring and cleaning up --- tools/scripts/protocheck/cmd/ast_tracing.go | 492 ------------------ tools/scripts/protocheck/cmd/status_errors.go | 293 ++++++----- tools/scripts/protocheck/cmd/unstable.go | 4 +- tools/scripts/protocheck/goast/find.go | 239 +++++++++ tools/scripts/protocheck/goast/inspect.go | 98 ++++ tools/scripts/protocheck/goast/trace.go | 206 ++++++++ .../keeper/msg_server_delegate_to_gateway.go | 1 - .../keeper/msg_server_stake_supplier.go | 4 +- 8 files changed, 698 insertions(+), 639 deletions(-) delete mode 100644 tools/scripts/protocheck/cmd/ast_tracing.go create mode 100644 tools/scripts/protocheck/goast/find.go create mode 100644 tools/scripts/protocheck/goast/inspect.go create mode 100644 tools/scripts/protocheck/goast/trace.go diff --git a/tools/scripts/protocheck/cmd/ast_tracing.go b/tools/scripts/protocheck/cmd/ast_tracing.go deleted file mode 100644 index 85a5df8c8..000000000 --- a/tools/scripts/protocheck/cmd/ast_tracing.go +++ /dev/null @@ -1,492 +0,0 @@ -package main - -import ( - "fmt" - "go/ast" - "go/token" - "slices" - "strings" - - "golang.org/x/tools/go/packages" -) - -const grpcStatusImportPath = "google.golang.org/grpc/status" - -// TODO_IN_THIS_COMMIT: move & godoc... -func walkFuncBody( - pkg *packages.Package, - pkgs []*packages.Package, - shouldAppend bool, -) func(ast.Node) bool { - return func(n ast.Node) bool { - if n == nil { - return false - } - - logger.Debug(). - Str("position", fmt.Sprintf(" %s ", pkg.Fset.Position(n.Pos()).String())). - Str("node_type", fmt.Sprintf("%T", n)). - Bool("shouldAppend", shouldAppend). - Msg("walking function body") - - switch n := n.(type) { - case *ast.ReturnStmt: - lastResult := n.Results[len(n.Results)-1] - inspectPosition := pkg.Fset.Position(lastResult.Pos()).String() - - logger := logger.With( - "node_type", fmt.Sprintf("%T", lastResult), - "inspectPosition", fmt.Sprintf(" %s ", inspectPosition), - ) - - logger.Debug().Msgf("lastResult: %+v", lastResult) - - switch lastReturnArgNode := lastResult.(type) { - // E.g. `return nil, err` <-- last arg is an *ast.Ident. - case *ast.Ident: - // DEV_NOTE: No need to check that the last return arg is an error type - // if we checked that the function returns an error as the last arg. - if lastReturnArgNode.Obj == nil { - logger.Debug().Msg("lastReturnArgNode.Obj is nil") - return true - } - - def := pkg.TypesInfo.Uses[lastReturnArgNode] - if def == nil { - logger.Debug().Msg("def is nil") - return true - } - - if def.Type().String() != "error" { - logger.Debug().Msg("def is not error") - return false - } - - if shouldAppend { - logger.Debug().Msg("appending potential offending line") - appendOffendingLine(inspectPosition) - } - traceExpressionStack(lastReturnArgNode, pkgs, nil, pkg, lastReturnArgNode, offendingPkgErrLineSet) - return true - - // E.g. `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. - case *ast.CallExpr: - if shouldAppend { - logger.Debug().Msg("appending potential offending line") - appendOffendingLine(inspectPosition) - } - traceExpressionStack(lastReturnArgNode, pkgs, nil, pkg, lastReturnArgNode, offendingPkgErrLineSet) - return true - - case *ast.SelectorExpr: - if shouldAppend { - logger.Debug().Msg("appending potential offending line") - appendOffendingLine(inspectPosition) - } - traceSelectorExpr(lastReturnArgNode, pkg, pkgs, lastReturnArgNode, offendingPkgErrLineSet) - return true - } - - default: - return true - } - - return true - } -} - -// Helper function to trace selector expressions -func traceSelectorExpr( - expr *ast.SelectorExpr, - //scopeNode ast.Node, - candidatePkg *packages.Package, - pkgs []*packages.Package, - candidateNode ast.Node, - offendingPositions map[string]struct{}, -) bool { - logger.Debug().Msg("tracing selector expression") - for _, pkg := range pkgs { - if selection := pkg.TypesInfo.Selections[expr]; selection != nil { - for _, pkg2 := range pkgs { - position := pkg2.Fset.Position(selection.Obj().Pos()) - - var foundNode ast.Node - for _, fileNode := range pkg2.Syntax { - foundNode = findNodeByPosition(pkg2.Fset, fileNode, position) - if foundNode != nil { - logger.Debug(). - Str("node_type", fmt.Sprintf("%T", foundNode)). - Str("selection_position", fmt.Sprintf(" %s ", position)). - Str("expr_position", fmt.Sprintf(" %s ", pkg.Fset.Position(expr.Pos()).String())). - Str("found_node_position", fmt.Sprintf(" %s ", pkg2.Fset.Position(foundNode.Pos()).String())). - Msg("found node") - - var declNode *ast.FuncDecl - ast.Inspect(fileNode, func(n ast.Node) bool { - if declNode != nil { - return false - } - - if decl, ok := n.(*ast.FuncDecl); ok { - if decl.Name.Name == foundNode.(*ast.Ident).Name && - decl.Pos() < foundNode.Pos() && - foundNode.Pos() <= decl.End() { - declNode = decl - return false - } - } - return true - }) - - if declNode != nil { - logger.Debug().Str("decl_position", pkg2.Fset.Position(declNode.Pos()).String()).Msg("tracing decl node") - logger.Debug().Str("decl_body", pkg2.Fset.Position(declNode.Body.Pos()).String()).Msg("tracing decl node body") - ast.Inspect(declNode.Body, walkFuncBody(pkg, pkgs, false)) - } else { - logger.Debug().Msg("could not find decl node") - } - - return true - } - } - } - return true - } - } - - // TODO_IN_THIS_COMMIT: refactor; below happens when the selector is not found within any package. - - // Resolve the base expression - switch x := expr.X.(type) { - case *ast.Ident: // e.g., `pkg.Func` - for _, pkg := range pkgs { - if obj := pkg.TypesInfo.Uses[x]; obj != nil { - pkgParts := strings.Split(obj.String(), " ") - - var pkgStr string - switch { - // e.g., package (error) google.golang.org/grpc/status - case strings.HasPrefix(obj.String(), "package ("): - pkgStr = pkgParts[2] - // e.g. package fmt - default: - pkgStr = pkgParts[1] - } - - logger := logger.With( - "node_type", fmt.Sprintf("%T", x), - "position", fmt.Sprintf(" %s ", pkg.Fset.Position(x.Pos()).String()), - "package", strings.Trim(pkgStr, "\"()"), - ) - logger.Debug().Msg("tracing selector expression") - - isMatch := strings.Contains(obj.String(), grpcStatusImportPath) && - expr.Sel.Name == "Error" - if isMatch { - candidateNodePosition := pkg.Fset.Position(candidateNode.Pos()).String() - if _, ok := offendingPositions[candidateNodePosition]; ok { - logger.Debug().Msgf("exhonerating %s", candidateNodePosition) - delete(offendingPositions, candidateNodePosition) - } else { - logger.Warn().Msgf("can't exhonerate %s", candidateNodePosition) - } - return false - } - } else if obj = pkg.TypesInfo.Defs[x]; obj != nil { - logger.Debug().Msgf("no use but def: %+v", obj) - } else if obj = pkg.TypesInfo.Defs[expr.Sel]; obj != nil { - logger.Debug(). - Str("pkg_path", pkg.PkgPath). - Str("name", expr.Sel.Name). - Msgf("sel def") - traceExpressionStack(expr.Sel, pkgs, expr, candidatePkg, candidateNode, offendingPositions) - } - } - case *ast.SelectorExpr: // e.g., `obj.Method.Func` - logger.Debug().Msgf("tracing recursive selector expression: %+v", expr) - return traceSelectorExpr(x, candidatePkg, pkgs, candidateNode, offendingPositions) - case *ast.CallExpr: - logger.Debug().Msgf("tracing call expression: %+v", expr) - traceExpressionStack(x.Fun, pkgs, expr, candidatePkg, candidateNode, offendingPositions) - default: - logger.Warn().Msgf("skipping selector expression X type: %T", x) - } - return true -} - -// Trace any expression recursively, including selector expressions -func traceExpressionStack( - exprToTrace ast.Expr, - pkgs []*packages.Package, - _ ast.Node, - candidatePkg *packages.Package, - candidateNode ast.Node, - offendingPositions map[string]struct{}, -) bool { - if exprToTrace == nil { - return false - } - - logger := logger.With("node_type", fmt.Sprintf("%T", exprToTrace)) - logger.Debug().Msg("tracing expression stack") - - switch expr := exprToTrace.(type) { - case nil: - return false - case *ast.CallExpr: - if sel, ok := expr.Fun.(*ast.SelectorExpr); ok { - return traceSelectorExpr(sel, candidatePkg, pkgs, candidateNode, offendingPositions) - } - return true - case *ast.BinaryExpr: - logger.Debug().Msg("tracing binary expression") - // TODO_IN_THIS_COMMIT: return traceExpressionStack... ? - if traceExpressionStack(expr.X, pkgs, expr, candidatePkg, candidateNode, offendingPositions) { - traceExpressionStack(expr.Y, pkgs, expr, candidatePkg, candidateNode, offendingPositions) - } - return true - case *ast.ParenExpr: - logger.Debug().Msg("tracing paren expression") - return traceExpressionStack(expr.X, pkgs, expr, candidatePkg, candidateNode, offendingPositions) - case *ast.SelectorExpr: - logger.Debug().Msg("tracing selector expression") - return traceSelectorExpr(expr, candidatePkg, pkgs, candidateNode, offendingPositions) - case *ast.Ident: - logger.Debug().Str("name", expr.Name).Msg("tracing ident") - - // TODO_IN_THIS_COMMIT: return a slice of all decls and assignments - // and their respective files/pkgs. - declOrAssign, _ := newFindDeclOrAssign(expr, candidatePkg) - if declOrAssign == nil { - return false - } - - switch doa := declOrAssign.(type) { - case ast.Expr: - traceExpressionStack(doa, pkgs, expr, candidatePkg, candidateNode, offendingPositions) - case *ast.AssignStmt: - traceExpressionStack(doa.Rhs[0], pkgs, expr, candidatePkg, candidateNode, offendingPositions) - default: - logger.Warn().Msgf("unknown node type 3: %T", doa) - } - return true - default: - logger.Warn().Msgf("unknown node type 2: %T", exprToTrace) - return true - } -} - -// TODO_IN_THIS_COMMIT: move & godoc... -func newFindDeclOrAssign( - targetIdent *ast.Ident, - pkg *packages.Package, -) (declNode ast.Node, declPos token.Position) { - for _, fileNode := range pkg.Syntax { - if declNode != nil { - break - } - - if obj := pkg.TypesInfo.Uses[targetIdent]; obj != nil { - logger.Debug().Fields(map[string]any{ - "file_path": fmt.Sprintf(" %s ", pkg.Fset.File(fileNode.Pos()).Name()), - "target_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(targetIdent.Pos()).String()), - "decl_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(obj.Pos()).String()), - }).Msg("uses") - declPos = pkg.Fset.Position(obj.Pos()) - declNode = findNodeByPosition(pkg.Fset, fileNode, declPos) - if declNode != nil { - logger.Debug(). - Str("decl_node", fmt.Sprintf("%+v", declNode)). - Str("decl_pos", fmt.Sprintf(" %s ", declPos)). - Msg("found decl node") - } - } - } - - // TODO_IN_THIS_COMMIT: improve comment... - // Look through decl node to see if it contains a valudspec with values. - // If it does, return the value(s). - if declNode != nil { - ast.Inspect(declNode, func(n ast.Node) bool { - switch doa := n.(type) { - case *ast.ValueSpec: - logger.Debug(). - Int("len(values)", len(doa.Values)). - Msgf(">>>>>>> value spec: %+v", doa) - - if doa.Values != nil { - logger.Debug().Msg("dao.Values != nil") - for _, value := range doa.Values { - declPos = pkg.Fset.Position(value.Pos()) - declNode = value - } - } else { - logger.Debug().Msg("dao.Values == nil") - declNode = nil - } - } - - return true - }) - } else { - logger.Debug().Msgf("no declaration or assignment found for ident %q", targetIdent.String()) - } - - // TODO_IN_THIS_COMMIT: improve comment... - // If it does not, search the package for - // the ident and return the closest assignment. - if declNode == nil { - var assignsRhs []ast.Expr - for _, fileNode := range pkg.Syntax { - ast.Inspect(fileNode, func(n ast.Node) bool { - if assign, ok := n.(*ast.AssignStmt); ok { - for lhsIdx, lhs := range assign.Lhs { - // TODO_TECHDEBT: Ignoring assignments via selectors for now. - // E.g., `a.b = c` will not be considered. - lhsIdent, lhsIsIdent := lhs.(*ast.Ident) - if !lhsIsIdent { - continue - } - - if lhsIdent.Name != targetIdent.Name { - continue - } - - rhsIdx := 0 - if len(assign.Lhs) == len(assign.Rhs) { - rhsIdx = lhsIdx - } - - rhs := assign.Rhs[rhsIdx] - assignsRhs = append(assignsRhs, rhs) - } - } - return true - }) - } - - if len(assignsRhs) > 0 { - // TODO_IN_THIS_COMMIT: comment explaining what's going on here... - slices.SortFunc[[]ast.Expr, ast.Expr](assignsRhs, func(a, b ast.Expr) int { - aPos := pkg.Fset.Position(a.Pos()) - bPos := pkg.Fset.Position(b.Pos()) - - if aPos.Filename == bPos.Filename { - switch { - case aPos.Line < bPos.Line: - return -1 - case aPos.Line > bPos.Line: - return 1 - default: - return 0 - } - } else { - return 1 - } - - }) - - // DeclNode is the closest assignment whose position is less than or equal to the declPos. - var ( - closestAssignPos token.Position - closestAssignNode ast.Expr - targetIdentPos = pkg.Fset.Position(targetIdent.Pos()) - ) - for _, rhs := range assignsRhs { - if rhs == nil { - continue - } - - // DEV_NOTE: using pkg here assumes that rhs is in the same file as targetIdent. - // This SHOULD ALWAYS be the case for error type non-initialization declarations - // (e.g. var err error). I.e. we SHOULD NEVER be assigning an error value directly - // from aa pkg-level error variable. - rhsPos := pkg.Fset.Position(rhs.Pos()) - switch { - case rhsPos.Filename != targetIdentPos.Filename: - // TODO_TECHDEBT: handle case where rhs ident is defined in a different file. - logger.Debug(). - Str("assignment_position", rhsPos.String()). - Msg("ignoring assignment from different file") - continue - case rhsPos.Line < targetIdentPos.Line: - closestAssignPos = rhsPos - closestAssignNode = rhs - case rhsPos.Line == targetIdentPos.Line: - if rhsPos.Column <= targetIdentPos.Column { - closestAssignPos = rhsPos - closestAssignNode = rhs - } - } - } - declPos = closestAssignPos - declNode = closestAssignNode - } - } - - return declNode, declPos -} - -// TODO_IN_THIS_COMMIT: move & godoc... -// search for targetIdent by position -func findNodeByPosition( - fset *token.FileSet, - fileNode *ast.File, - position token.Position, -) (targetNode ast.Node) { - ast.Inspect(fileNode, func(n ast.Node) bool { - if targetNode != nil { - return false - } - - if n == nil { - return true - } - - if n != nil && fset.Position(n.Pos()) == position { - targetNode = n - return false - } - - if targetNode != nil { - return false - } - - return true - }) - return targetNode -} - -// Find the declaration or assignment of an identifier -// func findDeclOrAssign(ident *ast.Ident, fileNode ast.Node, pkg *packages.Package) (ast.Expr, *token.Position) { -func findDeclOrAssign(ident *ast.Ident, scopeNode ast.Node, pkg *packages.Package) (ast.Expr, *token.Position) { - var declOrAssign ast.Expr - var foundPos token.Position - - ast.Inspect(scopeNode, func(n ast.Node) bool { - switch stmt := n.(type) { - case *ast.AssignStmt: // Look for assignments - for i, lhs := range stmt.Lhs { - if lhsIdent, ok := lhs.(*ast.Ident); ok && lhsIdent.Name == ident.Name { - if len(stmt.Lhs) != len(stmt.Rhs) { - declOrAssign = stmt.Rhs[0] - } else { - declOrAssign = stmt.Rhs[i] - } - foundPos = pkg.Fset.Position(stmt.Pos()) - } - } - case *ast.ValueSpec: // Look for declarations with initialization - for i, name := range stmt.Names { - if name.Name == ident.Name && i < len(stmt.Values) { - declOrAssign = stmt.Values[i] - foundPos = pkg.Fset.Position(stmt.Pos()) - return false - } - } - } - return true - }) - - return declOrAssign, &foundPos -} diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/status_errors.go index 5b163e18b..5f2990a76 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/status_errors.go @@ -15,9 +15,16 @@ import ( "github.com/pokt-network/poktroll/pkg/polylog" "github.com/pokt-network/poktroll/pkg/polylog/polyzero" + "github.com/pokt-network/poktroll/tools/scripts/protocheck/goast" +) + +const ( + poktrollMoudlePkgsPattern = "github.com/pokt-network/poktroll/x/..." ) var ( + poktrollModulesRootPkgPath = filepath.Dir(poktrollMoudlePkgsPattern) + flagModule = "module" flagModuleShorthand = "m" flagModuleValue = "*" @@ -53,58 +60,25 @@ func setupLogger(_ *cobra.Command, _ []string) { } func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { - ctx := cmd.Context() - - // TODO_IN_THIS_COMMIT: extract to validation function. - if flagModuleValue != "*" { - switch flagModuleValue { - case "application": - case "gateway": - case "proof": - case "service": - case "session": - case "shared": - case "supplier": - case "tokenomics": - default: - return fmt.Errorf("ERROR: invalid module name: %s", flagModuleValue) - } - } - - // TODO_IN_THIS_COMMIT: refactor... - if err := checkModule(ctx); err != nil { + if err := validateModuleFlag(); err != nil { return err } - return nil -} - -// TODO_IN_THIS_COMMIT: 2-step check -// 1. Collect all return statements from `msgServer` methods and `Keeper` methods in `query_*.go` files. -// 2. For each return statement, check the type: -// *ast.Ident: search this package ... -// *ast.SelectorExpr: search the package of its declaration... -// ...for an *ast.AssignStmt with the given *ast.Ident as the left-hand side. - -func checkModule(_ context.Context) error { - - // 0. Get the package info for the given module's keeper package. + // 0. Get the package info for ALL module packages. // 1. Find the message server struct for the given module. // 2. Recursively traverse `msg_server_*.go` files to find all of its methods. // 3. Recursively traverse the method body to find all of its error returns. // 4. Lookup error assignments to ensure that they are wrapped in gRPC status errors. - // TODO_IN_THIS_COMMIT: extract --- BEGIN // Set up the package configuration cfg := &packages.Config{ Mode: packages.LoadSyntax, Tests: false, } + // TODO_IN_THIS_COMMIT: update comment... // Load the package containing the target file or directory - poktrollPkgPathPattern := "github.com/pokt-network/poktroll/x/..." - - pkgs, err := packages.Load(cfg, poktrollPkgPathPattern) + pkgs, err := packages.Load(cfg, poktrollMoudlePkgsPattern) if err != nil { return fmt.Errorf("failed to load package: %w", err) } @@ -112,125 +86,170 @@ func checkModule(_ context.Context) error { // Iterate over the keeper packages // E.g.: // - github.com/pokt-network/poktroll/x/application/keeper + // - github.com/pokt-network/poktroll/x/application/types // - github.com/pokt-network/poktroll/x/gateway/keeper // - ... for _, pkg := range pkgs { - if flagModuleValue != "*" { - moduleRootPath := fmt.Sprintf("github.com/pokt-network/poktroll/x/%s", flagModuleValue) - if !strings.HasPrefix(pkg.PkgPath, moduleRootPath) { - continue - } + if shouldSkipPackage(pkg) { + continue } - if pkg.Name != "keeper" { - continue + loggerCtx := logger.WithContext(cmd.Context()) + if err := checkPackage(loggerCtx, pkg, pkgs); err != nil { + return err } + } - if len(pkg.Errors) > 0 { - for _, pkgErr := range pkg.Errors { - logger.Error().Msgf("Package error: %v", pkgErr) - } - continue + printResults() + + return nil +} + +// TODO_IN_THIS_COMMIT: move & godoc... +func validateModuleFlag() error { + if flagModuleValue != "*" { + switch flagModuleValue { + case "application": + case "gateway": + case "proof": + case "service": + case "session": + case "shared": + case "supplier": + case "tokenomics": + default: + return fmt.Errorf("ERROR: invalid module name: %s", flagModuleValue) } + } + return nil +} - // Access type information - info := pkg.TypesInfo - if info == nil { - logger.Warn().Msgf("No type information available, skipping package %q", pkg.PkgPath) +// TODO_IN_THIS_COMMIT: move & godoc... +func shouldSkipPackage(pkg *packages.Package) bool { + if flagModuleValue != "*" { + moduleRootPath := fmt.Sprintf("%s/%s", poktrollModulesRootPkgPath, flagModuleValue) + if !strings.HasPrefix(pkg.PkgPath, moduleRootPath) { + return true + } + } + + if pkg.Name != "keeper" { + return true + } + + if len(pkg.Errors) > 0 { + for _, pkgErr := range pkg.Errors { + logger.Error().Msgf("⚠️ Skipping package %q due to error: %v", pkg.PkgPath, pkgErr) + } + return true + } + + // Access type information + if pkg.TypesInfo == nil { + logger.Warn().Msgf("⚠️ No type information available, skipping package %q", pkg.PkgPath) + return true + } + + return false +} + +// TODO_IN_THIS_COMMIT: move & godoc... +func checkPackage(ctx context.Context, pkg *packages.Package, pkgs []*packages.Package) error { + for _, astFile := range pkg.Syntax { + filename := pkg.Fset.Position(astFile.Pos()).Filename + + // Ignore protobuf generated files. + if strings.HasSuffix(filepath.Base(filename), ".pb.go") { + continue + } + if strings.HasSuffix(filepath.Base(filename), ".pb.gw.go") { continue } - // --- END + ast.Inspect(astFile, newInspectFileFn(ctx, pkg, pkgs)) + } - // TODO_IN_THIS_COMMIT: extract --- BEGIN - // TODO_IN_THIS_COMMIT: check the filename and only inspect each once! - for _, astFile := range pkg.Syntax { - filename := pkg.Fset.Position(astFile.Pos()).Filename + return nil +} - // Ignore protobuf generated files. - if strings.HasSuffix(filepath.Base(filename), ".pb.go") { - continue - } - if strings.HasSuffix(filepath.Base(filename), ".pb.gw.go") { - continue - } +// TODO_IN_THIS_COMMIT: move & godoc... +func newInspectFileFn(ctx context.Context, pkg *packages.Package, pkgs []*packages.Package) func(ast.Node) bool { + return func(n ast.Node) bool { + fnNode, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + // Skip functions which are not methods. + if fnNode.Recv == nil { + return false + } - ast.Inspect(astFile, func(n ast.Node) bool { - fnNode, ok := n.(*ast.FuncDecl) - if !ok { - return true - } - - // Skip functions which are not methods. - if fnNode.Recv == nil { - return false - } - - fnNodeTypeObj, ok := info.Defs[fnNode.Name] //.Type.Results.List[0].Type - if !ok { - fmt.Printf("ERROR: unable to find fnNode type def: %s\n", fnNode.Name.Name) - return true - } - - // Skip methods which are not exported. - if !fnNodeTypeObj.Exported() { - return false - } - - // Skip methods which have no return arguments. - if fnNode.Type.Results == nil { - return false - } - - // TODO_IN_THIS_COMMIT: check the signature of the method to ensure it returns an error type. - fnResultsList := fnNode.Type.Results.List - fnLastResultType := fnResultsList[len(fnResultsList)-1].Type - if fnLastResultIdent, ok := fnLastResultType.(*ast.Ident); ok { - if fnLastResultIdent.Name != "error" { - return false - } - } - - fnType := fnNode.Recv.List[0].Type - typeIdentNode, ok := fnType.(*ast.Ident) - if !ok { - return false - } - - fnPos := pkg.Fset.Position(fnNode.Pos()) - fnFilename := filepath.Base(fnPos.Filename) - fnSourceHasQueryHandlerPrefix := strings.HasPrefix(fnFilename, "query_") - - if typeIdentNode.Name != "msgServer" && !fnSourceHasQueryHandlerPrefix { - return false - } - - // TODO_IN_THIS_COMMIT: figure out why this file hangs the command. - isExcludedFile := false - for _, excludedFile := range []string{"query_get_session.go"} { - if fnFilename == excludedFile { - isExcludedFile = true - } - } - if isExcludedFile { - return false - } - - // Recursively traverse the function body, looking for non-nil error returns. - ast.Inspect(fnNode.Body, walkFuncBody(pkg, pkgs, true)) + fnNodeTypeObj, ok := pkg.TypesInfo.Defs[fnNode.Name] //.Type.Results.List[0].Type + if !ok { + fmt.Printf("ERROR: unable to find fnNode type def: %s\n", fnNode.Name.Name) + return true + } + + // Skip methods which are not exported. + if !fnNodeTypeObj.Exported() { + return false + } + // Skip methods which have no return arguments. + if fnNode.Type.Results == nil { + return false + } + + // Ensure the last return argument type is error. + fnResultTypes := fnNode.Type.Results.List + lastResultType := fnResultTypes[len(fnResultTypes)-1].Type + if lastResultTypeIdent, ok := lastResultType.(*ast.Ident); ok { + if lastResultTypeIdent.Name != "error" { return false - }) + } } + fnType := fnNode.Recv.List[0].Type + typeIdentNode, ok := fnType.(*ast.Ident) + if !ok { + return false + } + + fnPos := pkg.Fset.Position(fnNode.Pos()) + fnFilename := filepath.Base(fnPos.Filename) + fnSourceHasQueryHandlerPrefix := strings.HasPrefix(fnFilename, "query_") + + if typeIdentNode.Name != "msgServer" && !fnSourceHasQueryHandlerPrefix { + return false + } + + // Recursively traverse the function body, looking for non-nil error returns. + ast.Inspect(fnNode.Body, goast.NewInspectLastReturnArgFn(ctx, pkg, pkgs, appendOffendingLine, exonerateOffendingLine)) + + return false } +} + +// TODO_IN_THIS_COMMIT: move & godoc... +func appendOffendingLine(sourceLine string) { + offendingPkgErrLineSet[sourceLine] = struct{}{} +} - // --- END +// TODO_IN_THIS_COMMIT: move & godoc... +func exonerateOffendingLine(sourceLine string) { + if _, ok := offendingPkgErrLineSet[sourceLine]; ok { + logger.Debug().Msgf("exhonerating %s", sourceLine) + delete(offendingPkgErrLineSet, sourceLine) + } else { + logger.Warn().Msgf("can't exonerate %s", sourceLine) + } +} - // TODO_IN_THIS_COMMIT: extract --- BEGIN +// TODO_IN_THIS_COMMIT: move & godoc... exits with code CodeNonStatusGRPCErrorsFound if offending lines found if offending lines found. +func printResults() { // Print offending lines in package - // TODO_IN_THIS_COMMIT: refactor to const. - pkgsPattern := "github.com/pokt-network/poktroll/x/..." + pkgsPattern := poktrollMoudlePkgsPattern if flagModuleValue != "*" { pkgsPattern = fmt.Sprintf("github.com/pokt-network/poktroll/x/%s/...", flagModuleValue) } @@ -262,12 +281,4 @@ func checkModule(_ context.Context) error { os.Exit(CodeNonStatusGRPCErrorsFound) } - // --- END - - return nil -} - -// TODO_IN_THIS_COMMIT: move & godoc... -func appendOffendingLine(sourceLine string) { - offendingPkgErrLineSet[sourceLine] = struct{}{} } diff --git a/tools/scripts/protocheck/cmd/unstable.go b/tools/scripts/protocheck/cmd/unstable.go index 2e3f454a1..f0e39d414 100644 --- a/tools/scripts/protocheck/cmd/unstable.go +++ b/tools/scripts/protocheck/cmd/unstable.go @@ -214,7 +214,7 @@ func excludeFileIfStableVisitFn( optName, optNameOk := getOptNodeName(optNode) if !optNameOk { - logger.Debug().Msgf( + logger.Warn().Msgf( "unable to extract option name from option node at %s:%d:%d", protoFilePath, optSrc.Line, optSrc.Col, ) @@ -238,7 +238,7 @@ func excludeFileIfStableVisitFn( if optValue != "true" { // Not the value we're looking for, continue traversing... - logger.Debug().Msgf( + logger.Warn().Msgf( "discovered an unstable_marshaler_all option with unexpected value %q at %s:%d:%d", optValue, protoFilePath, optSrc.Line, optSrc.Col, ) diff --git a/tools/scripts/protocheck/goast/find.go b/tools/scripts/protocheck/goast/find.go new file mode 100644 index 000000000..9abcc7de4 --- /dev/null +++ b/tools/scripts/protocheck/goast/find.go @@ -0,0 +1,239 @@ +package goast + +import ( + "context" + "fmt" + "go/ast" + "go/token" + "slices" + + "golang.org/x/tools/go/packages" + + "github.com/pokt-network/poktroll/pkg/polylog" +) + +// TODO_IN_THIS_COMMIT: move & godoc... +func FindDeclOrAssignValueExpr( + ctx context.Context, + targetIdent *ast.Ident, + pkg *packages.Package, +) (valueNode ast.Node, valueNodePos token.Position) { + valueNode, valueNodePos = FindValueExprByDeclaration(ctx, targetIdent, pkg) + + // TODO_IN_THIS_COMMIT: improve comment... + // If it does not, search the package for + // the ident and return the closest assignment. + if valueNode == nil { + valueNode, valueNodePos = FindValueExprByAssignment(ctx, targetIdent, pkg) + } + + return valueNode, valueNodePos +} + +// TODO_IN_THIS_COMMIT: move & godoc... +func FindValueExprByDeclaration( + ctx context.Context, + targetIdent *ast.Ident, + pkg *packages.Package, +) (valueNode ast.Node, valueNodePos token.Position) { + logger := polylog.Ctx(ctx).With("func", "FindValueExprByDeclaration") + + // TODO_IN_THIS_COMMIT: comment... + for _, fileNode := range pkg.Syntax { + if valueNode != nil { + break + } + + if obj := pkg.TypesInfo.Uses[targetIdent]; obj != nil { + logger.Debug().Fields(map[string]any{ + "file_path": fmt.Sprintf(" %s ", pkg.Fset.File(fileNode.Pos()).Name()), + "target_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(targetIdent.Pos()).String()), + "decl_pos": fmt.Sprintf(" %s ", pkg.Fset.Position(obj.Pos()).String()), + }).Msg("uses") + valueNodePos = pkg.Fset.Position(obj.Pos()) + valueNode = FindNodeByPosition(pkg.Fset, fileNode, valueNodePos) + if valueNode != nil { + logger.Debug(). + Str("decl_node", fmt.Sprintf("%+v", valueNode)). + Str("decl_pos", fmt.Sprintf(" %s ", valueNodePos)). + Msg("found decl node") + } + } + } + + // TODO_IN_THIS_COMMIT: improve comment... + // Look through decl node to see if it contains a valudspec with values. + // If it does, return the value(s). + if valueNode != nil { + ast.Inspect(valueNode, func(n ast.Node) bool { + switch doa := n.(type) { + case *ast.ValueSpec: + logger.Debug(). + Int("len(values)", len(doa.Values)). + Msgf("value spec: %+v", doa) + + if doa.Values != nil { + logger.Debug().Msg("doa.Values != nil") + for _, value := range doa.Values { + valueNodePos = pkg.Fset.Position(value.Pos()) + valueNode = value + } + } else { + logger.Debug().Msg("dao.Values == nil") + valueNode = nil + } + } + + return true + }) + } else { + logger.Debug().Msgf("no declaration or assignment found for ident %q", targetIdent.String()) + } + + return valueNode, valueNodePos +} + +// TODO_IN_THIS_COMMIT: move & godoc... +// search for targetIdent by position +func FindNodeByPosition( + fset *token.FileSet, + fileNode *ast.File, + position token.Position, +) (targetNode ast.Node) { + ast.Inspect(fileNode, func(n ast.Node) bool { + if targetNode != nil { + return false + } + + if n == nil { + return true + } + + if n != nil && fset.Position(n.Pos()) == position { + targetNode = n + return false + } + + if targetNode != nil { + return false + } + + return true + }) + return targetNode +} + +// TODO_IN_THIS_COMMIT: move & godoc... +func FindValueExprByAssignment( + ctx context.Context, + targetIdent *ast.Ident, + pkg *packages.Package, +) (valueNode ast.Node, valueNodePos token.Position) { + assignsRhs := collectAssignments(targetIdent, pkg) + + if len(assignsRhs) < 1 { + return valueNode, valueNodePos + } + + // TODO_IN_THIS_COMMIT: comment explaining what's going on here... + slices.SortFunc[[]ast.Expr, ast.Expr](assignsRhs, func(a, b ast.Expr) int { + aPos := pkg.Fset.Position(a.Pos()) + bPos := pkg.Fset.Position(b.Pos()) + + if aPos.Filename == bPos.Filename { + switch { + case aPos.Line < bPos.Line: + return -1 + case aPos.Line > bPos.Line: + return 1 + default: + return 0 + } + } else { + return 1 + } + }) + + // TODO_IN_THIS_COMMIT: improve comment... + // ValueNode is the closest assignment whose position is less than or equal to the valueNodePos. + return FindClosestAssignment(ctx, assignsRhs, targetIdent, pkg) +} + +// TODO_IN_THIS_COMMIT: move & godoc... +func FindClosestAssignment( + ctx context.Context, + assignsRhs []ast.Expr, + targetIdent *ast.Ident, + pkg *packages.Package, +) (valueNode ast.Expr, valueNodePos token.Position) { + logger := polylog.Ctx(ctx).With("func", "FindClosestAssignment") + + var ( + targetIdentPos = pkg.Fset.Position(targetIdent.Pos()) + ) + for _, rhs := range assignsRhs { + if rhs == nil { + continue + } + + // DEV_NOTE: using pkg here assumes that rhs is in the same file as targetIdent. + // This SHOULD ALWAYS be the case for error type non-initialization declarations + // (e.g. var err error). I.e. we SHOULD NEVER be assigning an error value directly + // from aa pkg-level error variable. + rhsPos := pkg.Fset.Position(rhs.Pos()) + switch { + case rhsPos.Filename != targetIdentPos.Filename: + // TODO_TECHDEBT: handle case where rhs ident is defined in a different file. + logger.Debug(). + Str("assignment_position", rhsPos.String()). + Msg("ignoring assignment from different file") + continue + case rhsPos.Line < targetIdentPos.Line: + valueNodePos = rhsPos + valueNode = rhs + case rhsPos.Line == targetIdentPos.Line: + if rhsPos.Column <= targetIdentPos.Column { + valueNodePos = rhsPos + valueNode = rhs + } + } + } + + return valueNode, valueNodePos +} + +// TODO_IN_THIS_COMMIT: move & godoc... +func collectAssignments( + targetIdent *ast.Ident, + pkg *packages.Package, +) (assignsRhs []ast.Expr) { + for _, fileNode := range pkg.Syntax { + ast.Inspect(fileNode, func(n ast.Node) bool { + if assign, ok := n.(*ast.AssignStmt); ok { + for lhsIdx, lhs := range assign.Lhs { + // TODO_TECHDEBT: Ignoring assignments via selectors for now. + // E.g., `a.b = c` will not be considered. + lhsIdent, lhsIsIdent := lhs.(*ast.Ident) + if !lhsIsIdent { + continue + } + + if lhsIdent.Name != targetIdent.Name { + continue + } + + rhsIdx := 0 + if len(assign.Lhs) == len(assign.Rhs) { + rhsIdx = lhsIdx + } + + rhs := assign.Rhs[rhsIdx] + assignsRhs = append(assignsRhs, rhs) + } + } + return true + }) + } + + return assignsRhs +} diff --git a/tools/scripts/protocheck/goast/inspect.go b/tools/scripts/protocheck/goast/inspect.go new file mode 100644 index 000000000..f85747c10 --- /dev/null +++ b/tools/scripts/protocheck/goast/inspect.go @@ -0,0 +1,98 @@ +package goast + +import ( + "context" + "fmt" + "go/ast" + + "golang.org/x/tools/go/packages" + + "github.com/pokt-network/poktroll/pkg/polylog" +) + +// TODO_IN_THIS_COMMIT: move & godoc... +func NewInspectLastReturnArgFn( + ctx context.Context, + pkg *packages.Package, + modulePkgs []*packages.Package, + flag func(string), + exonerate func(string), +) func(ast.Node) bool { + logger := polylog.Ctx(ctx) + + return func(n ast.Node) bool { + if n == nil { + return false + } + + logger.Debug(). + Str("position", fmt.Sprintf(" %s ", pkg.Fset.Position(n.Pos()).String())). + Str("node_type", fmt.Sprintf("%T", n)). + Bool("flagging", flag != nil). + Msg("walking function body") + + switch n := n.(type) { + case *ast.ReturnStmt: + lastResult := n.Results[len(n.Results)-1] + inspectPosition := pkg.Fset.Position(lastResult.Pos()).String() + + logger = logger.With( + "node_type", fmt.Sprintf("%T", lastResult), + "inspectPosition", fmt.Sprintf(" %s ", inspectPosition), + ) + + logger.Debug().Msgf("lastResult: %+v", lastResult) + + switch lastReturnArgNode := lastResult.(type) { + // E.g. `return nil, err` <-- last arg is an *ast.Ident. + case *ast.Ident: + // DEV_NOTE: No need to check that the last return arg is an error type + // if we checked that the function returns an error as the last arg. + if lastReturnArgNode.Obj == nil { + logger.Debug().Msg("lastReturnArgNode.Obj is nil") + return true + } + + def := pkg.TypesInfo.Uses[lastReturnArgNode] + if def == nil { + logger.Debug().Msg("def is nil") + return true + } + + if def.Type().String() != "error" { + logger.Debug().Msg("def is not error") + return false + } + + if flag != nil { + logger.Debug().Msg("appending potential offending line") + flag(inspectPosition) + } + TraceExpressionStack(ctx, lastReturnArgNode, modulePkgs, pkg, lastReturnArgNode, exonerate) + return true + + // E.g. `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. + case *ast.CallExpr: + if flag != nil { + logger.Debug().Msg("appending potential offending line") + flag(inspectPosition) + } + TraceExpressionStack(ctx, lastReturnArgNode, modulePkgs, pkg, lastReturnArgNode, exonerate) + return true + + case *ast.SelectorExpr: + if flag != nil { + logger.Debug().Msg("appending potential offending line") + flag(inspectPosition) + } + TraceSelectorExpr(ctx, lastReturnArgNode, pkg, modulePkgs, lastReturnArgNode, exonerate) + return true + } + + default: + return true + } + + return true + } +} diff --git a/tools/scripts/protocheck/goast/trace.go b/tools/scripts/protocheck/goast/trace.go new file mode 100644 index 000000000..7e15121c4 --- /dev/null +++ b/tools/scripts/protocheck/goast/trace.go @@ -0,0 +1,206 @@ +package goast + +import ( + "context" + "fmt" + "go/ast" + "strings" + + "golang.org/x/tools/go/packages" + + "github.com/pokt-network/poktroll/pkg/polylog" +) + +const grpcStatusImportPath = "google.golang.org/grpc/status" + +// TODO_IN_THIS_COMMIT: move & godoc... +// TODO_IN_THIS_COMMIT: detemine whether this actually needs to return anything. +func TraceSelectorExpr( + ctx context.Context, + expr *ast.SelectorExpr, + candidatePkg *packages.Package, + modulePkgs []*packages.Package, + candidateNode ast.Node, + exonerate func(string), +) bool { + logger := polylog.Ctx(ctx).With("func", "TraceSelectorExpr") + logger.Debug().Send() + + // Search for the selector expression in all module packages. + // TODO_IN_THIS_COMMIT: is it expected and/or guaranteed that it will only be found in one pkg? + var declNode *ast.FuncDecl + for _, pkg := range modulePkgs { + //if declNode != nil { + // return true + //} + + selection := pkg.TypesInfo.Selections[expr] + if selection == nil { + continue + } + + for _, pkg2 := range modulePkgs { + selectionPos := pkg2.Fset.Position(selection.Obj().Pos()) + + for _, fileNode := range pkg2.Syntax { + selectionNode := FindNodeByPosition(pkg2.Fset, fileNode, selectionPos) + if selectionNode == nil { + continue + } + + logger.Debug(). + Str("node_type", fmt.Sprintf("%T", selectionNode)). + Str("selection_position", fmt.Sprintf(" %s ", selectionPos)). + Str("expr_position", fmt.Sprintf(" %s ", pkg.Fset.Position(expr.Pos()).String())). + Str("found_node_position", fmt.Sprintf(" %s ", pkg2.Fset.Position(selectionNode.Pos()).String())). + Msg("found node") + + ast.Inspect(fileNode, func(n ast.Node) bool { + if declNode != nil { + return false + } + + if decl, ok := n.(*ast.FuncDecl); ok { + if decl.Name.Name == selectionNode.(*ast.Ident).Name && + decl.Pos() < selectionNode.Pos() && + selectionNode.Pos() <= decl.End() { + declNode = decl + return false + } + } + return true + }) + + if declNode != nil { + logger.Debug().Str("decl_position", pkg2.Fset.Position(declNode.Pos()).String()).Msg("tracing decl node") + logger.Debug().Str("decl_body", pkg2.Fset.Position(declNode.Body.Pos()).String()).Msg("tracing decl node body") + ast.Inspect(declNode.Body, NewInspectLastReturnArgFn(ctx, pkg, modulePkgs, nil, nil)) + } else { + logger.Debug().Msg("could not find decl node") + } + + } + } + + // TODO_IN_THIS_COMMIT: note early return... + return true + } + + // TODO_IN_THIS_COMMIT: refactor; below happens when the selector is not found within any package. + + // Resolve the base expression + switch x := expr.X.(type) { + case *ast.Ident: // e.g., `pkg.Func` + for _, pkg := range modulePkgs { + if obj := pkg.TypesInfo.Uses[x]; obj != nil { + pkgParts := strings.Split(obj.String(), " ") + + var pkgStr string + switch { + // e.g., package (error) google.golang.org/grpc/status + case strings.HasPrefix(obj.String(), "package ("): + pkgStr = pkgParts[2] + // e.g. package fmt + default: + pkgStr = pkgParts[1] + } + + logger := logger.With( + "node_type", fmt.Sprintf("%T", x), + "position", fmt.Sprintf(" %s ", pkg.Fset.Position(x.Pos()).String()), + "package", strings.Trim(pkgStr, "\"()"), + ) + logger.Debug().Msg("tracing selector expression") + + isMatch := strings.Contains(obj.String(), grpcStatusImportPath) && + expr.Sel.Name == "Error" + if isMatch { + exonerate(pkg.Fset.Position(candidateNode.Pos()).String()) + return false + } + } else if obj = pkg.TypesInfo.Defs[x]; obj != nil { + logger.Debug().Msgf("no use but def: %+v", obj) + } else if obj = pkg.TypesInfo.Defs[expr.Sel]; obj != nil { + logger.Debug(). + Str("pkg_path", pkg.PkgPath). + Str("name", expr.Sel.Name). + Msgf("sel def") + TraceExpressionStack(ctx, expr.Sel, modulePkgs, candidatePkg, candidateNode, exonerate) + } + } + case *ast.SelectorExpr: // e.g., `obj.Method.Func` + logger.Debug().Msgf("tracing recursive selector expression: %+v", expr) + return TraceSelectorExpr(ctx, x, candidatePkg, modulePkgs, candidateNode, exonerate) + case *ast.CallExpr: + logger.Debug().Msgf("tracing call expression: %+v", expr) + TraceExpressionStack(ctx, x.Fun, modulePkgs, candidatePkg, candidateNode, exonerate) + default: + logger.Warn().Msgf("skipping selector expression X type: %T", x) + } + return true +} + +// Trace any expression recursively, including selector expressions +func TraceExpressionStack( + ctx context.Context, + exprToTrace ast.Expr, + modulePkgs []*packages.Package, + candidatePkg *packages.Package, + candidateNode ast.Node, + //flag func(string), + exonerate func(string), +) bool { + logger := polylog.Ctx(ctx).With( + "func", "TraceExpressionStack", + "node_type", fmt.Sprintf("%T", exprToTrace), + ) + logger.Debug().Send() + + if exprToTrace == nil { + return false + } + + switch expr := exprToTrace.(type) { + case nil: + return false + case *ast.CallExpr: + if sel, ok := expr.Fun.(*ast.SelectorExpr); ok { + return TraceSelectorExpr(ctx, sel, candidatePkg, modulePkgs, candidateNode, exonerate) + } + return true + case *ast.BinaryExpr: + logger.Debug().Msg("tracing binary expression") + // TODO_IN_THIS_COMMIT: return traceExpressionStack... ? + if TraceExpressionStack(ctx, expr.X, modulePkgs, candidatePkg, candidateNode, exonerate) { + TraceExpressionStack(ctx, expr.Y, modulePkgs, candidatePkg, candidateNode, exonerate) + } + return true + case *ast.ParenExpr: + logger.Debug().Msg("tracing paren expression") + return TraceExpressionStack(ctx, expr.X, modulePkgs, candidatePkg, candidateNode, exonerate) + case *ast.SelectorExpr: + logger.Debug().Msg("tracing selector expression") + return TraceSelectorExpr(ctx, expr, candidatePkg, modulePkgs, candidateNode, exonerate) + case *ast.Ident: + logger.Debug().Str("name", expr.Name).Msg("tracing ident") + + valueExpr, _ := FindDeclOrAssignValueExpr(ctx, expr, candidatePkg) + if valueExpr == nil { + return false + } + + switch valueNode := valueExpr.(type) { + case ast.Expr: + TraceExpressionStack(ctx, valueNode, modulePkgs, candidatePkg, candidateNode, exonerate) + case *ast.AssignStmt: + TraceExpressionStack(ctx, valueNode.Rhs[0], modulePkgs, candidatePkg, candidateNode, exonerate) + default: + logger.Warn().Msgf("unknown value node type: %T", valueNode) + } + + return true + default: + logger.Warn().Msgf("unknown expression node type: %T", exprToTrace) + return true + } +} diff --git a/x/application/keeper/msg_server_delegate_to_gateway.go b/x/application/keeper/msg_server_delegate_to_gateway.go index 502692bf6..f109d0ce3 100644 --- a/x/application/keeper/msg_server_delegate_to_gateway.go +++ b/x/application/keeper/msg_server_delegate_to_gateway.go @@ -25,7 +25,6 @@ func (k msgServer) DelegateToGateway(ctx context.Context, msg *apptypes.MsgDeleg if err := msg.ValidateBasic(); err != nil { logger.Error(fmt.Sprintf("Delegation Message failed basic validation: %v", err)) return nil, err - //return nil, status.Error(codes.InvalidArgument, err.Error()) } // Retrieve the application from the store diff --git a/x/supplier/keeper/msg_server_stake_supplier.go b/x/supplier/keeper/msg_server_stake_supplier.go index 4a7fb9992..d92a41201 100644 --- a/x/supplier/keeper/msg_server_stake_supplier.go +++ b/x/supplier/keeper/msg_server_stake_supplier.go @@ -119,9 +119,7 @@ func (k msgServer) StakeSupplier(ctx context.Context, msg *types.MsgStakeSupplie msg.Signer, msg.GetStake(), supplier.Stake, ) logger.Info(fmt.Sprintf("WARN: %s", err)) - //return nil, status.Error(codes.InvalidArgument, err.Error()) - err = status.Error(codes.InvalidArgument, err.Error()) - return nil, err + return nil, status.Error(codes.InvalidArgument, err.Error()) } // MUST ALWAYS have at least minimum stake. From 38eaa571f0191bccf29de8a815ebad710c8999a2 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Wed, 27 Nov 2024 14:33:21 +0100 Subject: [PATCH 15/18] chore: rename protocheck cmd files --- .../cmd/{status_errors.go => check_status_errors.go} | 8 ++++---- .../protocheck/cmd/{unstable.go => check_unstable.go} | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) rename tools/scripts/protocheck/cmd/{status_errors.go => check_status_errors.go} (97%) rename tools/scripts/protocheck/cmd/{unstable.go => check_unstable.go} (98%) diff --git a/tools/scripts/protocheck/cmd/status_errors.go b/tools/scripts/protocheck/cmd/check_status_errors.go similarity index 97% rename from tools/scripts/protocheck/cmd/status_errors.go rename to tools/scripts/protocheck/cmd/check_status_errors.go index 5f2990a76..0ef7e506c 100644 --- a/tools/scripts/protocheck/cmd/status_errors.go +++ b/tools/scripts/protocheck/cmd/check_status_errors.go @@ -35,7 +35,7 @@ var ( flagLogLevelValue = "info" flagLogLevelUsage = "The logging level (debug|info|warn|error)" - statusErrorsCheckCmd = &cobra.Command{ + checkCtatusErrorsCmd = &cobra.Command{ Use: "status-errors [flags]", Short: "Checks that all message handler function errors are wrapped in gRPC status errors.", PreRun: setupLogger, @@ -47,9 +47,9 @@ var ( ) func init() { - statusErrorsCheckCmd.Flags().StringVarP(&flagModuleValue, flagModule, flagModuleShorthand, flagModuleValue, flagModuleUsage) - statusErrorsCheckCmd.Flags().StringVarP(&flagLogLevelValue, flagLogLevel, flagLogLevelShorthand, flagLogLevelValue, flagLogLevelUsage) - rootCmd.AddCommand(statusErrorsCheckCmd) + checkCtatusErrorsCmd.Flags().StringVarP(&flagModuleValue, flagModule, flagModuleShorthand, flagModuleValue, flagModuleUsage) + checkCtatusErrorsCmd.Flags().StringVarP(&flagLogLevelValue, flagLogLevel, flagLogLevelShorthand, flagLogLevelValue, flagLogLevelUsage) + rootCmd.AddCommand(checkCtatusErrorsCmd) } func setupLogger(_ *cobra.Command, _ []string) { diff --git a/tools/scripts/protocheck/cmd/unstable.go b/tools/scripts/protocheck/cmd/check_unstable.go similarity index 98% rename from tools/scripts/protocheck/cmd/unstable.go rename to tools/scripts/protocheck/cmd/check_unstable.go index f0e39d414..0dd8735cf 100644 --- a/tools/scripts/protocheck/cmd/unstable.go +++ b/tools/scripts/protocheck/cmd/check_unstable.go @@ -35,7 +35,7 @@ var ( flagFixValue = false flagFixUsage = "If present, protocheck will add the 'gogoproto.stable_marshaler_all' option to files which were discovered to be unstable." - unstableCmd = &cobra.Command{ + checkUnstableCmd = &cobra.Command{ Use: "unstable [flags]", Short: "Recursively list or fix all protobuf files which omit the 'stable_marshaler_all' option.", RunE: runUnstable, @@ -50,8 +50,8 @@ var ( ) func init() { - unstableCmd.Flags().BoolVarP(&flagFixValue, flagFixName, flagFixShorthand, flagFixValue, flagFixUsage) - rootCmd.AddCommand(unstableCmd) + checkUnstableCmd.Flags().BoolVarP(&flagFixValue, flagFixName, flagFixShorthand, flagFixValue, flagFixUsage) + rootCmd.AddCommand(checkUnstableCmd) } func runUnstable(cmd *cobra.Command, args []string) error { From 5d3244185cb43bb3daec09c571713439f8662218 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Wed, 27 Nov 2024 16:54:09 +0100 Subject: [PATCH 16/18] squash! HEAD - refactor: extract flags & encapsulate logger --- .../protocheck/cmd/check_status_errors.go | 59 ++- .../scripts/protocheck/cmd/check_unstable.go | 339 +----------------- tools/scripts/protocheck/cmd/flags.go | 18 + tools/scripts/protocheck/cmd/logging.go | 24 ++ tools/scripts/protocheck/goast/inspect.go | 10 +- tools/scripts/protocheck/goast/trace.go | 7 +- 6 files changed, 85 insertions(+), 372 deletions(-) create mode 100644 tools/scripts/protocheck/cmd/flags.go create mode 100644 tools/scripts/protocheck/cmd/logging.go diff --git a/tools/scripts/protocheck/cmd/check_status_errors.go b/tools/scripts/protocheck/cmd/check_status_errors.go index 0ef7e506c..4400cc024 100644 --- a/tools/scripts/protocheck/cmd/check_status_errors.go +++ b/tools/scripts/protocheck/cmd/check_status_errors.go @@ -9,12 +9,10 @@ import ( "sort" "strings" - "github.com/rs/zerolog" "github.com/spf13/cobra" "golang.org/x/tools/go/packages" "github.com/pokt-network/poktroll/pkg/polylog" - "github.com/pokt-network/poktroll/pkg/polylog/polyzero" "github.com/pokt-network/poktroll/tools/scripts/protocheck/goast" ) @@ -25,41 +23,27 @@ const ( var ( poktrollModulesRootPkgPath = filepath.Dir(poktrollMoudlePkgsPattern) - flagModule = "module" - flagModuleShorthand = "m" - flagModuleValue = "*" - flagModuleUsage = "If present, only check message handlers of the given module." - - flagLogLevel = "log-level" - flagLogLevelShorthand = "l" - flagLogLevelValue = "info" - flagLogLevelUsage = "The logging level (debug|info|warn|error)" - - checkCtatusErrorsCmd = &cobra.Command{ + checkStatusErrorsCmd = &cobra.Command{ Use: "status-errors [flags]", Short: "Checks that all message handler function errors are wrapped in gRPC status errors.", - PreRun: setupLogger, + PreRun: setupPrettyLogger, RunE: runStatusErrorsCheck, } - logger polylog.Logger + // TODO_IN_THIS_COMMIT: refactor to avoid global logger var. + //logger polylog.Logger offendingPkgErrLineSet = make(map[string]struct{}) ) func init() { - checkCtatusErrorsCmd.Flags().StringVarP(&flagModuleValue, flagModule, flagModuleShorthand, flagModuleValue, flagModuleUsage) - checkCtatusErrorsCmd.Flags().StringVarP(&flagLogLevelValue, flagLogLevel, flagLogLevelShorthand, flagLogLevelValue, flagLogLevelUsage) - rootCmd.AddCommand(checkCtatusErrorsCmd) -} - -func setupLogger(_ *cobra.Command, _ []string) { - logger = polyzero.NewLogger( - polyzero.WithWriter(zerolog.ConsoleWriter{Out: os.Stdout}), - polyzero.WithLevel(polyzero.ParseLevel(flagLogLevelValue)), - ) + checkStatusErrorsCmd.Flags().StringVarP(&flagModuleValue, flagModule, flagModuleShorthand, flagModuleValue, flagModuleUsage) + checkStatusErrorsCmd.Flags().StringVarP(&flagLogLevelValue, flagLogLevel, flagLogLevelShorthand, flagLogLevelValue, flagLogLevelUsage) + rootCmd.AddCommand(checkStatusErrorsCmd) } func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() + if err := validateModuleFlag(); err != nil { return err } @@ -90,17 +74,16 @@ func runStatusErrorsCheck(cmd *cobra.Command, _ []string) error { // - github.com/pokt-network/poktroll/x/gateway/keeper // - ... for _, pkg := range pkgs { - if shouldSkipPackage(pkg) { + if shouldSkipPackage(ctx, pkg) { continue } - loggerCtx := logger.WithContext(cmd.Context()) - if err := checkPackage(loggerCtx, pkg, pkgs); err != nil { + if err = checkPackage(ctx, pkg, pkgs); err != nil { return err } } - printResults() + printResults(ctx) return nil } @@ -125,7 +108,9 @@ func validateModuleFlag() error { } // TODO_IN_THIS_COMMIT: move & godoc... -func shouldSkipPackage(pkg *packages.Package) bool { +func shouldSkipPackage(ctx context.Context, pkg *packages.Package) bool { + logger := polylog.Ctx(ctx) + if flagModuleValue != "*" { moduleRootPath := fmt.Sprintf("%s/%s", poktrollModulesRootPkgPath, flagModuleValue) if !strings.HasPrefix(pkg.PkgPath, moduleRootPath) { @@ -174,6 +159,8 @@ func checkPackage(ctx context.Context, pkg *packages.Package, pkgs []*packages.P // TODO_IN_THIS_COMMIT: move & godoc... func newInspectFileFn(ctx context.Context, pkg *packages.Package, pkgs []*packages.Package) func(ast.Node) bool { + logger := polylog.Ctx(ctx) + return func(n ast.Node) bool { fnNode, ok := n.(*ast.FuncDecl) if !ok { @@ -187,7 +174,7 @@ func newInspectFileFn(ctx context.Context, pkg *packages.Package, pkgs []*packag fnNodeTypeObj, ok := pkg.TypesInfo.Defs[fnNode.Name] //.Type.Results.List[0].Type if !ok { - fmt.Printf("ERROR: unable to find fnNode type def: %s\n", fnNode.Name.Name) + logger.Warn().Msgf("ERROR: unable to find fnNode type def: %s\n", fnNode.Name.Name) return true } @@ -232,12 +219,14 @@ func newInspectFileFn(ctx context.Context, pkg *packages.Package, pkgs []*packag } // TODO_IN_THIS_COMMIT: move & godoc... -func appendOffendingLine(sourceLine string) { +func appendOffendingLine(_ context.Context, sourceLine string) { offendingPkgErrLineSet[sourceLine] = struct{}{} } // TODO_IN_THIS_COMMIT: move & godoc... -func exonerateOffendingLine(sourceLine string) { +func exonerateOffendingLine(ctx context.Context, sourceLine string) { + logger := polylog.Ctx(ctx) + if _, ok := offendingPkgErrLineSet[sourceLine]; ok { logger.Debug().Msgf("exhonerating %s", sourceLine) delete(offendingPkgErrLineSet, sourceLine) @@ -247,7 +236,9 @@ func exonerateOffendingLine(sourceLine string) { } // TODO_IN_THIS_COMMIT: move & godoc... exits with code CodeNonStatusGRPCErrorsFound if offending lines found if offending lines found. -func printResults() { +func printResults(ctx context.Context) { + logger := polylog.Ctx(ctx) + // Print offending lines in package pkgsPattern := poktrollMoudlePkgsPattern if flagModuleValue != "*" { diff --git a/tools/scripts/protocheck/cmd/check_unstable.go b/tools/scripts/protocheck/cmd/check_unstable.go index 0dd8735cf..eb6712060 100644 --- a/tools/scripts/protocheck/cmd/check_unstable.go +++ b/tools/scripts/protocheck/cmd/check_unstable.go @@ -1,52 +1,23 @@ package main import ( - "bufio" "context" - "fmt" - "io/fs" "os" "path/filepath" - "strings" - "github.com/jhump/protoreflect/desc/protoparse" - "github.com/jhump/protoreflect/desc/protoparse/ast" "github.com/spf13/cobra" "github.com/pokt-network/poktroll/pkg/polylog" + "github.com/pokt-network/poktroll/tools/scripts/protocheck/protoast" ) -const ( - stableMarshalerAllOptName = "(gogoproto.stable_marshaler_all)" - expectedStableMarshalerAllOptionValue = "true" - gogoImportName = "gogoproto/gogo.proto" -) - -type protoFileStat struct { - pkgSource *ast.SourcePos - lastOptSource *ast.SourcePos - lastImportSource *ast.SourcePos - hasGogoImport bool -} - var ( - flagFixName = "fix" - flagFixShorthand = "f" - flagFixValue = false - flagFixUsage = "If present, protocheck will add the 'gogoproto.stable_marshaler_all' option to files which were discovered to be unstable." - checkUnstableCmd = &cobra.Command{ - Use: "unstable [flags]", - Short: "Recursively list or fix all protobuf files which omit the 'stable_marshaler_all' option.", - RunE: runUnstable, - } - - protoParser = protoparse.Parser{ - IncludeSourceCodeInfo: true, + Use: "unstable [flags]", + Short: "Recursively list or fix all protobuf files which omit the 'stable_marshaler_all' option.", + PreRun: setupPrettyLogger, + RunE: runUnstable, } - - stableMarshalerAllOptionSource = fmt.Sprintf(`option %s = %s;`, stableMarshalerAllOptName, expectedStableMarshalerAllOptionValue) - gogoImportSource = fmt.Sprintf(`import "%s";`, gogoImportName) ) func init() { @@ -58,7 +29,7 @@ func runUnstable(cmd *cobra.Command, args []string) error { ctx := cmd.Context() logger := polylog.Ctx(ctx) - unstableProtoFilesByPath := make(map[string]*protoFileStat) + unstableProtoFilesByPath := make(map[string]*protoast.ProtoFileStat) logger.Info().Msgf("Recursively checking for files matching %q in %q", flagFileIncludePatternValue, flagRootValue) @@ -69,9 +40,9 @@ func runUnstable(cmd *cobra.Command, args []string) error { // 2c. Exclude files which contain the stable_marshaler_all option. if pathWalkErr := filepath.Walk( flagRootValue, - forEachMatchingFileWalkFn( + protoast.ForEachMatchingFileWalkFn( flagFileIncludePatternValue, - findUnstableProtosInFileFn(ctx, unstableProtoFilesByPath), + protoast.FindUnstableProtosInFileFn(ctx, unstableProtoFilesByPath), ), ); pathWalkErr != nil { logger.Error().Err(pathWalkErr) @@ -102,14 +73,14 @@ func runUnstable(cmd *cobra.Command, args []string) error { return nil } -func runFixUnstable(ctx context.Context, unstableProtoFilesByPath map[string]*protoFileStat) { +func runFixUnstable(ctx context.Context, unstableProtoFilesByPath map[string]*protoast.ProtoFileStat) { logger := polylog.Ctx(ctx) logger.Info().Msg("Fixing unstable marshaler proto files...") var fixedProtoFilePaths []string for unstableProtoFile, protoStat := range unstableProtoFilesByPath { if protoStat != nil { - if insertErr := insertStableMarshalerAllOption(unstableProtoFile, protoStat); insertErr != nil { + if insertErr := protoast.InsertStableMarshalerAllOption(unstableProtoFile, protoStat); insertErr != nil { logger.Error().Err(insertErr).Msgf("unable to fix unstable marshaler proto file: %q", unstableProtoFile) continue } @@ -125,293 +96,3 @@ func runFixUnstable(ctx context.Context, unstableProtoFilesByPath map[string]*pr logger.Info().Msgf("\t%s", protoFilePath) } } - -// forEachMatchingFileWalkFn returns a filepath.WalkFunc which does the following: -// 1. Iterates over files matching fileNamePattern against each file name. -// 2. For matching files, it calls fileMatchedFn with the respective path. -func forEachMatchingFileWalkFn( - fileNamePattern string, - fileMatchedFn func(path string), -) filepath.WalkFunc { - return func(path string, info fs.FileInfo, err error) error { - if err != nil { - return err - } - - // Ignore directories - if info.IsDir() { - return nil - } - - matched, matchErr := filepath.Match(fileNamePattern, info.Name()) - if matchErr != nil { - return matchErr - } - - if matched { - fileMatchedFn(path) - } - return nil - } -} - -// findUnstableProtosInFileFn returns a function which is expected to be called for -// each proto file found. The function walks that file's AST and excludes it from -// the list of unstable files if it contains the stable_marshaler_all option. -func findUnstableProtosInFileFn( - ctx context.Context, - unstableProtoFilesByPath map[string]*protoFileStat, -) func(path string) { - logger := polylog.Ctx(ctx) - - return func(protoFilePath string) { - // Parse the .proto file into file nodes. - // NB: MUST use #ParseToAST instead of #ParseFiles to get source positions. - protoAST, parseErr := protoParser.ParseToAST(protoFilePath) - if parseErr != nil { - logger.Error().Err(parseErr).Msgf("Unable to parse proto file: %q", protoFilePath) - } - - // Iterate through the file nodes and build a protoFileStat for each file. - // NB: There should only be one file node per file. - for _, fileNode := range protoAST { - protoStat := newProtoFileStat(fileNode) - - // Add all proto files to unstableProtoFilePaths by default. If a file - // has a stable_marshaler_all option, that file protoFilePath will be - // removed from the map when the option is traversed (found). - unstableProtoFilesByPath[protoFilePath] = protoStat - - ast.Walk( - fileNode, - excludeFileIfStableVisitFn( - ctx, - protoFilePath, - unstableProtoFilesByPath, - ), - ) - } - } -} - -// excludeStableMarshalersVisitFn returns an ast.VisitFunc which removes proto files -// from unstableProtoFilesByPath if they contain the stable_marshaler_all option, and -// it is set to true. -func excludeFileIfStableVisitFn( - ctx context.Context, - protoFilePath string, - unstableProtoFilesByPath map[string]*protoFileStat, -) ast.VisitFunc { - logger := polylog.Ctx(ctx) - - return func(n ast.Node) (bool, ast.VisitFunc) { - optNode, optNodeOk := n.(*ast.OptionNode) - if !optNodeOk { - return true, nil - } - - optSrc := optNode.Start() - - optName, optNameOk := getOptNodeName(optNode) - if !optNameOk { - logger.Warn().Msgf( - "unable to extract option name from option node at %s:%d:%d", - protoFilePath, optSrc.Line, optSrc.Col, - ) - return true, nil - } - - if optName != stableMarshalerAllOptName { - // Not the option we're looking for, continue traversing... - return true, nil - } - - optValueNode := optNode.GetValue().Value() - optValue, ok := optValueNode.(ast.Identifier) - if !ok { - logger.Error().Msgf( - "unable to cast option value to ast.Identifier for option %q, got: %T at %s:%d:%d", - optName, optValueNode, protoFilePath, optSrc.Line, optSrc.Col, - ) - return true, nil - } - - if optValue != "true" { - // Not the value we're looking for, continue traversing... - logger.Warn().Msgf( - "discovered an unstable_marshaler_all option with unexpected value %q at %s:%d:%d", - optValue, protoFilePath, optSrc.Line, optSrc.Col, - ) - return true, nil - } - - // Remove stable proto file from unstableProtoFilesByPath. - delete(unstableProtoFilesByPath, protoFilePath) - - // Stop traversing the AST after finding the stable_marshaler_all option. - // We only expect one stable_marshaler_all option per file. - return false, nil - } -} - -// getOptNodeName returns the name of the option node as a string and a boolean -// indicating whether the name was successfully extracted. -func getOptNodeName(optNode *ast.OptionNode) (optName string, ok bool) { - optNameNode, optNameNodeOk := optNode.GetName().(*ast.OptionNameNode) - if !optNameNodeOk { - return "", false - } - - if len(optNameNode.Parts) < 1 { - return "", false - } - - for i, optNamePart := range optNameNode.Parts { - // Only insert delimiters if there is more than one part. - if i > 0 { - optName += "." - } - optName += optNamePart.Value() - } - - return optName, true -} - -func newProtoFileStat(fileNode *ast.FileNode) *protoFileStat { - var ( - pkgNode *ast.PackageNode - lastOptionNode *ast.OptionNode - lastImportNode *ast.ImportNode - foundGogoImport bool - ) - - for _, n := range fileNode.Children() { - switch node := n.(type) { - case *ast.PackageNode: - pkgNode = node - case *ast.ImportNode: - lastImportNode = node - - if node.Name.AsString() == gogoImportName { - foundGogoImport = true - } - case *ast.OptionNode: - lastOptionNode = node - } - } - - protoStat := &protoFileStat{ - pkgSource: pkgNode.Start(), - hasGogoImport: foundGogoImport, - } - if lastOptionNode != nil { - protoStat.lastOptSource = lastOptionNode.Start() - } - if lastImportNode != nil { - protoStat.lastImportSource = lastImportNode.Start() - } - - return protoStat -} - -func insertStableMarshalerAllOption(protoFilePath string, protoFile *protoFileStat) (err error) { - var ( - importInsertLine, - importInsertCol, - optionInsertLine, - optionInsertCol, - numInsertedLines int - - optionLine = stableMarshalerAllOptionSource - importLine = gogoImportSource - ) - - if protoFile.lastOptSource == nil { - optionInsertLine = protoFile.pkgSource.Line + 1 - optionInsertCol = protoFile.pkgSource.Col - optionLine += "\n" - numInsertedLines += 2 - } else { - optionInsertLine = protoFile.lastOptSource.Line - optionInsertCol = protoFile.lastOptSource.Col - numInsertedLines++ - } - - if err = insertLine( - protoFilePath, - optionInsertLine, - optionInsertCol, - optionLine, - ); err != nil { - return err - } - - if protoFile.hasGogoImport { - return nil - } - - if protoFile.lastImportSource == nil { - importInsertLine = optionInsertLine + 1 + numInsertedLines - importInsertCol = optionInsertCol - importLine += "\n" - } else { - importInsertLine = protoFile.lastImportSource.Line + numInsertedLines - importInsertCol = protoFile.lastImportSource.Col - } - - return insertLine( - protoFilePath, - importInsertLine, - importInsertCol, - importLine, - ) -} - -func insertLine(filePath string, lineNumber int, columnNumber int, textToInsert string) error { - // Open the file for reading - file, err := os.Open(filePath) - if err != nil { - return err - } - defer file.Close() - - // Read the file into a slice of strings (each string is a line) - var lines []string - scanner := bufio.NewScanner(file) - for scanner.Scan() { - lines = append(lines, scanner.Text()) - } - if err = scanner.Err(); err != nil { - return err - } - - // Check if the line number is within the range of lines in the file - if lineNumber < 1 || lineNumber > len(lines) { - return fmt.Errorf("line number %d is out of range", lineNumber) - } - - // Create the new line with the specified amount of leading whitespace - whitespace := strings.Repeat(" ", columnNumber-1) - newLine := whitespace + textToInsert - - // Insert the new line after the specified line - lineIndex := lineNumber - 1 - lines = append(lines[:lineIndex+1], append([]string{newLine}, lines[lineIndex+1:]...)...) - - // Open the file for writing - file, err = os.Create(filePath) - if err != nil { - return err - } - defer file.Close() - - // Write the modified lines back to the file - writer := bufio.NewWriter(file) - for _, line := range lines { - _, err := writer.WriteString(line + "\n") - if err != nil { - return err - } - } - return writer.Flush() -} diff --git a/tools/scripts/protocheck/cmd/flags.go b/tools/scripts/protocheck/cmd/flags.go new file mode 100644 index 000000000..adeed68cc --- /dev/null +++ b/tools/scripts/protocheck/cmd/flags.go @@ -0,0 +1,18 @@ +package main + +var ( + flagModule = "module" + flagModuleShorthand = "m" + flagModuleValue = "*" + flagModuleUsage = "If present, only check message handlers of the given module." + + flagLogLevel = "log-level" + flagLogLevelShorthand = "l" + flagLogLevelValue = "info" + flagLogLevelUsage = "The logging level (debug|info|warn|error)" + + flagFixName = "fix" + flagFixShorthand = "f" + flagFixValue = false + flagFixUsage = "If present, protocheck will add the 'gogoproto.stable_marshaler_all' option to files which were discovered to be unstable." +) diff --git a/tools/scripts/protocheck/cmd/logging.go b/tools/scripts/protocheck/cmd/logging.go new file mode 100644 index 000000000..f8ee0906f --- /dev/null +++ b/tools/scripts/protocheck/cmd/logging.go @@ -0,0 +1,24 @@ +package main + +import ( + "os" + + "github.com/rs/zerolog" + "github.com/spf13/cobra" + + "github.com/pokt-network/poktroll/pkg/polylog/polyzero" +) + +func setupPrettyLogger(cmd *cobra.Command, _ []string) { + logger := polyzero.NewLogger( + polyzero.WithWriter(zerolog.ConsoleWriter{ + Out: os.Stdout, + PartsExclude: []string{ + zerolog.TimestampFieldName, + }, + }), + polyzero.WithLevel(polyzero.ParseLevel(flagLogLevelValue)), + ) + + cmd.SetContext(logger.WithContext(cmd.Context())) +} diff --git a/tools/scripts/protocheck/goast/inspect.go b/tools/scripts/protocheck/goast/inspect.go index f85747c10..d60c5ede5 100644 --- a/tools/scripts/protocheck/goast/inspect.go +++ b/tools/scripts/protocheck/goast/inspect.go @@ -15,8 +15,8 @@ func NewInspectLastReturnArgFn( ctx context.Context, pkg *packages.Package, modulePkgs []*packages.Package, - flag func(string), - exonerate func(string), + flag func(context.Context, string), + exonerate func(context.Context, string), ) func(ast.Node) bool { logger := polylog.Ctx(ctx) @@ -66,7 +66,7 @@ func NewInspectLastReturnArgFn( if flag != nil { logger.Debug().Msg("appending potential offending line") - flag(inspectPosition) + flag(ctx, inspectPosition) } TraceExpressionStack(ctx, lastReturnArgNode, modulePkgs, pkg, lastReturnArgNode, exonerate) return true @@ -75,7 +75,7 @@ func NewInspectLastReturnArgFn( case *ast.CallExpr: if flag != nil { logger.Debug().Msg("appending potential offending line") - flag(inspectPosition) + flag(ctx, inspectPosition) } TraceExpressionStack(ctx, lastReturnArgNode, modulePkgs, pkg, lastReturnArgNode, exonerate) return true @@ -83,7 +83,7 @@ func NewInspectLastReturnArgFn( case *ast.SelectorExpr: if flag != nil { logger.Debug().Msg("appending potential offending line") - flag(inspectPosition) + flag(ctx, inspectPosition) } TraceSelectorExpr(ctx, lastReturnArgNode, pkg, modulePkgs, lastReturnArgNode, exonerate) return true diff --git a/tools/scripts/protocheck/goast/trace.go b/tools/scripts/protocheck/goast/trace.go index 7e15121c4..b2bd976c3 100644 --- a/tools/scripts/protocheck/goast/trace.go +++ b/tools/scripts/protocheck/goast/trace.go @@ -21,7 +21,7 @@ func TraceSelectorExpr( candidatePkg *packages.Package, modulePkgs []*packages.Package, candidateNode ast.Node, - exonerate func(string), + exonerate func(context.Context, string), ) bool { logger := polylog.Ctx(ctx).With("func", "TraceSelectorExpr") logger.Debug().Send() @@ -115,7 +115,7 @@ func TraceSelectorExpr( isMatch := strings.Contains(obj.String(), grpcStatusImportPath) && expr.Sel.Name == "Error" if isMatch { - exonerate(pkg.Fset.Position(candidateNode.Pos()).String()) + exonerate(ctx, pkg.Fset.Position(candidateNode.Pos()).String()) return false } } else if obj = pkg.TypesInfo.Defs[x]; obj != nil { @@ -147,8 +147,7 @@ func TraceExpressionStack( modulePkgs []*packages.Package, candidatePkg *packages.Package, candidateNode ast.Node, - //flag func(string), - exonerate func(string), + exonerate func(context.Context, string), ) bool { logger := polylog.Ctx(ctx).With( "func", "TraceExpressionStack", From b9f9a2c775432cdfd34214239d2c3b67f60fa254 Mon Sep 17 00:00:00 2001 From: Bryan White Date: Wed, 27 Nov 2024 17:55:12 +0100 Subject: [PATCH 17/18] refactor: find and replace "exonerate" with "exclude" --- .../protocheck/cmd/check_status_errors.go | 6 ++--- tools/scripts/protocheck/goast/inspect.go | 8 +++--- tools/scripts/protocheck/goast/trace.go | 26 +++++++++---------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/tools/scripts/protocheck/cmd/check_status_errors.go b/tools/scripts/protocheck/cmd/check_status_errors.go index 4400cc024..7e15ef29d 100644 --- a/tools/scripts/protocheck/cmd/check_status_errors.go +++ b/tools/scripts/protocheck/cmd/check_status_errors.go @@ -212,7 +212,7 @@ func newInspectFileFn(ctx context.Context, pkg *packages.Package, pkgs []*packag } // Recursively traverse the function body, looking for non-nil error returns. - ast.Inspect(fnNode.Body, goast.NewInspectLastReturnArgFn(ctx, pkg, pkgs, appendOffendingLine, exonerateOffendingLine)) + ast.Inspect(fnNode.Body, goast.NewInspectLastReturnArgFn(ctx, pkg, pkgs, appendOffendingLine, excludeOffendingLine)) return false } @@ -224,14 +224,14 @@ func appendOffendingLine(_ context.Context, sourceLine string) { } // TODO_IN_THIS_COMMIT: move & godoc... -func exonerateOffendingLine(ctx context.Context, sourceLine string) { +func excludeOffendingLine(ctx context.Context, sourceLine string) { logger := polylog.Ctx(ctx) if _, ok := offendingPkgErrLineSet[sourceLine]; ok { logger.Debug().Msgf("exhonerating %s", sourceLine) delete(offendingPkgErrLineSet, sourceLine) } else { - logger.Warn().Msgf("can't exonerate %s", sourceLine) + logger.Warn().Msgf("can't exclude %s", sourceLine) } } diff --git a/tools/scripts/protocheck/goast/inspect.go b/tools/scripts/protocheck/goast/inspect.go index d60c5ede5..96904d8d5 100644 --- a/tools/scripts/protocheck/goast/inspect.go +++ b/tools/scripts/protocheck/goast/inspect.go @@ -16,7 +16,7 @@ func NewInspectLastReturnArgFn( pkg *packages.Package, modulePkgs []*packages.Package, flag func(context.Context, string), - exonerate func(context.Context, string), + exclude func(context.Context, string), ) func(ast.Node) bool { logger := polylog.Ctx(ctx) @@ -68,7 +68,7 @@ func NewInspectLastReturnArgFn( logger.Debug().Msg("appending potential offending line") flag(ctx, inspectPosition) } - TraceExpressionStack(ctx, lastReturnArgNode, modulePkgs, pkg, lastReturnArgNode, exonerate) + TraceExpressionStack(ctx, lastReturnArgNode, modulePkgs, pkg, lastReturnArgNode, exclude) return true // E.g. `return nil, types.ErrXXX.Wrapf(...)` <-- last arg is a *ast.CallExpr. @@ -77,7 +77,7 @@ func NewInspectLastReturnArgFn( logger.Debug().Msg("appending potential offending line") flag(ctx, inspectPosition) } - TraceExpressionStack(ctx, lastReturnArgNode, modulePkgs, pkg, lastReturnArgNode, exonerate) + TraceExpressionStack(ctx, lastReturnArgNode, modulePkgs, pkg, lastReturnArgNode, exclude) return true case *ast.SelectorExpr: @@ -85,7 +85,7 @@ func NewInspectLastReturnArgFn( logger.Debug().Msg("appending potential offending line") flag(ctx, inspectPosition) } - TraceSelectorExpr(ctx, lastReturnArgNode, pkg, modulePkgs, lastReturnArgNode, exonerate) + TraceSelectorExpr(ctx, lastReturnArgNode, pkg, modulePkgs, lastReturnArgNode, exclude) return true } diff --git a/tools/scripts/protocheck/goast/trace.go b/tools/scripts/protocheck/goast/trace.go index b2bd976c3..709c13a71 100644 --- a/tools/scripts/protocheck/goast/trace.go +++ b/tools/scripts/protocheck/goast/trace.go @@ -21,7 +21,7 @@ func TraceSelectorExpr( candidatePkg *packages.Package, modulePkgs []*packages.Package, candidateNode ast.Node, - exonerate func(context.Context, string), + exclude func(context.Context, string), ) bool { logger := polylog.Ctx(ctx).With("func", "TraceSelectorExpr") logger.Debug().Send() @@ -115,7 +115,7 @@ func TraceSelectorExpr( isMatch := strings.Contains(obj.String(), grpcStatusImportPath) && expr.Sel.Name == "Error" if isMatch { - exonerate(ctx, pkg.Fset.Position(candidateNode.Pos()).String()) + exclude(ctx, pkg.Fset.Position(candidateNode.Pos()).String()) return false } } else if obj = pkg.TypesInfo.Defs[x]; obj != nil { @@ -125,15 +125,15 @@ func TraceSelectorExpr( Str("pkg_path", pkg.PkgPath). Str("name", expr.Sel.Name). Msgf("sel def") - TraceExpressionStack(ctx, expr.Sel, modulePkgs, candidatePkg, candidateNode, exonerate) + TraceExpressionStack(ctx, expr.Sel, modulePkgs, candidatePkg, candidateNode, exclude) } } case *ast.SelectorExpr: // e.g., `obj.Method.Func` logger.Debug().Msgf("tracing recursive selector expression: %+v", expr) - return TraceSelectorExpr(ctx, x, candidatePkg, modulePkgs, candidateNode, exonerate) + return TraceSelectorExpr(ctx, x, candidatePkg, modulePkgs, candidateNode, exclude) case *ast.CallExpr: logger.Debug().Msgf("tracing call expression: %+v", expr) - TraceExpressionStack(ctx, x.Fun, modulePkgs, candidatePkg, candidateNode, exonerate) + TraceExpressionStack(ctx, x.Fun, modulePkgs, candidatePkg, candidateNode, exclude) default: logger.Warn().Msgf("skipping selector expression X type: %T", x) } @@ -147,7 +147,7 @@ func TraceExpressionStack( modulePkgs []*packages.Package, candidatePkg *packages.Package, candidateNode ast.Node, - exonerate func(context.Context, string), + exclude func(context.Context, string), ) bool { logger := polylog.Ctx(ctx).With( "func", "TraceExpressionStack", @@ -164,22 +164,22 @@ func TraceExpressionStack( return false case *ast.CallExpr: if sel, ok := expr.Fun.(*ast.SelectorExpr); ok { - return TraceSelectorExpr(ctx, sel, candidatePkg, modulePkgs, candidateNode, exonerate) + return TraceSelectorExpr(ctx, sel, candidatePkg, modulePkgs, candidateNode, exclude) } return true case *ast.BinaryExpr: logger.Debug().Msg("tracing binary expression") // TODO_IN_THIS_COMMIT: return traceExpressionStack... ? - if TraceExpressionStack(ctx, expr.X, modulePkgs, candidatePkg, candidateNode, exonerate) { - TraceExpressionStack(ctx, expr.Y, modulePkgs, candidatePkg, candidateNode, exonerate) + if TraceExpressionStack(ctx, expr.X, modulePkgs, candidatePkg, candidateNode, exclude) { + TraceExpressionStack(ctx, expr.Y, modulePkgs, candidatePkg, candidateNode, exclude) } return true case *ast.ParenExpr: logger.Debug().Msg("tracing paren expression") - return TraceExpressionStack(ctx, expr.X, modulePkgs, candidatePkg, candidateNode, exonerate) + return TraceExpressionStack(ctx, expr.X, modulePkgs, candidatePkg, candidateNode, exclude) case *ast.SelectorExpr: logger.Debug().Msg("tracing selector expression") - return TraceSelectorExpr(ctx, expr, candidatePkg, modulePkgs, candidateNode, exonerate) + return TraceSelectorExpr(ctx, expr, candidatePkg, modulePkgs, candidateNode, exclude) case *ast.Ident: logger.Debug().Str("name", expr.Name).Msg("tracing ident") @@ -190,9 +190,9 @@ func TraceExpressionStack( switch valueNode := valueExpr.(type) { case ast.Expr: - TraceExpressionStack(ctx, valueNode, modulePkgs, candidatePkg, candidateNode, exonerate) + TraceExpressionStack(ctx, valueNode, modulePkgs, candidatePkg, candidateNode, exclude) case *ast.AssignStmt: - TraceExpressionStack(ctx, valueNode.Rhs[0], modulePkgs, candidatePkg, candidateNode, exonerate) + TraceExpressionStack(ctx, valueNode.Rhs[0], modulePkgs, candidatePkg, candidateNode, exclude) default: logger.Warn().Msgf("unknown value node type: %T", valueNode) } From 392aa7e7a081ab9e8104efa1e1f10d95e6dbbecf Mon Sep 17 00:00:00 2001 From: Bryan White Date: Wed, 27 Nov 2024 18:25:59 +0100 Subject: [PATCH 18/18] fixup: flags --- tools/scripts/protocheck/cmd/check_unstable.go | 2 ++ tools/scripts/protocheck/cmd/flags.go | 10 ++++++++++ tools/scripts/protocheck/cmd/root.go | 18 ------------------ 3 files changed, 12 insertions(+), 18 deletions(-) diff --git a/tools/scripts/protocheck/cmd/check_unstable.go b/tools/scripts/protocheck/cmd/check_unstable.go index eb6712060..16b2cda0c 100644 --- a/tools/scripts/protocheck/cmd/check_unstable.go +++ b/tools/scripts/protocheck/cmd/check_unstable.go @@ -21,6 +21,8 @@ var ( ) func init() { + checkUnstableCmd.Flags().StringVarP(&flagRootValue, flagRootName, flagRootShorthand, flagRootValue, flagRootUsage) + checkUnstableCmd.Flags().StringVarP(&flagFileIncludePatternValue, flagFileIncludePatternName, flagFileIncludePatternShorthand, flagFileIncludePatternValue, flagFileIncludePatternUsage) checkUnstableCmd.Flags().BoolVarP(&flagFixValue, flagFixName, flagFixShorthand, flagFixValue, flagFixUsage) rootCmd.AddCommand(checkUnstableCmd) } diff --git a/tools/scripts/protocheck/cmd/flags.go b/tools/scripts/protocheck/cmd/flags.go index adeed68cc..618ff607e 100644 --- a/tools/scripts/protocheck/cmd/flags.go +++ b/tools/scripts/protocheck/cmd/flags.go @@ -1,6 +1,16 @@ package main var ( + flagRootName = "root" + flagRootShorthand = "r" + flagRootValue = "./proto" + flagRootUsage = "Set the path of the directory from which to start walking the filesystem tree in search of files matching --file-pattern." + + flagFileIncludePatternName = "file-pattern" + flagFileIncludePatternShorthand = "p" + flagFileIncludePatternValue = "*.proto" + flagFileIncludePatternUsage = "Set the pattern passed to filepath.Match(), used to include file names which match." + flagModule = "module" flagModuleShorthand = "m" flagModuleValue = "*" diff --git a/tools/scripts/protocheck/cmd/root.go b/tools/scripts/protocheck/cmd/root.go index 76861d4ba..dac771af9 100644 --- a/tools/scripts/protocheck/cmd/root.go +++ b/tools/scripts/protocheck/cmd/root.go @@ -2,7 +2,6 @@ package main import ( "context" - "flag" "os" "github.com/rs/zerolog" @@ -13,30 +12,13 @@ import ( ) var ( - flagRootName = "root" - flagRootShorthand = "r" - flagRootValue = "./proto" - flagRootUsage = "Set the path of the directory from which to start walking the filesystem tree in search of files matching --file-pattern." - - flagFileIncludePatternName = "file-pattern" - flagFileIncludePatternShorthand = "p" - flagFileIncludePatternValue = "*.proto" - flagFileIncludePatternUsage = "Set the pattern passed to filepath.Match(), used to include file names which match." - rootCmd = &cobra.Command{ Use: "protocheck [subcommand] [flags]", Short: "A tool for heuristically identifying and fixing issues in protobuf files and usage.", } ) -func init() { - rootCmd.PersistentFlags().StringVarP(&flagRootValue, flagRootName, flagRootShorthand, flagRootValue, flagRootUsage) - rootCmd.PersistentFlags().StringVarP(&flagFileIncludePatternValue, flagFileIncludePatternName, flagFileIncludePatternShorthand, flagFileIncludePatternValue, flagFileIncludePatternUsage) -} - func main() { - flag.Parse() - zlConsoleWriter := zerolog.ConsoleWriter{ Out: os.Stderr, // Remove the timestamp from the output