Skip to content

Commit

Permalink
change image order (#310)
Browse files: browse the repository at this point in the history
* change image order

* change image order
  • Loading branch information
Ceceliachenen authored Dec 20, 2024
1 parent fd31306 commit 93c7f2a
Showing 1 changed file with 52 additions and 36 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -186,33 +186,41 @@ def _retrieve(
if self._enable_multimodal and self._image_vector_store is not None:
image_nodes = self._text_to_image_retrieve(query_bundle)

seen_images = set([node.node.image_url for node in image_nodes])
# 从文本中召回图片
if self._search_image and len(image_nodes) < self._image_similarity_top_k:
if not text_nodes:
text_nodes = []
if not image_nodes:
image_nodes = []

# 优先从文本中召回图片
integrated_image_nodes = []
seen_image_urls = []
if self._search_image:
for node in text_nodes:
if len(integrated_image_nodes) >= self._image_similarity_top_k:
break
image_url_infos = node.node.metadata.get("image_info_list")
if not image_url_infos:
continue
for image_url_info in image_url_infos:
if image_url_info.get("image_url", None) not in seen_images:
image_nodes.extend(
if len(integrated_image_nodes) >= self._image_similarity_top_k:
break
image_url = image_url_info.get("image_url", None)
if image_url and image_url not in seen_image_urls:
integrated_image_nodes.append(
NodeWithScore(
ImageNode(
image_url=image_url_info.get("image_url", None)
),
score=node.score
* 0.5, # discount the score from text nodes
node=ImageNode(image_url=image_url),
score=node.score,
)
)
seen_images.add(image_url_info.get("image_url", None))
if len(image_nodes) >= self._image_similarity_top_k:
break
seen_image_urls.append(image_url)

if not text_nodes:
text_nodes = []
if not image_nodes:
image_nodes = []
results = text_nodes + image_nodes
for node in image_nodes:
if len(integrated_image_nodes) >= self._image_similarity_top_k:
break
if node.node.image_url not in seen_image_urls:
integrated_image_nodes.append(node)

results = text_nodes + integrated_image_nodes
return results

def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
Expand Down Expand Up @@ -491,33 +499,41 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
logger.debug(f"Retrieved text nodes: {text_nodes}")
logger.debug(f"Retrieved image nodes: {image_nodes}")

seen_images = set([node.node.image_url for node in image_nodes])
# 从文本中召回图片
if self._search_image and len(image_nodes) < self._image_similarity_top_k:
if not text_nodes:
text_nodes = []
if not image_nodes:
image_nodes = []

# 优先从文本中召回图片
integrated_image_nodes = []
seen_image_urls = []
if self._search_image:
for node in text_nodes:
if len(integrated_image_nodes) >= self._image_similarity_top_k:
break
image_url_infos = node.node.metadata.get("image_info_list")
if not image_url_infos:
continue
for image_url_info in image_url_infos:
if image_url_info.get("image_url", None) not in seen_images:
image_nodes.extend(
if len(integrated_image_nodes) >= self._image_similarity_top_k:
break
image_url = image_url_info.get("image_url", None)
if image_url and image_url not in seen_image_urls:
integrated_image_nodes.append(
NodeWithScore(
ImageNode(
image_url=image_url_info.get("image_url", None)
),
score=node.score
* 0.5, # discount the score from text nodes
node=ImageNode(image_url=image_url),
score=node.score,
)
)
seen_images.add(image_url_info.get("image_url", None))
if len(image_nodes) >= self._image_similarity_top_k:
break
seen_image_urls.append(image_url)

if not text_nodes:
text_nodes = []
if not image_nodes:
image_nodes = []
results = text_nodes + image_nodes
for node in image_nodes:
if len(integrated_image_nodes) >= self._image_similarity_top_k:
break
if node.node.image_url not in seen_image_urls:
integrated_image_nodes.append(node)

results = text_nodes + integrated_image_nodes
return results

async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:
Expand Down

0 comments on commit 93c7f2a

Please sign in to comment.