Skip to content

Commit

Permalink
FEAT: Support CogVideoX video model (#2049)
Browse files Browse the repository at this point in the history
Co-authored-by: qinxuye <[email protected]>
  • Loading branch information
codingl2k1 and qinxuye authored Aug 9, 2024
1 parent cbdc811 commit 3ebe1f3
Show file tree
Hide file tree
Showing 20 changed files with 731 additions and 0 deletions.
4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ all =
sentence-transformers>=2.7.0
vllm>=0.2.6 ; sys_platform=='linux'
diffusers>=0.25.0 # fix conflict with matcha-tts
imageio-ffmpeg # For video
controlnet_aux
orjson
auto-gptq ; sys_platform!='darwin'
Expand Down Expand Up @@ -158,6 +159,9 @@ rerank =
image =
diffusers>=0.25.0 # fix conflict with matcha-tts
controlnet_aux
video =
diffusers
imageio-ffmpeg
audio =
funasr
omegaconf~=2.3.0
Expand Down
52 changes: 52 additions & 0 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
CreateCompletion,
ImageList,
PeftModelConfig,
VideoList,
max_tokens_field,
)
from .oauth2.auth_service import AuthService
Expand Down Expand Up @@ -123,6 +124,14 @@ class TextToImageRequest(BaseModel):
user: Optional[str] = None


class TextToVideoRequest(BaseModel):
model: str
prompt: Union[str, List[str]] = Field(description="The input to embed.")
n: Optional[int] = 1
kwargs: Optional[str] = None
user: Optional[str] = None


class SpeechRequest(BaseModel):
model: str
input: str
Expand Down Expand Up @@ -512,6 +521,17 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
self._router.add_api_route(
"/v1/video/generations",
self.create_videos,
methods=["POST"],
response_model=VideoList,
dependencies=(
[Security(self._auth_service, scopes=["models:read"])]
if self.is_authenticated()
else None
),
)
self._router.add_api_route(
"/v1/chat/completions",
self.create_chat_completion,
Expand Down Expand Up @@ -1546,6 +1566,38 @@ async def create_flexible_infer(self, request: Request) -> Response:
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def create_videos(self, request: Request) -> Response:
body = TextToVideoRequest.parse_obj(await request.json())
model_uid = body.model
try:
model = await (await self._get_supervisor_ref()).get_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
await self._report_error_event(model_uid, str(ve))
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

try:
kwargs = json.loads(body.kwargs) if body.kwargs else {}
video_list = await model.text_to_video(
prompt=body.prompt,
n=body.n,
**kwargs,
)
return Response(content=video_list, media_type="application/json")
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
self.handle_request_limit_error(re)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def create_chat_completion(self, request: Request) -> Response:
raw_body = await request.json()
body = CreateChatCompletion.parse_obj(raw_body)
Expand Down
43 changes: 43 additions & 0 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ImageList,
LlamaCppGenerateConfig,
PytorchGenerateConfig,
VideoList,
)


