-
Notifications
You must be signed in to change notification settings - Fork 7
/
gosafe.go
421 lines (386 loc) · 11.8 KB
/
gosafe.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
/*
A Go tool to safely compile and run Go programs by only allowing importing of whitelisted packages.
Use gosafe.Compiler.Allow(string) to allow given packages, then run code with gosafe.Compiler.Run(string) or gosafe.Compiler.RunFile(string).
Use child.Stdin(), child.Stdout() and child.Stderr() in github.com/zond/gosafe/child to communicate with the child processes via structured data.
Use gosafe.Compiler.Command(string), gosafe.Compiler.CommandFile(string) and gosafe.Cmd.Handle(interface{}, interface{}) to create child process handlers that stay dormant until needed (when gosafe.Cmd.Handle(...) is called) and die again after a customizable timeout without new messages.
Go to https://github.com/zond/gosafe for the source, of course.
*/
package gosafe
import (
"bytes"
"crypto/sha1"
"encoding/json"
"errors"
"fmt"
"github.com/zond/gosafe/child"
"github.com/zond/tools"
"go/ast"
"go/parser"
"go/token"
"hash"
"io"
"os"
"os/exec"
"path"
"time"
)
// HANDLER_TIMEOUT is the idle lifetime applied to a Cmd whose Timeout field is zero.
const HANDLER_TIMEOUT = 10 * time.Second

// hasher is a package-wide SHA-1 hash.Hash, created once when the package is initialized.
var hasher hash.Hash = sha1.New()
// visitor adapts a plain callback to the ast.Visitor interface.
type visitor func(ast.Node)

// Visit hands node (which ast.Walk sets to nil after a node's children) to the
// callback and returns the same visitor, so ast.Walk descends into every child.
func (v visitor) Visit(node ast.Node) ast.Visitor {
	v(node)
	return v
}
// Error is an error whose message is the string itself.
type Error string

