From 7c3e60be4c026c94f071a2d62fd4363fbe96d7de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=B5=E3=81=81?= Date: Fri, 29 Nov 2024 13:18:35 +0900 Subject: [PATCH 1/3] fix sum providers bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: ふぁ --- rembg/sessions/sam.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/rembg/sessions/sam.py b/rembg/sessions/sam.py index 7a287cd0..e6d8c4bf 100644 --- a/rembg/sessions/sam.py +++ b/rembg/sessions/sam.py @@ -105,9 +105,10 @@ def __init__( valid_providers = [] available_providers = ort.get_available_providers() - for provider in providers or []: - if provider in available_providers: - valid_providers.append(provider) + if providers: + for provider in providers or []: + if provider in available_providers: + valid_providers.append(provider) else: valid_providers.extend(available_providers) From 00ad95d1c00965323e53b2cf289ec948b6f9fbbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=B5=E3=81=81?= Date: Fri, 29 Nov 2024 13:19:33 +0900 Subject: [PATCH 2/3] remove unused type imports in sam MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: ふぁ --- rembg/sessions/sam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rembg/sessions/sam.py b/rembg/sessions/sam.py index e6d8c4bf..8b358a4c 100644 --- a/rembg/sessions/sam.py +++ b/rembg/sessions/sam.py @@ -1,6 +1,6 @@ import os from copy import deepcopy -from typing import Dict, List, Tuple +from typing import List import cv2 import numpy as np From 9e6c46184d645b12e09de41f62123f2e9df89c7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=B5=E3=81=81?= Date: Fri, 29 Nov 2024 13:21:10 +0900 Subject: [PATCH 3/3] update default sam_prompt structure to include point data MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: ふぁ --- rembg/sessions/sam.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/rembg/sessions/sam.py b/rembg/sessions/sam.py index 8b358a4c..b0c02214 100644 --- a/rembg/sessions/sam.py +++ b/rembg/sessions/sam.py @@ -143,7 +143,16 @@ def predict( Returns: List[PILImage]: A list of masks generated by the decoder. """ - prompt = kwargs.get("sam_prompt", "{}") + prompt = kwargs.get( + "sam_prompt", + [ + { + "type": "point", + "label": 1, + "data": [int(img.width / 2), int(img.height / 2)], + } + ], + ) schema = { "type": "array", "items": {