-
-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add BiRefNet-General and BiRefNet-Portrait models as available models (…
- Loading branch information
1 parent
ed1c295
commit d4c40e1
Showing
38 changed files
with
453 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
|
||
import pooch | ||
|
||
from . import BiRefNetSessionGeneral | ||
|
||
|
||
class BiRefNetSessionCOD(BiRefNetSessionGeneral): | ||
""" | ||
This class represents a BiRefNet-COD session, which is a subclass of BiRefNetSessionGeneral. | ||
""" | ||
|
||
@classmethod | ||
def download_models(cls, *args, **kwargs): | ||
""" | ||
Downloads the BiRefNet-COD model file from a specific URL and saves it. | ||
Parameters: | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
str: The path to the downloaded model file. | ||
""" | ||
fname = f"{cls.name(*args, **kwargs)}.onnx" | ||
pooch.retrieve( | ||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-COD-epoch_125.onnx", | ||
( | ||
None | ||
if cls.checksum_disabled(*args, **kwargs) | ||
else "md5:f6d0d21ca89d287f17e7afe9f5fd3b45" | ||
), | ||
fname=fname, | ||
path=cls.u2net_home(*args, **kwargs), | ||
progressbar=True, | ||
) | ||
|
||
return os.path.join(cls.u2net_home(*args, **kwargs), fname) | ||
|
||
@classmethod | ||
def name(cls, *args, **kwargs): | ||
""" | ||
Returns the name of the BiRefNet-COD session. | ||
Parameters: | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
str: The name of the session. | ||
""" | ||
return "birefnet-cod" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
|
||
import pooch | ||
|
||
from . import BiRefNetSessionGeneral | ||
|
||
|
||
class BiRefNetSessionDIS(BiRefNetSessionGeneral): | ||
""" | ||
This class represents a BiRefNet-DIS session, which is a subclass of BiRefNetSessionGeneral. | ||
""" | ||
|
||
@classmethod | ||
def download_models(cls, *args, **kwargs): | ||
""" | ||
Downloads the BiRefNet-DIS model file from a specific URL and saves it. | ||
Parameters: | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
str: The path to the downloaded model file. | ||
""" | ||
fname = f"{cls.name(*args, **kwargs)}.onnx" | ||
pooch.retrieve( | ||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-DIS-epoch_590.onnx", | ||
( | ||
None | ||
if cls.checksum_disabled(*args, **kwargs) | ||
else "md5:2d4d44102b446f33a4ebb2e56c051f2b" | ||
), | ||
fname=fname, | ||
path=cls.u2net_home(*args, **kwargs), | ||
progressbar=True, | ||
) | ||
|
||
return os.path.join(cls.u2net_home(*args, **kwargs), fname) | ||
|
||
@classmethod | ||
def name(cls, *args, **kwargs): | ||
""" | ||
Returns the name of the BiRefNet-DIS session. | ||
Parameters: | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
str: The name of the session. | ||
""" | ||
return "birefnet-dis" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import os | ||
from typing import List | ||
|
||
import numpy as np | ||
import pooch | ||
from PIL import Image | ||
from PIL.Image import Image as PILImage | ||
|
||
from .base import BaseSession | ||
|
||
|
||
class BiRefNetSessionGeneral(BaseSession): | ||
""" | ||
This class represents a BiRefNet-General session, which is a subclass of BaseSession. | ||
""" | ||
|
||
def sigmoid(self, mat): | ||
return 1 / (1 + np.exp(-mat)) | ||
|
||
def predict(self, img: PILImage, *args, **kwargs) -> List[PILImage]: | ||
""" | ||
Predicts the output masks for the input image using the inner session. | ||
Parameters: | ||
img (PILImage): The input image. | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
List[PILImage]: The list of output masks. | ||
""" | ||
ort_outs = self.inner_session.run( | ||
None, | ||
self.normalize( | ||
img, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), (1024, 1024) | ||
), | ||
) | ||
|
||
pred = self.sigmoid(ort_outs[0][:, 0, :, :]) | ||
|
||
ma = np.max(pred) | ||
mi = np.min(pred) | ||
|
||
pred = (pred - mi) / (ma - mi) | ||
pred = np.squeeze(pred) | ||
|
||
mask = Image.fromarray((pred * 255).astype("uint8"), mode="L") | ||
mask = mask.resize(img.size, Image.Resampling.LANCZOS) | ||
|
||
return [mask] | ||
|
||
@classmethod | ||
def download_models(cls, *args, **kwargs): | ||
""" | ||
Downloads the BiRefNet-General model file from a specific URL and saves it. | ||
Parameters: | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
str: The path to the downloaded model file. | ||
""" | ||
fname = f"{cls.name(*args, **kwargs)}.onnx" | ||
pooch.retrieve( | ||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-epoch_244.onnx", | ||
( | ||
None | ||
if cls.checksum_disabled(*args, **kwargs) | ||
else "md5:7a35a0141cbbc80de11d9c9a28f52697" | ||
), | ||
fname=fname, | ||
path=cls.u2net_home(*args, **kwargs), | ||
progressbar=True, | ||
) | ||
|
||
return os.path.join(cls.u2net_home(*args, **kwargs), fname) | ||
|
||
@classmethod | ||
def name(cls, *args, **kwargs): | ||
""" | ||
Returns the name of the BiRefNet-General session. | ||
Parameters: | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
str: The name of the session. | ||
""" | ||
return "birefnet-general" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
|
||
import pooch | ||
|
||
from . import BiRefNetSessionGeneral | ||
|
||
|
||
class BiRefNetSessionGeneralLite(BiRefNetSessionGeneral): | ||
""" | ||
This class represents a BiRefNet-General-Lite session, which is a subclass of BiRefNetSessionGeneral. | ||
""" | ||
|
||
@classmethod | ||
def download_models(cls, *args, **kwargs): | ||
""" | ||
Downloads the BiRefNet-General-Lite model file from a specific URL and saves it. | ||
Parameters: | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
str: The path to the downloaded model file. | ||
""" | ||
fname = f"{cls.name(*args, **kwargs)}.onnx" | ||
pooch.retrieve( | ||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-general-bb_swin_v1_tiny-epoch_232.onnx", | ||
( | ||
None | ||
if cls.checksum_disabled(*args, **kwargs) | ||
else "md5:4fab47adc4ff364be1713e97b7e66334" | ||
), | ||
fname=fname, | ||
path=cls.u2net_home(*args, **kwargs), | ||
progressbar=True, | ||
) | ||
|
||
return os.path.join(cls.u2net_home(*args, **kwargs), fname) | ||
|
||
@classmethod | ||
def name(cls, *args, **kwargs): | ||
""" | ||
Returns the name of the BiRefNet-General-Lite session. | ||
Parameters: | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
str: The name of the session. | ||
""" | ||
return "birefnet-general-lite" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
|
||
import pooch | ||
|
||
from . import BiRefNetSessionGeneral | ||
|
||
|
||
class BiRefNetSessionHRSOD(BiRefNetSessionGeneral): | ||
""" | ||
This class represents a BiRefNet-HRSOD session, which is a subclass of BiRefNetSessionGeneral. | ||
""" | ||
|
||
@classmethod | ||
def download_models(cls, *args, **kwargs): | ||
""" | ||
Downloads the BiRefNet-HRSOD model file from a specific URL and saves it. | ||
Parameters: | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
str: The path to the downloaded model file. | ||
""" | ||
fname = f"{cls.name(*args, **kwargs)}.onnx" | ||
pooch.retrieve( | ||
"https://github.com/danielgatis/rembg/releases/download/v0.0.0/BiRefNet-HRSOD_DHU-epoch_115.onnx", | ||
( | ||
None | ||
if cls.checksum_disabled(*args, **kwargs) | ||
else "md5:c017ade5de8a50ff0fd74d790d268dda" | ||
), | ||
fname=fname, | ||
path=cls.u2net_home(*args, **kwargs), | ||
progressbar=True, | ||
) | ||
|
||
return os.path.join(cls.u2net_home(*args, **kwargs), fname) | ||
|
||
@classmethod | ||
def name(cls, *args, **kwargs): | ||
""" | ||
Returns the name of the BiRefNet-HRSOD session. | ||
Parameters: | ||
*args: Additional positional arguments. | ||
**kwargs: Additional keyword arguments. | ||
Returns: | ||
str: The name of the session. | ||
""" | ||
return "birefnet-hrsod" |
Oops, something went wrong.