ENH: Skip download if model exists (#495)
aresnow1 authored Sep 26, 2023
1 parent: 5f2078b · commit: aac134d
Showing 1 changed file with 27 additions and 2 deletions.
xinference/model/llm/llm_family.py (+27, -2)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import datetime
 import logging
 import os
 import platform
@@ -450,15 +451,27 @@ def cache_from_modelscope(
 
     cache_dir = _get_cache_dir(llm_family, llm_spec)
     if llm_spec.model_format == "pytorch":
-        download_dir = snapshot_download(
-            llm_spec.model_id, revision=llm_spec.model_revision
+        meta_path = os.path.join(cache_dir, "__valid_download")
+        if os.path.exists(meta_path):
+            return cache_dir
+        download_dir = retry_download(
+            snapshot_download,
+            llm_family,
+            llm_spec,
+            llm_spec.model_id,
+            revision=llm_spec.model_revision,
         )
         for subdir, dirs, files in os.walk(download_dir):
             for file in files:
                 relpath = os.path.relpath(os.path.join(subdir, file), download_dir)
                 symlink_local_file(os.path.join(subdir, file), cache_dir, relpath)
+        with open(meta_path, "w") as f:
+            f.write(str(datetime.datetime.now()))
 
     elif llm_spec.model_format in ["ggmlv3", "ggufv2"]:
+        meta_path = os.path.join(cache_dir, f"__valid_download_{quantization}")
+        if os.path.exists(meta_path):
+            return cache_dir
         filename = llm_spec.model_file_name_template.format(quantization=quantization)
         download_path = retry_download(
             model_file_download,
@@ -469,6 +482,8 @@ def cache_from_modelscope(
             revision=llm_spec.model_revision,
         )
         symlink_local_file(download_path, cache_dir, filename)
+        with open(meta_path, "w") as f:
+            f.write(str(datetime.datetime.now()))
     else:
         raise ValueError(f"Unsupported format: {llm_spec.model_format}")
     return cache_dir
@@ -487,6 +502,9 @@ def cache_from_huggingface(
     cache_dir = _get_cache_dir(llm_family, llm_spec)
     if llm_spec.model_format == "pytorch":
         assert isinstance(llm_spec, PytorchLLMSpecV1)
+        meta_path = os.path.join(cache_dir, "__valid_download")
+        if os.path.exists(meta_path):
+            return cache_dir
 
         retry_download(
             huggingface_hub.snapshot_download,
@@ -497,9 +515,14 @@
             local_dir=cache_dir,
             local_dir_use_symlinks=True,
         )
+        with open(meta_path, "w") as f:
+            f.write(str(datetime.datetime.now()))
 
     elif llm_spec.model_format in ["ggmlv3", "ggufv2"]:
         assert isinstance(llm_spec, GgmlLLMSpecV1)
+        meta_path = os.path.join(cache_dir, f"__valid_download_{quantization}")
+        if os.path.exists(meta_path):
+            return cache_dir
         file_name = llm_spec.model_file_name_template.format(quantization=quantization)
         retry_download(
             huggingface_hub.hf_hub_download,
@@ -511,6 +534,8 @@
             local_dir=cache_dir,
             local_dir_use_symlinks=True,
         )
+        with open(meta_path, "w") as f:
+            f.write(str(datetime.datetime.now()))
     else:
         raise ValueError(f"Unsupported model format: {llm_spec.model_format}")
 
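The pattern the commit introduces is worth pulling out of the diff: each cache routine writes a __valid_download marker file (suffixed with the quantization for ggmlv3/ggufv2 files) into the cache directory only after the download finishes, and returns early on the next call if that marker exists. Below is a minimal, self-contained sketch of the same idea; cached_download and fetch are hypothetical stand-ins for the real retry_download/snapshot_download plumbing, not part of xinference:

    import datetime
    import os
    from typing import Callable, Optional


    def cached_download(
        cache_dir: str,
        fetch: Callable[[str], None],
        quantization: Optional[str] = None,
    ) -> str:
        """Download into cache_dir unless a completion marker already exists.

        Mirrors the commit's approach: the marker is written only *after*
        the download succeeds, so an interrupted download leaves no marker
        and is simply retried on the next call.
        """
        suffix = f"_{quantization}" if quantization else ""
        meta_path = os.path.join(cache_dir, f"__valid_download{suffix}")
        if os.path.exists(meta_path):
            # A previous download completed successfully; skip the network.
            return cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        fetch(cache_dir)  # hypothetical callable doing the actual download
        with open(meta_path, "w") as f:
            # Store a timestamp, as the commit does; only the file's
            # existence matters for the skip check.
            f.write(str(datetime.datetime.now()))
        return cache_dir

Writing the marker last is what makes the skip safe: a crash or network failure mid-download leaves no __valid_download file, so the next attempt starts over instead of trusting a half-populated cache directory.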
