Skip to content

Commit

Permalink
Merge pull request #8 from anyproto/go-4369-research-jieba-add-weights
Browse files Browse the repository at this point in the history
Add weights
  • Loading branch information
fat-fellow authored Nov 5, 2024
2 parents c56da9d + 134cb29 commit 246130d
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 33 deletions.
1 change: 1 addition & 0 deletions bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ uint64_t context_num_docs(struct TantivyContext *context_ptr, char **error_buffe

struct SearchResult *context_search(struct TantivyContext *context_ptr,
const char **field_names_ptr,
float *field_weights_ptr,
uintptr_t field_names_len,
const char *query_ptr,
char **error_buffer,
Expand Down
9 changes: 8 additions & 1 deletion example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,14 @@ func main() {
}

// Search index
result, err := index.Search("body", 100, true, NameBody)
sCtx := tantivy_go.NewSearchContextBuilder().
SetQuery("body").
SetDocsLimit(100).
SetWithHighlights(true).
AddFieldDefaultWeight(NameBody).
Build()

result, err := index.Search(sCtx)
if err != nil {
fmt.Println("Failed to search index:", err)
return
Expand Down
51 changes: 44 additions & 7 deletions rust/src/c_util/util.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::{fs, panic, slice};
use std::collections::HashMap;
use std::ffi::{CStr, CString};
use std::os::raw::c_char;
use std::os::raw::{c_char, c_float};
use std::path::Path;
use std::sync::Mutex;
use lazy_static::lazy_static;
use log::debug;
use serde_json::json;
use tantivy::{Index, IndexWriter, TantivyDocument, TantivyError, Term};
use tantivy::{Index, IndexWriter, Score, TantivyDocument, TantivyError, Term};
use tantivy::directory::MmapDirectory;
use tantivy::query::{QueryParser};
use tantivy::schema::{Field, Schema};
Expand Down Expand Up @@ -133,6 +133,30 @@ where
Ok(())
}

/// Visits every element of a caller-supplied raw array, passing each element
/// together with its index to `func`.
///
/// `ptr` is first checked with `assert_pointer`; if that check yields `None`
/// the function returns `Err(())` without invoking `func` (presumably the
/// helper reports the problem through `error_buffer` — confirm against its
/// definition). Otherwise the `len` elements starting at the checked pointer
/// are visited in order. Iteration stops at the first `Err(())` returned by
/// `func`, and that error is propagated to the caller.
///
/// The caller must guarantee that `ptr` points to at least `len` initialized,
/// properly aligned values of `T` that stay valid for the duration of the call.
pub fn process_slice<'a, F, T>(
    ptr: *mut T,
    error_buffer: *mut *mut c_char,
    len: usize,
    mut func: F,
) -> Result<(), ()>
where
    F: FnMut(usize, T) -> Result<(), ()>,
    T: Copy,
{
    let checked = assert_pointer(ptr, error_buffer).ok_or(())?;

    // SAFETY: relies on the caller's promise that `len` readable elements of
    // `T` live behind the pointer; only the pointer itself has been checked.
    let items = unsafe { slice::from_raw_parts(checked, len) };

    items
        .iter()
        .copied()
        .enumerate()
        .try_for_each(|(index, value)| func(index, value))
}

pub fn schema_apply_for_field<'a, T, K, F: FnMut(Field, &'a str) -> Result<T, ()>>(
error_buffer: *mut *mut c_char,
schema: Schema,
Expand Down Expand Up @@ -283,10 +307,10 @@ pub fn delete_docs(

let field = match schema_apply_for_field::<Field, (), _>
(error_buffer, schema.clone(), field_name, |field, _|
match get_string_field_entry(schema.clone(), field) {
Ok(value) => Ok(value),
Err(_) => Err(())
},
match get_string_field_entry(schema.clone(), field) {
Ok(value) => Ok(value),
Err(_) => Err(())
},
) {
Ok(value) => value,
Err(_) => {
Expand Down Expand Up @@ -350,6 +374,7 @@ pub fn add_field(

pub fn search(
field_names_ptr: *mut *const c_char,
field_weights_ptr: *mut c_float,
field_names_len: usize,
query_ptr: *const c_char,
error_buffer: *mut *mut c_char,
Expand All @@ -372,12 +397,24 @@ pub fn search(
return Err(());
}

let mut weights = HashMap::with_capacity(field_names_len);

if process_slice(field_weights_ptr, error_buffer, field_names_len, |i, field_weight| {
weights.insert(fields[i], field_weight);
Ok(())
}).is_err() {
return Err(());
}

let query = match assert_string(query_ptr, error_buffer) {
Some(value) => value,
None => return Err(())
};

let query_parser = QueryParser::for_index(&context.index, fields);
let mut query_parser = QueryParser::for_index(&context.index, fields);
for (field, weight) in weights {
query_parser.set_field_boost(field, weight as Score);
}

let query = match query_parser.parse_query(query.as_str()) {
Ok(query) => query,
Expand Down
4 changes: 3 additions & 1 deletion rust/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::ffi::CString;
use std::os::raw::c_char;
use std::os::raw::{c_char, c_float};
use std::ptr;
use logcall::logcall;
use tantivy::{schema::*};
Expand Down Expand Up @@ -284,6 +284,7 @@ pub extern "C" fn context_num_docs(
pub extern "C" fn context_search(
context_ptr: *mut TantivyContext,
field_names_ptr: *mut *const c_char,
field_weights_ptr: *mut c_float,
field_names_len: usize,
query_ptr: *const c_char,
error_buffer: *mut *mut c_char,
Expand All @@ -297,6 +298,7 @@ pub extern "C" fn context_search(

match search(
field_names_ptr,
field_weights_ptr,
field_names_len,
query_ptr,
error_buffer,
Expand Down
103 changes: 103 additions & 0 deletions searchcontext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package tantivy_go

// SearchContext is the read-only description of a single search request:
// the query text, the maximum number of documents, the highlights flag and
// the fields to search together with their weights.
type SearchContext interface {
// GetQuery returns the search query string.
GetQuery() string
// GetDocsLimit returns the document limit as a uintptr.
GetDocsLimit() uintptr
// WithHighlights returns true if highlights are enabled.
WithHighlights() bool
// GetFieldAndWeights returns slices of field names and their corresponding weights.
// The two slices are expected to be index-aligned: the weight at position i
// belongs to the field name at position i.
GetFieldAndWeights() ([]string, []float32)
}

// searchContext is the package-private implementation of SearchContext.
// Instances are created and populated through SearchContextBuilder.
type searchContext struct {
// query is the search query string.
query string
// docsLimit is the document limit returned by GetDocsLimit.
docsLimit uintptr
// withHighlights is the flag returned by WithHighlights.
withHighlights bool
// fieldNames maps each field name to its weight.
fieldNames map[string]float32
}

// GetQuery returns the search query string stored in the context.
func (sc *searchContext) GetQuery() string {
return sc.query
}

// GetDocsLimit returns the document limit stored in the context.
func (sc *searchContext) GetDocsLimit() uintptr {
return sc.docsLimit
}

// WithHighlights reports whether the highlights flag is set on the context.
func (sc *searchContext) WithHighlights() bool {
return sc.withHighlights
}

// GetFieldNames returns the map of field names to their weights.
// The returned map is the context's own map, not a copy, so changes made
// through it are visible to the context. This method is not part of the
// SearchContext interface.
func (sc *searchContext) GetFieldNames() map[string]float32 {
return sc.fieldNames
}

// GetFieldAndWeights returns the configured field names and their weights as
// two parallel slices: the weight at index i belongs to the field name at
// index i. Both slices are non-nil and have one entry per configured field.
// The order of the entries follows Go's map iteration order and is therefore
// not stable between calls.
func (sc *searchContext) GetFieldAndWeights() ([]string, []float32) {
	count := len(sc.fieldNames)
	names := make([]string, count)
	boosts := make([]float32, count)

	idx := 0
	for name, boost := range sc.fieldNames {
		names[idx] = name
		boosts[idx] = boost
		idx++
	}

	return names, boosts
}

// SearchContextBuilder is a builder for assembling a SearchContext step by
// step. Its setter methods return the builder itself so calls can be chained.
type SearchContextBuilder struct {
// context is the searchContext being populated; Build returns it.
context *searchContext
}

// NewSearchContextBuilder returns a builder wrapping a fresh searchContext:
// empty query, zero document limit, highlights disabled and an empty (but
// initialized) field map, so fields can be added immediately.
func NewSearchContextBuilder() *SearchContextBuilder {
	ctx := new(searchContext)
	ctx.fieldNames = map[string]float32{}

	builder := new(SearchContextBuilder)
	builder.context = ctx
	return builder
}

// SetQuery sets the query string on the context under construction and
// returns the builder for chaining.
func (b *SearchContextBuilder) SetQuery(query string) *SearchContextBuilder {
b.context.query = query
return b
}

// SetDocsLimit sets the document limit on the context under construction and
// returns the builder for chaining.
func (b *SearchContextBuilder) SetDocsLimit(limit uintptr) *SearchContextBuilder {
b.context.docsLimit = limit
return b
}

// SetWithHighlights sets the highlights flag on the context under
// construction and returns the builder for chaining.
func (b *SearchContextBuilder) SetWithHighlights(withHighlights bool) *SearchContextBuilder {
b.context.withHighlights = withHighlights
return b
}

// AddField registers a field with the specified weight on the context under
// construction and returns the builder for chaining. Adding a field name that
// is already present replaces its previous weight.
func (b *SearchContextBuilder) AddField(field string, weight float32) *SearchContextBuilder {
b.context.fieldNames[field] = weight
return b
}

// AddFieldDefaultWeight registers a field with the default weight of 1.0 and
// returns the builder for chaining. As with AddField, a field that is already
// present has its weight replaced.
func (b *SearchContextBuilder) AddFieldDefaultWeight(field string) *SearchContextBuilder {
	const defaultWeight float32 = 1.0
	return b.AddField(field, defaultWeight)
}

// Build returns the constructed searchContext as a SearchContext interface.
// The returned value is the builder's own context pointer, not a copy, so
// later calls on the same builder also change the context already returned.
func (b *SearchContextBuilder) Build() SearchContext {
return b.context
}
Loading

0 comments on commit 246130d

Please sign in to comment.