Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix database repost #36

Merged
merged 3 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions database/query/query_fuzz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ type structElement struct {
Name []string
// Full index of the field.
Addr []int
// Number of elements to generate for this slice, if applicable.
NumElem int
}

func (f *structElement) Clone() *structElement {
Expand All @@ -30,8 +32,9 @@ func (f *structElement) Clone() *structElement {
}

return &structElement{
Name: append([]string{}, f.Name...),
Addr: append([]int{}, f.Addr...),
Name: append([]string{}, f.Name...),
Addr: append([]int{}, f.Addr...),
NumElem: f.NumElem,
}
}

Expand Down Expand Up @@ -60,7 +63,13 @@ func helperParseFilterStruct(t *testing.T, typ reflect.Type, parent *structEleme
el := parent.Clone()
el.Name = append(el.Name, field.Name)
el.Addr = append(el.Addr, field.Index...)
el.NumElem = 1
fields = append(fields, el)
if field.Type.Kind() == reflect.Slice {
next := el.Clone()
next.NumElem = int(rand.Int31n(6)) + 1
fields = append(fields, next)
}

case reflect.String:
for _, v := range []string{"Images", "Quotes", "References", "Videos", "Expiration"} {
Expand Down Expand Up @@ -93,23 +102,25 @@ func helperNewFilterFromElements(t *testing.T, fields []*structElement) model.Fi
value := reflect.ValueOf(&f).Elem().FieldByIndex(field.GetAddress())
switch field.GetName() {
case "Authors", "IDs":
n := rand.Int31n(4)
vals := make([]string, n)
for i := range n {
vals := make([]string, field.NumElem)
for i := range field.NumElem {
vals[i] = generateHexString()
}
value.Set(reflect.ValueOf(vals))

case "Kinds":
k := []int{generateKind()}
value.Set(reflect.ValueOf(k))
vals := make([]int, field.NumElem)
for i := range field.NumElem {
vals[i] = generateKind()
}
value.Set(reflect.ValueOf(vals))

case "Tags":
values := []string{}
for range rand.Intn(3) {
values = append(values, generateHexString())
vals := make([]string, field.NumElem)
for i := range field.NumElem {
vals[i] = generateHexString()
}
m := model.TagMap{}.SetLiterals("e", values...)
m := model.TagMap{}.SetLiterals("e", vals...)

value.Set(reflect.ValueOf(m))

Expand Down Expand Up @@ -203,15 +214,22 @@ func TestQueryFuzzNoUseTempBTREEOrScan(t *testing.T) {

rows, err := stmt.QueryContext(context.Background(), params)
require.NoError(t, err)
var hasPK bool
for rows.Next() {
var s1, s2, s3, s4 string
err := rows.Scan(&s1, &s2, &s3, &s4)
require.NoError(t, err)
op[s4]++
if strings.Contains(s4, "SEARCH e USING PRIMARY KEY") {
hasPK = true
}
if s4 == "USE TEMP B-TREE FOR ORDER BY" || (strings.HasPrefix(s4, "SCAN ") && !strings.Contains(s4, "INDEX")) {
if strings.Contains(filter.Search, "Expiration:true") {
// It uses SCAN over CTE, which is expected.
continue
} else if (hasPK || len(filter.Authors) > 0) && s4 == "USE TEMP B-TREE FOR ORDER BY" {
// Allow B-TREE for ORDER BY if there are multiple authors or PK is used.
continue
}
t.Logf("filter: %#v", filter)
t.Logf("set #%d: %s (%+v)", i+1, sql, params)
Expand Down
155 changes: 50 additions & 105 deletions database/query/query_where_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@ package query
import (
"cmp"
"log"
"slices"
"strconv"
"strings"
"sync"

"github.com/cockroachdb/errors"
"github.com/nbd-wtf/go-nostr"
Expand Down Expand Up @@ -70,35 +68,8 @@ type (
Tag string
Marker string
}
// filterBuilder bundles the per-filter state used while rendering the
// event ID portion of a WHERE clause.
filterBuilder struct {
// Name is handed to buildFromSlice as the filter ID when the event ID
// clause is rendered.
Name string
// EventIds lists the event IDs to render into the clause.
EventIds []string
// EventIdsString caches the clause text produced by BuildEvents.
EventIdsString string
// Once ensures the clause is rendered a single time by BuildEvents.
sync.Once
}
)

// HasEvents reports whether the builder carries at least one event ID.
func (f *filterBuilder) HasEvents() bool {
	return len(f.EventIds) != 0
}

// BuildEvents returns the "event_id" clause for the builder's event IDs.
//
// The clause is rendered on the first call only (guarded by the embedded
// sync.Once) into a scratch whereBuilder that reuses w.Params, and the
// resulting text is cached in EventIdsString. Later calls return the cached
// text without touching w.
func (f *filterBuilder) BuildEvents(w *whereBuilder) string {
	render := func() {
		// Render into a scratch builder so that nothing is written to w itself.
		scratch := &whereBuilder{Params: w.Params}
		clause := buildFromSlice(scratch, sqlOpCodeAND, f.Name, f.EventIds, "event_id", "")
		f.EventIdsString = clause.String()
	}
	f.Do(render)

	return f.EventIdsString
}

func parseEventAsFilterForDelete(e *model.Event) (*databaseFilterDelete, error) {
filter := databaseFilterDelete{
Author: e.PubKey,
Expand Down Expand Up @@ -174,12 +145,9 @@ func buildFromSlice[T comparable](builder *whereBuilder, op int, filterID string
}

maybeOpCode(builder, op)
if len(s) > 1 && (name == "id" || name == "pubkey") {
builder.WriteRune('+')
}
builder.WriteString(name)
s = model.DeduplicateSlice(s, func(elem T) T { return elem })
if len(s) == 1 && name != "kind" {
if len(s) == 1 {
// X = :X_name.
builder.WriteString(" = :")
builder.WriteString(builder.addParam(filterID, paramName, s[0]))
Expand Down Expand Up @@ -227,22 +195,22 @@ func (w *whereBuilder) maybeOR() {
w.WriteString(" OR ")
}

func (w *whereBuilder) applyFilterTagMarkers(filter *filterBuilder, markers []databaseFilterMarker) {
func (w *whereBuilder) applyFilterTagMarkers(name string, markers []databaseFilterMarker) {
if len(markers) == 0 {
return
}

for id, marker := range markers {
w.maybeAND()
w.WriteString("EXISTS (select true from event_tags where event_id = e.id AND event_tag_key = :")
w.WriteString(w.addParam(filter.Name, "mtag"+strconv.Itoa(id), marker.Tag))
w.WriteString(w.addParam(name, "mtag"+strconv.Itoa(id), marker.Tag))
w.WriteString(" AND event_tag_value3 = :")
w.WriteString(w.addParam(filter.Name, "mtagvalue"+strconv.Itoa(id), marker.Marker))
w.WriteString(w.addParam(name, "mtagvalue"+strconv.Itoa(id), marker.Marker))
w.WriteRune(')')
}
}

func (w *whereBuilder) applyFilterTags(filter *filterBuilder, tags model.TagMap) {
func (w *whereBuilder) applyFilterTags(name string, tags model.TagMap) {
const valuesMax = 21

if len(tags) == 0 {
Expand All @@ -254,7 +222,7 @@ func (w *whereBuilder) applyFilterTags(filter *filterBuilder, tags model.TagMap)
tagID++

w.maybeAND()
tagParam := w.addParam(filter.Name, "tag"+strconv.Itoa(tagID), tagName)
tagParam := w.addParam(name, "tag"+strconv.Itoa(tagID), tagName)

// Only the tag name is specified, no values.
if !tags.HasValues(tagName) {
Expand Down Expand Up @@ -288,7 +256,7 @@ func (w *whereBuilder) applyFilterTags(filter *filterBuilder, tags model.TagMap)
w.WriteString("event_tag_value")
w.WriteString(strconv.Itoa(j + 1))
w.WriteString(" = :")
w.WriteString(w.addParam(filter.Name, "tagvalue"+strconv.Itoa(tagID<<8|(j+1)*(i+1)), *values[j]))
w.WriteString(w.addParam(name, "tagvalue"+strconv.Itoa(tagID<<8|(j+1)*(i+1)), *values[j]))
}
w.WriteRune(')')
}
Expand All @@ -311,12 +279,12 @@ func isFilterEmpty(filter *databaseFilterSearch) bool {
filter.Images == nil
}

func (w *whereBuilder) applyTimeRange(filter *filterBuilder, since, until *model.Timestamp) error {
func (w *whereBuilder) applyTimeRange(name string, since, until *model.Timestamp) error {
if since != nil && until != nil {
if *since == *until {
w.maybeAND()
w.WriteString("created_at = :")
w.WriteString(w.addParam(filter.Name, "timestamp", *since))
w.WriteString(w.addParam(name, "timestamp", *since))

return nil
} else if *since > *until {
Expand All @@ -328,14 +296,14 @@ func (w *whereBuilder) applyTimeRange(filter *filterBuilder, since, until *model
if since != nil && *since > 0 {
w.maybeAND()
w.WriteString("created_at >= :")
w.WriteString(w.addParam(filter.Name, "since", *since))
w.WriteString(w.addParam(name, "since", *since))
}

// The `until` property is similar except that `created_at` must be less than or equal to `until`.
if until != nil && *until > 0 {
w.maybeAND()
w.WriteString("created_at <= :")
w.WriteString(w.addParam(filter.Name, "until", *until))
w.WriteString(w.addParam(name, "until", *until))
}

return nil
Expand Down Expand Up @@ -368,22 +336,24 @@ func filterHasExtensions(filter *databaseFilterSearch) (positive, negative int)
return
}

func (w *whereBuilder) applyFilterForExtensions(filter *databaseFilterSearch, builder *filterBuilder, include bool) {
func (w *whereBuilder) applyFilterForExtensions(filter *databaseFilterSearch, include bool) {
separator := w.maybeOR
w.WriteString("select event_id from event_tags where ")
if include && builder.HasEvents() {
w.WriteString(builder.BuildEvents(w))
w.maybeAND()
if !include {
w.WriteString("NOT ")
}
w.WriteString("exists (select true from event_tags where event_id in (e.id, e.reference_id) AND (")

w.WriteRune('(')
if filter.Quotes != nil && *filter.Quotes == include {
separator()
w.WriteString("(event_tag_key = 'q')")
}
if filter.References != nil && *filter.References == include {
separator()
w.WriteString("(event_tag_key = 'e')")
result := "true"
if !include {
result = "false"
}
w.WriteString("(case when e.reference_id is not null then " + result + " else event_tag_key = 'e' end)")
}
if filter.Images != nil && *filter.Images == include {
separator()
Expand All @@ -409,85 +379,60 @@ func (w *whereBuilder) applyFilterForExtensions(filter *databaseFilterSearch, bu
w.WriteString(" as integer) > unixepoch())")
}
}
w.WriteRune(')')
w.WriteString("))")
}

func (w *whereBuilder) applyRepostFilter(filter *databaseFilterSearch, builder *filterBuilder, positiveExtensions, negativeExtensions *int) (applied bool) {
if (*positiveExtensions + *negativeExtensions) == 0 {
// No extensions in the filter.
return
func filterMainIndexField(filter *databaseFilterSearch) string {
if len(filter.Authors) > 0 {
return "master_pubkey"
}

if !slices.ContainsFunc(filter.Kinds, func(k int) bool {
return k == nostr.KindRepost || k == nostr.KindGenericRepost
}) {
// No reposts in the filter.
return
if len(filter.Kinds) > 0 {
return "kind"
}

// Not allowed.
filter.References = nil
*positiveExtensions &= ^extensionReferences
*negativeExtensions &= ^extensionReferences

if *positiveExtensions > 0 {
w.maybeAND()
w.WriteString("(+id IN (select e.id from events subev where subev.id = e.reference_id and subev.kind = 1 and exists (")
w.applyFilterForExtensions(filter, builder, true)
w.WriteString(")))")
}
return ""
}

if *negativeExtensions > 0 {
w.maybeAND()
w.WriteString("(+id NOT IN (select e.id from events subev where subev.id = e.reference_id and subev.kind = 1 and exists (")
w.applyFilterForExtensions(filter, builder, false)
w.WriteString(")))")
func filterMaybeForceIndex(filter *databaseFilterSearch, field string) string {
main := filterMainIndexField(filter)
if main == field {
field = "+" + field
}

return (*positiveExtensions + *negativeExtensions) > 0
return field
}

func (w *whereBuilder) applyFilter(idx int, filter *databaseFilterSearch) error {
if isFilterEmpty(filter) {
return nil
}

builder := &filterBuilder{
Name: "filter" + strconv.Itoa(idx) + "_",
EventIds: filter.IDs,
}
name := "filter" + strconv.Itoa(idx) + "_"
positiveExtensions, negativeExtensions := filterHasExtensions(filter)
w.WriteRune('(') // Begin the filter section.
if w.applyRepostFilter(filter, builder, &positiveExtensions, &negativeExtensions) {
buildFromSlice(w, sqlOpCodeAND, builder.Name, filter.IDs, "id", "")
} else {
if positiveExtensions > 0 {
w.WriteString("+id IN (")
w.applyFilterForExtensions(filter, builder, true)
w.WriteRune(')')
} else {
buildFromSlice(w, sqlOpCodeAND, builder.Name, filter.IDs, "id", "")
}
if negativeExtensions > 0 {
w.maybeAND()
w.WriteString("(+id NOT IN (")
w.applyFilterForExtensions(filter, builder, false)
w.WriteString("))")
}
buildFromSlice(w, sqlOpCodeNONE, name, filter.IDs, "id", "")
buildFromSlice(w, sqlOpCodeAND, name, filter.Kinds, filterMaybeForceIndex(filter, "kind"), "kind")
if positiveExtensions > 0 {
w.maybeAND()
w.applyFilterForExtensions(filter, true)
}
if negativeExtensions > 0 {
w.maybeAND()
w.applyFilterForExtensions(filter, false)
}
buildFromSlice(w, sqlOpCodeAND, builder.Name, filter.Kinds, "kind", "")
if len(filter.Authors) > 0 {
w.maybeAND()
w.WriteRune('(')
buildFromSlice(w, sqlOpCodeNONE, builder.Name, filter.Authors, "pubkey", "")
buildFromSlice(w, sqlOpCodeOR, builder.Name, filter.Authors, "master_pubkey", "pubkey")
w.WriteRune(')')
buildFromSlice(w, sqlOpCodeNONE, name, filter.Authors, "pubkey", "")
w.WriteString(" and hidden=0 OR ")
buildFromSlice(w, sqlOpCodeNONE, name, filter.Authors, "master_pubkey", "pubkey")
w.WriteString(" and hidden=0)")
}
if err := w.applyTimeRange(builder, filter.Since, filter.Until); err != nil {
if err := w.applyTimeRange(name, filter.Since, filter.Until); err != nil {
return err
}
w.applyFilterTags(builder, filter.Tags)
w.applyFilterTagMarkers(builder, filter.TagMarkers)
w.applyFilterTags(name, filter.Tags)
w.applyFilterTagMarkers(name, filter.TagMarkers)

w.WriteRune(')') // End the filter section.

Expand Down
Loading
Loading