Expand Down Expand Up @@ -370,6 +371,44 @@ def inpainting(
return response_data


class RESTfulVideoModelHandle(RESTfulModelHandle):
def text_to_video(
self,
prompt: str,
n: int = 1,
**kwargs,
) -> "VideoList":
"""
Creates a video by the input text.
Parameters
----------
prompt: `str` or `List[str]`
The prompt or prompts to guide video generation. If not defined, you need to pass `prompt_embeds`.
n: `int`, defaults to 1
The number of videos to generate per prompt. Must be between 1 and 10.
Returns
-------
VideoList
A list of video objects.
"""
url = f"{self._base_url}/v1/video/generations"
request_body = {
"model": self._model_uid,
"prompt": prompt,
"n": n,
"kwargs": json.dumps(kwargs),
}
response = requests.post(url, json=request_body, headers=self.auth_headers)
if response.status_code != 200:
raise RuntimeError(
f"Failed to create the video, detail: {_get_error_string(response)}"
)

response_data = response.json()
return response_data


class RESTfulGenerateModelHandle(RESTfulModelHandle):
def generate(
self,
Expand Down Expand Up @@ -1015,6 +1054,10 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
return RESTfulAudioModelHandle(
model_uid, self.base_url, auth_headers=self._headers
)
elif desc["model_type"] == "video":
return RESTfulVideoModelHandle(
model_uid, self.base_url, auth_headers=self._headers
)
elif desc["model_type"] == "flexible":
return RESTfulFlexibleModelHandle(
model_uid, self.base_url, auth_headers=self._headers
Expand Down
1 change: 1 addition & 0 deletions xinference/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def get_xinference_home() -> str:
XINFERENCE_MODEL_DIR = os.path.join(XINFERENCE_HOME, "model")
XINFERENCE_LOG_DIR = os.path.join(XINFERENCE_HOME, "logs")
XINFERENCE_IMAGE_DIR = os.path.join(XINFERENCE_HOME, "image")
XINFERENCE_VIDEO_DIR = os.path.join(XINFERENCE_HOME, "video")
XINFERENCE_AUTH_DIR = os.path.join(XINFERENCE_HOME, "auth")
XINFERENCE_CSG_ENDPOINT = str(
os.environ.get(XINFERENCE_ENV_CSG_ENDPOINT, "https://hub-stg.opencsg.com/")
Expand Down
21 changes: 21 additions & 0 deletions xinference/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,27 @@ async def infer(
f"Model {self._model.model_spec} is not for flexible infer."
)

@log_async(logger=logger)
@request_limit
async def text_to_video(
self,
prompt: str,
n: int = 1,
*args,
**kwargs,
):
if hasattr(self._model, "text_to_video"):
return await self._call_wrapper_json(
self._model.text_to_video,
prompt,
n,
*args,
**kwargs,
)
raise AttributeError(
f"Model {self._model.model_spec} is not for creating video."
)

async def record_metrics(self, name, op, kwargs):
worker_ref = await self._get_worker_ref()
await worker_ref.record_metrics(name, op, kwargs)
37 changes: 37 additions & 0 deletions xinference/core/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from ..model.image import ImageModelFamilyV1
from ..model.llm import LLMFamilyV1
from ..model.rerank import RerankModelSpec
from ..model.video import VideoModelFamilyV1
from .worker import WorkerActor


Expand Down Expand Up @@ -484,6 +485,31 @@ async def _to_audio_model_reg(
res["model_instance_count"] = instance_cnt
return res

async def _to_video_model_reg(
self, model_family: "VideoModelFamilyV1", is_builtin: bool
) -> Dict[str, Any]:
from ..model.video import get_cache_status

instance_cnt = await self.get_instance_count(model_family.model_name)
version_cnt = await self.get_model_version_count(model_family.model_name)

if self.is_local_deployment():
# TODO: does not work when the supervisor and worker are running on separate nodes.
cache_status = get_cache_status(model_family)
res = {
**model_family.dict(),
"cache_status": cache_status,
"is_builtin": is_builtin,
}
else:
res = {
**model_family.dict(),
"is_builtin": is_builtin,
}
res["model_version_count"] = version_cnt
res["model_instance_count"] = instance_cnt
return res

async def _to_flexible_model_reg(
self, model_spec: "FlexibleModelSpec", is_builtin: bool
) -> Dict[str, Any]:
Expand Down Expand Up @@ -602,6 +628,17 @@ def sort_helper(item):
{"model_name": model_spec.model_name, "is_builtin": False}
)

ret.sort(key=sort_helper)
return ret
elif model_type == "video":
from ..model.video import BUILTIN_VIDEO_MODELS

for model_name, family in BUILTIN_VIDEO_MODELS.items():
if detailed:
ret.append(await self._to_video_model_reg(family, is_builtin=True))
else:
ret.append({"model_name": model_name, "is_builtin": True})

ret.sort(key=sort_helper)
return ret
elif model_type == "rerank":
Expand Down
2 changes: 2 additions & 0 deletions xinference/core/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,8 @@ async def _get_model_ability(self, model: Any, model_type: str) -> List[str]:
return ["text_to_image"]
elif model_type == "audio":
return ["audio_to_text"]
elif model_type == "video":
return ["text_to_video"]
elif model_type == "flexible":
return ["flexible"]
else:
Expand Down
1 change: 1 addition & 0 deletions xinference/deploy/docker/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # Fo
openai-whisper # For CosyVoice
boto3>=1.28.55,<1.28.65 # For tensorizer
tensorizer~=2.9.0
imageio-ffmpeg # For video

# sglang
outlines>=0.0.44
Expand Down
1 change: 1 addition & 0 deletions xinference/deploy/docker/requirements_cpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,4 @@ matcha-tts # For CosyVoice
onnxruntime-gpu==1.16.0; sys_platform == 'linux' # For CosyVoice
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # For CosyVoice
openai-whisper # For CosyVoice
imageio-ffmpeg # For video
12 changes: 12 additions & 0 deletions xinference/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def create_model_instance(
from .image.core import create_image_model_instance
from .llm.core import create_llm_model_instance
from .rerank.core import create_rerank_model_instance
from .video.core import create_video_model_instance

if model_type == "LLM":
return create_llm_model_instance(
Expand Down Expand Up @@ -127,6 +128,17 @@ def create_model_instance(
model_path,
**kwargs,
)
elif model_type == "video":
kwargs.pop("trust_remote_code", None)
return create_video_model_instance(
subpool_addr,
devices,
model_uid,
model_name,
download_hub,
model_path,
**kwargs,
)
elif model_type == "flexible":
kwargs.pop("trust_remote_code", None)
return create_flexible_model_instance(
Expand Down
62 changes: 62 additions & 0 deletions xinference/model/video/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import json
import os
from itertools import chain

from .core import (
BUILTIN_VIDEO_MODELS,
MODEL_NAME_TO_REVISION,
MODELSCOPE_VIDEO_MODELS,
VIDEO_MODEL_DESCRIPTIONS,
VideoModelFamilyV1,
generate_video_description,
get_cache_status,
get_video_model_descriptions,
)

_model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
_model_spec_modelscope_json = os.path.join(
os.path.dirname(__file__), "model_spec_modelscope.json"
)
BUILTIN_VIDEO_MODELS.update(
dict(
(spec["model_name"], VideoModelFamilyV1(**spec))
for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
)
)
for model_name, model_spec in BUILTIN_VIDEO_MODELS.items():
MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)

MODELSCOPE_VIDEO_MODELS.update(
dict(
(spec["model_name"], VideoModelFamilyV1(**spec))
for spec in json.load(
codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
)
)
)
for model_name, model_spec in MODELSCOPE_VIDEO_MODELS.items():
MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)

# register model description
for model_name, model_spec in chain(
MODELSCOPE_VIDEO_MODELS.items(), BUILTIN_VIDEO_MODELS.items()
):
VIDEO_MODEL_DESCRIPTIONS.update(generate_video_description(model_spec))

del _model_spec_json
del _model_spec_modelscope_json
Loading

0 comments on commit 3ebe1f3

Please sign in to comment.