Skip to content

Commit

Permalink
Merge pull request #55 from saileshd1402/update-hf-hub
Browse files Browse the repository at this point in the history
Update download logic
  • Loading branch information
johnugeorge authored Jun 17, 2024
2 parents f36edf2 + 4518948 commit 79570fe
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 116 deletions.
144 changes: 45 additions & 99 deletions llm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import argparse
import json
import sys
import re
from collections import Counter
from typing import List
from huggingface_hub import snapshot_download
import utils.marsgen as mg
Expand All @@ -20,7 +18,6 @@
create_folder_if_not_exists,
delete_directory,
copy_file,
get_all_files_in_directory,
check_if_folder_empty,
)

Expand All @@ -29,63 +26,40 @@
MODEL_STORE_DIR = "model-store"
HANDLER = "handler.py"
MODEL_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "model_config.json")
FILE_EXTENSIONS_TO_IGNORE = [
".safetensors",
".safetensors.index.json",
".h5",
".ot",
".tflite",
".msgpack",
".onnx",
PREFERRED_MODEL_FORMATS = [".safetensors", ".bin"] # In order of Preference
OTHER_MODEL_FORMATS = [
"*.pt",
"*.h5",
"*.gguf",
"*.msgpack",
"*.tflite",
"*.ot",
"*.onnx",
]


def get_ignore_pattern_list(extension_list: List[str]) -> List[str]:
def get_ignore_pattern_list(gen_model: GenerateDataModel) -> List[str]:
"""
This function takes a list of file extensions and returns a list of patterns
that can be used to filter out files with these extensions.
This method creates a list of file extensions to ignore from a priority list based on files
present in the Hugging Face Repo. It filters out extensions not found in the repository and
returns them as ignore patterns prefixed with '*' which is expected by Hugging Face client.
Args:
extension_list (list): A list of file extensions.
gen_model (GenerateDataModel): An instance of the GenerateDataModel class
Returns:
list: A list of patterns with '*' prepended to each extension, suitable for filtering files.
list(str): A list of patterns with '*' prepended to each extension,
suitable for filtering files.
"""
return ["*" + pattern for pattern in extension_list]


def compare_lists(list1: List[str], list2: List[str]) -> bool:
    """
    Tell whether two lists hold the same items the same number of times,
    ignoring the order in which the items appear.
    Args:
        list1 (list): First list of items.
        list2 (list): Second list of items.
    Returns:
        bool: True when both lists contain identical items with identical
              multiplicities, False otherwise.
    """
    # Counter maps each item to its occurrence count, so comparing the two
    # Counters is an order-insensitive, duplicate-aware equality check.
    first_counts = Counter(list1)
    second_counts = Counter(list2)
    return first_counts == second_counts


def filter_files_by_extension(
    filenames: List[str], extensions_to_remove: List[str]
) -> List[str]:
    """
    This function takes a list of filenames and a list of extensions to
    remove. It returns a new list containing, in their original order, only
    the filenames that do not end with any of the given extensions.
    Args:
        filenames (list): A list of filenames to be filtered.
        extensions_to_remove (list): A list of file extensions (suffixes such
                                     as ".safetensors") to remove. When the
                                     list is empty nothing is removed.
    Returns:
        list: A list of filenames after filtering.
    """
    # str.endswith accepts a tuple of suffixes and returns False for an empty
    # tuple, so an empty extension list keeps every filename. Joining the
    # suffixes into a regex instead would collapse to the empty pattern for an
    # empty list, and the empty pattern matches (i.e. removes) every filename.
    suffixes = tuple(extensions_to_remove)
    return [filename for filename in filenames if not filename.endswith(suffixes)]
repo_file_extensions = hf.get_repo_file_extensions(gen_model)
for desired_extension in PREFERRED_MODEL_FORMATS:
if desired_extension in repo_file_extensions:
ignore_list = [
"*" + ignore_extension
for ignore_extension in PREFERRED_MODEL_FORMATS
if ignore_extension != desired_extension
]
ignore_list.extend(OTHER_MODEL_FORMATS)
return ignore_list
return []


def set_config(gen_model: GenerateDataModel) -> None:
Expand Down Expand Up @@ -141,24 +115,6 @@ class with relevant information.
config_file.writelines(config_info)


def check_if_model_files_exist(gen_model: GenerateDataModel) -> bool:
    """
    Check whether the files found under the model path match the files
    listed for the HuggingFace repository, once repository files with an
    extension in FILE_EXTENSIONS_TO_IGNORE have been dropped.
    Args:
        gen_model (GenerateDataModel): An instance of the GenerateDataModel
                                       class with relevant information.
    Returns:
        bool: True if the local file list and the filtered repository file
              list contain the same entries (order is not significant),
              False otherwise.
    """
    # Local listing first, then the remote listing, mirroring the order in
    # which the two lookups have always been performed.
    local_files = get_all_files_in_directory(gen_model.mar_utils.model_path)
    expected_files = filter_files_by_extension(
        hf.get_repo_files_list(gen_model), FILE_EXTENSIONS_TO_IGNORE
    )
    return compare_lists(local_files, expected_files)


def check_if_mar_file_exist(gen_model: GenerateDataModel) -> bool:
"""
This function checks if the Model Archive (MAR) file for the
Expand Down Expand Up @@ -266,27 +222,24 @@ class with relevant information.
Returns:
GenerateDataModel: An instance of the GenerateDataModel class.
"""
if os.path.exists(gen_model.mar_utils.model_path) and check_if_model_files_exist(
gen_model
):
print(
(
"## Skipping downloading as model files of the needed"
" repo version are already present\n"
)
)
return gen_model
print("## Starting model files download\n")
delete_directory(gen_model.mar_utils.model_path)
create_folder_if_not_exists(gen_model.mar_utils.model_path)

tmp_hf_cache = os.path.join(gen_model.mar_utils.model_path, "tmp_hf_cache")
create_folder_if_not_exists(tmp_hf_cache)

snapshot_download(
repo_id=gen_model.repo_info.repo_id,
revision=gen_model.repo_info.repo_version,
local_dir=gen_model.mar_utils.model_path,
local_dir_use_symlinks=False,
token=gen_model.repo_info.hf_token,
ignore_patterns=get_ignore_pattern_list(FILE_EXTENSIONS_TO_IGNORE),
local_dir_use_symlinks=False,
cache_dir=tmp_hf_cache,
force_download=True,
ignore_patterns=get_ignore_pattern_list(gen_model),
)
delete_directory(tmp_hf_cache)
print("## Successfully downloaded model_files\n")
return gen_model

Expand All @@ -305,23 +258,16 @@ class with relevant information.
print("## Skipping generation of model archive file as it is present\n")
else:
check_if_path_exists(gen_model.mar_utils.model_path, "model_path", is_dir=True)
if not gen_model.is_custom:
if not check_if_model_files_exist(gen_model):
# checking if local model files are same the repository files
print("## Model files do not match HuggingFace repository Files")
sys.exit(1)
else:
if check_if_folder_empty(gen_model.mar_utils.model_path):
print(
f"\n##Error: {gen_model.model_name} model files for the custom"
f" model not found in the provided path: {gen_model.mar_utils.model_path}"
)
sys.exit(1)
else:
print(
f"\n## Generating MAR file for custom model files: {gen_model.model_name} \n"
)
if check_if_folder_empty(gen_model.mar_utils.model_path):
print(
f"\n##Error: {gen_model.model_name} model files for the custom"
f" model not found in the provided path: {gen_model.mar_utils.model_path}"
)
sys.exit(1)

print(
f"\n## Generating MAR file for custom model files: {gen_model.model_name} \n"
)
create_folder_if_not_exists(gen_model.mar_utils.mar_output)

mg.generate_mars(
Expand Down
2 changes: 1 addition & 1 deletion llm/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torch-model-archiver==0.8.1
kubernetes==28.1.0
kserve==0.11.1
huggingface-hub==0.20.1
huggingface-hub==0.22.2
32 changes: 16 additions & 16 deletions llm/utils/hf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
"""

import sys
from typing import List
import os
from huggingface_hub import HfApi
from huggingface_hub.utils import (
GatedRepoError,
RepositoryNotFoundError,
RevisionNotFoundError,
HfHubHTTPError,
Expand All @@ -14,20 +15,19 @@
from utils.generate_data_model import GenerateDataModel


def get_repo_files_list(gen_model: GenerateDataModel) -> List[str]:
def get_repo_file_extensions(gen_model: GenerateDataModel) -> set:
"""
This function returns a list of all files in the HuggingFace repo of
This function returns set of all file extensions in the Hugging Face repo of
the model.
Args:
gen_model (GenerateDataModel): An instance of the GenerateDataModel
class with relevant information.
gen_model (GenerateDataModel): An instance of the GenerateDataModel class
Returns:
repo_files (list): all files in the HuggingFace repo of
the model
repo_file_extension (set): The set of all file extensions in the
Hugging Face repo of the model
Raises:
sys.exit(1): If repo_id, repo_version or huggingface token
is not valid, the function will terminate
the program with an exit code of 1.
is not valid, the function will terminate
the program with an exit code of 1.
"""
try:
hf_api = HfApi()
Expand All @@ -36,19 +36,19 @@ class with relevant information.
revision=gen_model.repo_info.repo_version,
token=gen_model.repo_info.hf_token,
)
return repo_files
return {os.path.splitext(file_name)[1] for file_name in repo_files}
except (
HfHubHTTPError,
HFValidationError,
GatedRepoError,
RepositoryNotFoundError,
RevisionNotFoundError,
HfHubHTTPError,
HFValidationError,
ValueError,
KeyError,
):
print(
(
"\n## Error: Please check either repo_id, repo_version "
"or huggingface token is not correct\n"
)
"## Error: Please check either repo_id, repo_version"
" or HuggingFace ID is not correct\n"
)
sys.exit(1)

Expand Down

0 comments on commit 79570fe

Please sign in to comment.