diff --git a/localping.py b/localping.py index 57b873d..b949f5f 100644 --- a/localping.py +++ b/localping.py @@ -17,7 +17,7 @@ def run(): stub = nsfw_detector_pb2_grpc.NSFWDetectorStub(channel) # Test video ID (replace with an actual video ID for testing) - video_id = "00034f1c9c9148388bf6873776222535" + video_id = "21727f4f6cbe496d9a410202537faff8" try: print("Sending request to server") diff --git a/nsfw_detect_utils.py b/nsfw_detect_utils.py index 79b8dea..d326ea1 100644 --- a/nsfw_detect_utils.py +++ b/nsfw_detect_utils.py @@ -6,6 +6,7 @@ import logging _LOGGER = logging.getLogger(__name__) + class NSFWDetect: def __init__(self, pipe3c, pipe5c): self.pipe3c = pipe3c @@ -15,7 +16,7 @@ def __init__(self, pipe3c, pipe5c): credentials = service_account.Credentials.from_service_account_info(service_acc_creds) self.gclient = vision.ImageAnnotatorClient(credentials=credentials) - def nsfw_detect(self, img): + def nsfw_detect(self, imgs): marks = {'QUESTIONABLE provocative': 'provocative', 'QUESTIONABLE porn': 'explicit', 'QUESTIONABLE neutral': 'neutral', @@ -34,14 +35,14 @@ def nsfw_detect(self, img): from concurrent.futures import ThreadPoolExecutor - res3c = self.pipe3c(img) - res5c = self.pipe5c(img) + res3cs = self.pipe3c(imgs) + res5cs = self.pipe5c(imgs) - mark3c = max(res3c, key=lambda x: x['score'])['label'] - mark5c = max(res5c, key=lambda x: x['score'])['label'] + mark3cs = [max(res3c, key=lambda x: x['score']) for res3c in res3cs] + mark5cs = [max(res5c, key=lambda x: x['score']) for res5c in res5cs] - mark = marks[mark3c + ' ' + mark5c] - return {'res':mark, 'metadata': {'mark3c': mark3c, 'mark5c': mark5c}} + marks = [(marks[mark3c['label'] + ' ' + mark5c['label']], mark3c['score'], mark5c['score']) for mark3c, mark5c in zip(mark3cs, mark5cs)] + return marks def detect_nsfw_gore(self, pil_image): try: @@ -108,6 +109,9 @@ def load_model_artifacts(artifact_path): pipe3c, pipe5c = load_model_artifacts("model_artifacts") detector = NSFWDetect(pipe3c, pipe5c) img = Image.open("/Users/jaydhanwant/Downloads/WhatsApp Image 2024-08-29 at 13.18.09.jpeg") - result = detector.detect_nsfw_gore(img) - print(result['violence']) + img2 = Image.open("/Users/jaydhanwant/Downloads/3.jpg") + img3 = Image.open("/Users/jaydhanwant/Documents/questionable.png") + imgs = [img, img2, img3] + result = detector.nsfw_detect(imgs) + print(result) \ No newline at end of file diff --git a/server.py b/server.py index d743bef..fbfd476 100644 --- a/server.py +++ b/server.py @@ -25,7 +25,6 @@ _ONE_DAY = datetime.timedelta(days=1) _PROCESS_COUNT = multiprocessing.cpu_count() -_LOGGER.info(f"Process count: {_PROCESS_COUNT}") # _PROCESS_COUNT = 1 _THREAD_CONCURRENCY = 10 # heuristic _BIND_ADDRESS = "[::]:50051" @@ -104,15 +103,16 @@ def DetectNSFW(self, request, context): def process_frames(self, video_id): frames = get_images_from_gcs("yral-video-frames", video_id) - nsfw_tags = [] + nsfw_tags = self.nsfw_detector.nsfw_detect([frame['image'] for frame in frames]) gore_tags = [] for frame in frames: - nsfw_tags.append(self.nsfw_detector.nsfw_detect(frame['image'])['res']) gore_tags.append(self.nsfw_detector.detect_nsfw_gore(frame['image'])) tag_priority = "explicit nudity provocative neutral".split() gore_priority = ["UNKNOWN", "VERY_UNLIKELY", "UNLIKELY", "POSSIBLE", "LIKELY", "VERY_LIKELY"][::-1] # Sort nsfw_tags based on the priority defined in tag_priority + nsfw_tags = [i[0] for i in nsfw_tags if i[1] > 0.82 and i[2]>0.9] nsfw_tags.sort(key=lambda tag: tag_priority.index(tag)) + gore_tags.sort(key=lambda tag: gore_priority.index(tag)) nsfw_tag = None @@ -158,7 +158,7 @@ def _run_server(): def main(): multiprocessing.set_start_method("spawn", force=True) - _LOGGER.info(f"Binding to '{_BIND_ADDRESS}'") + _LOGGER.info(f"Binding to '{_BIND_ADDRESS}' with Process Count: {_PROCESS_COUNT}") sys.stdout.flush() workers = [] for _ in range(_PROCESS_COUNT):