diff --git a/tracer.go b/tracer.go index 9a64320..cbfa075 100644 --- a/tracer.go +++ b/tracer.go @@ -75,12 +75,12 @@ const sqlOperationUnknkown = "UNKNOWN" // sqlOperationName attempts to get the first 'word' from a given SQL query, which usually // is the operation name (e.g. 'SELECT'). -func sqlOperationName(stmt string, fn func(string) string) string { +func (t *Tracer) sqlOperationName(stmt string) string { // If a custom function is provided, use that. Otherwise, fall back to the // default implementation. This allows users to override the default // behavior without having to reimplement it. - if fn != nil { - return fn(stmt) + if t.spanNameFunc != nil { + return t.spanNameFunc(stmt) } parts := strings.Fields(stmt) @@ -127,7 +127,7 @@ func (t *Tracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.T spanName := "query " + data.SQL if t.trimQuerySpanName { - spanName = "query " + sqlOperationName(data.SQL, t.spanNameFunc) + spanName = "query " + t.sqlOperationName(data.SQL) } ctx, _ = t.tracer.Start(ctx, spanName, opts...) @@ -227,7 +227,7 @@ func (t *Tracer) TraceBatchQuery(ctx context.Context, conn *pgx.Conn, data pgx.T spanName := "batch query " + data.SQL if t.trimQuerySpanName { - spanName = "query " + sqlOperationName(data.SQL, t.spanNameFunc) + spanName = "query " + t.sqlOperationName(data.SQL) } _, span := t.tracer.Start(ctx, spanName, opts...) @@ -297,7 +297,7 @@ func (t *Tracer) TracePrepareStart(ctx context.Context, conn *pgx.Conn, data pgx spanName := "prepare " + data.SQL if t.trimQuerySpanName { - spanName = "prepare " + sqlOperationName(data.SQL, t.spanNameFunc) + spanName = "prepare " + t.sqlOperationName(data.SQL) } ctx, _ = t.tracer.Start(ctx, spanName, opts...) diff --git a/tracer_test.go b/tracer_test.go index e565a61..59240fc 100644 --- a/tracer_test.go +++ b/tracer_test.go @@ -5,134 +5,131 @@ import ( "testing" ) -func TestSqlOperationName(t *testing.T) { +func TestTracer_sqlOperationName(t *testing.T) { tests := []struct { - name string - query string - spanNameFunc func(string) string - expName string + name string + tracer *Tracer + query string + expName string }{ { - name: "Spaces only", - query: "SELECT * FROM users", - spanNameFunc: nil, - expName: "SELECT", + name: "Spaces only", + query: "SELECT * FROM users", + tracer: NewTracer(), + expName: "SELECT", }, { - name: "Newline and tab", - query: "UPDATE\n\tfoo", - spanNameFunc: nil, - expName: "UPDATE", + name: "Newline and tab", + query: "UPDATE\n\tfoo", + tracer: NewTracer(), + expName: "UPDATE", }, { - name: "Additional whitespace", - query: " \n SELECT\n\t * FROM users ", - spanNameFunc: nil, - expName: "SELECT", + name: "Additional whitespace", + query: " \n SELECT\n\t * FROM users ", + tracer: NewTracer(), + expName: "SELECT", }, { - name: "Whitespace-only query", - query: " \n\t", - spanNameFunc: nil, - expName: sqlOperationUnknkown, + name: "Whitespace-only query", + query: " \n\t", + tracer: NewTracer(), + expName: sqlOperationUnknkown, }, { - name: "Empty query", - query: "", - spanNameFunc: nil, - expName: sqlOperationUnknkown, + name: "Empty query", + query: "", + tracer: NewTracer(), + expName: sqlOperationUnknkown, }, { - name: "Functional span name (-- comment style)", - query: "-- name: GetUsers :many\nSELECT * FROM users", - spanNameFunc: defaultSpanNameFunc(), - expName: "GetUsers :many", + name: "Functional span name (-- comment style)", + query: "-- name: GetUsers :many\nSELECT * FROM users", + tracer: NewTracer(WithSpanNameFunc(defaultSpanNameFunc)), + expName: "GetUsers :many", }, { - name: "Functional span name (/**/ comment style)", - query: "/* name: GetBooks :many */\nSELECT * FROM books", - spanNameFunc: defaultSpanNameFunc(), - expName: "GetBooks :many", + name: "Functional span name (/**/ comment style)", + query: "/* name: GetBooks :many */\nSELECT * FROM books", + tracer: NewTracer(WithSpanNameFunc(defaultSpanNameFunc)), + expName: "GetBooks :many", }, { - name: "Functional span name (# comment style)", - query: "# name: GetRecords :many\nSELECT * FROM records", - spanNameFunc: defaultSpanNameFunc(), - expName: "GetRecords :many", + name: "Functional span name (# comment style)", + query: "# name: GetRecords :many\nSELECT * FROM records", + tracer: NewTracer(WithSpanNameFunc(defaultSpanNameFunc)), + expName: "GetRecords :many", }, { - name: "Functional span name (no annotation)", - query: "--\nSELECT * FROM user", - spanNameFunc: defaultSpanNameFunc(), - expName: sqlOperationUnknkown, + name: "Functional span name (no annotation)", + query: "--\nSELECT * FROM user", + tracer: NewTracer(WithSpanNameFunc(defaultSpanNameFunc)), + expName: sqlOperationUnknkown, }, { - name: "Custom SQL name query (normal comment)", - query: "-- foo \nSELECT * FROM users", - spanNameFunc: defaultSpanNameFunc(), - expName: sqlOperationUnknkown, + name: "Custom SQL name query (normal comment)", + query: "-- foo \nSELECT * FROM users", + tracer: NewTracer(WithSpanNameFunc(defaultSpanNameFunc)), + expName: sqlOperationUnknkown, }, { - name: "Custom SQL name query (invalid formatting)", - query: "foo \nSELECT * FROM users", - spanNameFunc: defaultSpanNameFunc(), - expName: sqlOperationUnknkown, + name: "Custom SQL name query (invalid formatting)", + query: "foo \nSELECT * FROM users", + tracer: NewTracer(WithSpanNameFunc(defaultSpanNameFunc)), + expName: sqlOperationUnknkown, }, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - name := sqlOperationName(tt.query, tt.spanNameFunc) - if name != tt.expName { - t.Errorf("Got name %q, expected %q", name, tt.expName) + tr := tt.tracer + if got := tr.sqlOperationName(tt.query); got != tt.expName { + t.Errorf("Tracer.sqlOperationName() = %v, want %v", got, tt.expName) } }) } } -// defaultSpanNameFunc is an utility fucntion for testing that attempts to get +// defaultSpanNameFunc is an utility function for testing that attempts to get // the first name of the query from a given SQL statement. -func defaultSpanNameFunc() SpanNameFunc { - return func(query string) string { - for _, line := range strings.Split(query, "\n") { - var prefix string - switch { - case strings.HasPrefix(line, "--"): - prefix = "--" +var defaultSpanNameFunc SpanNameFunc = func(query string) string { + for _, line := range strings.Split(query, "\n") { + var prefix string + switch { + case strings.HasPrefix(line, "--"): + prefix = "--" - case strings.HasPrefix(line, "/*"): - prefix = "/*" + case strings.HasPrefix(line, "/*"): + prefix = "/*" - case strings.HasPrefix(line, "#"): - prefix = "#" - default: - continue - } + case strings.HasPrefix(line, "#"): + prefix = "#" + default: + continue + } - rest := line[len(prefix):] - if !strings.HasPrefix(strings.TrimSpace(rest), "name") { - continue - } - if !strings.Contains(rest, ":") { - continue - } - if !strings.HasPrefix(rest, " name: ") { - return sqlOperationUnknkown - } + rest := line[len(prefix):] + if !strings.HasPrefix(strings.TrimSpace(rest), "name") { + continue + } + if !strings.Contains(rest, ":") { + continue + } + if !strings.HasPrefix(rest, " name: ") { + return sqlOperationUnknkown + } - part := strings.Split(strings.TrimSpace(line), " ") - if prefix == "/*" { - part = part[:len(part)-1] // removes the trailing "*/" element - } - if len(part) == 2 { - return sqlOperationUnknkown - } + part := strings.Split(strings.TrimSpace(line), " ") + if prefix == "/*" { + part = part[:len(part)-1] // removes the trailing "*/" element + } + if len(part) == 2 { + return sqlOperationUnknkown + } - queryName := part[2] - queryType := strings.TrimSpace(part[3]) + queryName := part[2] + queryType := strings.TrimSpace(part[3]) - return queryName + " " + queryType - } - return sqlOperationUnknkown + return queryName + " " + queryType } + return sqlOperationUnknkown }