// Error returns the message unchanged.
func (e Error) Error() string { return string(e) }
/*
Cmd is a wrapper around os/exec.Cmd that provides ready io.Readers and io.Writers for communicating with the contained process.
It also provides gosafe.Cmd.Encode(interface{}) and gosafe.Cmd.Decode(interface{}), which send and receive JSON-encoded structured data to and from the child process.
Use gosafe.Cmd.Handle() to spin up child processes on demand. A child that keeps running after answering its first call goes on handling incoming messages until it is killed for having been idle longer than Timeout.

NOTE(review): lastEvent is written by Handle, by Start and by the idle-kill
goroutine that Handle spawns, without any locking; presumably a Cmd must not be
used from several goroutines at once — confirm.
*/
type Cmd struct {
	// Binary is the path to the executable that Start launches.
	Binary string
	// Cmd is the os/exec.Cmd of the current child process; every Start replaces it.
	Cmd *exec.Cmd
	// Stdin is the write end of the current child's stdin pipe, set by Start.
	Stdin io.WriteCloser
	// Stdout is the read end of the current child's stdout pipe, set by Start.
	Stdout io.Reader
	// Stderr receives the child's stderr. Unlike Stdin and Stdout it is not set by
	// Start: assign it before starting, or leave it nil to use os.Stderr.
	Stderr io.Writer
	// encoder and decoder are created lazily by Encode and Decode on top of Stdin
	// and Stdout, and cleared by Start so they never outlive their pipes.
	encoder *json.Encoder
	decoder *json.Decoder
	// lastEvent is the time of the latest Start or Handle; the idle-kill timer
	// started by Handle compares against it.
	lastEvent time.Time
	// server holds the services added with Register, used by Call to answer
	// callbacks from the child.
	server child.Server
	// The amount of time idle child processes are allowed to live without handling messages.
	// Zero means HANDLER_TIMEOUT.
	Timeout time.Duration
}
// String renders the Cmd as "<Cmd BINARY PID>", with "dead" in place of the pid
// when Pid reports that no process is running.
func (self *Cmd) String() string {
	state := "dead"
	if pid, running := self.Pid(); running {
		state = fmt.Sprint(pid)
	}
	return fmt.Sprintf("<Cmd %v %v>", self.Binary, state)
}
// Encode writes i as JSON to the child's stdin. The json.Encoder is created on
// first use and reused until Start clears it.
func (self *Cmd) Encode(i interface{}) error {
	enc := self.encoder
	if enc == nil {
		enc = json.NewEncoder(self.Stdin)
		self.encoder = enc
	}
	return enc.Encode(i)
}
// Decode reads the next JSON value from the child's stdout into i. The
// json.Decoder is created on first use and reused until Start clears it.
func (self *Cmd) Decode(i interface{}) error {
	dec := self.decoder
	if dec == nil {
		dec = json.NewDecoder(self.Stdout)
		self.decoder = dec
	}
	return dec.Decode(i)
}
// Kill kills the current child process if one has been started, returning the
// error from os.Process.Kill. It returns nil when no process was ever started.
func (self *Cmd) Kill() error {
	if self.Cmd != nil && self.Cmd.Process != nil {
		return self.Cmd.Process.Kill()
	}
	return nil
}
// Pid returns the pid of the current child process and true, or 0 and false
// when no process has been started or os.FindProcess fails.
//
// NOTE(review): on Unix os.FindProcess always succeeds, so there the boolean only
// means "a process was started", not that it is still alive.
func (self *Cmd) Pid() (int, bool) {
	if self.Cmd == nil || self.Cmd.Process == nil {
		return 0, false
	}
	proc, err := os.FindProcess(self.Cmd.Process.Pid)
	if err != nil {
		return 0, false
	}
	return proc.Pid, true
}
// reHandle starts a fresh child process and retries Handle with the same arguments.
//
// If the process cannot be started its error is returned. Retrying anyway would
// recurse without end: with no Process, Pid reports "not running" and Handle
// calls reHandle again.
func (self *Cmd) reHandle(i, o interface{}) error {
	if err := self.Start(); err != nil {
		return err
	}
	return self.Handle(i, o)
}
// timeout is the idle lifetime for this Cmd: Timeout when set, else HANDLER_TIMEOUT.
func (self *Cmd) timeout() time.Duration {
	if self.Timeout != 0 {
		return self.Timeout
	}
	return HANDLER_TIMEOUT
}
// Register the given name and Service function to serve callbacks from the child process.
//
// func get(args... interface{}) interface{} {
// return MyDatabase.Get(args[0])
// }
//
// cmd.Register("get", get)
//
// Now the child processes can use this method via child.Call("get", key).
// Register returns the Cmd itself so calls can be chained.
func (self *Cmd) Register(name string, service child.Service) *Cmd {
	if self.server == nil {
		// A Cmd built as a literal rather than through Compiler.Command* has no
		// server map yet, and assigning into a nil map panics.
		self.server = make(child.Server)
	}
	self.server[name] = service
	return self
}
// createRequest turns the payload of a Callback response into a child.Request.
// The payload must be a map[string]interface{} whose "Name" entry is a string and
// whose "Args" entry is a []interface{} (presumably the shape JSON decoding gives
// a request object — confirm against the child package). Any other payload yields
// an error formatted from child.NotProperRequest.
func createRequest(response child.Response) (rval child.Request, err error) {
	malformed := func() (child.Request, error) {
		return child.Request{}, errors.New(fmt.Sprintf(child.NotProperRequest, response.Payload))
	}
	call, isMap := response.Payload.(map[string]interface{})
	if !isMap {
		return malformed()
	}
	name, nameOK := call["Name"].(string)
	if !nameOK {
		return malformed()
	}
	args, argsOK := call["Args"].([]interface{})
	if !argsOK {
		return malformed()
	}
	return child.Request{Name: name, Args: args}, nil
}
// Call will call one function registered via child.Server#Register and return its return value.
//
// While waiting for the result, the child may issue callbacks. Each one is answered
// with the services registered on this Cmd via Register, and the exchange
// continues until the child sends a Return or an Error. Errors from Handle are
// returned immediately. Ignoring them would leave a zeroed Response to be interpreted.
func (self *Cmd) Call(name string, args ...interface{}) (rval interface{}, err error) {
	response := child.Response{}
	if err = self.Handle(child.Request{Name: name, Args: args}, &response); err != nil {
		return nil, err
	}
	for {
		switch response.Type {
		case child.Return:
			return response.Payload, nil
		case child.Error:
			return nil, errors.New(fmt.Sprint(response.Payload))
		case child.Callback:
			request, reqErr := createRequest(response)
			if reqErr != nil {
				// Best effort: tell the child its callback was malformed. The malformed
				// request is the error worth reporting, so an Encode failure is dropped.
				// NOTE(review): reqErr is JSON-encoded as-is, which may not carry its message.
				self.Encode(child.Response{Type: child.Error, Payload: reqErr})
				return nil, reqErr
			}
			response = child.Response{}
			if err = self.Handle(self.server.Handle(request), &response); err != nil {
				return nil, err
			}
		default:
			return nil, errors.New(fmt.Sprintf(child.UnknownResponseType, response))
		}
	}
}
// Handle sends i to the child process with Encode and decodes its reply into o
// with Decode, so o should be a pointer, as for json.Decoder.Decode. The child is
// started first when Pid reports it is not running.
//
// When a pipe to the child turns out to be closed, or the reply stream ends with
// io.EOF, the process is restarted via reHandle and the whole exchange is retried.
// Pipes close when the process exited or was killed by the idle timer, because
// exec.Cmd.Wait closes them.
//
// After a successful exchange, a timer is armed that kills this process once
// gosafe.Cmd.Timeout (HANDLER_TIMEOUT when zero) has passed without newer messages.
func (self *Cmd) Handle(i, o interface{}) error {
	if _, running := self.Pid(); !running {
		return self.reHandle(i, o)
	}
	self.lastEvent = time.Now()
	if err := self.Encode(i); err != nil {
		if isClosedPipeError(err) {
			return self.reHandle(i, o)
		}
		return err
	}
	// Decode into o itself. Decoding into &o (a *interface{}) silently replaced the
	// local copy, dropping the reply for a non-pointer o or a JSON null.
	if err := self.Decode(o); err != nil {
		if err == io.EOF || isClosedPipeError(err) {
			return self.reHandle(i, o)
		}
		return err
	}
	go func() {
		<-time.After(self.timeout())
		// NOTE(review): lastEvent is shared with Handle and Start without
		// synchronization, so this check races with concurrent use of the Cmd.
		if time.Now().Sub(self.lastEvent) > self.timeout() {
			self.lastEvent = time.Now()
			if err := self.Kill(); err != nil {
				fmt.Fprintln(os.Stderr, "While trying to kill an idle process: ", err)
			}
		}
	}()
	return nil
}

