-
-
Notifications
You must be signed in to change notification settings - Fork 10
/
transaction.go
86 lines (65 loc) · 1.84 KB
/
transaction.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
package otelsql
import (
"context"
"database/sql/driver"
"go.opentelemetry.io/otel/trace"
)
// Method names handed to methodRecorder.Record for transaction operations.
const (
	metricMethodCommit   = "go.sql.commit"
	metricMethodRollback = "go.sql.rollback"
)

// Method names handed to methodTracer.MustTrace for transaction operations.
const (
	traceMethodCommit   = "commit"
	traceMethodRollback = "rollback"
)
var (
	// Compile-time assertion that *tx satisfies driver.Tx; the build fails
	// if Commit or Rollback drift from the interface.
	_ driver.Tx = (*tx)(nil)
)

type (
	// txFuncMiddleware is middleware over txFunc values (txStats and txTrace
	// build these as func(next txFunc) txFunc). It is a type alias, so it is
	// interchangeable with middleware[txFunc].
	txFuncMiddleware = middleware[txFunc]
)
// txFunc matches the signature of driver.Tx's Commit and Rollback methods.
type txFunc func() error

// tx is a driver.Tx whose methods simply invoke pre-built functions; the
// functions carry whatever instrumentation was chained around the real
// transaction when tx was constructed.
type tx struct {
	commit, rollback txFunc
}

// Commit invokes the stored commit function and returns its error.
func (t tx) Commit() error { return t.commit() }

// Rollback invokes the stored rollback function and returns its error.
func (t tx) Rollback() error { return t.rollback() }
// wrapTx returns a driver.Tx whose Commit and Rollback call parent.Commit and
// parent.Rollback through the middleware stack built by makeTxFuncMiddlewares
// (stats always, tracing only when t is non-nil).
//
// The middlewares do not see ctx itself: they get a fresh background context
// that carries only ctx's span context, so ctx's cancellation, deadline and
// other values are dropped. NOTE(review): presumably this keeps the
// commit/rollback instrumentation alive even after the caller's context has
// been cancelled — confirm intent.
func wrapTx(ctx context.Context, parent driver.Tx, r methodRecorder, t methodTracer) driver.Tx {
	spanCtx := trace.SpanContextFromContext(ctx)
	ctx = trace.ContextWithSpanContext(context.Background(), spanCtx)

	commitMWs := makeTxFuncMiddlewares(ctx, r, t, metricMethodCommit, traceMethodCommit)
	commit := chainMiddlewares(commitMWs, parent.Commit)

	rollbackMWs := makeTxFuncMiddlewares(ctx, r, t, metricMethodRollback, traceMethodRollback)
	rollback := chainMiddlewares(rollbackMWs, parent.Rollback)

	return &tx{commit: commit, rollback: rollback}
}
// nopTxFunc is a txFunc-compatible function that does nothing and always
// reports success.
func nopTxFunc() (err error) {
	return err
}
// txStats returns a middleware that records method through r around each
// call of the wrapped txFunc. r.Record is called before next runs, and the
// callback it returns is invoked, via defer, with next's error. Because the
// call is deferred it also runs if next panics, in which case it sees a nil
// error.
func txStats(ctx context.Context, r methodRecorder, method string) txFuncMiddleware {
	wrap := func(next txFunc) txFunc {
		recorded := func() (err error) {
			finish := r.Record(ctx, method)
			defer func() { finish(err) }()

			err = next()

			return err
		}

		return recorded
	}

	return wrap
}
// txTrace returns a middleware that wraps each call of the txFunc in
// t.MustTrace(ctx, method). The context returned by MustTrace is discarded
// because txFunc takes no context to pass it on to. The end callback is
// invoked, via defer, with next's error; because the call is deferred it also
// runs if next panics, in which case it sees a nil error.
func txTrace(ctx context.Context, t methodTracer, method string) txFuncMiddleware {
	wrap := func(next txFunc) txFunc {
		traced := func() (err error) {
			_, finish := t.MustTrace(ctx, method)
			defer func() { finish(err) }()

			err = next()

			return err
		}

		return traced
	}

	return wrap
}
// makeTxFuncMiddlewares builds the middleware list for one transaction
// method. The list always starts with the stats middleware (recording
// metricMethod via r). The tracing middleware (tracing traceMethod via t) is
// appended only when t is non-nil. How list order maps to call nesting is
// decided by chainMiddlewares.
func makeTxFuncMiddlewares(ctx context.Context, r methodRecorder, t methodTracer, metricMethod, traceMethod string) []txFuncMiddleware {
	stack := make([]txFuncMiddleware, 1, 2)
	stack[0] = txStats(ctx, r, metricMethod)

	if t == nil {
		return stack
	}

	return append(stack, txTrace(ctx, t, traceMethod))
}