Skip to content

Commit

Permalink
Feature: Search endpoint improvements (#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
miagilepner authored Jun 25, 2021
1 parent c158bb5 commit ed4e215
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 21 deletions.
54 changes: 39 additions & 15 deletions query.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package stream_chat // nolint: golint

import (
"encoding/json"
"errors"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -117,24 +118,53 @@ func (c *Client) QueryChannels(q *QueryOption, sort ...*SortOption) ([]*Channel,

type SearchRequest struct {
// Required
Query string `json:"query"`
Filters map[string]interface{} `json:"filter_conditions"`
Query string `json:"query"`
Filters map[string]interface{} `json:"filter_conditions"`
MessageFilters map[string]interface{} `json:"message_filter_conditions"`

// Pagination, optional
Limit int `json:"limit,omitempty"`
Offset int `json:"offset,omitempty"`
Limit int `json:"limit,omitempty"`
Offset int `json:"offset,omitempty"`
Next string `json:"next,omitempty"`

// Sort, optional
Sort []SortOption `json:"sort,omitempty"`
}

type searchResponse struct {
Results []searchMessageResponse `json:"results"`
type SearchResponse struct {
Results []SearchMessageResponse `json:"results"`
Next string `json:"next,omitempty"`
Previous string `json:"previous,omitempty"`
}

type searchMessageResponse struct {
type SearchMessageResponse struct {
Message *Message `json:"message"`
}

// Search returns channels matching for given keyword.
func (c *Client) Search(request SearchRequest) ([]*Message, error) {
result, err := c.SearchWithFullResponse(request)
if err != nil {
return nil, err
}
messages := make([]*Message, 0, len(result.Results))
for _, res := range result.Results {
messages = append(messages, res.Message)
}

return messages, nil
}

// SearchWithFullResponse performs a search and returns the full results.
func (c *Client) SearchWithFullResponse(request SearchRequest) (*SearchResponse, error) {
if request.Offset != 0 {
if len(request.Sort) > 0 || request.Next != "" {
return nil, errors.New("cannot use Offset with Next or Sort parameters")
}
}
if request.Query != "" && len(request.MessageFilters) != 0 {
return nil, errors.New("can only specify Query or MessageFilters, not both")
}
var buf strings.Builder

if err := json.NewEncoder(&buf).Encode(request); err != nil {
Expand All @@ -144,17 +174,11 @@ func (c *Client) Search(request SearchRequest) ([]*Message, error) {
values := url.Values{}
values.Set("payload", buf.String())

var result searchResponse
var result SearchResponse
if err := c.makeRequest(http.MethodGet, "search", values, nil, &result); err != nil {
return nil, err
}

messages := make([]*Message, 0, len(result.Results))
for _, res := range result.Results {
messages = append(messages, res.Message)
}

return messages, nil
return &result, nil
}

type queryMessageFlagsResponse struct {
Expand Down
142 changes: 136 additions & 6 deletions query_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package stream_chat // nolint: golint

import (
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -104,15 +105,144 @@ func TestClient_Search(t *testing.T) {
_, err = ch.SendMessage(&Message{Text: text + " " + randomString(25)}, user2.ID)
require.NoError(t, err)

got, err := c.Search(SearchRequest{Query: text, Filters: map[string]interface{}{
"members": map[string][]string{
"$in": {user1.ID, user2.ID},
t.Run("Query", func(tt *testing.T) {
got, err := c.Search(SearchRequest{Query: text, Filters: map[string]interface{}{
"members": map[string][]string{
"$in": {user1.ID, user2.ID},
},
}})

require.NoError(tt, err)

assert.Len(tt, got, 2)
})
t.Run("Message filters", func(tt *testing.T) {
got, err := c.Search(SearchRequest{
Filters: map[string]interface{}{
"members": map[string][]string{
"$in": {user1.ID, user2.ID},
},
},
MessageFilters: map[string]interface{}{
"text": map[string]interface{}{
"$q": text,
},
},
})
require.NoError(tt, err)

assert.Len(tt, got, 2)
})
t.Run("Query and message filters error", func(tt *testing.T) {
_, err := c.Search(SearchRequest{
Filters: map[string]interface{}{
"members": map[string][]string{
"$in": {user1.ID, user2.ID},
},
},
MessageFilters: map[string]interface{}{
"text": map[string]interface{}{
"$q": text,
},
},
Query: text,
})
require.Error(tt, err)
})
t.Run("Offset and sort error", func(tt *testing.T) {
_, err := c.Search(SearchRequest{
Filters: map[string]interface{}{
"members": map[string][]string{
"$in": {user1.ID, user2.ID},
},
},
Offset: 1,
Query: text,
Sort: []SortOption{{
Field: "created_at",
Direction: -1,
}},
})
require.Error(tt, err)
})
t.Run("Offset and next error", func(tt *testing.T) {
_, err := c.Search(SearchRequest{
Filters: map[string]interface{}{
"members": map[string][]string{
"$in": {user1.ID, user2.ID},
},
},
Offset: 1,
Query: text,
Next: randomString(5),
})
require.Error(tt, err)
})
}

func TestClient_SearchWithFullResponse(t *testing.T) {
t.Skip()
c := initClient(t)
ch := initChannel(t, c)

user1, user2 := randomUser(), randomUser()

text := randomString(10)

messageIDs := make([]string, 6)
for i := 0; i < 6; i++ {
userID := user1.ID
if i%2 == 0 {
userID = user2.ID
}
messageID := fmt.Sprintf("%d-%s", i, text)
_, err := ch.SendMessage(&Message{
ID: messageID,
Text: text + " " + randomString(25),
}, userID)
require.NoError(t, err)

messageIDs[6-i] = messageID
}

got, err := c.SearchWithFullResponse(SearchRequest{
Query: text,
Filters: map[string]interface{}{
"members": map[string][]string{
"$in": {user1.ID, user2.ID},
},
},
Sort: []SortOption{
{Field: "created_at", Direction: -1},
},
}})
Limit: 3,
})

gotMessageIDs := make([]string, 0, 6)
require.NoError(t, err)

assert.Len(t, got, 2)
assert.NotEmpty(t, got.Next)
assert.Len(t, got.Results, 3)
for _, result := range got.Results {
gotMessageIDs = append(gotMessageIDs, result.Message.ID)
}
got, err = c.SearchWithFullResponse(SearchRequest{
Query: text,
Filters: map[string]interface{}{
"members": map[string][]string{
"$in": {user1.ID, user2.ID},
},
},
Next: got.Next,
Limit: 3,
})
require.NoError(t, err)
assert.NotEmpty(t, got.Previous)
assert.Empty(t, got.Next)
assert.Len(t, got.Results, 3)
for _, result := range got.Results {
gotMessageIDs = append(gotMessageIDs, result.Message.ID)
}
assert.Equal(t, messageIDs, gotMessageIDs)
}

func TestClient_QueryMessageFlags(t *testing.T) {
Expand Down

0 comments on commit ed4e215

Please sign in to comment.