From 7f4ff96c4f307f569e2e8fbb0250a3bf42912984 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=86=E9=80=8A?= Date: Mon, 21 Oct 2024 14:36:51 +0800 Subject: [PATCH] Update cnclip --- src/pai_rag/app/web/index_utils.py | 2 +- .../integrations/embeddings/clip/cnclip_embedding.py | 4 +++- src/pai_rag/utils/download_models.py | 12 ++++++------ 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/pai_rag/app/web/index_utils.py b/src/pai_rag/app/web/index_utils.py index 07ff57ab..8e98d30d 100644 --- a/src/pai_rag/app/web/index_utils.py +++ b/src/pai_rag/app/web/index_utils.py @@ -354,7 +354,7 @@ def components_to_index( milvus_collection_name, **kwargs, ) -> RagIndexEntry: - if vector_index.lower() == "new": + if vector_index is None or vector_index.lower() == "new": index_name = new_index_name else: index_name = vector_index diff --git a/src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py b/src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py index d41d9d2c..4d212073 100644 --- a/src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py +++ b/src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py @@ -11,7 +11,9 @@ from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE from pai_rag.utils.constants import DEFAULT_MODEL_DIR -DEFAULT_CNCLIP_MODEL_DIR = os.path.join(DEFAULT_MODEL_DIR, "cn_clip") +DEFAULT_CNCLIP_MODEL_DIR = os.path.join( + DEFAULT_MODEL_DIR, "chinese-clip-vit-large-patch14" +) DEFAULT_CNCLIP_MODEL = "ViT-L-14" diff --git a/src/pai_rag/utils/download_models.py b/src/pai_rag/utils/download_models.py index 56bb7ec4..cf75859b 100644 --- a/src/pai_rag/utils/download_models.py +++ b/src/pai_rag/utils/download_models.py @@ -14,10 +14,10 @@ class ModelScopeDownloader: - def __init__(self): + def __init__(self, fetch_config: bool = False): self.download_directory_path = Path(DEFAULT_MODEL_DIR) - if not os.path.exists(self.download_directory_path): - os.makedirs(self.download_directory_path) + if fetch_config or not os.path.exists(self.download_directory_path): + os.makedirs(self.download_directory_path, exist_ok=True) logger.info( f"Create model directory: {self.download_directory_path} and get model info from oss {OSS_URL}." ) @@ -100,7 +100,7 @@ def load_models(self, model): help="model name. Default: download all models provided", default=None, ) -def load_models(model): - download_models = ModelScopeDownloader() - download_models.load_models(model) +def load_models(model_name): + download_models = ModelScopeDownloader(fetch_config=True) + download_models.load_models(model=model_name) download_models.load_mineru_config()