Skip to content

Commit

Permalink
Add weights
Browse files Browse the repository at this point in the history
  • Loading branch information
fat-fellow committed Nov 5, 2024
1 parent e97c5b6 commit 134cb29
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 62 deletions.
2 changes: 1 addition & 1 deletion example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func main() {
SetQuery("body").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldWithoutWeight(NameBody).
AddFieldDefaultWeight(NameBody).
Build()

result, err := index.Search(sCtx)
Expand Down
20 changes: 9 additions & 11 deletions rust/src/c_util/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,16 @@ pub fn process_slice<'a, F, T>(
mut func: F,
) -> Result<(), ()>
where
F: FnMut(T) -> Result<(), ()>,
F: FnMut(usize, T) -> Result<(), ()>,
T: Copy,
{
let slice = match assert_pointer(ptr, error_buffer) {
Some(ptr) => unsafe { slice::from_raw_parts(ptr, len) },
None => return Err(()),
};

for item in slice {
if func(*item).is_err() {
for (i, item) in slice.iter().enumerate() {
if func(i, *item).is_err() {
return Err(());
}
}
Expand Down Expand Up @@ -307,10 +307,10 @@ pub fn delete_docs(

let field = match schema_apply_for_field::<Field, (), _>
(error_buffer, schema.clone(), field_name, |field, _|
match get_string_field_entry(schema.clone(), field) {
Ok(value) => Ok(value),
Err(_) => Err(())
},
match get_string_field_entry(schema.clone(), field) {
Ok(value) => Ok(value),
Err(_) => Err(())
},
) {
Ok(value) => value,
Err(_) => {
Expand Down Expand Up @@ -399,10 +399,8 @@ pub fn search(

let mut weights = HashMap::with_capacity(field_names_len);

let iter = 0;
if process_slice(field_weights_ptr, error_buffer, field_names_len, |field_weight| {
weights.insert(fields[iter], field_weight);
debug!("weights azaza: {:?}", weights);
if process_slice(field_weights_ptr, error_buffer, field_names_len, |i, field_weight| {
weights.insert(fields[i], field_weight);
Ok(())
}).is_err() {
return Err(());
Expand Down
34 changes: 5 additions & 29 deletions searchcontext.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package tantivy_go

import "fmt"

// SearchContext defines the interface for searchContext
type SearchContext interface {
// GetQuery returns the search query string.
Expand All @@ -10,8 +8,8 @@ type SearchContext interface {
GetDocsLimit() uintptr
// WithHighlights returns true if highlights are enabled.
WithHighlights() bool
// GetFieldWeights returns slices of field names and their corresponding weights.
GetFieldWeights() ([]string, []float32)
// GetFieldAndWeights returns slices of field names and their corresponding weights.
GetFieldAndWeights() ([]string, []float32)
}

// searchContext is a structure that implements SearchContext.
Expand Down Expand Up @@ -43,7 +41,7 @@ func (sc *searchContext) GetFieldNames() map[string]float32 {
}

// GetFieldWeights returns slices of field names and their corresponding weights.
func (sc *searchContext) GetFieldWeights() ([]string, []float32) {
func (sc *searchContext) GetFieldAndWeights() ([]string, []float32) {
fields := make([]string, 0, len(sc.fieldNames))
weights := make([]float32, 0, len(sc.fieldNames))

Expand Down Expand Up @@ -93,8 +91,8 @@ func (b *SearchContextBuilder) AddField(field string, weight float32) *SearchCon
return b
}

// AddFieldWithoutWeight adds a field with a default weight of 1.0 to searchContext.
func (b *SearchContextBuilder) AddFieldWithoutWeight(field string) *SearchContextBuilder {
// AddFieldDefaultWeight adds a field with a default weight of 1.0 to searchContext.
func (b *SearchContextBuilder) AddFieldDefaultWeight(field string) *SearchContextBuilder {
b.context.fieldNames[field] = 1.0
return b
}
Expand All @@ -103,25 +101,3 @@ func (b *SearchContextBuilder) AddFieldWithoutWeight(field string) *SearchContex
func (b *SearchContextBuilder) Build() SearchContext {
return b.context
}

// Example usage of the searchContext and SearchContextBuilder.
func main() {
builder := NewSearchContextBuilder()
searchContext := builder.
SetQuery("example search query").
SetDocsLimit(10).
SetWithHighlights(true).
AddField("title", 1.5).
AddFieldWithoutWeight("description").
Build()

// Retrieve fields and weights
fields, weights := searchContext.GetFieldWeights()
fmt.Printf("Fields: %v\n", fields)
fmt.Printf("Weights: %v\n", weights)

// Additional information
fmt.Printf("searchContext Query: %s\n", searchContext.GetQuery())
fmt.Printf("Docs Limit: %d\n", searchContext.GetDocsLimit())
fmt.Printf("With Highlights: %t\n", searchContext.WithHighlights())
}
151 changes: 135 additions & 16 deletions tantivy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (
const NameBody = "body"
const NameId = "id"
const NameTitle = "title"
const NameBodyCh = "bodyCh"
const NameTitleCh = "titleCh"
const NameBodyZh = "bodyCh"
const NameTitleZh = "titleCh"

const limit = 40
const minGram = 2
Expand Down Expand Up @@ -51,7 +51,14 @@ func Test(t *testing.T) {
err = tc.AddAndConsumeDocuments(doc)
require.NoError(t, err)

result, err := tc.Search("body", 100, true, NameBody)
sCtx := tantivy_go.NewSearchContextBuilder().
SetQuery("body").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameBody).
Build()

result, err := tc.Search(sCtx)
require.NoError(t, err)

size, err := result.GetSize()
Expand Down Expand Up @@ -177,14 +184,26 @@ func Test(t *testing.T) {
require.NoError(t, err)
require.Equal(t, uint64(1), docs)

result, err := tc.Search("ย", 100, true, NameTitle)
sCtx := tantivy_go.NewSearchContextBuilder().
SetQuery("ย").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameTitle).
Build()
result, err := tc.Search(sCtx)
require.NoError(t, err)

size, err := result.GetSize()
defer result.Free()
require.Equal(t, 0, int(size))

result2, err := tc.Search("ย่", 100, true, NameTitle)
sCtx2 := tantivy_go.NewSearchContextBuilder().
SetQuery("ย่").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameTitle).
Build()
result2, err := tc.Search(sCtx2)
require.NoError(t, err)

size2, err := result2.GetSize()
Expand Down Expand Up @@ -217,28 +236,52 @@ func Test(t *testing.T) {
require.NoError(t, err)
require.Equal(t, uint64(1), docs)

result, err := tc.Search("Idées fête", 100, true, NameTitle)
sCtx := tantivy_go.NewSearchContextBuilder().
SetQuery("Idées fête").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameTitle).
Build()
result, err := tc.Search(sCtx)
require.NoError(t, err)

size, err := result.GetSize()
defer result.Free()
require.Equal(t, 1, int(size))

result2, err := tc.Search("idees fete", 100, true, NameTitle)
sCtx2 := tantivy_go.NewSearchContextBuilder().
SetQuery("idees fete").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameTitle).
Build()
result2, err := tc.Search(sCtx2)
require.NoError(t, err)

size2, err := result2.GetSize()
defer result2.Free()
require.Equal(t, 1, int(size2))

result3, err := tc.Search("straße", 100, true, NameBody)
sCtx3 := tantivy_go.NewSearchContextBuilder().
SetQuery("straße").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameBody).
Build()
result3, err := tc.Search(sCtx3)
require.NoError(t, err)

size3, err := result3.GetSize()
defer result3.Free()
require.Equal(t, 1, int(size3))

result4, err := tc.Search("strasse", 100, true, NameBody)
sCtx4 := tantivy_go.NewSearchContextBuilder().
SetQuery("strasse").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameBody).
Build()
result4, err := tc.Search(sCtx4)
require.NoError(t, err)

size4, err := result4.GetSize()
Expand All @@ -261,7 +304,13 @@ func Test(t *testing.T) {
require.NoError(t, err)
require.Equal(t, uint64(1), docs)

result, err := tc.Search("1", 100, true, NameId)
sCtx := tantivy_go.NewSearchContextBuilder().
SetQuery("1").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameId).
Build()
result, err := tc.Search(sCtx)
require.NoError(t, err)

size, err := result.GetSize()
Expand All @@ -285,7 +334,13 @@ func Test(t *testing.T) {
err = tc.AddAndConsumeDocuments(doc)
require.NoError(t, err)

result, err := tc.Search("create", 100, true, NameTitle)
sCtx := tantivy_go.NewSearchContextBuilder().
SetQuery("create").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameTitle).
Build()
result, err := tc.Search(sCtx)
require.NoError(t, err)

size, err := result.GetSize()
Expand Down Expand Up @@ -346,13 +401,77 @@ func Test(t *testing.T) {
require.NoError(t, err)
require.Equal(t, uint64(2), docs)

result, err := tc.Search("售货员", 100, true, NameBodyCh, NameTitleCh)
sCtx := tantivy_go.NewSearchContextBuilder().
SetQuery("售货员").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameBodyZh).
AddFieldDefaultWeight(NameTitleZh).
Build()
result, err := tc.Search(sCtx)
require.NoError(t, err)

size, err := result.GetSize()
defer result.Free()
require.Equal(t, 2, int(size))
})

t.Run("docs search - when weights apply", func(t *testing.T) {
schema, tc := fx(t, limit, 1, false)

defer tc.Free()

doc, err := addDoc(t, "an apple", "", "id1", tc)
require.NoError(t, err)

doc2, err := addDoc(t, "", "an apple", "id2", tc)
require.NoError(t, err)

err = tc.AddAndConsumeDocuments(doc, doc2)
require.NoError(t, err)

docs, err := tc.NumDocs()
require.NoError(t, err)
require.Equal(t, uint64(2), docs)

sCtx := tantivy_go.NewSearchContextBuilder().
SetQuery("apple").
SetDocsLimit(100).
SetWithHighlights(false).
AddField(NameTitle, 1.0).
AddField(NameBody, 1.0).
Build()
result, err := tc.Search(sCtx)
require.NoError(t, err)

size, err := result.GetSize()
defer result.Free()
require.Equal(t, 2, int(size))
resDoc, err := result.Get(0)
require.NoError(t, err)
jsonStr, err := resDoc.ToJson(schema, NameId)
require.NoError(t, err)
require.JSONEq(t, `{"highlights":[],"id":"id1","score":1.9676434993743896}`, jsonStr)

sCtx2 := tantivy_go.NewSearchContextBuilder().
SetQuery("apple").
SetDocsLimit(100).
SetWithHighlights(false).
AddField(NameTitle, 1.0).
AddField(NameBody, 10.0).
Build()
result2, err := tc.Search(sCtx2)
require.NoError(t, err)

size2, err := result2.GetSize()
defer result2.Free()
require.Equal(t, 2, int(size2))
resDoc2, err := result2.Get(0)
require.NoError(t, err)
jsonStr2, err := resDoc2.ToJson(schema, NameId)
require.NoError(t, err)
require.JSONEq(t, `{"highlights":[],"id":"id2","score":4.919108867645264}`, jsonStr2)
})
}

func addDoc(
Expand All @@ -367,7 +486,7 @@ func addDoc(
err := doc.AddField(NameTitle, title, tc)
require.NoError(t, err)

err = doc.AddField(NameTitleCh, title, tc)
err = doc.AddField(NameTitleZh, title, tc)
require.NoError(t, err)

err = doc.AddField(NameId, id, tc)
Expand All @@ -376,7 +495,7 @@ func addDoc(
err = doc.AddField(NameBody, body, tc)
require.NoError(t, err)

err = doc.AddField(NameBodyCh, body, tc)
err = doc.AddField(NameBodyZh, body, tc)
return doc, err
}

Expand All @@ -402,7 +521,7 @@ func fx(
require.NoError(t, err)

err = builder.AddTextField(
NameTitleCh,
NameTitleZh,
true,
true,
false,
Expand Down Expand Up @@ -432,7 +551,7 @@ func fx(
require.NoError(t, err)

err = builder.AddTextField(
NameBodyCh,
NameBodyZh,
true,
true,
false,
Expand Down
8 changes: 3 additions & 5 deletions tantivycontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,14 @@ func (tc *TantivyContext) NumDocs() (uint64, error) {
// Search performs a search query on the index and returns the search results.
//
// Parameters:
// - query (string): The search query string.
// - docsLimit (uintptr): The maximum number of documents to return.
// - withHighlights (bool): Whether to include highlights in the results.
// - fieldNames (...string): The names of the fields to be included in the search.
// - sCtx (SearchContext): The context for the search, containing query string,
// document limit, highlight option, and field weights.
//
// Returns:
// - *SearchResult: A pointer to the SearchResult containing the search results.
// - error: An error if the search fails.
func (tc *TantivyContext) Search(sCtx SearchContext) (*SearchResult, error) {
fieldNames, weights := sCtx.GetFieldWeights()
fieldNames, weights := sCtx.GetFieldAndWeights()
if len(fieldNames) == 0 {
return nil, fmt.Errorf("fieldNames must not be empty")
}
Expand Down

0 comments on commit 134cb29

Please sign in to comment.