Skip to content

Commit

Permalink
Add BiRefNet-General and BiRefNet-Portrait models as available models (
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitribarbot authored Aug 26, 2024
1 parent ed1c295 commit d4c40e1
Show file tree
Hide file tree
Showing 38 changed files with 453 additions and 1 deletion.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,13 @@ The available models are:
- isnet-general-use ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-general-use.onnx), [source](https://github.com/xuebinqin/DIS)): A new pre-trained model for general use cases.
- isnet-anime ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/isnet-anime.onnx), [source](https://github.com/SkyTNT/anime-segmentation)): A high-accuracy segmentation model for anime characters.
- sam ([download encoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-encoder-quant.onnx), [download decoder](https://github.com/danielgatis/rembg/releases/download/v0.0.0/vit_b-decoder-quant.onnx), [source](https://github.com/facebookresearch/segment-anything)): A pre-trained model for any use case.
- birefnet-general ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for general use cases.
- birefnet-general-lite ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A light pre-trained model for general use cases.
- birefnet-portrait ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-portrait-epoch_150.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for human portraits.
- birefnet-dis ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for dichotomous image segmentation (DIS).
- birefnet-hrsod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for high-resolution salient object detection (HRSOD).
- birefnet-cod ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model for concealed object detection (COD).
- birefnet-massive ([download](https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-massive-TR_DIS5K_TR_TEs-epoch_420.onnx), [source](https://github.com/ZhengPeng7/BiRefNet)): A pre-trained model trained on a massive dataset.

### How to train your own model

Expand Down
35 changes: 35 additions & 0 deletions rembg/sessions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,41 @@
# Registries filled in below: each session class is appended to
# `sessions_class` and its name() to `sessions_names`, in the same order.
sessions_class: List[type[BaseSession]] = []
sessions_names: List[str] = []

# BiRefNetSessionGeneral is imported first on purpose: sibling modules such
# as birefnet_cod and birefnet_dis do `from . import BiRefNetSessionGeneral`,
# so the name must already be bound in this package when they are loaded.
from .birefnet_general import BiRefNetSessionGeneral

sessions_class.append(BiRefNetSessionGeneral)
sessions_names.append(BiRefNetSessionGeneral.name())

from .birefnet_general_lite import BiRefNetSessionGeneralLite

sessions_class.append(BiRefNetSessionGeneralLite)
sessions_names.append(BiRefNetSessionGeneralLite.name())

from .birefnet_portrait import BiRefNetSessionPortrait

sessions_class.append(BiRefNetSessionPortrait)
sessions_names.append(BiRefNetSessionPortrait.name())

from .birefnet_dis import BiRefNetSessionDIS

sessions_class.append(BiRefNetSessionDIS)
sessions_names.append(BiRefNetSessionDIS.name())

from .birefnet_hrsod import BiRefNetSessionHRSOD

sessions_class.append(BiRefNetSessionHRSOD)
sessions_names.append(BiRefNetSessionHRSOD.name())

from .birefnet_cod import BiRefNetSessionCOD

sessions_class.append(BiRefNetSessionCOD)
sessions_names.append(BiRefNetSessionCOD.name())

from .birefnet_massive import BiRefNetSessionMassive

sessions_class.append(BiRefNetSessionMassive)
sessions_names.append(BiRefNetSessionMassive.name())

from .dis_anime import DisSession

sessions_class.append(DisSession)
Expand Down
52 changes: 52 additions & 0 deletions rembg/sessions/birefnet_cod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os

import pooch

from . import BiRefNetSessionGeneral


class BiRefNetSessionCOD(BiRefNetSessionGeneral):
    """
    BiRefNet session that uses the COD (concealed object detection) weights.

    Only the session name and the downloaded model file are overridden;
    everything else is inherited from BiRefNetSessionGeneral.
    """

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Ignored positional arguments.
            **kwargs: Ignored keyword arguments.

        Returns:
            str: Always "birefnet-cod".
        """
        return "birefnet-cod"

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-COD ONNX file with pooch and return its local path.

        The file is stored as "<session name>.onnx" inside the directory
        returned by cls.u2net_home(...). No hash is handed to pooch when
        cls.checksum_disabled(...) is true.

        Parameters:
            *args: Forwarded to cls.name, cls.checksum_disabled and cls.u2net_home.
            **kwargs: Forwarded to the same helpers.

        Returns:
            str: The path to the downloaded model file.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"

        if cls.checksum_disabled(*args, **kwargs):
            expected_hash = None
        else:
            expected_hash = "md5:f6d0d21ca89d287f17e7afe9f5fd3b45"

        pooch.retrieve(
            url="https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx",
            known_hash=expected_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)
52 changes: 52 additions & 0 deletions rembg/sessions/birefnet_dis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os

import pooch

from . import BiRefNetSessionGeneral


class BiRefNetSessionDIS(BiRefNetSessionGeneral):
    """
    BiRefNet session that uses the DIS (dichotomous image segmentation) weights.

    Only the session name and the downloaded model file are overridden;
    everything else is inherited from BiRefNetSessionGeneral.
    """

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Ignored positional arguments.
            **kwargs: Ignored keyword arguments.

        Returns:
            str: Always "birefnet-dis".
        """
        return "birefnet-dis"

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-DIS ONNX file with pooch and return its local path.

        The file is stored as "<session name>.onnx" inside the directory
        returned by cls.u2net_home(...). No hash is handed to pooch when
        cls.checksum_disabled(...) is true.

        Parameters:
            *args: Forwarded to cls.name, cls.checksum_disabled and cls.u2net_home.
            **kwargs: Forwarded to the same helpers.

        Returns:
            str: The path to the downloaded model file.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"

        if cls.checksum_disabled(*args, **kwargs):
            expected_hash = None
        else:
            expected_hash = "md5:2d4d44102b446f33a4ebb2e56c051f2b"

        pooch.retrieve(
            url="https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx",
            known_hash=expected_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)
91 changes: 91 additions & 0 deletions rembg/sessions/birefnet_general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import os
from typing import List

import numpy as np
import pooch
from PIL import Image
from PIL.Image import Image as PILImage

from .base import BaseSession


class BiRefNetSessionGeneral(BaseSession):
    """
    Session that runs the BiRefNet-General ONNX model to produce a mask.

    Subclasses reuse predict() and override download_models() / name() to
    point at a different set of weights.
    """

    def sigmoid(self, mat):
        """
        Apply the logistic function 1 / (1 + exp(-x)) element-wise.

        Parameters:
            mat: Array (or scalar) of raw model outputs.

        Returns:
            The same shape as `mat`, with every value mapped into [0, 1].
        """
        return 1 / (1 + np.exp(-mat))

    def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]:
        """
        Predict the output mask for the input image using the inner session.

        The first model output is passed through a sigmoid, min-max rescaled
        to [0, 1], converted to an 8-bit grayscale image and resized back to
        the size of `img`.

        Parameters:
            img (PILImage): The input image.
            *args: Additional positional arguments (unused here).
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            List[PILImage]: A single-element list holding the "L"-mode mask.
        """
        # Mean, std and target size handed to self.normalize.
        # NOTE(review): these look like the usual ImageNet statistics —
        # confirm against BaseSession.normalize.
        ort_outs = self.inner_session.run(
            None,
            self.normalize(
                img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024)
            ),
        )

        # Index 0 on the second axis of the first output; the indexing implies
        # a 4-D array, presumably (batch, channel, height, width) — TODO confirm.
        pred = self.sigmoid(ort_outs[0][:, 0, :, :])

        ma = np.max(pred)
        mi = np.min(pred)

        # A perfectly uniform prediction has ma == mi; dividing by that zero
        # range would fill the mask with NaN before the uint8 cast. In that
        # case keep the sigmoid values (already within [0, 1]) unchanged.
        span = ma - mi
        if span > 0:
            pred = (pred - mi) / span
        pred = np.squeeze(pred)

        mask = Image.fromarray((pred * 255).astype("uint8"), mode="L")
        mask = mask.resize(img.size, Image.Resampling.LANCZOS)

        return [mask]

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-General ONNX file with pooch and return its local path.

        The file is stored as "<session name>.onnx" inside the directory
        returned by cls.u2net_home(...). No hash is handed to pooch when
        cls.checksum_disabled(...) is true.

        Parameters:
            *args: Forwarded to cls.name, cls.checksum_disabled and cls.u2net_home.
            **kwargs: Forwarded to the same helpers.

        Returns:
            str: The path to the downloaded model file.
        """
        fname = f"{cls.name(*args, **kwargs)}.onnx"
        pooch.retrieve(
            "https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx",
            (
                None
                if cls.checksum_disabled(*args, **kwargs)
                else "md5:7a35a0141cbbc80de11d9c9a28f52697"
            ),
            fname=fname,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), fname)

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Ignored positional arguments.
            **kwargs: Ignored keyword arguments.

        Returns:
            str: Always "birefnet-general".
        """
        return "birefnet-general"
52 changes: 52 additions & 0 deletions rembg/sessions/birefnet_general_lite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os

import pooch

from . import BiRefNetSessionGeneral


class BiRefNetSessionGeneralLite(BiRefNetSessionGeneral):
    """
    BiRefNet session that uses the General-Lite weights.

    Only the session name and the downloaded model file are overridden;
    everything else is inherited from BiRefNetSessionGeneral.
    """

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Ignored positional arguments.
            **kwargs: Ignored keyword arguments.

        Returns:
            str: Always "birefnet-general-lite".
        """
        return "birefnet-general-lite"

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-General-Lite ONNX file with pooch and return its local path.

        The file is stored as "<session name>.onnx" inside the directory
        returned by cls.u2net_home(...). No hash is handed to pooch when
        cls.checksum_disabled(...) is true.

        Parameters:
            *args: Forwarded to cls.name, cls.checksum_disabled and cls.u2net_home.
            **kwargs: Forwarded to the same helpers.

        Returns:
            str: The path to the downloaded model file.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"

        if cls.checksum_disabled(*args, **kwargs):
            expected_hash = None
        else:
            expected_hash = "md5:4fab47adc4ff364be1713e97b7e66334"

        pooch.retrieve(
            url="https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx",
            known_hash=expected_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)
52 changes: 52 additions & 0 deletions rembg/sessions/birefnet_hrsod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os

import pooch

from . import BiRefNetSessionGeneral


class BiRefNetSessionHRSOD(BiRefNetSessionGeneral):
    """
    BiRefNet session that uses the HRSOD (high-resolution salient object
    detection) weights.

    Only the session name and the downloaded model file are overridden;
    everything else is inherited from BiRefNetSessionGeneral.
    """

    @classmethod
    def name(cls, *args, **kwargs):
        """
        Return the identifier of this session.

        Parameters:
            *args: Ignored positional arguments.
            **kwargs: Ignored keyword arguments.

        Returns:
            str: Always "birefnet-hrsod".
        """
        return "birefnet-hrsod"

    @classmethod
    def download_models(cls, *args, **kwargs):
        """
        Fetch the BiRefNet-HRSOD ONNX file with pooch and return its local path.

        The file is stored as "<session name>.onnx" inside the directory
        returned by cls.u2net_home(...). No hash is handed to pooch when
        cls.checksum_disabled(...) is true.

        Parameters:
            *args: Forwarded to cls.name, cls.checksum_disabled and cls.u2net_home.
            **kwargs: Forwarded to the same helpers.

        Returns:
            str: The path to the downloaded model file.
        """
        model_file = f"{cls.name(*args, **kwargs)}.onnx"

        if cls.checksum_disabled(*args, **kwargs):
            expected_hash = None
        else:
            expected_hash = "md5:c017ade5de8a50ff0fd74d790d268dda"

        pooch.retrieve(
            url="https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx",
            known_hash=expected_hash,
            fname=model_file,
            path=cls.u2net_home(*args, **kwargs),
            progressbar=True,
        )

        return os.path.join(cls.u2net_home(*args, **kwargs), model_file)
Loading

0 comments on commit d4c40e1

Please sign in to comment.