-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.go
231 lines (212 loc) · 6.04 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
package main
import (
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"os"
"path/filepath"
"strings"
"syscall"
)
// server holds the daemon's configuration together with the state shared by
// the accept loop and the code that parks connections until the file they
// asked for appears in SecretDir.
type server struct {
// Socket is the filesystem path of the unix socket to listen on when no
// socket is inherited from systemd; it is also what gets logged at startup.
Socket string
// SecretDir is the directory from which environment files and <unit>.json
// secret maps are read.
SecretDir string
// epfd is the epoll instance created with EpollCreate1(EPOLL_CLOEXEC).
// NOTE(review): presumably used by epollWatch/handleEpoll to watch the
// descriptors of waiting connections; those helpers are not visible here.
epfd int
// inotifyRequests receives connections that must wait until a file appears
// (sent by queueInotifyRequest). Unbuffered, so senders block until it is read.
inotifyRequests chan inotifyRequest
// connectionClosed is an unbuffered channel of ints.
// NOTE(review): presumably carries descriptors of waiting connections that
// were closed by the peer; producer and consumer are outside this view — confirm.
connectionClosed chan int
}
// inheritSocket looks through the sockets passed in by systemd socket
// activation (via systemdSockets) and returns the first one that is a unix
// listening socket, or nil if there is none.
//
// net.FileListener operates on a duplicate of the descriptor, so every
// *os.File received from systemd is closed once it has been examined;
// keeping them open would leak one descriptor per inherited socket for the
// lifetime of the daemon.
func inheritSocket() *net.UnixListener {
	var found *net.UnixListener
	for _, f := range systemdSockets(true) {
		if found == nil {
			found = inheritedUnixListener(f)
		}
		// Safe even for the socket we use: the listener owns its own dup.
		f.Close()
	}
	return found
}

