diff --git a/goimportssort.go b/goimportssort.go index 10e3ca5..3bab486 100644 --- a/goimportssort.go +++ b/goimportssort.go @@ -32,6 +32,7 @@ var ( list = flag.Bool("l", false, "write results to stdout") write = flag.Bool("w", false, "write result to (source) file instead of stdout") localPrefix = flag.String("local", "", "put imports beginning with this string after 3rd-party packages; comma-separated list") + order = flag.String("o", "iel", "custom the order of the section of imports. e.g. ixl means inbuilt, external, and local") verbose bool // verbose logging standardPackages = make(map[string]struct{}) ) @@ -290,6 +291,25 @@ func countImports(impModels [][]impModel) int { func convertImportsToSlice(node *dst.File) ([][]impModel, error) { importCategories := make([][]impModel, 3) + inbuild := &importCategories[0] + external := &importCategories[1] + local := &importCategories[2] + if sortString(*order) == sortString("iel") { + chars := []rune(*order) + for i := 0; i < 3; i++ { + switch chars[i] { + case 'l': + local = &importCategories[i] + case 'e': + external = &importCategories[i] + case 'i': + inbuild = &importCategories[i] + default: + return importCategories, fmt.Errorf("cannot parse the order argument given: %s", *order) + } + } + } + for _, importSpec := range node.Imports { impName := importSpec.Path.Value impNameWithoutQuotes := strings.Trim(impName, "\"") @@ -302,17 +322,25 @@ func convertImportsToSlice(node *dst.File) ([][]impModel, error) { locImpModel.path = impName if *localPrefix != "" && strings.Count(impName, *localPrefix) > 0 { - importCategories[2] = append(importCategories[2], locImpModel) + *local = append(*local, locImpModel) } else if isStandardPackage(impNameWithoutQuotes) { - importCategories[0] = append(importCategories[0], locImpModel) + *inbuild = append(*inbuild, locImpModel) } else { - importCategories[1] = append(importCategories[1], locImpModel) + *external = append(*external, locImpModel) } } return importCategories, nil } +func sortString(str string) string { + charArray := []rune(str) + sort.Slice(charArray, func(i int, j int) bool { + return charArray[i] < charArray[j] + }) + return string(charArray) +} + // loadStandardPackages tries to fetch all golang std packages func loadStandardPackages() error { pkgs, err := packages.Load(nil, "std")