Skip to content

Commit

Permalink
PromQL variables support
Browse files Browse the repository at this point in the history
  • Loading branch information
alpinskiy committed Sep 11, 2023
1 parent ee12ac2 commit cb53ba5
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 25 deletions.
69 changes: 45 additions & 24 deletions internal/api/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ type (
}

DashboardVar struct {
Name string `json:"name"`
Args DashboardVarArgs `json:"args"`
Vals []string `json:"values"`
Link [][]int `json:"link"`
Expand Down Expand Up @@ -377,6 +378,7 @@ type (
rand *rand.Rand
stat *endpointStat
timeNow time.Time
vars map[string]promql.Variable
}

//easyjson:json
Expand Down Expand Up @@ -419,6 +421,7 @@ type (
renderRequest struct {
ai accessInfo
seriesRequest []seriesRequest
vars map[string]promql.Variable
renderWidth string
renderFormat string
}
Expand Down Expand Up @@ -1750,7 +1753,7 @@ func (h *Handler) HandleSeriesQuery(w http.ResponseWriter, r *http.Request) {
return
}
// Parse request
qry, err := h.parseHTTPRequest(r)
qry, _, err := h.parseHTTPRequest(r)
if err != nil {
respondJSON(w, nil, 0, 0, err, h.verbose, ai.user, sl)
return
Expand Down Expand Up @@ -1892,7 +1895,7 @@ func (h *Handler) HandleSeriesQuery(w http.ResponseWriter, r *http.Request) {

func (h *Handler) handleSeriesQueryPromQL(w http.ResponseWriter, r *http.Request, sl *endpointStat, ai accessInfo) {
// Parse request
qry, err := h.parseHTTPRequest(r)
qry, vars, err := h.parseHTTPRequest(r)
if err != nil {
respondJSON(w, nil, 0, 0, err, h.verbose, ai.user, sl)
return
Expand Down Expand Up @@ -1927,6 +1930,7 @@ func (h *Handler) handleSeriesQueryPromQL(w http.ResponseWriter, r *http.Request
res, freeRes, err = h.handlePromqlQuery(withHTTPEndpointStat(ctx, sl), ai, qry, seriesRequestOptions{
debugQueries: true,
stat: sl,
vars: vars,
metricNameCallback: func(name string) {
qry.metricWithNamespace = name
g.Go(func() error {
Expand All @@ -1943,6 +1947,7 @@ func (h *Handler) handleSeriesQueryPromQL(w http.ResponseWriter, r *http.Request
res, freeRes, err = h.handlePromqlQuery(withHTTPEndpointStat(ctx, sl), ai, qry, seriesRequestOptions{
debugQueries: true,
stat: sl,
vars: vars,
})
}
var traces []string
Expand Down Expand Up @@ -2096,6 +2101,7 @@ func (h *Handler) handlePromqlQuery(ctx context.Context, ai accessInfo, req seri
}
},
SeriesQueryCallback: seriesQueryCallback,
Vars: opt.vars,
}
)
if req.widthKind == widthAutoRes {
Expand Down Expand Up @@ -2523,7 +2529,7 @@ func (h *Handler) HandleGetPoint(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), h.querySelectTimeout)
defer cancel()

req, err := h.parseHTTPRequest(r)
req, _, err := h.parseHTTPRequest(r)
if err != nil {
respondJSON(w, nil, 0, 0, err, h.verbose, ai.user, sl)
return
Expand Down Expand Up @@ -2746,7 +2752,7 @@ func (h *Handler) HandleGetRender(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), h.querySelectTimeout)
defer cancel()

s, err := h.parseHTTPRequestS(r, 12)
s, vars, err := h.parseHTTPRequestS(r, 12)
if err != nil {
respondJSON(w, nil, 0, 0, err, h.verbose, ai.user, sl)
return
Expand All @@ -2756,6 +2762,7 @@ func (h *Handler) HandleGetRender(w http.ResponseWriter, r *http.Request) {
ctx, ai,
renderRequest{
seriesRequest: s,
vars: vars,
renderWidth: r.FormValue(paramRenderWidth),
renderFormat: r.FormValue(paramDataFormat),
})
Expand Down Expand Up @@ -2853,9 +2860,12 @@ func (h *Handler) handleGetRender(ctx context.Context, ai accessInfo, req render
err error
start = time.Now()
)
data, cancel, err = h.handlePromqlQuery(ctx, ai, r, seriesRequestOptions{metricNameCallback: func(s string) {
req.seriesRequest[i].metricWithNamespace = s
}})
data, cancel, err = h.handlePromqlQuery(ctx, ai, r, seriesRequestOptions{
vars: req.vars,
metricNameCallback: func(s string) {
req.seriesRequest[i].metricWithNamespace = s
},
})
if err != nil {
return nil, false, err
}
Expand Down Expand Up @@ -3648,18 +3658,18 @@ func getQueryRespEqual(a, b *SeriesResponse) bool {
return true
}

func (h *Handler) parseHTTPRequest(r *http.Request) (seriesRequest, error) {
res, err := h.parseHTTPRequestS(r, 1)
func (h *Handler) parseHTTPRequest(r *http.Request) (seriesRequest, map[string]promql.Variable, error) {
res, vars, err := h.parseHTTPRequestS(r, 1)
if err != nil {
return seriesRequest{}, err
return seriesRequest{}, nil, err
}
if len(res) == 0 {
return seriesRequest{}, httpErr(http.StatusBadRequest, fmt.Errorf("request is empty"))
return seriesRequest{}, nil, httpErr(http.StatusBadRequest, fmt.Errorf("request is empty"))
}
return res[0], nil
return res[0], vars, nil
}

func (h *Handler) parseHTTPRequestS(r *http.Request, maxTabs int) (res []seriesRequest, err error) {
func (h *Handler) parseHTTPRequestS(r *http.Request, maxTabs int) (res []seriesRequest, env map[string]promql.Variable, err error) {
defer func() {
var dummy httpError
if err != nil && !errors.As(err, &dummy) {
Expand Down Expand Up @@ -3763,7 +3773,13 @@ func (h *Handler) parseHTTPRequestS(r *http.Request, maxTabs int) (res []seriesR
tab.maxHost = v.MaxHost
n++
}
env = make(map[string]promql.Variable)
for _, v := range dash.Vars {
env[v.Name] = promql.Variable{
Value: v.Vals,
Group: v.Args.Group,
Negate: v.Args.Negate,
}
for _, link := range v.Link {
if len(link) != 2 {
continue
Expand Down Expand Up @@ -3894,7 +3910,7 @@ func (h *Handler) parseHTTPRequestS(r *http.Request, maxTabs int) (res []seriesR
switch s[1] {
case "g":
varByName(s[0]).group = first(v)
case "nk":
case "nv":
varByName(s[0]).negate = first(v)
}
}
Expand Down Expand Up @@ -3949,7 +3965,7 @@ func (h *Handler) parseHTTPRequestS(r *http.Request, maxTabs int) (res []seriesR
var tid string
tid, err = parseTagID(s)
if err != nil {
return nil, err
return nil, nil, err
}
t.by = append(t.by, tid)
}
Expand All @@ -3969,7 +3985,7 @@ func (h *Handler) parseHTTPRequestS(r *http.Request, maxTabs int) (res []seriesR
case Version1, Version2:
t.version = s
default:
return nil, fmt.Errorf("invalid version: %q", s)
return nil, nil, fmt.Errorf("invalid version: %q", s)
}
case ParamWidth:
t.strWidth = first(v)
Expand All @@ -3987,17 +4003,22 @@ func (h *Handler) parseHTTPRequestS(r *http.Request, maxTabs int) (res []seriesR
t.expandToLODBoundary = true
}
if err != nil {
return nil, err
return nil, nil, err
}
}
if len(tabs) == 0 {
return nil, nil
return nil, nil, nil
}
for _, v := range vars {
vv := varM[v.name]
if vv == nil {
continue
}
env[v.name] = promql.Variable{
Value: vv.val,
Group: vv.group == "1",
Negate: vv.negate == "1",
}
for _, link := range v.link {
if len(link) != 2 {
continue
Expand Down Expand Up @@ -4068,28 +4089,28 @@ func (h *Handler) parseHTTPRequestS(r *http.Request, maxTabs int) (res []seriesR
if len(tab0.strFrom) != 0 || len(tab0.strTo) != 0 {
tab0.from, tab0.to, err = parseFromTo(tab0.strFrom, tab0.strTo)
if err != nil {
return nil, err
return nil, nil, err
}
}
err = finalize(tab0)
if err != nil {
return nil, err
return nil, nil, err
}
for i := range tabs[1:] {
t := &tabs[i+1]
t.from = tab0.from
t.to = tab0.to
err = finalize(t)
if err != nil {
return nil, err
return nil, nil, err
}
}
// build resulting slice
if tabX != -1 {
if tabs[tabX].strType == "1" {
return nil, nil
return nil, nil, nil
}
return []seriesRequest{tabs[tabX].seriesRequest}, nil
return []seriesRequest{tabs[tabX].seriesRequest}, env, nil
}
res = make([]seriesRequest, 0, len(tabs))
for _, t := range tabs {
Expand All @@ -4100,7 +4121,7 @@ func (h *Handler) parseHTTPRequestS(r *http.Request, maxTabs int) (res []seriesR
res = append(res, t.seriesRequest)
}
}
return res, nil
return res, env, nil
}

func (r *DashboardTimeRange) UnmarshalJSON(bs []byte) error {
Expand Down
61 changes: 60 additions & 1 deletion internal/promql/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
const (
labelWhat = "__what__"
labelBy = "__by__"
labelBind = "__bind__"
labelOffset = "__offset__"
labelTotal = "__total__"
LabelShard = "__shard__"
Expand Down Expand Up @@ -56,6 +57,7 @@ type Options struct {
MaxHost bool
Offsets []int64
Rand *rand.Rand
Vars map[string]Variable

ExprQueriesSingleMetricCallback MetricMetaValueCallback
SeriesQueryCallback SeriesQueryCallback
Expand All @@ -66,6 +68,12 @@ type (
SeriesQueryCallback func(version string, key string, pq any, lod any, avoidCache bool)
)

type Variable struct {
Value []string
Group bool
Negate bool
}

type Engine struct {
h Handler
loc *time.Location
Expand Down Expand Up @@ -190,7 +198,9 @@ func (ng Engine) newEvaluator(ctx context.Context, qry Query) (evaluator, error)
parser.Inspect(ast, func(node parser.Node, path []parser.Node) error {
switch e := node.(type) {
case *parser.VectorSelector:
err = ng.matchMetrics(ctx, e, path, metricOffset, offsets[0])
if err = ng.bindVariables(e, qry.Options.Vars); err == nil {
err = ng.matchMetrics(ctx, e, path, metricOffset, offsets[0])
}
case *parser.MatrixSelector:
if maxRange < e.Range {
maxRange = e.Range
Expand Down Expand Up @@ -264,6 +274,55 @@ func (ng Engine) newEvaluator(ctx context.Context, qry Query) (evaluator, error)
}, nil
}

func (ng Engine) bindVariables(sel *parser.VectorSelector, vars map[string]Variable) error {
var s []*labels.Matcher
for _, matcher := range sel.LabelMatchers {
if matcher.Name == labelBind {
if matcher.Type != labels.MatchEqual {
return fmt.Errorf("%s supports only strict equality", matcher.Name)
}
s = append(s, matcher)
}
}
for _, matcher := range s {
for _, bind := range strings.Split(matcher.Value, ",") {
s := strings.Split(bind, ":")
if len(s) != 2 || len(s[0]) == 0 || len(s[1]) == 0 {
return fmt.Errorf("%s invalid value format: expected \"tag:var\", got %q", matcher.Name, bind)
}
var (
vn = s[1] // variable name
vv Variable // variable value
ok bool
)
if vv, ok = vars[vn]; !ok {
return fmt.Errorf("variable %q not specified", vn)
}
var mt labels.MatchType
if vv.Negate {
mt = labels.MatchNotEqual
} else {
mt = labels.MatchEqual
}
var (
tn = s[0] // tag name
m *labels.Matcher
err error
)
for _, v := range vv.Value {
if m, err = labels.NewMatcher(mt, tn, v); err != nil {
return err
}
sel.LabelMatchers = append(sel.LabelMatchers, m)
}
if vv.Group {
sel.GroupBy = append(sel.GroupBy, tn)
}
}
}
return nil
}

func (ng Engine) matchMetrics(ctx context.Context, sel *parser.VectorSelector, path []parser.Node, metricOffset map[*format.MetricMetaValue]int64, offset int64) error {
for _, matcher := range sel.LabelMatchers {
if len(sel.MatchingMetrics) != 0 && len(sel.What) != 0 {
Expand Down

0 comments on commit cb53ba5

Please sign in to comment.