// inheritedUnixListener converts one descriptor received from systemd into a
// unix listener, logging the reason and returning nil when it is unusable.
// The caller keeps ownership of f.
func inheritedUnixListener(f *os.File) *net.UnixListener {
	fd := f.Fd()
	stat := &syscall.Stat_t{}
	if err := syscall.Fstat(int(fd), stat); err != nil {
		log.Printf("Received invalid file descriptor from systemd for fd%d: %v", fd, err)
		return nil
	}
	listener, err := net.FileListener(f)
	if err != nil {
		log.Printf("Received file descriptor %d from systemd that is not a valid socket: %v", fd, err)
		return nil
	}
	unixListener, ok := listener.(*net.UnixListener)
	if !ok {
		// FileListener created a duplicate descriptor for this listener; we
		// do not use it, so release it instead of leaking it.
		listener.Close()
		log.Printf("Ignore file descriptor %d from systemd, which is not a unix socket", fd)
		return nil
	}
	log.Printf("Use unix socket received from systemd")
	return unixListener
}
// listenSocket returns the listener the daemon accepts requests on.
//
// A unix socket handed over by systemd socket activation takes precedence.
// Otherwise any stale file at path is removed, path is made absolute, and a
// fresh unix socket is bound there.
//
// Returned errors wrap the underlying cause with %w, so callers can inspect
// it with errors.Is / errors.As.
func listenSocket(path string) (*net.UnixListener, error) {
	if inherited := inheritSocket(); inherited != nil {
		return inherited, nil
	}
	// A previous instance may have left its socket file behind; bind would
	// fail with EADDRINUSE while it exists. A missing file is fine.
	if err := syscall.Unlink(path); err != nil && !os.IsNotExist(err) {
		return nil, fmt.Errorf("Cannot remove old socket: %w", err)
	}
	abs, err := filepath.Abs(path)
	if err != nil {
		return nil, fmt.Errorf("'%s' is not a valid socket path: %w", path, err)
	}
	addr, err := net.ResolveUnixAddr("unix", abs)
	if err != nil {
		return nil, fmt.Errorf("Failed to resolve '%s' as a unix address: %w", abs, err)
	}
	listener, err := net.ListenUnix("unix", addr)
	if err != nil {
		return nil, fmt.Errorf("Failed to open socket at %s: %w", addr.Name, err)
	}
	return listener, nil
}
// parseCredentialsAddr extracts the requesting unit name and credential ID
// from the peer address of an accepted connection. When systemd fetches a
// credential over a unix socket, it binds its side of the connection to an
// address of the form "@<random_hex>/unit/<service_name>/<secret_id>".
//
// It returns pointers to the unit name and the secret ID, or (nil, nil, err)
// if addr does not have that shape. The address can be chosen by any local
// client able to reach the socket, and callers build file names from the
// two values. Empty components and the special names "." and ".." are
// therefore rejected; systemd never produces them.
func parseCredentialsAddr(addr string) (*string, *string, error) {
	fields := strings.Split(addr, "/")
	if len(fields) != 4 || fields[1] != "unit" {
		return nil, nil, fmt.Errorf("Address needs to match this format: @<random_hex>/unit/<service_name>/<secret_id>, got '%s'", addr)
	}
	unit, secret := fields[2], fields[3]
	for _, name := range []string{unit, secret} {
		if name == "" || name == "." || name == ".." {
			return nil, nil, fmt.Errorf("Address contains an empty or invalid service name or secret id, got '%s'", addr)
		}
	}
	return &unit, &secret, nil
}
// queueInotifyRequest arranges for conn to wait until filename appears.
// It obtains the connection's descriptor via connFd, registers it through
// s.epollWatch, and then sends the request on the unbuffered
// s.inotifyRequests channel, blocking until a receiver takes it.
//
// key is carried along in the request unchanged. On error conn is NOT
// closed; the caller still owns it.
func (s *server) queueInotifyRequest(conn *net.UnixConn, filename string, key string) error {
	log.Printf("Block start until %s appears", filename)

	fd, fdErr := connFd(conn)
	if fdErr != nil {
		// NOTE(review): presumably the peer closed the connection before we
		// could start waiting; connFd's semantics are not visible here.
		return fdErr
	}

	if watchErr := s.epollWatch(fd); watchErr != nil {
		log.Printf("Cannot get setup epoll for unix socket: %s", watchErr)
		return watchErr
	}

	request := inotifyRequest{filename: filename, key: key, conn: conn}
	s.inotifyRequests <- request
	return nil
}
// serveServiceEnvironment answers a request for a whole environment file.
// The file named secret inside s.SecretDir is streamed verbatim to conn.
//
// If the file does not exist yet, conn is queued to wait for it. When
// queueing succeeds, conn is left open for the waiter. In every other case
// (success, open error, or failed queueing) conn is closed before returning.
func (s *server) serveServiceEnvironment(conn *net.UnixConn, unit string, secret string) {
	keepOpen := false
	defer func() {
		if !keepOpen {
			conn.Close()
		}
	}()

	log.Printf("Systemd requested environment file for %s from %s", secret, unit)

	file, err := os.Open(filepath.Join(s.SecretDir, secret))
	switch {
	case errors.Is(err, os.ErrNotExist):
		// The file name doubles as the lookup key for environment files.
		keepOpen = s.queueInotifyRequest(conn, secret, secret) == nil
		return
	case err != nil:
		log.Printf("Cannot open environment file %s/%s: %v", unit, secret, err)
		return
	}
	defer file.Close()

	if _, copyErr := io.Copy(conn, file); copyErr != nil {
		log.Printf("Failed to send environment file: %v", copyErr)
	}
}
// serveServiceSecrets answers a request for a single secret value. It
// parses <unit>.json from s.SecretDir with parseServiceSecrets and writes the
// value stored under the key secret to conn, formatted with fmt.Sprint.
//
// If <unit>.json does not exist yet, conn is queued to wait for it. When
// queueing succeeds, conn stays open for the waiter. Otherwise conn is always
// closed before returning, including when the key is missing from the map.
func (s *server) serveServiceSecrets(conn *net.UnixConn, unit string, secret string) {
	keepOpen := false
	defer func() {
		if !keepOpen {
			conn.Close()
		}
	}()

	log.Printf("Systemd requested secret for %s/%s", unit, secret)

	jsonName := unit + ".json"
	jsonPath := filepath.Join(s.SecretDir, jsonName)
	secrets, err := parseServiceSecrets(jsonPath)
	switch {
	case errors.Is(err, os.ErrNotExist):
		keepOpen = s.queueInotifyRequest(conn, jsonName, secret) == nil
		return
	case err != nil:
		log.Printf("Cannot process secret %s/%s: %v", unit, secret, err)
		return
	}

	value, found := secrets[secret]
	if !found {
		log.Printf("Secret map at %s has no value for key %s", jsonPath, secret)
		return
	}
	if _, writeErr := io.WriteString(conn, fmt.Sprint(value)); writeErr != nil {
		log.Printf("Failed to send secret: %v", writeErr)
	}
}
// serveConnection dispatches one accepted connection. The peer's unix
// address is parsed with parseCredentialsAddr to learn the requesting unit
// and the secret name. Names for which isEnvironmentFile reports true are
// served by serveServiceEnvironment; all others by serveServiceSecrets.
// A connection whose address does not parse is closed immediately.
func (s *server) serveConnection(conn *net.UnixConn) {
	remote := conn.RemoteAddr().String()
	unit, secret, err := parseCredentialsAddr(remote)
	if err != nil {
		conn.Close()
		log.Printf("Received connection but remote unix address seems to be not from systemd: %v", err)
		return
	}

	if !isEnvironmentFile(*secret) {
		s.serveServiceSecrets(conn, *unit, *secret)
		return
	}
	s.serveServiceEnvironment(conn, *unit, *secret)
}
// serveSecrets opens the listening socket via listenSocket (inherited from
// systemd or bound at s.Socket) and starts s.handleEpoll in a goroutine. It
// then accepts connections in a loop, serving each in its own goroutine.
//
// It never returns nil. The first failure, either setting up the listener
// or accepting a connection, ends the loop and is returned; the deferred
// Close then shuts the listener down.
func serveSecrets(s *server) error {
	listener, setupErr := listenSocket(s.Socket)
	if setupErr != nil {
		return fmt.Errorf("Failed to setup listening socket: %v", setupErr)
	}
	defer listener.Close()

	log.Printf("Listening on %s", s.Socket)
	go s.handleEpoll()

	for {
		conn, acceptErr := listener.AcceptUnix()
		if acceptErr != nil {
			return fmt.Errorf("Error accepting unix connection: %v", acceptErr)
		}
		go s.serveConnection(conn)
	}
}
// secretDir and socketDir hold the command-line configuration. They are
// registered as flags in init and populated by flag.Parse in main.
var secretDir, socketDir string

