From e5c9a64f2e47e3a73359907a0f0429246a74156b Mon Sep 17 00:00:00 2001 From: Dmitry Maslennikov Date: Tue, 16 Apr 2024 20:19:44 +1000 Subject: [PATCH] vector_dot_product --- langchain_iris/vectorstores.py | 32 +++++++++++--------------------- 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/langchain_iris/vectorstores.py b/langchain_iris/vectorstores.py index 173cb7e..ce7e4ed 100644 --- a/langchain_iris/vectorstores.py +++ b/langchain_iris/vectorstores.py @@ -5,9 +5,7 @@ import enum import logging import uuid -from functools import partial from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -20,24 +18,19 @@ Type, ) -import numpy as np from sqlalchemy import ( Connection, and_, asc, - literal, VARCHAR, - UUID, TEXT, Column, - String, Table, create_engine, insert, text, delete, Row, - func, ) from sqlalchemy_iris import IRISListBuild from sqlalchemy_iris import IRISVector as IRISVectorType @@ -53,9 +46,6 @@ from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore -from langchain.utils import get_from_dict_or_env -from langchain.vectorstores.utils import maximal_marginal_relevance - Base = declarative_base() # type: Any @@ -64,7 +54,7 @@ class DistanceStrategy(str, enum.Enum): EUCLIDEAN = "l2" COSINE = "cosine" - MAX_INNER_PRODUCT = "inner" + DOT_PRODUCT = "dot" DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE @@ -77,7 +67,6 @@ class DistanceStrategy(str, enum.Enum): class IRISVector(VectorStore): _conn = None native_vector = False - native_vector_cosine_similarity = False def __init__( self, @@ -201,8 +190,8 @@ def distance_strategy(self) -> str: if self.native_vector: if self._distance_strategy == DistanceStrategy.COSINE: return self.table.c.embedding.cosine - elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: - return self.table.c.embedding.max_inner_product + elif self._distance_strategy == DistanceStrategy.DOT_PRODUCT: + return self.table.c.embedding.DOT_product # elif self._distance_strategy == DistanceStrategy.EUCLIDEAN: # return "langchain_l2_distance" else: @@ -215,7 +204,7 @@ def distance_strategy(self) -> str: return "langchain_l2_distance" elif self._distance_strategy == DistanceStrategy.COSINE: return "langchain_cosine_distance" - elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: + elif self._distance_strategy == DistanceStrategy.DOT_PRODUCT: return "langchain_inner_distance" else: raise ValueError( @@ -229,9 +218,6 @@ def connect(self) -> Connection: try: if conn.dialect.supports_vectors: self.native_vector = True - self.native_vector_cosine_similarity = ( - conn.dialect.vector_cosine_similarity - ) except: # noqa pass return conn @@ -307,8 +293,8 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: return self._cosine_relevance_score_fn elif self._distance_strategy == DistanceStrategy.EUCLIDEAN: return self._euclidean_relevance_score_fn - elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT: - return self._max_inner_product_relevance_score_fn + elif self._distance_strategy == DistanceStrategy.DOT_PRODUCT: + return self._DOT_product_relevance_score_fn else: raise ValueError( "No supported normalization function" @@ -550,7 +536,11 @@ def similarity_search_with_score_by_vector( page_content=result.document, metadata=json.loads(result.metadata), ), - round(float(result.distance), 15) if self.embedding_function is not None else None, + ( + round(float(result.distance), 15) + if self.embedding_function is not None + else None + ), ) for result in results ]