diff --git a/qtext/config.py b/qtext/config.py
index 004bd0b..f454ecc 100644
--- a/qtext/config.py
+++ b/qtext/config.py
@@ -2,7 +2,7 @@
 
 import logging
 from pathlib import Path
-from typing import Annotated, Type
+from typing import Annotated, Literal, Type
 
 import msgspec
 
@@ -24,6 +24,7 @@ class VectorStoreConfig(msgspec.Struct, kw_only=True, frozen=True):
 
 
 class EmbeddingConfig(msgspec.Struct, kw_only=True, frozen=True):
+    client: Literal["openai", "cohere"] = "cohere"
     model_name: str = "thenlper/gte-base"
     dim: Annotated[int, msgspec.Meta(ge=1, le=65535)] = 768
     api_key: str = "fake"
diff --git a/qtext/emb_client.py b/qtext/emb_client.py
index d3f04bf..f315b92 100644
--- a/qtext/emb_client.py
+++ b/qtext/emb_client.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import cohere
 import httpx
 import msgspec
 import openai
@@ -31,6 +32,18 @@ def embedding(self, text: str | list[str]) -> list[float]:
         return response.data[0].embedding
 
 
+class CohereEmbeddingClient:
+    def __init__(self, model_name: str, api_key: str):
+        self.client = cohere.Client(api_key=api_key)
+        self.model_name = model_name
+
+    @time_it
+    @embedding_histogram.time()
+    def embedding(self, text: str) -> list[float]:
+        response = self.client.embed([text], model=self.model_name)
+        return response.embeddings[0]
+
+
 class SparseEmbeddingClient:
     def __init__(self, endpoint: str, dim: int, timeout: int) -> None:
         self.dim = dim
diff --git a/qtext/engine.py b/qtext/engine.py
index bc5e137..f15bc79 100644
--- a/qtext/engine.py
+++ b/qtext/engine.py
@@ -3,7 +3,11 @@
 from time import perf_counter
 
 from qtext.config import Config
-from qtext.emb_client import EmbeddingClient, SparseEmbeddingClient
+from qtext.emb_client import (
+    CohereEmbeddingClient,
+    EmbeddingClient,
+    SparseEmbeddingClient,
+)
 from qtext.highlight_client import ENGLISH_STOPWORDS, HighlightClient
 from qtext.metrics import rerank_histogram
 from qtext.pg_client import PgVectorsClient
@@ -25,12 +29,18 @@ def __init__(self, config: Config) -> None:
         self.resp_cls = self.querier.table_type
         self.pg_client = PgVectorsClient(config.vector_store.url, querier=self.querier)
         self.highlight_client = HighlightClient(config.highlight.addr)
-        self.emb_client = EmbeddingClient(
-            model_name=config.embedding.model_name,
-            api_key=config.embedding.api_key,
-            endpoint=config.embedding.api_endpoint,
-            timeout=config.embedding.timeout,
-        )
+        if config.embedding.client == "openai":
+            self.emb_client = EmbeddingClient(
+                model_name=config.embedding.model_name,
+                api_key=config.embedding.api_key,
+                endpoint=config.embedding.api_endpoint,
+                timeout=config.embedding.timeout,
+            )
+        else:
+            self.emb_client = CohereEmbeddingClient(
+                model_name=config.embedding.model_name,
+                api_key=config.embedding.api_key,
+            )
         self.sparse_client = SparseEmbeddingClient(
             endpoint=config.sparse.addr,
             dim=config.sparse.dim,
diff --git a/tui.py b/tui.py
index 46517e6..b56508a 100644
--- a/tui.py
+++ b/tui.py
@@ -40,7 +40,7 @@ def compose(self) -> ComposeResult:
         with HorizontalScroll(id="namespace"):
             yield Label("Namespace:")
             yield Input(
-                "sparse_test",
+                "cohere_wiki",
                 placeholder="Type the namespace you want to query",
                 max_length=128,
                 id="namespace-input",
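
For reviewers, a minimal usage sketch of the Cohere path added above (not part of the patch itself): it only uses the CohereEmbeddingClient constructor and embedding() method defined in qtext/emb_client.py; the model name "embed-english-v3.0" and the API key are illustrative placeholders, not values pinned by this change.

    from qtext.emb_client import CohereEmbeddingClient

    # Hypothetical values: any Cohere embedding model works as long as its
    # output dimension matches EmbeddingConfig.dim in qtext/config.py.
    client = CohereEmbeddingClient(model_name="embed-english-v3.0", api_key="...")
    vector = client.embedding("hybrid search with full-text and vector retrieval")
    print(len(vector))  # should equal the configured embedding dimension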