Skip to content

Commit

Permalink
Fix dup file error (#241)
Browse files Browse the repository at this point in the history
* Fix dup file error

* Fix max_score bug

* Fix lint

* Remove empty flag
  • Loading branch information
moria97 authored Oct 10, 2024
1 parent f3a620b commit 7edee20
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 130 deletions.
5 changes: 3 additions & 2 deletions src/pai_rag/app/api/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import APIRouter, FastAPI
from pai_rag.core.rag_configuration import RagConfiguration
from pai_rag.core.rag_service import rag_service
from pai_rag.app.api import query
from pai_rag.app.api.middleware import init_middleware
Expand All @@ -11,8 +12,8 @@ def init_router(app: FastAPI):
app.include_router(api_router, prefix="/service")


def configure_app(app: FastAPI, config_file: str):
rag_service.initialize(config_file)
def configure_app(app: FastAPI, rag_configuration: RagConfiguration):
rag_service.initialize(rag_configuration)
init_middleware(app)
init_router(app)
config_app_errors(app)
9 changes: 2 additions & 7 deletions src/pai_rag/core/rag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from typing import Any, List
import logging

from pai_rag.core.rag_trace import init_trace

TASK_STATUS_FILE = "__upload_task_status.tmp"
logger = logging.getLogger(__name__)

Expand All @@ -43,16 +41,13 @@ async def _a_trace_correlation_id(*args, **kwargs):


class RagService:
def initialize(self, config_file: str):
self.config_file = config_file
self.rag_configuration = RagConfiguration.from_file(config_file)
def initialize(self, rag_configuration: RagConfiguration):
self.rag_configuration = rag_configuration
self.config_dict_value = self.rag_configuration.get_value().to_dict()
self.config_modified_time = self.rag_configuration.get_config_mtime()

self.rag_configuration.persist()

init_trace(self.rag_configuration.get_value().get("RAG.trace"))

self.rag = RagApplication()
self.rag.initialize(self.rag_configuration.get_value())

Expand Down
7 changes: 4 additions & 3 deletions src/pai_rag/integrations/index/pai/local/local_bm25_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,11 +297,12 @@ def query(
)
)

if normalize:
if normalize and len(results) > 0:
bm25_scores = [node.score for node in results]
max_score = max(bm25_scores)
for node_with_score in results:
node_with_score.score = node_with_score.score / max_score
if max_score > 0:
for node_with_score in results:
node_with_score.score = node_with_score.score / max_score

return results