// init registers the -secrets and -sock flags. Their defaults come from the
// SYSTEMD_VAULT_SECRETS and SYSTEMD_VAULT_SOCK environment variables and fall
// back to /run/systemd-vaultd/secrets and /run/systemd-vaultd/sock.
//
// flag.Parse is intentionally NOT called here. Since Go 1.13 the testing
// package registers its -test.* flags only after package initialization,
// so parsing during init makes `go test` abort on unknown flags.
func init() {
	defaultDir := os.Getenv("SYSTEMD_VAULT_SECRETS")
	if defaultDir == "" {
		defaultDir = "/run/systemd-vaultd/secrets"
	}
	flag.StringVar(&secretDir, "secrets", defaultDir, "directory where secrets are looked up")

	defaultSock := os.Getenv("SYSTEMD_VAULT_SOCK")
	if defaultSock == "" {
		defaultSock = "/run/systemd-vaultd/sock"
	}
	flag.StringVar(&socketDir, "sock", defaultSock, "unix socket to listen to for systemd requests")
}

// createServer builds a server for the given secrets directory and socket
// path. It creates the close-on-exec epoll instance stored in server.epfd
// and installs the file system watcher on secretDir.
//
// On failure it returns (nil, err) with the cause wrapped via %w. If the
// watcher cannot be set up, the epoll descriptor is closed again instead
// of being leaked.
func createServer(secretDir string, socketDir string) (*server, error) {
	epfd, err := syscall.EpollCreate1(syscall.EPOLL_CLOEXEC)
	if err != nil {
		return nil, fmt.Errorf("failed to create epoll fd: %w", err)
	}
	s := &server{
		Socket:           socketDir,
		SecretDir:        secretDir,
		epfd:             epfd,
		inotifyRequests:  make(chan inotifyRequest),
		connectionClosed: make(chan int),
	}
	if err := s.setupWatcher(secretDir); err != nil {
		// s is discarded, so nothing else will release epfd.
		// NOTE(review): assumes a failed setupWatcher leaves nothing running
		// that still uses s.epfd — confirm against its implementation.
		_ = syscall.Close(epfd) // best effort; the watcher error is what gets reported
		return nil, fmt.Errorf("Failed to setup file system watcher: %w", err)
	}
	return s, nil
}

// main parses the command line, builds the server and serves secrets. Any
// error from setup or from the accept loop is fatal.
func main() {
	flag.Parse()

	s, err := createServer(secretDir, socketDir)
	if err != nil {
		log.Fatalf("Failed to create server: %v", err)
	}
	if err := serveSecrets(s); err != nil {
		log.Fatalf("Failed serve secrets: %v", err)
	}
}