diff --git a/setup.cfg b/setup.cfg
index 05e33f9b5a..8664c59fdd 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -95,6 +95,7 @@ all =
sentence-transformers>=2.7.0
vllm>=0.2.6 ; sys_platform=='linux'
diffusers>=0.25.0 # fix conflict with matcha-tts
+ imageio-ffmpeg # For video
controlnet_aux
orjson
auto-gptq ; sys_platform!='darwin'
@@ -158,6 +159,9 @@ rerank =
image =
diffusers>=0.25.0 # fix conflict with matcha-tts
controlnet_aux
+video =
+ diffusers
+ imageio-ffmpeg
audio =
funasr
omegaconf~=2.3.0
diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index 87e19d98cf..47b4848c80 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -65,6 +65,7 @@
CreateCompletion,
ImageList,
PeftModelConfig,
+ VideoList,
max_tokens_field,
)
from .oauth2.auth_service import AuthService
@@ -123,6 +124,14 @@ class TextToImageRequest(BaseModel):
user: Optional[str] = None
+class TextToVideoRequest(BaseModel):
+ model: str
+    prompt: Union[str, List[str]] = Field(description="The prompt or prompts to guide video generation.")
+ n: Optional[int] = 1
+ kwargs: Optional[str] = None
+ user: Optional[str] = None
+
+
class SpeechRequest(BaseModel):
model: str
input: str
@@ -512,6 +521,17 @@ async def internal_exception_handler(request: Request, exc: Exception):
else None
),
)
+ self._router.add_api_route(
+ "/v1/video/generations",
+ self.create_videos,
+ methods=["POST"],
+ response_model=VideoList,
+ dependencies=(
+ [Security(self._auth_service, scopes=["models:read"])]
+ if self.is_authenticated()
+ else None
+ ),
+ )
self._router.add_api_route(
"/v1/chat/completions",
self.create_chat_completion,
@@ -1546,6 +1566,38 @@ async def create_flexible_infer(self, request: Request) -> Response:
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))
+ async def create_videos(self, request: Request) -> Response:
+ body = TextToVideoRequest.parse_obj(await request.json())
+ model_uid = body.model
+ try:
+ model = await (await self._get_supervisor_ref()).get_model(model_uid)
+ except ValueError as ve:
+ logger.error(str(ve), exc_info=True)
+ await self._report_error_event(model_uid, str(ve))
+ raise HTTPException(status_code=400, detail=str(ve))
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ await self._report_error_event(model_uid, str(e))
+ raise HTTPException(status_code=500, detail=str(e))
+
+ try:
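+            # Extra generation kwargs arrive from the client as a JSON-encoded string;
+            # decode them and forward to the model's text_to_video call.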
+ kwargs = json.loads(body.kwargs) if body.kwargs else {}
+ video_list = await model.text_to_video(
+ prompt=body.prompt,
+ n=body.n,
+ **kwargs,
+ )
+ return Response(content=video_list, media_type="application/json")
+ except RuntimeError as re:
+ logger.error(re, exc_info=True)
+ await self._report_error_event(model_uid, str(re))
+ self.handle_request_limit_error(re)
+ raise HTTPException(status_code=400, detail=str(re))
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ await self._report_error_event(model_uid, str(e))
+ raise HTTPException(status_code=500, detail=str(e))
+
async def create_chat_completion(self, request: Request) -> Response:
raw_body = await request.json()
body = CreateChatCompletion.parse_obj(raw_body)
diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py
index aa0955f75d..c11c30c29f 100644
--- a/xinference/client/restful/restful_client.py
+++ b/xinference/client/restful/restful_client.py
@@ -31,6 +31,7 @@
ImageList,
LlamaCppGenerateConfig,
PytorchGenerateConfig,
+ VideoList,
)
@@ -370,6 +371,44 @@ def inpainting(
return response_data
+class RESTfulVideoModelHandle(RESTfulModelHandle):
+ def text_to_video(
+ self,
+ prompt: str,
+ n: int = 1,
+ **kwargs,
+ ) -> "VideoList":
+ """
+        Creates a video from the input text.
+
+ Parameters
+ ----------
+ prompt: `str` or `List[str]`
+            The prompt or prompts to guide video generation.
+ n: `int`, defaults to 1
+            The number of videos to generate per prompt. Must be between 1 and 10.
+
+        Returns
+ -------
+ VideoList
+ A list of video objects.
+ """
+ url = f"{self._base_url}/v1/video/generations"
+ request_body = {
+ "model": self._model_uid,
+ "prompt": prompt,
+ "n": n,
+ "kwargs": json.dumps(kwargs),
+ }
+ response = requests.post(url, json=request_body, headers=self.auth_headers)
+ if response.status_code != 200:
+ raise RuntimeError(
+ f"Failed to create the video, detail: {_get_error_string(response)}"
+ )
+
+ response_data = response.json()
+ return response_data
+
+
class RESTfulGenerateModelHandle(RESTfulModelHandle):
def generate(
self,
@@ -1015,6 +1054,10 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
return RESTfulAudioModelHandle(
model_uid, self.base_url, auth_headers=self._headers
)
+ elif desc["model_type"] == "video":
+ return RESTfulVideoModelHandle(
+ model_uid, self.base_url, auth_headers=self._headers
+ )
elif desc["model_type"] == "flexible":
return RESTfulFlexibleModelHandle(
model_uid, self.base_url, auth_headers=self._headers
diff --git a/xinference/constants.py b/xinference/constants.py
index 3efad56ed3..c9ba4e5ddc 100644
--- a/xinference/constants.py
+++ b/xinference/constants.py
@@ -47,6 +47,7 @@ def get_xinference_home() -> str:
XINFERENCE_MODEL_DIR = os.path.join(XINFERENCE_HOME, "model")
XINFERENCE_LOG_DIR = os.path.join(XINFERENCE_HOME, "logs")
XINFERENCE_IMAGE_DIR = os.path.join(XINFERENCE_HOME, "image")
+XINFERENCE_VIDEO_DIR = os.path.join(XINFERENCE_HOME, "video")
XINFERENCE_AUTH_DIR = os.path.join(XINFERENCE_HOME, "auth")
XINFERENCE_CSG_ENDPOINT = str(
os.environ.get(XINFERENCE_ENV_CSG_ENDPOINT, "https://hub-stg.opencsg.com/")
diff --git a/xinference/core/model.py b/xinference/core/model.py
index 7fc41b9c53..24cfe3c6e8 100644
--- a/xinference/core/model.py
+++ b/xinference/core/model.py
@@ -774,6 +774,27 @@ async def infer(
f"Model {self._model.model_spec} is not for flexible infer."
)
+ @log_async(logger=logger)
+ @request_limit
+ async def text_to_video(
+ self,
+ prompt: str,
+ n: int = 1,
+ *args,
+ **kwargs,
+ ):
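+        # Delegate to the underlying model if it implements text_to_video and
+        # serialize the result to JSON via the common call wrapper.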
+ if hasattr(self._model, "text_to_video"):
+ return await self._call_wrapper_json(
+ self._model.text_to_video,
+ prompt,
+ n,
+ *args,
+ **kwargs,
+ )
+ raise AttributeError(
+ f"Model {self._model.model_spec} is not for creating video."
+ )
+
async def record_metrics(self, name, op, kwargs):
worker_ref = await self._get_worker_ref()
await worker_ref.record_metrics(name, op, kwargs)
diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py
index 54e4b65849..2b6f7b9fc5 100644
--- a/xinference/core/supervisor.py
+++ b/xinference/core/supervisor.py
@@ -64,6 +64,7 @@
from ..model.image import ImageModelFamilyV1
from ..model.llm import LLMFamilyV1
from ..model.rerank import RerankModelSpec
+ from ..model.video import VideoModelFamilyV1
from .worker import WorkerActor
@@ -484,6 +485,31 @@ async def _to_audio_model_reg(
res["model_instance_count"] = instance_cnt
return res
+ async def _to_video_model_reg(
+ self, model_family: "VideoModelFamilyV1", is_builtin: bool
+ ) -> Dict[str, Any]:
+ from ..model.video import get_cache_status
+
+ instance_cnt = await self.get_instance_count(model_family.model_name)
+ version_cnt = await self.get_model_version_count(model_family.model_name)
+
+ if self.is_local_deployment():
+ # TODO: does not work when the supervisor and worker are running on separate nodes.
+ cache_status = get_cache_status(model_family)
+ res = {
+ **model_family.dict(),
+ "cache_status": cache_status,
+ "is_builtin": is_builtin,
+ }
+ else:
+ res = {
+ **model_family.dict(),
+ "is_builtin": is_builtin,
+ }
+ res["model_version_count"] = version_cnt
+ res["model_instance_count"] = instance_cnt
+ return res
+
async def _to_flexible_model_reg(
self, model_spec: "FlexibleModelSpec", is_builtin: bool
) -> Dict[str, Any]:
@@ -602,6 +628,17 @@ def sort_helper(item):
{"model_name": model_spec.model_name, "is_builtin": False}
)
+ ret.sort(key=sort_helper)
+ return ret
+ elif model_type == "video":
+ from ..model.video import BUILTIN_VIDEO_MODELS
+
+ for model_name, family in BUILTIN_VIDEO_MODELS.items():
+ if detailed:
+ ret.append(await self._to_video_model_reg(family, is_builtin=True))
+ else:
+ ret.append({"model_name": model_name, "is_builtin": True})
+
ret.sort(key=sort_helper)
return ret
elif model_type == "rerank":
diff --git a/xinference/core/worker.py b/xinference/core/worker.py
index 9524bd604a..cfffd7fb17 100644
--- a/xinference/core/worker.py
+++ b/xinference/core/worker.py
@@ -735,6 +735,8 @@ async def _get_model_ability(self, model: Any, model_type: str) -> List[str]:
return ["text_to_image"]
elif model_type == "audio":
return ["audio_to_text"]
+ elif model_type == "video":
+ return ["text_to_video"]
elif model_type == "flexible":
return ["flexible"]
else:
diff --git a/xinference/deploy/docker/requirements.txt b/xinference/deploy/docker/requirements.txt
index 66f6d650af..1830a7de25 100644
--- a/xinference/deploy/docker/requirements.txt
+++ b/xinference/deploy/docker/requirements.txt
@@ -60,6 +60,7 @@ onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # Fo
openai-whisper # For CosyVoice
boto3>=1.28.55,<1.28.65 # For tensorizer
tensorizer~=2.9.0
+imageio-ffmpeg # For video
# sglang
outlines>=0.0.44
diff --git a/xinference/deploy/docker/requirements_cpu.txt b/xinference/deploy/docker/requirements_cpu.txt
index a117e0c549..7ae0a2544d 100644
--- a/xinference/deploy/docker/requirements_cpu.txt
+++ b/xinference/deploy/docker/requirements_cpu.txt
@@ -55,3 +55,4 @@ matcha-tts # For CosyVoice
onnxruntime-gpu==1.16.0; sys_platform == 'linux' # For CosyVoice
onnxruntime==1.16.0; sys_platform == 'darwin' or sys_platform == 'windows' # For CosyVoice
openai-whisper # For CosyVoice
+imageio-ffmpeg # For video
diff --git a/xinference/model/core.py b/xinference/model/core.py
index 09cb4104a4..4591d255b0 100644
--- a/xinference/model/core.py
+++ b/xinference/model/core.py
@@ -65,6 +65,7 @@ def create_model_instance(
from .image.core import create_image_model_instance
from .llm.core import create_llm_model_instance
from .rerank.core import create_rerank_model_instance
+ from .video.core import create_video_model_instance
if model_type == "LLM":
return create_llm_model_instance(
@@ -127,6 +128,17 @@ def create_model_instance(
model_path,
**kwargs,
)
+ elif model_type == "video":
+ kwargs.pop("trust_remote_code", None)
+ return create_video_model_instance(
+ subpool_addr,
+ devices,
+ model_uid,
+ model_name,
+ download_hub,
+ model_path,
+ **kwargs,
+ )
elif model_type == "flexible":
kwargs.pop("trust_remote_code", None)
return create_flexible_model_instance(
diff --git a/xinference/model/video/__init__.py b/xinference/model/video/__init__.py
new file mode 100644
index 0000000000..e1325b0bbb
--- /dev/null
+++ b/xinference/model/video/__init__.py
@@ -0,0 +1,62 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import json
+import os
+from itertools import chain
+
+from .core import (
+ BUILTIN_VIDEO_MODELS,
+ MODEL_NAME_TO_REVISION,
+ MODELSCOPE_VIDEO_MODELS,
+ VIDEO_MODEL_DESCRIPTIONS,
+ VideoModelFamilyV1,
+ generate_video_description,
+ get_cache_status,
+ get_video_model_descriptions,
+)
+
+_model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json")
+_model_spec_modelscope_json = os.path.join(
+ os.path.dirname(__file__), "model_spec_modelscope.json"
+)
+BUILTIN_VIDEO_MODELS.update(
+ dict(
+ (spec["model_name"], VideoModelFamilyV1(**spec))
+ for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8"))
+ )
+)
+for model_name, model_spec in BUILTIN_VIDEO_MODELS.items():
+ MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
+MODELSCOPE_VIDEO_MODELS.update(
+ dict(
+ (spec["model_name"], VideoModelFamilyV1(**spec))
+ for spec in json.load(
+ codecs.open(_model_spec_modelscope_json, "r", encoding="utf-8")
+ )
+ )
+)
+for model_name, model_spec in MODELSCOPE_VIDEO_MODELS.items():
+ MODEL_NAME_TO_REVISION[model_name].append(model_spec.model_revision)
+
+# register model description
+for model_name, model_spec in chain(
+ MODELSCOPE_VIDEO_MODELS.items(), BUILTIN_VIDEO_MODELS.items()
+):
+ VIDEO_MODEL_DESCRIPTIONS.update(generate_video_description(model_spec))
+
+del _model_spec_json
+del _model_spec_modelscope_json
diff --git a/xinference/model/video/core.py b/xinference/model/video/core.py
new file mode 100644
index 0000000000..3b9f96ad9a
--- /dev/null
+++ b/xinference/model/video/core.py
@@ -0,0 +1,178 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from collections import defaultdict
+from typing import Dict, List, Literal, Optional, Tuple
+
+from ...constants import XINFERENCE_CACHE_DIR
+from ..core import CacheableModelSpec, ModelDescription
+from ..utils import valid_model_revision
+from .diffusers import DiffUsersVideoModel
+
+MAX_ATTEMPTS = 3
+
+logger = logging.getLogger(__name__)
+
+MODEL_NAME_TO_REVISION: Dict[str, List[str]] = defaultdict(list)
+VIDEO_MODEL_DESCRIPTIONS: Dict[str, List[Dict]] = defaultdict(list)
+BUILTIN_VIDEO_MODELS: Dict[str, "VideoModelFamilyV1"] = {}
+MODELSCOPE_VIDEO_MODELS: Dict[str, "VideoModelFamilyV1"] = {}
+
+
+def get_video_model_descriptions():
+ import copy
+
+ return copy.deepcopy(VIDEO_MODEL_DESCRIPTIONS)
+
+
+class VideoModelFamilyV1(CacheableModelSpec):
+ model_family: str
+ model_name: str
+ model_id: str
+ model_revision: str
+ model_hub: str = "huggingface"
+ model_ability: Optional[List[str]]
+
+
+class VideoModelDescription(ModelDescription):
+ def __init__(
+ self,
+ address: Optional[str],
+ devices: Optional[List[str]],
+ model_spec: VideoModelFamilyV1,
+ model_path: Optional[str] = None,
+ ):
+ super().__init__(address, devices, model_path=model_path)
+ self._model_spec = model_spec
+
+ def to_dict(self):
+ return {
+ "model_type": "video",
+ "address": self.address,
+ "accelerators": self.devices,
+ "model_name": self._model_spec.model_name,
+ "model_family": self._model_spec.model_family,
+ "model_revision": self._model_spec.model_revision,
+ "model_ability": self._model_spec.model_ability,
+ }
+
+ def to_version_info(self):
+ if self._model_path is None:
+ is_cached = get_cache_status(self._model_spec)
+ file_location = get_cache_dir(self._model_spec)
+ else:
+ is_cached = True
+ file_location = self._model_path
+
+ return [
+ {
+ "model_version": self._model_spec.model_name,
+ "model_file_location": file_location,
+ "cache_status": is_cached,
+ }
+ ]
+
+
+def generate_video_description(
+ video_model: VideoModelFamilyV1,
+) -> Dict[str, List[Dict]]:
+ res = defaultdict(list)
+ res[video_model.model_name].extend(
+ VideoModelDescription(None, None, video_model).to_version_info()
+ )
+ return res
+
+
+def match_diffusion(
+ model_name: str,
+ download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+) -> VideoModelFamilyV1:
+ from ..utils import download_from_modelscope
+ from . import BUILTIN_VIDEO_MODELS, MODELSCOPE_VIDEO_MODELS
+
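+    # Resolution order: the explicitly requested hub first, then ModelScope if it
+    # is the configured download source, falling back to the Huggingface builtins.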
+ if download_hub == "modelscope" and model_name in MODELSCOPE_VIDEO_MODELS:
+ logger.debug(f"Video model {model_name} found in ModelScope.")
+ return MODELSCOPE_VIDEO_MODELS[model_name]
+ elif download_hub == "huggingface" and model_name in BUILTIN_VIDEO_MODELS:
+ logger.debug(f"Video model {model_name} found in Huggingface.")
+ return BUILTIN_VIDEO_MODELS[model_name]
+ elif download_from_modelscope() and model_name in MODELSCOPE_VIDEO_MODELS:
+ logger.debug(f"Video model {model_name} found in ModelScope.")
+ return MODELSCOPE_VIDEO_MODELS[model_name]
+ elif model_name in BUILTIN_VIDEO_MODELS:
+ logger.debug(f"Video model {model_name} found in Huggingface.")
+ return BUILTIN_VIDEO_MODELS[model_name]
+ else:
+ raise ValueError(
+            f"Video model {model_name} not found, available "
+            f"model list: {BUILTIN_VIDEO_MODELS.keys()}"
+ )
+
+
+def cache(model_spec: VideoModelFamilyV1):
+ from ..utils import cache
+
+ return cache(model_spec, VideoModelDescription)
+
+
+def get_cache_dir(model_spec: VideoModelFamilyV1):
+ return os.path.realpath(os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name))
+
+
+def get_cache_status(
+ model_spec: VideoModelFamilyV1,
+) -> bool:
+ cache_dir = get_cache_dir(model_spec)
+ meta_path = os.path.join(cache_dir, "__valid_download")
+
+ model_name = model_spec.model_name
+ if model_name in BUILTIN_VIDEO_MODELS and model_name in MODELSCOPE_VIDEO_MODELS:
+ hf_spec = BUILTIN_VIDEO_MODELS[model_name]
+ ms_spec = MODELSCOPE_VIDEO_MODELS[model_name]
+
+ return any(
+ [
+ valid_model_revision(meta_path, hf_spec.model_revision),
+ valid_model_revision(meta_path, ms_spec.model_revision),
+ ]
+ )
+ else: # Usually for UT
+ return valid_model_revision(meta_path, model_spec.model_revision)
+
+
+def create_video_model_instance(
+ subpool_addr: str,
+ devices: List[str],
+ model_uid: str,
+ model_name: str,
+ download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
+ model_path: Optional[str] = None,
+ **kwargs,
+) -> Tuple[DiffUsersVideoModel, VideoModelDescription]:
+ model_spec = match_diffusion(model_name, download_hub)
+ if not model_path:
+ model_path = cache(model_spec)
+ assert model_path is not None
+
+ model = DiffUsersVideoModel(
+ model_uid,
+ model_path,
+ model_spec,
+ **kwargs,
+ )
+ model_description = VideoModelDescription(
+ subpool_addr, devices, model_spec, model_path=model_path
+ )
+ return model, model_description
diff --git a/xinference/model/video/diffusers.py b/xinference/model/video/diffusers.py
new file mode 100644
index 0000000000..b9b8569918
--- /dev/null
+++ b/xinference/model/video/diffusers.py
@@ -0,0 +1,180 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import logging
+import os
+import sys
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from typing import TYPE_CHECKING, List, Union
+
+import numpy as np
+import PIL.Image
+import torch
+
+from ...constants import XINFERENCE_VIDEO_DIR
+from ...device_utils import move_model_to_available_device
+from ...types import Video, VideoList
+
+if TYPE_CHECKING:
+ from .core import VideoModelFamilyV1
+
+
+logger = logging.getLogger(__name__)
+
+
+def export_to_video_imageio(
+ video_frames: Union[List[np.ndarray], List["PIL.Image.Image"]],
+ output_video_path: str,
+ fps: int = 8,
+) -> str:
+ """
+    Export video frames to a video file using the imageio library, to avoid the "green screen" issue seen with some models (e.g. CogVideoX).
+ """
+ import imageio
+
+ if isinstance(video_frames[0], PIL.Image.Image):
+ video_frames = [np.array(frame) for frame in video_frames]
+ with imageio.get_writer(output_video_path, fps=fps) as writer:
+ for frame in video_frames:
+ writer.append_data(frame)
+ return output_video_path
+
+
+class DiffUsersVideoModel:
+ def __init__(
+ self,
+ model_uid: str,
+ model_path: str,
+ model_spec: "VideoModelFamilyV1",
+ **kwargs,
+ ):
+ self._model_uid = model_uid
+ self._model_path = model_path
+ self._model_spec = model_spec
+ self._model = None
+ self._kwargs = kwargs
+
+ @property
+ def model_spec(self):
+ return self._model_spec
+
+ def load(self):
+ import torch
+
+ torch_dtype = self._kwargs.get("torch_dtype")
+ if sys.platform != "darwin" and torch_dtype is None:
+            # The following params crash on Mac M2
+ self._kwargs["torch_dtype"] = torch.float16
+ self._kwargs["variant"] = "fp16"
+ self._kwargs["use_safetensors"] = True
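+        # torch_dtype may arrive as a string (e.g. "float16") from launch kwargs;
+        # resolve it to the corresponding torch dtype object.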
+ if isinstance(torch_dtype, str):
+ self._kwargs["torch_dtype"] = getattr(torch, torch_dtype)
+
+ if self._model_spec.model_family == "CogVideoX":
+ from diffusers import CogVideoXPipeline
+
+ self._model = CogVideoXPipeline.from_pretrained(
+ self._model_path, **self._kwargs
+ )
+ else:
+ raise Exception(
+ f"Unsupported model family: {self._model_spec.model_family}"
+ )
+
+ if self._kwargs.get("cpu_offload", False):
+ logger.debug("CPU offloading model")
+ self._model.enable_model_cpu_offload()
+ elif not self._kwargs.get("device_map"):
+ logger.debug("Loading model to available device")
+ self._model = move_model_to_available_device(self._model)
+ # Recommended if your computer has < 64 GB of RAM
+ self._model.enable_attention_slicing()
+
+ def text_to_video(
+ self,
+ prompt: str,
+ n: int = 1,
+ num_inference_steps: int = 50,
+ guidance_scale: int = 6,
+ response_format: str = "b64_json",
+ **kwargs,
+ ) -> VideoList:
+ import gc
+
+        # A cv2 bug can prevent the exported video from displaying correctly,
+        # so we use the imageio-based exporter instead
+ # from diffusers.utils import export_to_video
+ from ...device_utils import empty_cache
+
+ logger.debug(
+ "diffusers text_to_video args: %s",
+ kwargs,
+ )
+ assert self._model is not None
+ if self._kwargs.get("cpu_offload"):
+            # With CPU offload enabled, model.device would report CPU,
+            # so use CUDA explicitly for prompt encoding
+ device = "cuda"
+ else:
+ device = self._model.device
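+        # Encode the prompt up front so that `n` maps to num_videos_per_prompt;
+        # the embeddings are then passed to the pipeline instead of the raw prompt.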
+ prompt_embeds, _ = self._model.encode_prompt(
+ prompt=prompt,
+ do_classifier_free_guidance=True,
+ num_videos_per_prompt=n,
+ max_sequence_length=226,
+ device=device,
+ dtype=torch.float16,
+ )
+ assert callable(self._model)
+ output = self._model(
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ prompt_embeds=prompt_embeds,
+ **kwargs,
+ )
+
+ # clean cache
+ gc.collect()
+ empty_cache()
+
+ os.makedirs(XINFERENCE_VIDEO_DIR, exist_ok=True)
+ urls = []
+ for f in output.frames:
+ path = os.path.join(XINFERENCE_VIDEO_DIR, uuid.uuid4().hex + ".mp4")
+ p = export_to_video_imageio(f, path, fps=8)
+ urls.append(p)
+ if response_format == "url":
+ return VideoList(
+ created=int(time.time()),
+ data=[Video(url=url, b64_json=None) for url in urls],
+ )
+ elif response_format == "b64_json":
+
+ def _gen_base64_video(_video_url):
+ try:
+ with open(_video_url, "rb") as f:
+ return base64.b64encode(f.read()).decode()
+ finally:
+ os.remove(_video_url)
+
+ with ThreadPoolExecutor() as executor:
+ results = list(map(partial(executor.submit, _gen_base64_video), urls)) # type: ignore
+ video_list = [Video(url=None, b64_json=s.result()) for s in results]
+ return VideoList(created=int(time.time()), data=video_list)
+ else:
+ raise ValueError(f"Unsupported response format: {response_format}")
diff --git a/xinference/model/video/model_spec.json b/xinference/model/video/model_spec.json
new file mode 100644
index 0000000000..52b748fd6a
--- /dev/null
+++ b/xinference/model/video/model_spec.json
@@ -0,0 +1,11 @@
+[
+ {
+ "model_name": "CogVideoX-2b",
+ "model_family": "CogVideoX",
+ "model_id": "THUDM/CogVideoX-2b",
+ "model_revision": "4bbfb1de622b80bc1b77b6e9aced75f816be0e38",
+ "model_ability": [
+ "text2video"
+ ]
+ }
+]
diff --git a/xinference/model/video/model_spec_modelscope.json b/xinference/model/video/model_spec_modelscope.json
new file mode 100644
index 0000000000..e3cb604921
--- /dev/null
+++ b/xinference/model/video/model_spec_modelscope.json
@@ -0,0 +1,12 @@
+[
+ {
+ "model_name": "CogVideoX-2b",
+ "model_family": "CogVideoX",
+ "model_hub": "modelscope",
+ "model_id": "ZhipuAI/CogVideoX-2b",
+ "model_revision": "master",
+ "model_ability": [
+ "text2video"
+ ]
+ }
+]
diff --git a/xinference/model/video/tests/__init__.py b/xinference/model/video/tests/__init__.py
new file mode 100644
index 0000000000..37f6558d95
--- /dev/null
+++ b/xinference/model/video/tests/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/xinference/model/video/tests/test_diffusers_video.py b/xinference/model/video/tests/test_diffusers_video.py
new file mode 100644
index 0000000000..3676612c05
--- /dev/null
+++ b/xinference/model/video/tests/test_diffusers_video.py
@@ -0,0 +1,63 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import pytest
+
+from .. import BUILTIN_VIDEO_MODELS
+from ..core import cache
+from ..diffusers import DiffUsersVideoModel
+
+logger = logging.getLogger(__name__)
+
+
+@pytest.mark.skip(reason="Video model requires too much GPU memory.")
+def test_model():
+ test_model_spec = next(iter(BUILTIN_VIDEO_MODELS.values()))
+ model_path = cache(test_model_spec)
+ model = DiffUsersVideoModel("mock", model_path, test_model_spec)
+ # input is a string
+ input_text = "an apple"
+ model.load()
+    r = model.text_to_video(input_text)
+ assert r
+
+
+@pytest.mark.skip(reason="Video model requires too much GPU memory.")
+def test_client(setup):
+ endpoint, _ = setup
+ from ....client import Client
+
+ client = Client(endpoint)
+
+ model_uid = client.launch_model(
+ model_uid="my_video_model",
+ model_name="CogVideoX-2b",
+ model_type="video",
+ )
+ model = client.get_model(model_uid)
+ assert model
+
+ r = model.text_to_video(
+ prompt="A panda, dressed in a small, red jacket and a tiny hat, "
+ "sits on a wooden stool in a serene bamboo forest. "
+ "The panda's fluffy paws strum a miniature acoustic guitar, "
+ "producing soft, melodic tunes. Nearby, a few other pandas gather, "
+ "watching curiously and some clapping in rhythm. "
+ "Sunlight filters through the tall bamboo, casting a gentle glow on the scene. "
+ "The panda's face is expressive, showing concentration and joy as it plays. "
+ "The background includes a small, flowing stream and vibrant green foliage, "
+ "enhancing the peaceful and magical atmosphere of this unique musical performance."
+ )
+ assert r
diff --git a/xinference/types.py b/xinference/types.py
index e66e90bee1..3f636d94c3 100644
--- a/xinference/types.py
+++ b/xinference/types.py
@@ -52,6 +52,16 @@ class ImageList(TypedDict):
data: List[Image]
+class Video(TypedDict):
+ url: Optional[str]
+ b64_json: Optional[str]
+
+
+class VideoList(TypedDict):
+ created: int
+ data: List[Video]
+
+
class EmbeddingUsage(TypedDict):
prompt_tokens: int
total_tokens: int
diff --git a/xinference/web/ui/src/scenes/launch_model/index.js b/xinference/web/ui/src/scenes/launch_model/index.js
index 1339e94d4f..55f05747bd 100644
--- a/xinference/web/ui/src/scenes/launch_model/index.js
+++ b/xinference/web/ui/src/scenes/launch_model/index.js
@@ -69,6 +69,7 @@ const LaunchModel = () => {
+
@@ -93,6 +94,9 @@ const LaunchModel = () => {
+
+
+
diff --git a/xinference/web/ui/src/scenes/running_models/index.js b/xinference/web/ui/src/scenes/running_models/index.js
index e91858f2fd..9f9486651a 100644
--- a/xinference/web/ui/src/scenes/running_models/index.js
+++ b/xinference/web/ui/src/scenes/running_models/index.js
@@ -21,6 +21,7 @@ const RunningModels = () => {
const [embeddingModelData, setEmbeddingModelData] = useState([])
const [imageModelData, setImageModelData] = useState([])
const [audioModelData, setAudioModelData] = useState([])
+ const [videoModelData, setVideoModelData] = useState([])
const [rerankModelData, setRerankModelData] = useState([])
const [flexibleModelData, setFlexibleModelData] = useState([])
const { isCallingApi, setIsCallingApi } = useContext(ApiContext)
@@ -53,6 +54,9 @@ const RunningModels = () => {
setAudioModelData([
{ id: 'Loading, do not refresh page...', url: 'IS_LOADING' },
])
+ setVideoModelData([
+ { id: 'Loading, do not refresh page...', url: 'IS_LOADING' },
+ ])
setImageModelData([
{ id: 'Loading, do not refresh page...', url: 'IS_LOADING' },
])
@@ -72,6 +76,7 @@ const RunningModels = () => {
const newEmbeddingModelData = []
const newImageModelData = []
const newAudioModelData = []
+ const newVideoModelData = []
const newRerankModelData = []
const newFlexibleModelData = []
response.data.forEach((model) => {
@@ -86,6 +91,8 @@ const RunningModels = () => {
newEmbeddingModelData.push(newValue)
} else if (newValue.model_type === 'audio') {
newAudioModelData.push(newValue)
+ } else if (newValue.model_type === 'video') {
+ newVideoModelData.push(newValue)
} else if (newValue.model_type === 'image') {
newImageModelData.push(newValue)
} else if (newValue.model_type === 'rerank') {
@@ -97,6 +104,7 @@ const RunningModels = () => {
setLlmData(newLlmData)
setEmbeddingModelData(newEmbeddingModelData)
setAudioModelData(newAudioModelData)
+ setVideoModelData(newVideoModelData)
setImageModelData(newImageModelData)
setRerankModelData(newRerankModelData)
setFlexibleModelData(newFlexibleModelData)
@@ -591,6 +599,7 @@ const RunningModels = () => {
},
]
const audioModelColumns = embeddingModelColumns
+ const videoModelColumns = embeddingModelColumns
const rerankModelColumns = embeddingModelColumns
const flexibleModelColumns = embeddingModelColumns
@@ -652,6 +661,7 @@ const RunningModels = () => {
+
@@ -725,6 +735,20 @@ const RunningModels = () => {
/>
+
+
+
+
+