From 93c7f2aaa76e02c992e02b34039a6d33aa6ded13 Mon Sep 17 00:00:00 2001 From: Ceceliachenen Date: Fri, 20 Dec 2024 13:54:38 +0800 Subject: [PATCH] change image order (#310) * change image order * change image order --- .../pai/multimodal/multimodal_retriever.py | 88 +++++++++++-------- 1 file changed, 52 insertions(+), 36 deletions(-) diff --git a/src/pai_rag/integrations/index/pai/multimodal/multimodal_retriever.py b/src/pai_rag/integrations/index/pai/multimodal/multimodal_retriever.py index 05da7dac..a0944509 100644 --- a/src/pai_rag/integrations/index/pai/multimodal/multimodal_retriever.py +++ b/src/pai_rag/integrations/index/pai/multimodal/multimodal_retriever.py @@ -186,33 +186,41 @@ def _retrieve( if self._enable_multimodal and self._image_vector_store is not None: image_nodes = self._text_to_image_retrieve(query_bundle) - seen_images = set([node.node.image_url for node in image_nodes]) - # 从文本中召回图片 - if self._search_image and len(image_nodes) < self._image_similarity_top_k: + if not text_nodes: + text_nodes = [] + if not image_nodes: + image_nodes = [] + + # 优先从文本中召回图片 + integrated_image_nodes = [] + seen_image_urls = [] + if self._search_image: for node in text_nodes: + if len(integrated_image_nodes) >= self._image_similarity_top_k: + break image_url_infos = node.node.metadata.get("image_info_list") if not image_url_infos: continue for image_url_info in image_url_infos: - if image_url_info.get("image_url", None) not in seen_images: - image_nodes.extend( + if len(integrated_image_nodes) >= self._image_similarity_top_k: + break + image_url = image_url_info.get("image_url", None) + if image_url and image_url not in seen_image_urls: + integrated_image_nodes.append( NodeWithScore( - ImageNode( - image_url=image_url_info.get("image_url", None) - ), - score=node.score - * 0.5, # discount the score from text nodes + node=ImageNode(image_url=image_url), + score=node.score, ) ) - seen_images.add(image_url_info.get("image_url", None)) - if len(image_nodes) >= self._image_similarity_top_k: - break + seen_image_urls.append(image_url) - if not text_nodes: - text_nodes = [] - if not image_nodes: - image_nodes = [] - results = text_nodes + image_nodes + for node in image_nodes: + if len(integrated_image_nodes) >= self._image_similarity_top_k: + break + if node.node.image_url not in seen_image_urls: + integrated_image_nodes.append(node) + + results = text_nodes + integrated_image_nodes return results def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: @@ -491,33 +499,41 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: logger.debug(f"Retrieved text nodes: {text_nodes}") logger.debug(f"Retrieved image nodes: {image_nodes}") - seen_images = set([node.node.image_url for node in image_nodes]) - # 从文本中召回图片 - if self._search_image and len(image_nodes) < self._image_similarity_top_k: + if not text_nodes: + text_nodes = [] + if not image_nodes: + image_nodes = [] + + # 优先从文本中召回图片 + integrated_image_nodes = [] + seen_image_urls = [] + if self._search_image: for node in text_nodes: + if len(integrated_image_nodes) >= self._image_similarity_top_k: + break image_url_infos = node.node.metadata.get("image_info_list") if not image_url_infos: continue for image_url_info in image_url_infos: - if image_url_info.get("image_url", None) not in seen_images: - image_nodes.extend( + if len(integrated_image_nodes) >= self._image_similarity_top_k: + break + image_url = image_url_info.get("image_url", None) + if image_url and image_url not in seen_image_urls: + integrated_image_nodes.append( NodeWithScore( - ImageNode( - image_url=image_url_info.get("image_url", None) - ), - score=node.score - * 0.5, # discount the score from text nodes + node=ImageNode(image_url=image_url), + score=node.score, ) ) - seen_images.add(image_url_info.get("image_url", None)) - if len(image_nodes) >= self._image_similarity_top_k: - break + seen_image_urls.append(image_url) - if not text_nodes: - text_nodes = [] - if not image_nodes: - image_nodes = [] - results = text_nodes + image_nodes + for node in image_nodes: + if len(integrated_image_nodes) >= self._image_similarity_top_k: + break + if node.node.image_url not in seen_image_urls: + integrated_image_nodes.append(node) + + results = text_nodes + integrated_image_nodes return results async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: