-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.go
144 lines (124 loc) · 2.91 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
package main
import (
	"fmt"
	"go/parser"
	"go/token"
	"io"
	"log"
	"os"
	"path/filepath"
	"sort"

	"github.com/hasitpbhatt/gonforce/models"
	"gopkg.in/yaml.v2"
)
// _enforcerConfig holds the configuration that decode reads from
// gonforce.yaml; processRoot consults it to decide which rules apply where.
var _enforcerConfig = enforcer{}
// Rule pairs a directory with the import policy enforced on it.
//
// Name is a path that processRoot joins onto the current working directory;
// that directory and everything beneath it is checked against PackageRule.
type Rule struct {
	Name        string             `yaml:"name"`
	PackageRule models.PackageRule `yaml:"rule"`
}
// enforcer mirrors the top-level structure of gonforce.yaml.
type enforcer struct {
	// Package is the module/package name, e.g. github.com/hasitpbhatt/gonforce.
	// NOTE(review): not read by the code in this file; it is presumably kept
	// so strict decoding accepts the `package` key — confirm.
	Package string `yaml:"package"`

	// Default is the import policy applied to the working directory itself.
	// Example:
	//
	//	default:
	//	  type: allowlist
	//	  imports:
	//	    - gopkg.in
	//	  except:
	//	    - gopkg.in/yaml.v2
	Default models.PackageRule `yaml:"default"`

	// Rules lists per-directory policies; each one pairs a directory name
	// with the package rule enforced under it.
	Rules []Rule `yaml:"rules"`
}
// main loads gonforce.yaml from the working directory, validates it, and then
// checks the imports of the configured directories. Any failure is reported
// through log.Fatal, which exits the process with a non-zero status.
func main() {
	f, err := os.Open("gonforce.yaml")
	if err != nil {
		// Surface the real cause (missing file, permission denied, ...)
		// instead of assuming the file does not exist.
		log.Fatalf("Unable to open gonforce.yaml: %v", err)
	}
	// Close the config as soon as it has been decoded. log.Fatal exits via
	// os.Exit, which would skip a deferred Close, and the file is not needed
	// while the source tree is walked.
	err = decode(f)
	_ = f.Close() // read-only file: a Close error cannot lose data
	if err != nil {
		log.Fatal(err)
	}
	if err := processRoot(); err != nil {
		log.Fatal(err)
	}
}
// decode reads a gonforce.yaml document from f into _enforcerConfig and then
// validates the default rule and every per-directory rule.
//
// Decoding is strict (SetStrict), so keys that do not map to a struct field
// are rejected rather than silently ignored. The returned error names the
// rule that failed validation.
func decode(f io.Reader) error {
	d := yaml.NewDecoder(f)
	d.SetStrict(true)
	if err := d.Decode(&_enforcerConfig); err != nil {
		return fmt.Errorf("Unable to decode gonforce.yaml: %v", err)
	}
	if err := _enforcerConfig.Default.Validate(); err != nil {
		return fmt.Errorf("Invalid gonforce.yaml: default rule: %v", err)
	}
	for _, rule := range _enforcerConfig.Rules {
		if err := rule.PackageRule.Validate(); err != nil {
			// With several rules configured, the bare validation error does
			// not say which one is wrong, so include its name.
			return fmt.Errorf("Invalid gonforce.yaml: rule %q: %v", rule.Name, err)
		}
	}
	return nil
}
// processRoot checks the working directory itself against the default rule,
// then checks each configured rule's directory tree (resolved relative to the
// working directory). It stops at the first error.
func processRoot() error {
	root, wdErr := os.Getwd()
	if wdErr != nil {
		return fmt.Errorf("Unable to get current dir: %v", wdErr)
	}

	if err := process(root, _enforcerConfig.Default); err != nil {
		return err
	}

	for i := range _enforcerConfig.Rules {
		r := _enforcerConfig.Rules[i]
		target := filepath.Join(root, r.Name)
		if err := processRecursively(target, r.PackageRule); err != nil {
			return err
		}
	}
	return nil
}
// processRecursively checks dir and every directory beneath it against pr,
// stopping at the first directory that fails.
//
// dir itself is checked up front, so a missing or unparsable root is reported
// with process's error message. filepath.Walk always invokes its callback on
// the root first, so the callback skips that one visit; otherwise the root
// would be parsed and validated a second time.
func processRecursively(dir string, pr models.PackageRule) error {
	if err := process(dir, pr); err != nil {
		return err
	}
	return filepath.Walk(dir, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		// Only directories are checked; files are handled by process. The
		// root was already checked above.
		if !info.IsDir() || path == dir {
			return nil
		}
		abs, err := filepath.Abs(path)
		if err != nil {
			return err
		}
		return process(abs, pr)
	})
}
// process checks the imports of every .go file directly inside dir against
// pr. It does not descend into subdirectories.
//
// Each violation reported by pr.IsValidImport is printed to stdout and
// checking continues, so one run shows every offending import in the
// directory. If any violation was found, a summary error is returned.
//
// parser.ParseDir returns maps, whose iteration order is randomized. Packages
// and files are therefore visited in sorted order so the printed report is
// stable from run to run.
func process(dir string, pr models.PackageRule) error {
	fset := token.NewFileSet()
	pkgs, err := parser.ParseDir(fset, dir, nil, parser.ImportsOnly)
	if err != nil {
		return fmt.Errorf("Unable to parse imports in %s: %v", dir, err)
	}

	pkgNames := make([]string, 0, len(pkgs))
	for name := range pkgs {
		pkgNames = append(pkgNames, name)
	}
	sort.Strings(pkgNames)

	errorFound := false
	for _, name := range pkgNames {
		files := pkgs[name].Files
		fpaths := make([]string, 0, len(files))
		for fpath := range files {
			fpaths = append(fpaths, fpath)
		}
		sort.Strings(fpaths)

		for _, fpath := range fpaths {
			for _, imp := range files[fpath].Imports {
				// imp.Path.Value is the import path as written in source,
				// surrounding quotes included.
				// NOTE(review): presumably IsValidImport accounts for the
				// quotes — confirm.
				if err := pr.IsValidImport(fpath, imp.Path.Value); err != nil {
					errorFound = true
					fmt.Println(err)
				}
			}
		}
	}

	if errorFound {
		return fmt.Errorf("validation failed in %v", dir)
	}
	return nil
}