Skip to content

Commit

Permalink
vector_dot_product
Browse files Browse the repository at this point in the history
  • Loading branch information
daimor committed Apr 16, 2024
1 parent d766737 commit e5c9a64
Showing 1 changed file with 11 additions and 21 deletions.
32 changes: 11 additions & 21 deletions langchain_iris/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import enum
import logging
import uuid
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Expand All @@ -20,24 +18,19 @@
Type,
)

import numpy as np
from sqlalchemy import (
Connection,
and_,
asc,
literal,
VARCHAR,
UUID,
TEXT,
Column,
String,
Table,
create_engine,
insert,
text,
delete,
Row,
func,
)
from sqlalchemy_iris import IRISListBuild
from sqlalchemy_iris import IRISVector as IRISVectorType
Expand All @@ -53,9 +46,6 @@
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.utils import maximal_marginal_relevance

Base = declarative_base() # type: Any


Expand All @@ -64,7 +54,7 @@ class DistanceStrategy(str, enum.Enum):

EUCLIDEAN = "l2"
COSINE = "cosine"
MAX_INNER_PRODUCT = "inner"
DOT_PRODUCT = "dot"


DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
Expand All @@ -77,7 +67,6 @@ class DistanceStrategy(str, enum.Enum):
class IRISVector(VectorStore):
_conn = None
native_vector = False
native_vector_cosine_similarity = False

def __init__(
self,
Expand Down Expand Up @@ -201,8 +190,8 @@ def distance_strategy(self) -> str:
if self.native_vector:
if self._distance_strategy == DistanceStrategy.COSINE:
return self.table.c.embedding.cosine
elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
return self.table.c.embedding.max_inner_product
elif self._distance_strategy == DistanceStrategy.DOT_PRODUCT:
return self.table.c.embedding.DOT_product
# elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
# return "langchain_l2_distance"
else:
Expand All @@ -215,7 +204,7 @@ def distance_strategy(self) -> str:
return "langchain_l2_distance"
elif self._distance_strategy == DistanceStrategy.COSINE:
return "langchain_cosine_distance"
elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
elif self._distance_strategy == DistanceStrategy.DOT_PRODUCT:
return "langchain_inner_distance"
else:
raise ValueError(
Expand All @@ -229,9 +218,6 @@ def connect(self) -> Connection:
try:
if conn.dialect.supports_vectors:
self.native_vector = True
self.native_vector_cosine_similarity = (
conn.dialect.vector_cosine_similarity
)
except: # noqa
pass
return conn
Expand Down Expand Up @@ -307,8 +293,8 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._cosine_relevance_score_fn
elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
return self._euclidean_relevance_score_fn
elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
return self._max_inner_product_relevance_score_fn
elif self._distance_strategy == DistanceStrategy.DOT_PRODUCT:
return self._DOT_product_relevance_score_fn
else:
raise ValueError(
"No supported normalization function"
Expand Down Expand Up @@ -550,7 +536,11 @@ def similarity_search_with_score_by_vector(
page_content=result.document,
metadata=json.loads(result.metadata),
),
round(float(result.distance), 15) if self.embedding_function is not None else None,
(
round(float(result.distance), 15)
if self.embedding_function is not None
else None
),
)
for result in results
]
Expand Down

0 comments on commit e5c9a64

Please sign in to comment.