ENH: caching from self-hosted storage (#419)
UranusSeven authored Sep 8, 2023
1 parent 1a292a5 commit d4d18da
Showing 9 changed files with 393 additions and 78 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python.yaml
@@ -94,6 +94,7 @@ jobs:
pip install bitsandbytes
pip install ctransformers
pip install sentence-transformers
+pip install s3fs
pip install -e ".[dev]"
working-directory: .
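`s3fs` is what lets `fsspec` resolve the `s3://` protocol used by the new self-hosted caching path; without it, `filesystem("s3")` fails with an import error. A minimal sketch of the dependency in action (anonymous access, matching how the caching code below talks to the public bucket):

```python
from fsspec import filesystem

# fsspec dispatches the "s3" protocol to s3fs under the hood;
# anon=True requests unsigned access, so no AWS credentials are needed.
fs = filesystem("s3", anon=True)
```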

4 changes: 2 additions & 2 deletions README.md
@@ -227,8 +227,8 @@ $ xinference registrations
For in-depth details on the built-in models, please refer to [built-in models](https://inference.readthedocs.io/en/latest/models/builtin/index.html).

**NOTE**:
- Xinference will download models automatically for you, and by default the models will be saved under `${USER}/.xinference/cache`.
+- If you have trouble downloading models from Hugging Face, run `export XINFERENCE_MODEL_SRC=xorbits` to download models from our mirror site.
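The same switch can be applied in-process, as a quick sketch (it must happen before the first model download is triggered):

```python
import os

# Equivalent to `export XINFERENCE_MODEL_SRC=xorbits` in the shell;
# set it before Xinference decides where to fetch a model from.
os.environ["XINFERENCE_MODEL_SRC"] = "xorbits"
```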

## Custom models
Please refer to [custom models](https://inference.readthedocs.io/en/latest/models/custom.html).
5 changes: 1 addition & 4 deletions README_ja_JP.md
@@ -205,10 +205,7 @@ $ xinference registrations

**NOTE**:
- Xinference will download models automatically for you, and by default the models will be saved under `${USER}/.xinference/cache`.
-- Foundation models only provide the `generate` interface.
-- RLHF and SFT models provide both `generate` and `chat`.
-- When using an Apple Metal GPU for acceleration, please choose the q4_0 and q4_1 quantization methods.
-- The `llama-2-chat` 70B ggmlv3 model currently only supports q4_0 quantization.
+- If you have trouble downloading models from Hugging Face, run `export XINFERENCE_MODEL_SRC=xorbits` to download models from our mirror site.

## Custom models
Please refer to [custom models](https://inference.readthedocs.io/en/latest/models/custom.html).
3 changes: 2 additions & 1 deletion README_zh_CN.md
@@ -210,6 +210,7 @@ $ xinference registrations

**NOTE**:
- Xinference will download models automatically for you, and by default the models will be saved under `${USER}/.xinference/cache`.
+- If you have trouble downloading models from Hugging Face, run `export XINFERENCE_MODEL_SRC=xorbits` to download models from our mirror site.
## Custom models
Please refer to [custom models](https://inference.readthedocs.io/en/latest/models/custom.html).
1 change: 1 addition & 0 deletions setup.cfg
@@ -38,6 +38,7 @@ install_requires =
huggingface-hub>=0.14.1,<1.0
typing_extensions
fsspec
+s3fs

[options.packages.find]
exclude =
2 changes: 2 additions & 0 deletions xinference/constants.py
@@ -23,4 +23,6 @@
XINFERENCE_DEFAULT_LOCAL_HOST = "127.0.0.1"
XINFERENCE_DEFAULT_DISTRIBUTED_HOST = "0.0.0.0"
XINFERENCE_DEFAULT_ENDPOINT_PORT = 9997

XINFERENCE_ENV_ENDPOINT = "XINFERENCE_ENDPOINT"
+XINFERENCE_ENV_MODEL_SRC = "XINFERENCE_MODEL_SRC"
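A hedged sketch of how these names are typically consumed; the actual lookup sites live elsewhere in the codebase:

```python
import os

from xinference.constants import XINFERENCE_ENV_MODEL_SRC

# Returns "xorbits" when the user opted into the mirror, else None.
model_src = os.environ.get(XINFERENCE_ENV_MODEL_SRC)
```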
4 changes: 2 additions & 2 deletions xinference/model/llm/llm_family.json
@@ -391,7 +391,7 @@
"none"
],
"model_id": "THUDM/chatglm-6b",
"model_revision": "b1502f4f75c71499a3d566b14463edd62620ce9f"
"model_revision": "8b7d33596d18c5e83e2da052d05ca4db02e60620"
}
],
"prompt_style": {
@@ -1265,7 +1265,7 @@
"none"
],
"model_id": "WizardLM/WizardMath-70B-V1.0",
"model_revision": " 8823afe1d77b1ebdd6ac0c14e6e8977037d1830e"
"model_revision": "e089c3f9d2ad9d1acb62425aec3f4126f498f4c5"
}
],
"prompt_style": {
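Both hunks correct a `model_revision` pin (the second old value even carried a stray leading space). A hedged sketch for auditing such pins, assuming `llm_family.json` is a JSON array of family objects that each carry a `model_specs` list, as the hunk context suggests:

```python
import json

with open("xinference/model/llm/llm_family.json") as f:
    families = json.load(f)

# Print every pinned Hugging Face revision for a quick audit.
for family in families:
    for spec in family.get("model_specs", []):
        if spec.get("model_revision"):
            print(spec.get("model_id"), spec["model_revision"])
```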
220 changes: 182 additions & 38 deletions xinference/model/llm/llm_family.py
@@ -15,6 +15,7 @@
import logging
import os
import platform
+import shutil
from threading import Lock
from typing import List, Optional, Tuple, Type, Union

@@ -87,6 +88,25 @@ class LLMFamilyV1(BaseModel):
UD_LLM_FAMILIES_LOCK = Lock()


+def is_locale_chinese_simplified() -> bool:
+    import locale
+
+    try:
+        lang, _ = locale.getdefaultlocale()
+        return lang == "zh_CN"
+    except Exception:
+        return False
+
+
+def download_from_self_hosted_storage() -> bool:
+    from ...constants import XINFERENCE_ENV_MODEL_SRC
+
+    return (
+        is_locale_chinese_simplified()
+        or os.environ.get(XINFERENCE_ENV_MODEL_SRC) == "xorbits"
+    )
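Taken together, the two helpers select the mirror either implicitly (simplified-Chinese locale) or explicitly via the environment. A small usage sketch, assuming the helper is imported from `xinference.model.llm.llm_family`:

```python
import os

# Explicit opt-in wins regardless of locale.
os.environ["XINFERENCE_MODEL_SRC"] = "xorbits"
assert download_from_self_hosted_storage()
```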


def get_legacy_cache_path(
    model_name: str,
    model_format: str,
@@ -109,14 +129,17 @@ def cache(
        quantization,
    )
    if os.path.exists(legacy_cache_path):
-        logger.debug("Legacy cache path exists: %s", legacy_cache_path)
+        logger.info("Legacy cache path exists: %s", legacy_cache_path)
        return os.path.dirname(legacy_cache_path)
+    elif download_from_self_hosted_storage() and is_self_hosted(llm_family, llm_spec):
+        logger.info("Caching from self-hosted storage")
+        return cache_from_self_hosted_storage(llm_family, llm_spec, quantization)
    else:
        if llm_spec.model_uri is not None:
-            logger.debug(f"Caching from URI: {llm_spec.model_uri}")
-            return cache_from_uri(llm_family, llm_spec)
+            logger.info(f"Caching from URI: {llm_spec.model_uri}")
+            return cache_from_uri(llm_family, llm_spec, quantization)
        else:
-            logger.debug(f"Caching from Hugging Face: {llm_spec.model_id}")
+            logger.info(f"Caching from Hugging Face: {llm_spec.model_id}")
            return cache_from_huggingface(llm_family, llm_spec, quantization)


@@ -138,74 +161,191 @@ def parse_uri(uri: str) -> Tuple[str, str]:
SUPPORTED_SCHEMES = ["s3"]
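`parse_uri` itself sits above this hunk; the callers below only rely on it splitting a URI into a scheme and a remainder. A plausible sketch of that contract (an assumption, not the committed implementation):

```python
from typing import Tuple

def parse_uri_sketch(uri: str) -> Tuple[str, str]:
    # "s3://bucket/path" -> ("s3", "bucket/path"); bare paths fall back to "file".
    if "://" in uri:
        scheme, rest = uri.split("://", 1)
        return scheme, rest
    return "file", uri
```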


+class AWSRegion:
+    def __init__(self, region: str):
+        self.region = region
+        self.original_aws_default_region = None
+
+    def __enter__(self):
+        if "AWS_DEFAULT_REGION" in os.environ:
+            self.original_aws_default_region = os.environ["AWS_DEFAULT_REGION"]
+        os.environ["AWS_DEFAULT_REGION"] = self.region
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        if self.original_aws_default_region:
+            os.environ["AWS_DEFAULT_REGION"] = self.original_aws_default_region
+        else:
+            del os.environ["AWS_DEFAULT_REGION"]
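A minimal demonstration of the context manager's save-and-restore contract:

```python
import os

os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
with AWSRegion("cn-northwest-1"):
    # Inside the block the region is pinned to the mirror's region...
    assert os.environ["AWS_DEFAULT_REGION"] == "cn-northwest-1"
# ...and the previous value is restored on exit.
assert os.environ["AWS_DEFAULT_REGION"] == "us-east-1"
```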


+def is_self_hosted(
+    llm_family: LLMFamilyV1,
+    llm_spec: "LLMSpecV1",
+):
+    from fsspec import AbstractFileSystem, filesystem
+
+    with AWSRegion("cn-northwest-1"):
+        src_fs: AbstractFileSystem = filesystem("s3", anon=True)
+        model_dir = (
+            f"/xinference-models/llm/"
+            f"{llm_family.model_name}-{llm_spec.model_format}-{llm_spec.model_size_in_billions}b"
+        )
+        return src_fs.exists(model_dir)
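The same pattern works for probing the public bucket by hand; the directory below is illustrative, following the `{model_name}-{model_format}-{model_size_in_billions}b` convention used above:

```python
from fsspec import filesystem

with AWSRegion("cn-northwest-1"):
    fs = filesystem("s3", anon=True)
    # e.g. a 7B ggmlv3 build, if the mirror hosts it
    print(fs.exists("/xinference-models/llm/llama-2-chat-ggmlv3-7b"))
```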


+def cache_from_self_hosted_storage(
+    llm_family: LLMFamilyV1,
+    llm_spec: "LLMSpecV1",
+    quantization: Optional[str] = None,
+) -> str:
+    with AWSRegion("cn-northwest-1"):
+        llm_spec = llm_spec.copy()
+        llm_spec.model_uri = (
+            f"s3://xinference-models/llm/"
+            f"{llm_family.model_name}-{llm_spec.model_format}-{llm_spec.model_size_in_billions}b"
+        )
+
+        return cache_from_uri(
+            llm_family, llm_spec, quantization, self_hosted_storage=True
+        )


def cache_from_uri(
    llm_family: LLMFamilyV1,
    llm_spec: "LLMSpecV1",
+    quantization: Optional[str] = None,
+    self_hosted_storage: bool = False,
) -> str:
    from fsspec import AbstractFileSystem, filesystem

    def copy(
        _src_fs: "AbstractFileSystem",
-        src_path: str,
+        _src_path: str,
        dst_fs: "AbstractFileSystem",
        dst_path: str,
+        max_attempt: int = 3,
    ):
-        logger.error((src_path, dst_path))
-        with _src_fs.open(src_path, "rb") as src_file:
-            with dst_fs.open(dst_path, "wb") as dst_file:
-                dst_file.write(src_file.read())
+        from tqdm import tqdm
+
+        for attempt in range(max_attempt):
+            logger.info(f"Copy from {_src_path} to {dst_path}, attempt: {attempt}")
+            try:
+                with _src_fs.open(_src_path, "rb") as src_file:
+                    file_size = _src_fs.info(_src_path)["size"]
+
+                    dst_fs.makedirs(os.path.dirname(dst_path), exist_ok=True)
+                    with dst_fs.open(dst_path, "wb") as dst_file:
+                        chunk_size = 1024 * 1024  # 1 MB
+
+                        with tqdm(
+                            total=file_size,
+                            unit="B",
+                            unit_scale=True,
+                            unit_divisor=1024,
+                            desc=_src_path,
+                        ) as pbar:
+                            while True:
+                                chunk = src_file.read(chunk_size)
+                                if not chunk:
+                                    break
+                                dst_file.write(chunk)
+                                pbar.update(len(chunk))
+                logger.info(
+                    f"Copy from {_src_path} to {dst_path} finished, attempt: {attempt}"
+                )
+                break
+            except Exception:
+                logger.error(
+                    f"Failed to copy from {_src_path} to {dst_path} on attempt {attempt + 1}",
+                    exc_info=True,
+                )
+                if attempt + 1 == max_attempt:
+                    raise

    cache_dir_name = (
        f"{llm_family.model_name}-{llm_spec.model_format}"
        f"-{llm_spec.model_size_in_billions}b"
    )
    cache_dir = os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, cache_dir_name))
-    if os.path.exists(cache_dir):
-        return cache_dir

    assert llm_spec.model_uri is not None
    src_scheme, src_root = parse_uri(llm_spec.model_uri)
    if src_root.endswith("/"):
-        # remove trailing path separator
+        # remove trailing path separator.
        src_root = src_root[:-1]

    if src_scheme == "file":
        if not os.path.isabs(src_root):
            raise ValueError(
                f"Model URI cannot be a relative path: {llm_spec.model_uri}"
            )
-        if not os.path.exists(XINFERENCE_CACHE_DIR):
-            os.makedirs(XINFERENCE_CACHE_DIR, exist_ok=True)
-        os.symlink(src_root, cache_dir, target_is_directory=True)
+        os.makedirs(XINFERENCE_CACHE_DIR, exist_ok=True)
+        if os.path.exists(cache_dir):
+            logger.info(f"Cache {cache_dir} exists")
+            return cache_dir
+        else:
+            os.symlink(src_root, cache_dir, target_is_directory=True)
        return cache_dir
    elif src_scheme in SUPPORTED_SCHEMES:
-        if not os.path.exists(cache_dir):
-            os.makedirs(cache_dir, exist_ok=True)
-
-        src_fs = filesystem(src_scheme)
+        # use anonymous connection for self-hosted storage.
+        src_fs: AbstractFileSystem = filesystem(src_scheme, anon=self_hosted_storage)
        local_fs: AbstractFileSystem = filesystem("file")

        files_to_download = []
-        for path, _, files in src_fs.walk(llm_spec.model_uri):
-            for file in files:
-                src_path = f"{path}/{file}"
-                local_path = src_path.replace(src_root, cache_dir)
-                files_to_download.append((src_path, local_path))
-
-        from concurrent.futures import ThreadPoolExecutor
+        if llm_spec.model_format == "pytorch":
+            if os.path.exists(cache_dir):
+                logger.info(f"Cache {cache_dir} exists")
+                return cache_dir
+            else:
+                os.makedirs(cache_dir, exist_ok=True)
+
+            for path, _, files in src_fs.walk(llm_spec.model_uri):
+                for file in files:
+                    src_path = f"{path}/{file}"
+                    local_path = src_path.replace(src_root, cache_dir)
+                    files_to_download.append((src_path, local_path))
+        elif llm_spec.model_format == "ggmlv3":
+            file = llm_spec.model_file_name_template.format(quantization=quantization)
+            if os.path.exists(os.path.join(cache_dir, file)):
+                logger.info(f"Cache {os.path.join(cache_dir, file)} exists")
+                return cache_dir
+            else:
+                os.makedirs(cache_dir, exist_ok=True)

-        from tqdm import tqdm
+            src_path = f"{src_root}/{file}"
+            local_path = f"{cache_dir}/{file}"
+            files_to_download.append((src_path, local_path))
+        else:
+            raise ValueError(f"Unsupported model format: {llm_spec.model_format}")

-        with tqdm(total=len(files_to_download), desc="Downloading files") as pbar:
-            with ThreadPoolExecutor(max_workers=4) as executor:
-                futures = [
-                    executor.submit(copy, src_fs, src_path, local_fs, local_path)
-                    for src_path, local_path in files_to_download
-                ]
-                for future in futures:
-                    future.result()
-                    pbar.update(1)
+        from concurrent.futures import ThreadPoolExecutor
+
+        failed = False
+        with ThreadPoolExecutor(max_workers=min(len(files_to_download), 4)) as executor:
+            futures = [
+                (
+                    src_path,
+                    executor.submit(copy, src_fs, src_path, local_fs, local_path),
+                )
+                for src_path, local_path in files_to_download
+            ]
+            for src_path, future in futures:
+                if failed:
+                    future.cancel()
+                else:
+                    try:
+                        future.result()
+                    except Exception:
+                        logger.error(f"Download {src_path} failed", exc_info=True)
+                        failed = True
+
+        if failed:
+            logger.warning(f"Removing cache directory: {cache_dir}")
+            shutil.rmtree(cache_dir, ignore_errors=True)
+            raise RuntimeError(
+                f"Failed to download model '{llm_family.model_name}' "
+                f"(size: {llm_spec.model_size_in_billions}, format: {llm_spec.model_format})"
+            )
        return cache_dir
    else:
        raise ValueError(f"Unsupported URL scheme: {src_scheme}")
@@ -249,7 +389,9 @@ def cache_from_huggingface(

        else:
            raise RuntimeError(
-                f"Failed to download model '{llm_spec.model_name}' (size: {llm_spec.model_size}, format: {llm_spec.model_format}) after multiple retries"
+                f"Failed to download model '{llm_family.model_name}' "
+                f"(size: {llm_spec.model_size_in_billions}, format: {llm_spec.model_format}) "
+                f"after multiple retries"
            )

    elif llm_spec.model_format == "ggmlv3":
@@ -274,7 +416,9 @@ def cache_from_huggingface(

        else:
            raise RuntimeError(
-                f"Failed to download model '{llm_spec.model_name}' (size: {llm_spec.model_size}, format: {llm_spec.model_format}) after multiple retries"
+                f"Failed to download model '{llm_family.model_name}' "
+                f"(size: {llm_spec.model_size_in_billions}, format: {llm_spec.model_format}) "
+                f"after multiple retries"
            )

    return cache_dir
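The download calls themselves sit above this hunk; the messages imply a bounded retry loop around `huggingface_hub`. A hedged reconstruction (the retry count and call shape are assumptions, not the committed code):

```python
from huggingface_hub import snapshot_download

max_attempt = 3  # assumed; the real constant lives outside this hunk
for attempt in range(max_attempt):
    try:
        cache_dir = snapshot_download(
            "THUDM/chatglm-6b",  # model_id pinned in llm_family.json
            revision="8b7d33596d18c5e83e2da052d05ca4db02e60620",
        )
        break
    except Exception:
        if attempt + 1 == max_attempt:
            raise RuntimeError("Failed to download model after multiple retries")
```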