Commit
FEAT: added MLX support for Flux.1 (#2459)
qinxuye authored Oct 25, 2024
1 parent f7f873f commit d4cd7b1
Showing 26 changed files with 2,428 additions and 19 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
@@ -33,7 +33,7 @@ install_requires =
     tabulate
     requests
     pydantic
-    fastapi==0.110.3
+    fastapi>=0.110.3
     uvicorn
     huggingface-hub>=0.19.4
     typing_extensions
6 changes: 5 additions & 1 deletion xinference/core/chat_interface.py
@@ -74,7 +74,11 @@ def build(self) -> "gr.Blocks":
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
-        interface.startup_events()
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
         favicon_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             os.path.pardir,
6 changes: 5 additions & 1 deletion xinference/core/image_interface.py
@@ -63,7 +63,11 @@ def build(self) -> gr.Blocks:
         # Gradio initiates the queue during a startup event, but since the app has already been
         # started, that event will not run, so manually invoke the startup events.
         # See: https://github.com/gradio-app/gradio/issues/5228
-        interface.startup_events()
+        try:
+            interface.run_startup_events()
+        except AttributeError:
+            # compatibility
+            interface.startup_events()
         favicon_path = os.path.join(
             os.path.dirname(os.path.abspath(__file__)),
             os.path.pardir,
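Both Gradio interfaces get the same compatibility shim: newer Gradio releases renamed Blocks.startup_events to run_startup_events, so the code feature-detects the hook instead of pinning a Gradio version (the fastapi pin is relaxed in this same commit). A minimal standalone sketch of the pattern (run_gradio_startup is a hypothetical helper name, not part of the commit):

    import gradio as gr

    def run_gradio_startup(interface: gr.Blocks) -> None:
        # Newer Gradio exposes run_startup_events(); older releases only
        # have startup_events(). Feature-detect rather than pin a version.
        try:
            interface.run_startup_events()
        except AttributeError:
            interface.startup_events()  # older Gradio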
2 changes: 1 addition & 1 deletion xinference/core/scheduler.py
@@ -79,7 +79,7 @@ def __init__(
         # For tool call
         self.tools = None
         # Currently, for storing tool call streaming results.
-        self.outputs: List[str] = []
+        self.outputs: List[str] = []  # type: ignore
         # inference results,
         # it is a list type because when stream=True,
         # self.completion contains all the results in a decode round.
2 changes: 1 addition & 1 deletion xinference/deploy/docker/requirements.txt
@@ -8,7 +8,7 @@ tqdm>=4.27
 tabulate
 requests
 pydantic
-fastapi==0.110.3
+fastapi>=0.110.3
 uvicorn
 huggingface-hub>=0.19.4
 typing_extensions
2 changes: 1 addition & 1 deletion xinference/deploy/docker/requirements_cpu.txt
@@ -7,7 +7,7 @@ tqdm>=4.27
 tabulate
 requests
 pydantic
-fastapi==0.110.3
+fastapi>=0.110.3
 uvicorn
 huggingface-hub>=0.19.4
 typing_extensions
28 changes: 25 additions & 3 deletions xinference/model/image/core.py
@@ -11,9 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections.abc
 import logging
 import os
+import platform
 from collections import defaultdict
 from typing import Dict, List, Literal, Optional, Tuple, Union
 
@@ -23,6 +25,7 @@
 from ..utils import valid_model_revision
 from .ocr.got_ocr2 import GotOCR2Model
 from .stable_diffusion.core import DiffusionModel
+from .stable_diffusion.mlx import MLXDiffusionModel
 
 logger = logging.getLogger(__name__)
 
@@ -46,6 +49,7 @@ class ImageModelFamilyV1(CacheableModelSpec):
     model_hub: str = "huggingface"
     model_ability: Optional[List[str]]
     controlnet: Optional[List["ImageModelFamilyV1"]]
+    default_model_config: Optional[dict] = {}
     default_generate_config: Optional[dict] = {}


@@ -212,7 +216,9 @@ def create_image_model_instance(
     download_hub: Optional[Literal["huggingface", "modelscope", "csghub"]] = None,
     model_path: Optional[str] = None,
     **kwargs,
-) -> Tuple[Union[DiffusionModel, GotOCR2Model], ImageModelDescription]:
+) -> Tuple[
+    Union[DiffusionModel, MLXDiffusionModel, GotOCR2Model], ImageModelDescription
+]:
     model_spec = match_diffusion(model_name, download_hub)
     if model_spec.model_ability and "ocr" in model_spec.model_ability:
         return create_ocr_model_instance(
@@ -224,6 +230,12 @@ def create_image_model_instance(
         model_path=model_path,
         **kwargs,
     )
+
+    # use default model config
+    model_default_config = (model_spec.default_model_config or {}).copy()
+    model_default_config.update(kwargs)
+    kwargs = model_default_config
+
     controlnet = kwargs.get("controlnet")
     # Handle controlnet
     if controlnet is not None:
@@ -265,10 +277,20 @@
         lora_load_kwargs = None
         lora_fuse_kwargs = None
 
-    model = DiffusionModel(
+    if (
+        platform.system() == "Darwin"
+        and "arm" in platform.machine().lower()
+        and model_name in MLXDiffusionModel.supported_models
+    ):
+        # Mac with M series silicon chips
+        model_cls = MLXDiffusionModel
+    else:
+        model_cls = DiffusionModel  # type: ignore
+
+    model = model_cls(
         model_uid,
         model_path,
-        lora_model_paths=lora_model,
+        lora_model=lora_model,
         lora_load_kwargs=lora_load_kwargs,
         lora_fuse_kwargs=lora_fuse_kwargs,
         model_spec=model_spec,
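The dispatch above routes Apple Silicon Macs to the new MLX backend whenever the requested model is in MLXDiffusionModel.supported_models; every other platform keeps the diffusers-based DiffusionModel. A minimal standalone sketch of the same check (is_apple_silicon and pick_backend are hypothetical names):

    import platform

    def is_apple_silicon() -> bool:
        # Darwin plus an "arm" machine string identifies Macs with M-series chips.
        return platform.system() == "Darwin" and "arm" in platform.machine().lower()

    def pick_backend(model_name: str, supported: set) -> str:
        # Mirrors create_image_model_instance: MLX only for supported models.
        return "mlx" if is_apple_silicon() and model_name in supported else "diffusers"

On an M-series Mac, pick_backend("FLUX.1-schnell", {"FLUX.1-schnell", "FLUX.1-dev"}) would return "mlx"; the set shown is an example argument, not necessarily the backend's actual supported list.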
18 changes: 15 additions & 3 deletions xinference/model/image/model_spec.json
@@ -8,7 +8,11 @@
"text2image",
"image2image",
"inpainting"
]
],
"default_model_config": {
"quantize": true,
"quantize_text_encoder": "text_encoder_2"
}
},
{
"model_name": "FLUX.1-dev",
@@ -19,7 +23,11 @@
"text2image",
"image2image",
"inpainting"
]
],
"default_model_config": {
"quantize": true,
"quantize_text_encoder": "text_encoder_2"
}
},
{
"model_name": "sd3-medium",
@@ -30,7 +38,11 @@
"text2image",
"image2image",
"inpainting"
]
],
"default_model_config": {
"quantize": true,
"quantize_text_encoder": "text_encoder_3"
}
},
{
"model_name": "sd-turbo",
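Each spec's new default_model_config is merged into the caller's kwargs by create_image_model_instance (see the core.py hunk above), with user-supplied values taking precedence over the spec defaults. A minimal sketch of the merge semantics:

    spec_defaults = {"quantize": True, "quantize_text_encoder": "text_encoder_2"}
    user_kwargs = {"quantize": False}  # e.g. passed when launching the model

    merged = spec_defaults.copy()
    merged.update(user_kwargs)  # user kwargs override the spec defaults
    assert merged == {"quantize": False, "quantize_text_encoder": "text_encoder_2"}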
18 changes: 15 additions & 3 deletions xinference/model/image/model_spec_modelscope.json
@@ -9,7 +9,11 @@
"text2image",
"image2image",
"inpainting"
]
],
"default_model_config": {
"quantize": true,
"quantize_text_encoder": "text_encoder_2"
}
},
{
"model_name": "FLUX.1-dev",
@@ -21,7 +25,11 @@
"text2image",
"image2image",
"inpainting"
]
],
"default_model_config": {
"quantize": true,
"quantize_text_encoder": "text_encoder_2"
}
},
{
"model_name": "sd3-medium",
@@ -33,7 +41,11 @@
"text2image",
"image2image",
"inpainting"
]
],
"default_model_config": {
"quantize": true,
"quantize_text_encoder": "text_encoder_3"
}
},
{
"model_name": "sd-turbo",
2 changes: 1 addition & 1 deletion xinference/model/image/scheduler/flux.py
@@ -124,7 +124,7 @@ def __init__(self):
         self._running_queue: deque[Text2ImageRequest] = deque()  # type: ignore
         self._model = None
         self._available_device = get_available_device()
-        self._id_to_req: Dict[str, Text2ImageRequest] = {}
+        self._id_to_req: Dict[str, Text2ImageRequest] = {}  # type: ignore
 
     def set_model(self, model):
         """
5 changes: 2 additions & 3 deletions xinference/model/image/stable_diffusion/core.py
@@ -283,9 +283,8 @@ def _load_to_device(self, model):
             model.enable_sequential_cpu_offload()
         elif not self._kwargs.get("device_map"):
             logger.debug("Loading model to available device")
-            model = move_model_to_available_device(self._model)
-        # Recommended if your computer has < 64 GB of RAM
-        if self._kwargs.get("attention_slicing", True):
+            model = move_model_to_available_device(model)
+        if self._kwargs.get("attention_slicing", False):
             model.enable_attention_slicing()
         if self._kwargs.get("vae_tiling", False):
             model.enable_vae_tiling()
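Besides fixing the offload path to move the freshly loaded model instead of self._model, this hunk flips the attention_slicing default from True to False, so slicing must now be requested explicitly. A hedged sketch of how that might look through the client API, assuming launch-time kwargs reach the model as in other xinference model types (endpoint and model name are placeholders):

    from xinference.client import Client

    client = Client("http://localhost:9997")
    model_uid = client.launch_model(
        model_name="FLUX.1-schnell",
        model_type="image",
        attention_slicing=True,  # re-enable slicing; the default is now False
    )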
[Diff truncated: the remaining changed files of the 26 in this commit are not shown.]