Expand Down
20 changes: 0 additions & 20 deletions src/pai_rag/integrations/index/pai/multimodal/multimodal_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ def __init__(
# keep image_vector_store here for backward compatibility
image_vector_store: Optional[BasePydanticVectorStore] = None,
image_embed_model: EmbedType = "clip:ViT-B/32",
# is_image_vector_store_empty is used to indicate whether image_vector_store is empty
# those flags are used for cases when only one vector store is used
is_image_vector_store_empty: bool = False,
is_text_vector_store_empty: bool = False,
# deprecated
service_context: Optional[ServiceContext] = None,
**kwargs: Any,
Expand Down Expand Up @@ -105,8 +101,6 @@ def __init__(
self.image_namespace
]

self._is_image_vector_store_empty = is_image_vector_store_empty
self._is_text_vector_store_empty = is_text_vector_store_empty
storage_context = storage_context or StorageContext.from_defaults()

super().__init__(
Expand All @@ -129,14 +123,6 @@ def image_vector_store(self) -> BasePydanticVectorStore:
def image_embed_model(self) -> MultiModalEmbedding:
return self._image_embed_model

@property
def is_image_vector_store_empty(self) -> bool:
return self._is_image_vector_store_empty

@property
def is_text_vector_store_empty(self) -> bool:
return self._is_text_vector_store_empty

def as_retriever(self, **kwargs: Any) -> PaiMultiModalVectorIndexRetriever:
return PaiMultiModalVectorIndexRetriever(
self,
Expand Down Expand Up @@ -301,8 +287,6 @@ async def _async_add_nodes_to_index(
new_text_ids = await self.storage_context.vector_stores[
DEFAULT_VECTOR_STORE
].async_add(text_nodes, **insert_kwargs)
else:
self._is_text_vector_store_empty = True

if len(image_nodes) > 0:
# embed image nodes as images directly
Expand All @@ -317,8 +301,6 @@ async def _async_add_nodes_to_index(

# TODO: Fix for FAISS
new_img_ids = [f"{self.image_namespace}_{i}" for i in new_img_ids]
else:
self._is_image_vector_store_empty = True

# if the vector store doesn't store text, we need to add the nodes to the
# index struct and document store
Expand Down Expand Up @@ -378,7 +360,6 @@ def _add_nodes_to_index(
)

else:
self._is_text_vector_store_empty = True
logger.info("No text nodes to insert.")

if len(image_nodes) > 0:
Expand All @@ -402,7 +383,6 @@ def _add_nodes_to_index(
)

else:
self._is_image_vector_store_empty = True
logger.info("No image nodes to insert.")

# if the vector store doesn't store text, we need to add the nodes to the
Expand Down
171 changes: 75 additions & 96 deletions src/pai_rag/integrations/index/pai/multimodal/multimodal_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,34 +279,31 @@ def _text_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if not self._index.is_text_vector_store_empty:
if (
self._supports_hybrid_search
or self._vector_store_query_mode == VectorStoreQueryMode.DEFAULT
):
return self._text_retrieve_from_vector_store(query_bundle)
elif (
self._vector_store_query_mode == VectorStoreQueryMode.TEXT_SEARCH
or self._vector_store_query_mode == VectorStoreQueryMode.SPARSE
):
return self._local_bm25_index.query(
query_str=query_bundle.query_str,
top_n=self.similarity_top_k,
normalize=True,
)
else:
vector_nodes = self._text_retrieve_from_vector_store(query_bundle)
keyword_nodes = self._local_bm25_index.query(
query_str=query_bundle.query_str,
top_n=self.similarity_top_k,
normalize=True,
)

return self._fusion_nodes(
vector_nodes, keyword_nodes, self._similarity_top_k
)
if (
self._supports_hybrid_search
or self._vector_store_query_mode == VectorStoreQueryMode.DEFAULT
):
return self._text_retrieve_from_vector_store(query_bundle)
elif (
self._vector_store_query_mode == VectorStoreQueryMode.TEXT_SEARCH
or self._vector_store_query_mode == VectorStoreQueryMode.SPARSE
):
return self._local_bm25_index.query(
query_str=query_bundle.query_str,
top_n=self.similarity_top_k,
normalize=True,
)
else:
return []
vector_nodes = self._text_retrieve_from_vector_store(query_bundle)
keyword_nodes = self._local_bm25_index.query(
query_str=query_bundle.query_str,
top_n=self.similarity_top_k,
normalize=True,
)

return self._fusion_nodes(
vector_nodes, keyword_nodes, self._similarity_top_k
)

def text_retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
if isinstance(str_or_query_bundle, str):
Expand All @@ -317,20 +314,17 @@ def _text_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if self._search_image and not self._index.is_image_vector_store_empty:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Text encoder
query_bundle.embedding = (
self._image_embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs
)
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Text encoder
query_bundle.embedding = (
self._image_embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs
)

return self._get_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)
else:
return []

return self._get_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)

def text_to_image_retrieve(
self, str_or_query_bundle: QueryType
Expand All @@ -343,18 +337,15 @@ def _image_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if not self._index.is_image_vector_store_empty:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Image encoder for image input
assert isinstance(self._index.image_embed_model, MultiModalEmbedding)
query_bundle.embedding = self._image_embed_model.get_image_embedding(
query_bundle.embedding_image[0]
)
return self._get_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Image encoder for image input
assert isinstance(self._index.image_embed_model, MultiModalEmbedding)
query_bundle.embedding = self._image_embed_model.get_image_embedding(
query_bundle.embedding_image[0]
)
else:
return []
return self._get_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)

def image_to_image_retrieve(
self, str_or_query_bundle: QueryType
Expand Down Expand Up @@ -490,34 +481,31 @@ async def _atext_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if not self._index.is_text_vector_store_empty:
if (
self._supports_hybrid_search
or self._vector_store_query_mode == VectorStoreQueryMode.DEFAULT
):
return await self._atext_retrieve_from_vector_store(query_bundle)
elif (
self._vector_store_query_mode == VectorStoreQueryMode.TEXT_SEARCH
or self._vector_store_query_mode == VectorStoreQueryMode.SPARSE
):
return self._local_bm25_index.query(
query_str=query_bundle.query_str,
top_n=self.similarity_top_k,
normalize=True,
)
else:
vector_nodes = await self._atext_retrieve_from_vector_store(
query_bundle
)
keyword_nodes = self._local_bm25_index.query(
query_str=query_bundle.query_str,
top_n=self.similarity_top_k,
normalize=True,
)
if (
self._supports_hybrid_search
or self._vector_store_query_mode == VectorStoreQueryMode.DEFAULT
):
return await self._atext_retrieve_from_vector_store(query_bundle)
elif (
self._vector_store_query_mode == VectorStoreQueryMode.TEXT_SEARCH
or self._vector_store_query_mode == VectorStoreQueryMode.SPARSE
):
return self._local_bm25_index.query(
query_str=query_bundle.query_str,
top_n=self.similarity_top_k,
normalize=True,
)
else:
vector_nodes = await self._atext_retrieve_from_vector_store(query_bundle)
keyword_nodes = self._local_bm25_index.query(
query_str=query_bundle.query_str,
top_n=self.similarity_top_k,
normalize=True,
)

return self._fusion_nodes(
vector_nodes, keyword_nodes, self.similarity_top_k
)
return self._fusion_nodes(
vector_nodes, keyword_nodes, self.similarity_top_k
)

async def atext_retrieve(
self, str_or_query_bundle: QueryType
Expand All @@ -530,11 +518,7 @@ async def _atext_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if (
self._enable_multimodal
and self._search_image
and not self._index.is_image_vector_store_empty
):
if self._enable_multimodal and self._search_image:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Text encoder
query_bundle.embedding = (
Expand Down Expand Up @@ -576,21 +560,16 @@ async def _aimage_to_image_retrieve(
self,
query_bundle: QueryBundle,
) -> List[NodeWithScore]:
if not self._index.is_image_vector_store_empty:
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Image encoder for image input
assert isinstance(self._index.image_embed_model, MultiModalEmbedding)
# Using the first image in the list for image retrieval
query_bundle.embedding = (
await self._image_embed_model.aget_image_embedding(
query_bundle.embedding_image[0]
)
)
return await self._aget_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
if self._image_vector_store.is_embedding_query:
# change the embedding for query bundle to Multi Modal Image encoder for image input
assert isinstance(self._index.image_embed_model, MultiModalEmbedding)
# Using the first image in the list for image retrieval
query_bundle.embedding = await self._image_embed_model.aget_image_embedding(
query_bundle.embedding_image[0]
)
else:
return []
return await self._aget_nodes_with_embeddings(
query_bundle, self._image_similarity_top_k, self._image_vector_store
)

async def aimage_to_image_retrieve(
self, str_or_query_bundle: QueryType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,8 @@ async def aquery(

if isinstance(retrieval_strategy, AsyncBM25Strategy) and len(top_k_nodes) > 0:
max_score = max(top_k_scores)
top_k_scores = [score / max_score for score in top_k_scores]
if max_score > 0:
top_k_scores = [score / max_score for score in top_k_scores]

if (
isinstance(retrieval_strategy, AsyncDenseVectorStrategy)
Expand Down
8 changes: 7 additions & 1 deletion src/pai_rag/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from fastapi import FastAPI
from pai_rag.app.api.service import configure_app
from pai_rag.app.web.webui import configure_webapp
from pai_rag.core.rag_configuration import RagConfiguration
from pai_rag.core.rag_trace import init_trace
from pai_rag.utils.download_models import ModelScopeDownloader
from pai_rag.utils.constants import DEFAULT_MODEL_DIR, EAS_DEFAULT_MODEL_DIR
from logging.config import dictConfig
Expand Down Expand Up @@ -181,6 +183,9 @@ def ui(host, port, rag_url):
default=False,
)
def serve(host, port, config_file, workers, enable_example, skip_download_models):
rag_configuration = RagConfiguration.from_file(config_file)
init_trace(rag_configuration.get_value().get("RAG.trace"))

if not skip_download_models and DEFAULT_MODEL_DIR != EAS_DEFAULT_MODEL_DIR:
logger.info("Start to download models.")
ModelScopeDownloader().load_basic_models()
Expand All @@ -190,7 +195,8 @@ def serve(host, port, config_file, workers, enable_example, skip_download_models
logger.info("Start to loading minerU config file.")
ModelScopeDownloader().load_mineru_config()
logger.info("Finished loading minerU config file.")

os.environ["PAI_RAG_MODEL_DIR"] = DEFAULT_MODEL_DIR
app = FastAPI(lifespan=lifespan)
configure_app(app, config_file=config_file)
configure_app(app, rag_configuration)
uvicorn.run(app=app, host=host, port=port, loop="asyncio", workers=workers)

0 comments on commit 7edee20

Please sign in to comment.