Skip to content

Commit

Permalink
parallelising, adding thresholds
Browse files Browse the repository at this point in the history
  • Loading branch information
jay-dhanwant-yral committed Oct 25, 2024
1 parent 65db7f5 commit 8585ad4
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 14 deletions.
2 changes: 1 addition & 1 deletion localping.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def run():
stub = nsfw_detector_pb2_grpc.NSFWDetectorStub(channel)

# Test video ID (replace with an actual video ID for testing)
video_id = "00034f1c9c9148388bf6873776222535"
video_id = "21727f4f6cbe496d9a410202537faff8"

try:
print("Sending request to server")
Expand Down
22 changes: 13 additions & 9 deletions nsfw_detect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging

_LOGGER = logging.getLogger(__name__)

class NSFWDetect:
def __init__(self, pipe3c, pipe5c):
self.pipe3c = pipe3c
Expand All @@ -15,7 +16,7 @@ def __init__(self, pipe3c, pipe5c):
credentials = service_account.Credentials.from_service_account_info(service_acc_creds)
self.gclient = vision.ImageAnnotatorClient(credentials=credentials)

def nsfw_detect(self, img):
def nsfw_detect(self, imgs):
marks = {'QUESTIONABLE provocative': 'provocative',
'QUESTIONABLE porn': 'explicit',
'QUESTIONABLE neutral': 'neutral',
Expand All @@ -34,14 +35,14 @@ def nsfw_detect(self, img):

from concurrent.futures import ThreadPoolExecutor

res3c = self.pipe3c(img)
res5c = self.pipe5c(img)
res3cs = self.pipe3c(imgs)
res5cs = self.pipe5c(imgs)

mark3c = max(res3c, key=lambda x: x['score'])['label']
mark5c = max(res5c, key=lambda x: x['score'])['label']
mark3cs = [max(res3c, key=lambda x: x['score']) for res3c in res3cs]
mark5cs = [max(res5c, key=lambda x: x['score']) for res5c in res5cs]

mark = marks[mark3c + ' ' + mark5c]
return {'res':mark, 'metadata': {'mark3c': mark3c, 'mark5c': mark5c}}
marks = [(marks[mark3c['label'] + ' ' + mark5c['label']], mark3c['score'], mark5c['score']) for mark3c, mark5c in zip(mark3cs, mark5cs)]
return marks

def detect_nsfw_gore(self, pil_image):
try:
Expand Down Expand Up @@ -108,6 +109,9 @@ def load_model_artifacts(artifact_path):
pipe3c, pipe5c = load_model_artifacts("model_artifacts")
detector = NSFWDetect(pipe3c, pipe5c)
img = Image.open("/Users/jaydhanwant/Downloads/WhatsApp Image 2024-08-29 at 13.18.09.jpeg")
result = detector.detect_nsfw_gore(img)
print(result['violence'])
img2 = Image.open("/Users/jaydhanwant/Downloads/3.jpg")
img3 = Image.open("/Users/jaydhanwant/Documents/questionable.png")
imgs = [img, img2, img3]
result = detector.nsfw_detect(imgs)
print(result)

8 changes: 4 additions & 4 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

_ONE_DAY = datetime.timedelta(days=1)
_PROCESS_COUNT = multiprocessing.cpu_count()
_LOGGER.info(f"Process count: {_PROCESS_COUNT}")
# _PROCESS_COUNT = 1
_THREAD_CONCURRENCY = 10 # heuristic
_BIND_ADDRESS = "[::]:50051"
Expand Down Expand Up @@ -104,15 +103,16 @@ def DetectNSFW(self, request, context):

def process_frames(self, video_id):
frames = get_images_from_gcs("yral-video-frames", video_id)
nsfw_tags = []
nsfw_tags = self.nsfw_detector.nsfw_detect([frame['image'] for frame in frames])
gore_tags = []
for frame in frames:
nsfw_tags.append(self.nsfw_detector.nsfw_detect(frame['image'])['res'])
gore_tags.append(self.nsfw_detector.detect_nsfw_gore(frame['image']))
tag_priority = "explicit nudity provocative neutral".split()
gore_priority = ["UNKNOWN", "VERY_UNLIKELY", "UNLIKELY", "POSSIBLE", "LIKELY", "VERY_LIKELY"][::-1]
# Sort nsfw_tags based on the priority defined in tag_priority
nsfw_tags = [i[0] for i in nsfw_tags if i[1] > 0.82 and i[2]>0.9]
nsfw_tags.sort(key=lambda tag: tag_priority.index(tag))

gore_tags.sort(key=lambda tag: gore_priority.index(tag))

nsfw_tag = None
Expand Down Expand Up @@ -158,7 +158,7 @@ def _run_server():

def main():
multiprocessing.set_start_method("spawn", force=True)
_LOGGER.info(f"Binding to '{_BIND_ADDRESS}'")
_LOGGER.info(f"Binding to '{_BIND_ADDRESS}' with Process Count: {_PROCESS_COUNT}")
sys.stdout.flush()
workers = []
for _ in range(_PROCESS_COUNT):
Expand Down

0 comments on commit 8585ad4

Please sign in to comment.