// isClosedPipeError reports whether err came from using a pipe to the child after
// it was closed, which exec.Cmd.Wait does once it has seen the child exit.
func isClosedPipeError(err error) bool {
	// Message older Go releases produced when writing to the closed stdin pipe.
	if err.Error() == "write |1: bad file descriptor" {
		return true
	}
	pathErr, ok := err.(*os.PathError)
	return ok && pathErr.Err == os.ErrClosed
}
// Start clears all child-process-specific state of this Cmd (pipes, encoder,
// decoder) and launches a fresh process running Binary.
//
// Stdin and Stdout are replaced with pipes to the new process. The child's stderr
// goes to self.Stderr, or to os.Stderr when that field is nil. A background
// goroutine waits for the process so it is reaped when it exits. Any non-nil Wait
// error, such as the one produced after Kill, is printed to os.Stderr.
func (self *Cmd) Start() error {
	cmd := exec.Command(self.Binary)
	self.Cmd = cmd
	self.encoder = nil
	self.decoder = nil
	self.lastEvent = time.Now()
	var err error
	if self.Stdin, err = cmd.StdinPipe(); err != nil {
		return err
	}
	if self.Stdout, err = cmd.StdoutPipe(); err != nil {
		return err
	}
	if self.Stderr == nil {
		cmd.Stderr = os.Stderr
	} else {
		cmd.Stderr = self.Stderr
	}
	if err = cmd.Start(); err != nil {
		return err
	}
	// Wait on this exact process through the local cmd. Reading self.Cmd inside the
	// goroutine races with a later Start replacing it: the old process would never be
	// reaped and the new one would be waited on twice. The goroutine also uses its
	// own error variable rather than writing to Start's.
	go func() {
		if waitErr := cmd.Wait(); waitErr != nil {
			fmt.Fprintln(os.Stderr, waitErr)
		}
	}()
	return nil
}
// Compiler is a compiler of potentially unsafe code: it refuses Go source files
// that import packages outside its whitelist and compiles the rest. Create one
// with NewCompiler; the zero value has nil maps.
//
// NOTE(review): the maps are plain maps with no locking, so a Compiler presumably
// must not be used from several goroutines at once — confirm.
type Compiler struct {
	// allowed holds whitelisted import paths stored with their quotes (e.g. `"fmt"`),
	// matching ast.ImportSpec.Path.Value as compared in Check.
	allowed map[string]bool
	// okChecked records when Check last accepted each file path.
	okChecked map[string]time.Time
	// okCompiled records successful builds; it is read and written only by CompileTo.
	okCompiled map[string]time.Time
}
// NewCompiler returns a Compiler with an empty whitelist and empty caches; use
// Allow and AllowRuntime to permit imports.
func NewCompiler() *Compiler {
	c := &Compiler{
		allowed:    make(map[string]bool),
		okChecked:  make(map[string]time.Time),
		okCompiled: make(map[string]time.Time),
	}
	return c
}
// AllowRuntime will allow the runtime package for this gosafe.Compiler.
// See https://github.com/zond/gosafe/issues/1 as to why this is necessary.
func (self *Compiler) AllowRuntime() {
	// Whitelist keys keep the quotes, as import path literals do in the AST.
	const quotedRuntime = `"runtime"`
	self.allowed[quotedRuntime] = true
}
// Allow adds the import path p to this gosafe.Compiler's whitelist.
// It will NOT allow the runtime package: passing "runtime" panics, and that
// package has to be enabled deliberately with AllowRuntime.
// See https://github.com/zond/gosafe/issues/1 as to why this is necessary.
func (self *Compiler) Allow(p string) {
	if p != "runtime" {
		// Stored quoted, exactly as ast.ImportSpec.Path.Value spells the path.
		self.allowed[`"`+p+`"`] = true
		return
	}
	panic(fmt.Errorf("Allowing \"runtime\" requires you to use Compiler#AllowRuntime. See https://github.com/zond/gosafe/issues/1"))
}
// shorten derives a short name from s and the current whitelist. It takes a SHA-1
// digest and renders it with tools.NewBigIntBytes(...).BaseString(tools.MAX_BASE);
// that this output is safe inside file names is assumed, not verifiable here.
// Command and Compile use it to name temporary files, so equal inputs must yield
// equal names. Otherwise Compile's target path drifts while CompileTo's cache says
// the build is already done.
//
// Go randomizes map iteration order, so the whitelist is folded in as the XOR of
// per-entry digests, which is order-independent. A local hash state avoids sharing
// one between concurrent callers.
func (self *Compiler) shorten(s string) string {
	var allowedDigest [sha1.Size]byte
	for allowed := range self.allowed {
		entry := sha1.Sum([]byte(allowed))
		for i := range allowedDigest {
			allowedDigest[i] ^= entry[i]
		}
	}
	h := sha1.New()
	h.Write(allowedDigest[:])
	h.Write([]byte(s))
	return tools.NewBigIntBytes(h.Sum(nil)).BaseString(tools.MAX_BASE)
}
// Check will return an error if this gosafe.Compiler doesn't allow the given file to be compiled.
//
// The file is parsed and every import path is compared with the whitelist. Parse
// errors are returned as-is. A rejected file yields an Error listing the
// disallowed imports. An accepted file is remembered, and later calls skip the
// parse until the file's modification time moves past the acceptance time.
func (self *Compiler) Check(file string) error {
	fstat, err := os.Stat(file)
	if err != nil {
		// Problem stating file
		return err
	}
	if checkTime, ok := self.okChecked[file]; ok && checkTime.After(fstat.ModTime()) {
		// Was checked before, and after the file was last changed
		return nil
	}
	tree, err := parser.ParseFile(token.NewFileSet(), file, nil, 0)
	if err != nil {
		// Source that cannot be read or parsed cannot be vetted. A nil tree would
		// also make ast.Walk panic.
		return err
	}
	var disallowed []string
	ast.Walk(visitor(func(node ast.Node) {
		importNode, isImport := node.(*ast.ImportSpec)
		if !isImport || importNode.Path == nil {
			return
		}
		if _, ok := self.allowed[importNode.Path.Value]; !ok {
			// This import declaration imports a package that is not allowed
			disallowed = append(disallowed, importNode.Path.Value)
		}
	}), tree)
	if len(disallowed) > 0 {
		var buffer bytes.Buffer
		for index, pkg := range disallowed {
			if index > 0 {
				buffer.WriteString(", ")
			}
			buffer.WriteString(pkg)
		}
		// We tried to import non-allowed packages
		return Error("Imports of disallowed libraries: " + buffer.String())
	}
	// We checked this file as OK now
	self.okChecked[file] = time.Now()
	return nil
}
// RunFile compiles the given file, starts a gosafe.Cmd running it and returns it.
// Compilation and process-start errors are both returned, rather than handing
// back a Cmd whose process never started.
func (self *Compiler) RunFile(file string) (cmd *Cmd, err error) {
	if cmd, err = self.CommandFile(file); err != nil {
		return nil, err
	}
	if err = cmd.Start(); err != nil {
		return nil, err
	}
	return cmd, nil
}
// Run compiles the given source code, starts a gosafe.Cmd running it and returns it.
// Compilation and process-start errors are both returned, rather than handing
// back a Cmd whose process never started.
func (self *Compiler) Run(s string) (cmd *Cmd, err error) {
	if cmd, err = self.Command(s); err != nil {
		return nil, err
	}
	if err = cmd.Start(); err != nil {
		return nil, err
	}
	return cmd, nil
}
// CommandFile compiles the given file and returns a not-yet-started gosafe.Cmd
// for the resulting binary, with an empty callback server.
func (self *Compiler) CommandFile(file string) (cmd *Cmd, err error) {
	var binary string
	if binary, err = self.Compile(file); err != nil {
		return nil, err
	}
	return &Cmd{Binary: binary, server: make(child.Server)}, nil
}
// Command writes the given code to a temporary .go file, compiles it and returns
// a not-yet-started gosafe.Cmd for the binary. The temporary source file is
// removed before returning.
func (self *Compiler) Command(s string) (cmd *Cmd, err error) {
	output := path.Join(os.TempDir(), fmt.Sprintf("%s.gosafe.go", self.shorten(s)))
	file, err := os.Create(output)
	if err != nil {
		return nil, err
	}
	// The source is only needed until CommandFile has compiled it.
	defer os.Remove(output)
	if _, err = file.WriteString(s); err != nil {
		// A short or failed write must not be compiled as if it were the full program.
		file.Close()
		return nil, err
	}
	if err = file.Close(); err != nil {
		return nil, err
	}
	return self.CommandFile(file.Name())
}
// Compile compiles the given file, if Check accepts it, into a binary in the
// system temp directory. The binary is named after shorten(file). The path of
// the binary is returned.
func (self *Compiler) Compile(file string) (output string, err error) {
	target := path.Join(os.TempDir(), fmt.Sprintf("%s.gosafe", self.shorten(file)))
	if err = self.CompileTo(file, target); err != nil {
		return "", err
	}
	return target, nil
}
// CompileTo compiles file into the executable at output if Check accepts it.
//
// Successful builds are cached per (file, output) pair. A cached build is reused
// only while the source is unchanged since then and the binary still exists.
// Keying on file alone would skip the build for a new output path and leave
// nothing there.
//
// The build runs `go build -o output file`. If it fails, the tool's stderr (or
// stdout) is returned as an Error when non-empty, otherwise the exec error. The
// exit status decides success, so diagnostics printed by a successful build do
// not fail it.
func (self *Compiler) CompileTo(file, output string) error {
	fstat, err := os.Stat(file)
	if err != nil {
		// Problem stating file
		return err
	}
	cacheKey := file + "\x00" + output
	if compileTime, ok := self.okCompiled[cacheKey]; ok && compileTime.After(fstat.ModTime()) {
		if _, statErr := os.Stat(output); statErr == nil {
			// Compiled to this output after the file last changed, and the binary is still there.
			return nil
		}
	}
	if err = self.Check(file); err != nil {
		return err
	}
	var stderr, stdout bytes.Buffer
	// Name the binary with go build's own -o. Passing "-o" inside -ldflags only
	// reaches the linker, behind the go command's back.
	cmd := exec.Command("go", "build", "-o", output, file)
	cmd.Stderr = &stderr
	cmd.Stdout = &stdout
	if err = cmd.Run(); err != nil {
		if stderr.Len() > 0 {
			return Error(stderr.String())
		}
		if stdout.Len() > 0 {
			return Error(stdout.String())
		}
		return err
	}
	self.okCompiled[cacheKey] = time.Now()
	return nil
}