diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py
index 20668ef75a..6da9928a65 100644
--- a/xinference/model/llm/llm_family.py
+++ b/xinference/model/llm/llm_family.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import datetime
 import logging
 import os
 import platform
@@ -450,15 +451,27 @@ def cache_from_modelscope(
     cache_dir = _get_cache_dir(llm_family, llm_spec)
     if llm_spec.model_format == "pytorch":
-        download_dir = snapshot_download(
-            llm_spec.model_id, revision=llm_spec.model_revision
+        meta_path = os.path.join(cache_dir, "__valid_download")
+        if os.path.exists(meta_path):
+            return cache_dir
+        download_dir = retry_download(
+            snapshot_download,
+            llm_family,
+            llm_spec,
+            llm_spec.model_id,
+            revision=llm_spec.model_revision,
         )
         for subdir, dirs, files in os.walk(download_dir):
             for file in files:
                 relpath = os.path.relpath(os.path.join(subdir, file), download_dir)
                 symlink_local_file(os.path.join(subdir, file), cache_dir, relpath)
+        with open(meta_path, "w") as f:
+            f.write(str(datetime.datetime.now()))
     elif llm_spec.model_format in ["ggmlv3", "ggufv2"]:
+        meta_path = os.path.join(cache_dir, f"__valid_download_{quantization}")
+        if os.path.exists(meta_path):
+            return cache_dir
         filename = llm_spec.model_file_name_template.format(quantization=quantization)
         download_path = retry_download(
             model_file_download,
@@ -469,6 +482,8 @@ def cache_from_modelscope(
             revision=llm_spec.model_revision,
         )
         symlink_local_file(download_path, cache_dir, filename)
+        with open(meta_path, "w") as f:
+            f.write(str(datetime.datetime.now()))
     else:
         raise ValueError(f"Unsupported format: {llm_spec.model_format}")
     return cache_dir
@@ -487,6 +502,9 @@ def cache_from_huggingface(
     cache_dir = _get_cache_dir(llm_family, llm_spec)
     if llm_spec.model_format == "pytorch":
         assert isinstance(llm_spec, PytorchLLMSpecV1)
+        meta_path = os.path.join(cache_dir, "__valid_download")
+        if os.path.exists(meta_path):
+            return cache_dir

         retry_download(
             huggingface_hub.snapshot_download,
@@ -497,9 +515,14 @@ def cache_from_huggingface(
             local_dir=cache_dir,
             local_dir_use_symlinks=True,
         )
+        with open(meta_path, "w") as f:
+            f.write(str(datetime.datetime.now()))
     elif llm_spec.model_format in ["ggmlv3", "ggufv2"]:
         assert isinstance(llm_spec, GgmlLLMSpecV1)
+        meta_path = os.path.join(cache_dir, f"__valid_download_{quantization}")
+        if os.path.exists(meta_path):
+            return cache_dir
         file_name = llm_spec.model_file_name_template.format(quantization=quantization)
         retry_download(
             huggingface_hub.hf_hub_download,
@@ -511,6 +534,8 @@ def cache_from_huggingface(
             local_dir=cache_dir,
             local_dir_use_symlinks=True,
         )
+        with open(meta_path, "w") as f:
+            f.write(str(datetime.datetime.now()))
     else:
         raise ValueError(f"Unsupported model format: {llm_spec.model_format}")