Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support IVF_FLAT and hamming in pylance #3301

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1870,24 +1870,28 @@ def create_index(
f" ({num_sub_vectors})"
)

if not pa.types.is_floating(field.type.value_type):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we check floating and int8/uint8 here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

raise TypeError(
f"Vector column {c} must have floating value type, "
f"got {field.type.value_type}"
)
if not (
pa.types.is_floating(field.type.value_type)
or pa.types.is_uint8(field.type.value_type)
):
raise TypeError(
f"Vector column {c} must have floating or binary (uint8) value type, "
f"got {field.type.value_type}"
)

if not isinstance(metric, str) or metric.lower() not in [
"l2",
"cosine",
"euclidean",
"dot",
"hamming",
]:
raise ValueError(f"Metric {metric} not supported.")

kwargs["metric_type"] = metric

index_type = index_type.upper()
valid_index_types = ["IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ"]
valid_index_types = ["IVF_FLAT", "IVF_PQ", "IVF_HNSW_PQ", "IVF_HNSW_SQ"]
if index_type not in valid_index_types:
raise NotImplementedError(
f"Only {valid_index_types} index types supported. " f"Got {index_type}"
Expand Down
23 changes: 23 additions & 0 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,29 @@ def test_create_4bit_ivf_pq_index(dataset, tmp_path):
assert index["indices"][0]["sub_index"]["nbits"] == 4


def test_ivf_flat_over_binary_vector(tmp_path):
    """IVF_FLAT with the hamming metric can index and search uint8 vectors.

    Each row is a fixed-size list of ``dim // 8`` uint8 values (``dim`` bits).
    The test builds an IVF_FLAT index with ``metric="hamming"``, checks the
    reported index stats, and runs a nearest-neighbour query against it.
    """
    dim = 128
    nvec = 1000
    k = 10
    data = np.random.randint(0, 256, (nvec, dim // 8)).tolist()
    array = pa.array(data, type=pa.list_(pa.uint8(), dim // 8))
    tbl = pa.Table.from_pydict({"vector": array})
    ds = lance.write_dataset(tbl, tmp_path)
    ds.create_index("vector", index_type="IVF_FLAT", num_partitions=4, metric="hamming")
    stats = ds.stats.index_stats("vector_idx")
    assert stats["indices"][0]["metric_type"] == "hamming"
    assert stats["index_type"] == "IVF_FLAT"

    query = np.random.randint(0, 256, dim // 8).astype(np.uint8)
    result = ds.to_table(
        nearest={
            "column": "vector",
            "q": query,
            "k": k,
            "metric": "hamming",
        }
    )
    # Verify the search actually yields k neighbours, not merely that the
    # query runs without raising; nvec is far larger than k.
    assert result.num_rows == k


def test_create_ivf_hnsw_pq_index(dataset, tmp_path):
assert not dataset.has_index
ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance")
Expand Down
23 changes: 21 additions & 2 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use std::collections::HashMap;
use std::str;
use std::sync::Arc;

use arrow::array::AsArray;
use arrow::datatypes::UInt8Type;
use arrow::ffi_stream::ArrowArrayStreamReader;
use arrow::pyarrow::*;
use arrow_array::{Float32Array, RecordBatch, RecordBatchReader};
Expand Down Expand Up @@ -43,6 +45,7 @@ use lance::dataset::{
BatchInfo, BatchUDF, CommitBuilder, NewColumnTransform, UDFCheckpointStore, WriteDestination,
};
use lance::dataset::{ColumnAlteration, ProjectionRequest};
use lance::index::vector::utils::get_vector_element_type;
use lance::index::{vector::VectorIndexParams, DatasetIndexInternalExt};
use lance_arrow::as_fixed_size_list_array;
use lance_index::scalar::InvertedIndexParams;
Expand Down Expand Up @@ -685,8 +688,19 @@ impl Dataset {
None
};

let element_type = get_vector_element_type(&self_.ds, &column)
.map_err(|e| PyValueError::new_err(e.to_string()))?;
let scanner = match element_type {
DataType::UInt8 => {
let q = arrow::compute::cast(&q, &DataType::UInt8).map_err(|e| {
PyValueError::new_err(format!("Failed to cast q to binary vector: {}", e))
})?;
let q = q.as_primitive::<UInt8Type>();
scanner.nearest(&column, q, k)
}
_ => scanner.nearest(&column, &q, k),
};
scanner
.nearest(column.as_str(), &q, k)
.map(|s| {
let mut s = s.nprobs(nprobes);
if let Some(factor) = refine_factor {
Expand Down Expand Up @@ -1135,7 +1149,7 @@ impl Dataset {
"BITMAP" => IndexType::Bitmap,
"LABEL_LIST" => IndexType::LabelList,
"INVERTED" | "FTS" => IndexType::Inverted,
"IVF_PQ" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector,
"IVF_FLAT" | "IVF_PQ" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector,
_ => {
return Err(PyValueError::new_err(format!(
"Index type '{index_type}' is not supported."
Expand Down Expand Up @@ -1749,6 +1763,11 @@ fn prepare_vector_index_params(
}

match index_type {
"IVF_FLAT" => Ok(Box::new(VectorIndexParams::ivf_flat(
ivf_params.num_partitions,
m_type,
))),

"IVF_PQ" => Ok(Box::new(VectorIndexParams::with_ivf_pq_params(
m_type, ivf_params, pq_params,
))),
Expand Down
2 changes: 1 addition & 1 deletion rust/lance/src/index/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{any::Any, collections::HashMap};
pub mod builder;
pub mod ivf;
pub mod pq;
mod utils;
pub mod utils;

#[cfg(test)]
mod fixture_test;
Expand Down
Loading