Skip to content

Commit

Permalink
Update cnclip
Browse files Browse the repository at this point in the history
  • Loading branch information
moria97 committed Oct 21, 2024
1 parent 86e5461 commit 7f4ff96
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/pai_rag/app/web/index_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def components_to_index(
milvus_collection_name,
**kwargs,
) -> RagIndexEntry:
if vector_index.lower() == "new":
if vector_index is None or vector_index.lower() == "new":
index_name = new_index_name
else:
index_name = vector_index
Expand Down
4 changes: 3 additions & 1 deletion src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from llama_index.core.constants import DEFAULT_EMBED_BATCH_SIZE
from pai_rag.utils.constants import DEFAULT_MODEL_DIR

DEFAULT_CNCLIP_MODEL_DIR = os.path.join(DEFAULT_MODEL_DIR, "cn_clip")
DEFAULT_CNCLIP_MODEL_DIR = os.path.join(
DEFAULT_MODEL_DIR, "chinese-clip-vit-large-patch14"
)
DEFAULT_CNCLIP_MODEL = "ViT-L-14"


Expand Down
12 changes: 6 additions & 6 deletions src/pai_rag/utils/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@


class ModelScopeDownloader:
def __init__(self):
def __init__(self, fetch_config: bool = False):
self.download_directory_path = Path(DEFAULT_MODEL_DIR)
if not os.path.exists(self.download_directory_path):
os.makedirs(self.download_directory_path)
if fetch_config or not os.path.exists(self.download_directory_path):
os.makedirs(self.download_directory_path, exist_ok=True)
logger.info(
f"Create model directory: {self.download_directory_path} and get model info from oss {OSS_URL}."
)
Expand Down Expand Up @@ -100,7 +100,7 @@ def load_models(self, model):
help="model name. Default: download all models provided",
default=None,
)
def load_models(model):
download_models = ModelScopeDownloader()
download_models.load_models(model)
def load_models(model_name):
download_models = ModelScopeDownloader(fetch_config=True)
download_models.load_models(model=model_name)
download_models.load_mineru_config()

0 comments on commit 7f4ff96

Please sign in to comment.