diff --git a/.github/workflows/deploy-on-merge-to-main.yml b/.github/workflows/deploy-on-merge-to-main.yml index 6f8adfb..7747496 100644 --- a/.github/workflows/deploy-on-merge-to-main.yml +++ b/.github/workflows/deploy-on-merge-to-main.yml @@ -21,9 +21,9 @@ jobs: run: | flyctl secrets set --stage --app "yral-nsfw-classification" "SERVICE_CRED=$SERVICE_CRED" env: - FLY_API_TOKEN: ${{ secrets.FLY_NSFW_TOKEN }} + FLY_API_TOKEN: ${{ secrets.FLY_IO_DEPLOY_TOKENS }} SERVICE_CRED: ${{ secrets.SERVICE_CRED }} - name: Deploy a docker container to fly.io run: flyctl deploy --remote-only env: - FLY_API_TOKEN: ${{ secrets.FLY_NSFW_TOKEN }} + FLY_API_TOKEN: ${{ secrets.FLY_IO_DEPLOY_TOKENS }} diff --git a/fly.toml b/fly.toml index 10d7229..91e4180 100644 --- a/fly.toml +++ b/fly.toml @@ -3,7 +3,7 @@ # See https://fly.io/docs/reference/configuration/ for information about how to use this file. # -app = 'yral-nsfw-classification' +app = 'prod-yral-nsfw-classification' primary_region = 'ewr' kill_signal = 'SIGINT' kill_timeout = '5s' @@ -22,7 +22,7 @@ swap_size_mb = 32768 port = 443 handlers = ['tls'] - [services.ports.tls_options] + [services.ports.tls_options] alpn = ['h2'] [services.concurrency] @@ -32,4 +32,4 @@ swap_size_mb = 32768 [[vm]] memory = '8gb' cpu_kind = 'performance' - cpus = 4 + cpus = 1 diff --git a/localping.py b/localping.py index b949f5f..879984d 100644 --- a/localping.py +++ b/localping.py @@ -3,28 +3,41 @@ import nsfw_detector_pb2_grpc import os import jwt +import base64 token_fpath = "/Users/jaydhanwant/Documents/SS/nsfw_jwt_token.txt" with open(token_fpath, 'r') as f: _JWT_TOKEN = f.read() server_url = 'localhost:50051' -# Load the private key from a path specified in an environment variable def run(): - # NOTE: Replace with the actual server address if different channel = grpc.insecure_channel(server_url) stub = nsfw_detector_pb2_grpc.NSFWDetectorStub(channel) - # Test video ID (replace with an actual video ID for testing) - video_id = 
"21727f4f6cbe496d9a410202537faff8" + # Test video ID + video_id = "21727f4f6cbe496d9a410202537faff8" + # Test image URL + image_url = "https://img-cdn.pixlr.com/image-generator/history/65bb506dcb310754719cf81f/ede935de-1138-4f66-8ed7-44bd16efc709/medium.webp" + # Test image path for byte64 + image_path = "/Users/jaydhanwant/Downloads/WhatsApp Image 2024-08-29 at 13.18.09.jpeg" try: - print("Sending request to server") + print("Sending video ID request to server") metadata = [('authorization', f'Bearer {_JWT_TOKEN}')] - response = stub.DetectNSFW(nsfw_detector_pb2.NSFWDetectorRequest(video_id=video_id), metadata=metadata) - # response = stub.DetectNSFW(nsfw_detector_pb2.NSFWDetectorRequest(video_id=video_id)) - print(response) + # video_response = stub.DetectNSFWVideoId(nsfw_detector_pb2.NSFWDetectorRequestVideoId(video_id=video_id), metadata=metadata) + # print(video_response) + + print("Sending image URL request to server") + url_response = stub.DetectNSFWURL(nsfw_detector_pb2.NSFWDetectorRequestURL(url=image_url), metadata=metadata) + print(url_response) + + print("Sending image to byte64 request to server") + with open(image_path, "rb") as image_file: + image_byte64 = base64.b64encode(image_file.read()).decode('utf-8') + byte64_response = stub.DetectNSFWImg(nsfw_detector_pb2.NSFWDetectorRequestImg(image=image_byte64), metadata=metadata) + print(byte64_response) + except grpc.RpcError as e: print(f"RPC failed: {e}") diff --git a/nsfw_detect_utils.py b/nsfw_detect_utils.py index d326ea1..4e4ae18 100644 --- a/nsfw_detect_utils.py +++ b/nsfw_detect_utils.py @@ -16,7 +16,7 @@ def __init__(self, pipe3c, pipe5c): credentials = service_account.Credentials.from_service_account_info(service_acc_creds) self.gclient = vision.ImageAnnotatorClient(credentials=credentials) - def nsfw_detect(self, imgs): + def explicit_detect(self, imgs): marks = {'QUESTIONABLE provocative': 'provocative', 'QUESTIONABLE porn': 'explicit', 'QUESTIONABLE neutral': 'neutral', @@ -38,13 +38,23 @@ 
def nsfw_detect(self, imgs): res3cs = self.pipe3c(imgs) res5cs = self.pipe5c(imgs) - mark3cs = [max(res3c, key=lambda x: x['score']) for res3c in res3cs] - mark5cs = [max(res5c, key=lambda x: x['score']) for res5c in res5cs] + """ + This is how the res3cs and res5cs look like + res3cs: [[{'label': 'QUESTIONABLE', 'score': 0.9604137539863586}, {'label': 'UNSAFE', 'score': 0.6502232551574707}, {'label': 'SAFE', 'score': 0.03905165567994118}], [{'label': 'SAFE', 'score': 0.9722312092781067}, {'label': 'UNSAFE', 'score': 0.31041547656059265}, {'label': 'QUESTIONABLE', 'score': 0.185553178191185}], [{'label': 'UNSAFE', 'score': 0.7398239970207214}, {'label': 'SAFE', 'score': 0.7222445607185364}, {'label': 'QUESTIONABLE', 'score': 0.2648926079273224}]] + res5cs: [[{'label': 'provocative', 'score': 0.9864626526832581}, {'label': 'neutral', 'score': 0.7108702063560486}, {'label': 'drawings', 'score': 0.22955366969108582}, {'label': 'hentai', 'score': 0.14005741477012634}, {'label': 'porn', 'score': 0.1315544694662094}], [{'label': 'neutral', 'score': 0.99712735414505}, {'label': 'hentai', 'score': 0.25641435384750366}, {'label': 'provocative', 'score': 0.23076564073562622}, {'label': 'drawings', 'score': 0.2289031445980072}, {'label': 'porn', 'score': 0.10852369666099548}], [{'label': 'neutral', 'score': 0.9969491362571716}, {'label': 'drawings', 'score': 0.37695300579071045}, {'label': 'hentai', 'score': 0.15976859629154205}, {'label': 'provocative', 'score': 0.14641618728637695}, {'label': 'porn', 'score': 0.14610938727855682}]] + + + Output corresponding to each image is a list of [{label, score}...] 
+ and the res is a list of such lists + """ + + mark3cs = [max(res3c, key=lambda x: x['score']) for res3c in res3cs] # maintaining the max score with the label + mark5cs = [max(res5c, key=lambda x: x['score']) for res5c in res5cs] # maintaining the max score with the label - marks = [(marks[mark3c['label'] + ' ' + mark5c['label']], mark3c['score'], mark5c['score']) for mark3c, mark5c in zip(mark3cs, mark5cs)] + marks = [(marks[mark3c['label'] + ' ' + mark5c['label']], mark3c['score'], mark5c['score']) for mark3c, mark5c in zip(mark3cs, mark5cs)] # zipped label return marks - def detect_nsfw_gore(self, pil_image): + def gore_detect(self, pil_image): try: """Detects NSFW content in a PIL image and returns the safe search annotation.""" @@ -112,6 +122,6 @@ def load_model_artifacts(artifact_path): img2 = Image.open("/Users/jaydhanwant/Downloads/3.jpg") img3 = Image.open("/Users/jaydhanwant/Documents/questionable.png") imgs = [img, img2, img3] - result = detector.nsfw_detect(imgs) + result = detector.explicit_detect(imgs) print(result) \ No newline at end of file diff --git a/nsfw_detector.proto b/nsfw_detector.proto index 99c606d..f510e8b 100644 --- a/nsfw_detector.proto +++ b/nsfw_detector.proto @@ -2,15 +2,26 @@ syntax = "proto3"; package nsfw_detector; -message NSFWDetectorRequest { +message NSFWDetectorRequestVideoId { string video_id = 1; } +message NSFWDetectorRequestURL { + string url = 1; +} + +message NSFWDetectorRequestImg { + string image = 1; +} + message NSFWDetectorResponse { string nsfw_ec = 1; string nsfw_gore = 2; + bool csam_detected = 3; } service NSFWDetector { - rpc DetectNSFW(NSFWDetectorRequest) returns (NSFWDetectorResponse) {}; + rpc DetectNSFWVideoId(NSFWDetectorRequestVideoId) returns (NSFWDetectorResponse) {}; + rpc DetectNSFWURL(NSFWDetectorRequestURL) returns (NSFWDetectorResponse) {}; + rpc DetectNSFWImg(NSFWDetectorRequestImg) returns (NSFWDetectorResponse) {}; } \ No newline at end of file diff --git a/nsfw_detector_pb2.py 
b/nsfw_detector_pb2.py index 9ef3d8a..443dbe7 100644 --- a/nsfw_detector_pb2.py +++ b/nsfw_detector_pb2.py @@ -14,17 +14,21 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13nsfw_detector.proto\x12\rnsfw_detector\"\'\n\x13NSFWDetectorRequest\x12\x10\n\x08video_id\x18\x01 \x01(\t\":\n\x14NSFWDetectorResponse\x12\x0f\n\x07nsfw_ec\x18\x01 \x01(\t\x12\x11\n\tnsfw_gore\x18\x02 \x01(\t2g\n\x0cNSFWDetector\x12W\n\nDetectNSFW\x12\".nsfw_detector.NSFWDetectorRequest\x1a#.nsfw_detector.NSFWDetectorResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13nsfw_detector.proto\x12\rnsfw_detector\".\n\x1aNSFWDetectorRequestVideoId\x12\x10\n\x08video_id\x18\x01 \x01(\t\"%\n\x16NSFWDetectorRequestURL\x12\x0b\n\x03url\x18\x01 \x01(\t\"\'\n\x16NSFWDetectorRequestImg\x12\r\n\x05image\x18\x01 \x01(\t\"Q\n\x14NSFWDetectorResponse\x12\x0f\n\x07nsfw_ec\x18\x01 \x01(\t\x12\x11\n\tnsfw_gore\x18\x02 \x01(\t\x12\x15\n\rcsam_detected\x18\x03 \x01(\x08\x32\xb3\x02\n\x0cNSFWDetector\x12\x65\n\x11\x44\x65tectNSFWVideoId\x12).nsfw_detector.NSFWDetectorRequestVideoId\x1a#.nsfw_detector.NSFWDetectorResponse\"\x00\x12]\n\rDetectNSFWURL\x12%.nsfw_detector.NSFWDetectorRequestURL\x1a#.nsfw_detector.NSFWDetectorResponse\"\x00\x12]\n\rDetectNSFWImg\x12%.nsfw_detector.NSFWDetectorRequestImg\x1a#.nsfw_detector.NSFWDetectorResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'nsfw_detector_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_NSFWDETECTORREQUEST']._serialized_start=38 - _globals['_NSFWDETECTORREQUEST']._serialized_end=77 - _globals['_NSFWDETECTORRESPONSE']._serialized_start=79 - _globals['_NSFWDETECTORRESPONSE']._serialized_end=137 - _globals['_NSFWDETECTOR']._serialized_start=139 - _globals['_NSFWDETECTOR']._serialized_end=242 + 
_globals['_NSFWDETECTORREQUESTVIDEOID']._serialized_start=38 + _globals['_NSFWDETECTORREQUESTVIDEOID']._serialized_end=84 + _globals['_NSFWDETECTORREQUESTURL']._serialized_start=86 + _globals['_NSFWDETECTORREQUESTURL']._serialized_end=123 + _globals['_NSFWDETECTORREQUESTIMG']._serialized_start=125 + _globals['_NSFWDETECTORREQUESTIMG']._serialized_end=164 + _globals['_NSFWDETECTORRESPONSE']._serialized_start=166 + _globals['_NSFWDETECTORRESPONSE']._serialized_end=247 + _globals['_NSFWDETECTOR']._serialized_start=250 + _globals['_NSFWDETECTOR']._serialized_end=557 # @@protoc_insertion_point(module_scope) diff --git a/nsfw_detector_pb2_grpc.py b/nsfw_detector_pb2_grpc.py index daebcbc..3b2d7cc 100644 --- a/nsfw_detector_pb2_grpc.py +++ b/nsfw_detector_pb2_grpc.py @@ -39,9 +39,19 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.DetectNSFW = channel.unary_unary( - '/nsfw_detector.NSFWDetector/DetectNSFW', - request_serializer=nsfw__detector__pb2.NSFWDetectorRequest.SerializeToString, + self.DetectNSFWVideoId = channel.unary_unary( + '/nsfw_detector.NSFWDetector/DetectNSFWVideoId', + request_serializer=nsfw__detector__pb2.NSFWDetectorRequestVideoId.SerializeToString, + response_deserializer=nsfw__detector__pb2.NSFWDetectorResponse.FromString, + _registered_method=True) + self.DetectNSFWURL = channel.unary_unary( + '/nsfw_detector.NSFWDetector/DetectNSFWURL', + request_serializer=nsfw__detector__pb2.NSFWDetectorRequestURL.SerializeToString, + response_deserializer=nsfw__detector__pb2.NSFWDetectorResponse.FromString, + _registered_method=True) + self.DetectNSFWImg = channel.unary_unary( + '/nsfw_detector.NSFWDetector/DetectNSFWImg', + request_serializer=nsfw__detector__pb2.NSFWDetectorRequestImg.SerializeToString, response_deserializer=nsfw__detector__pb2.NSFWDetectorResponse.FromString, _registered_method=True) @@ -49,7 +59,19 @@ def __init__(self, channel): class NSFWDetectorServicer(object): """Missing associated documentation comment in .proto 
file.""" - def DetectNSFW(self, request, context): + def DetectNSFWVideoId(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DetectNSFWURL(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DetectNSFWImg(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -58,9 +80,19 @@ def DetectNSFW(self, request, context): def add_NSFWDetectorServicer_to_server(servicer, server): rpc_method_handlers = { - 'DetectNSFW': grpc.unary_unary_rpc_method_handler( - servicer.DetectNSFW, - request_deserializer=nsfw__detector__pb2.NSFWDetectorRequest.FromString, + 'DetectNSFWVideoId': grpc.unary_unary_rpc_method_handler( + servicer.DetectNSFWVideoId, + request_deserializer=nsfw__detector__pb2.NSFWDetectorRequestVideoId.FromString, + response_serializer=nsfw__detector__pb2.NSFWDetectorResponse.SerializeToString, + ), + 'DetectNSFWURL': grpc.unary_unary_rpc_method_handler( + servicer.DetectNSFWURL, + request_deserializer=nsfw__detector__pb2.NSFWDetectorRequestURL.FromString, + response_serializer=nsfw__detector__pb2.NSFWDetectorResponse.SerializeToString, + ), + 'DetectNSFWImg': grpc.unary_unary_rpc_method_handler( + servicer.DetectNSFWImg, + request_deserializer=nsfw__detector__pb2.NSFWDetectorRequestImg.FromString, response_serializer=nsfw__detector__pb2.NSFWDetectorResponse.SerializeToString, ), } @@ -75,7 +107,61 @@ class NSFWDetector(object): """Missing associated documentation comment in .proto file.""" @staticmethod - def DetectNSFW(request, + def 
DetectNSFWVideoId(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/nsfw_detector.NSFWDetector/DetectNSFWVideoId', + nsfw__detector__pb2.NSFWDetectorRequestVideoId.SerializeToString, + nsfw__detector__pb2.NSFWDetectorResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def DetectNSFWURL(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/nsfw_detector.NSFWDetector/DetectNSFWURL', + nsfw__detector__pb2.NSFWDetectorRequestURL.SerializeToString, + nsfw__detector__pb2.NSFWDetectorResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def DetectNSFWImg(request, target, options=(), channel_credentials=None, @@ -88,8 +174,8 @@ def DetectNSFW(request, return grpc.experimental.unary_unary( request, target, - '/nsfw_detector.NSFWDetector/DetectNSFW', - nsfw__detector__pb2.NSFWDetectorRequest.SerializeToString, + '/nsfw_detector.NSFWDetector/DetectNSFWImg', + nsfw__detector__pb2.NSFWDetectorRequestImg.SerializeToString, nsfw__detector__pb2.NSFWDetectorResponse.FromString, options, channel_credentials, diff --git a/server.py b/server.py index fbfd476..77fb6e8 100644 --- a/server.py +++ b/server.py @@ -19,13 +19,16 @@ from nsfw_detect_utils import NSFWDetect from save_model_artifacts import download_artifacts # hardocded the gcr paths here TODO: Move that to a config import torch - +from PIL import Image +import io +import 
base64 +import requests _LOGGER = logging.getLogger(__name__) _ONE_DAY = datetime.timedelta(days=1) _PROCESS_COUNT = multiprocessing.cpu_count() -# _PROCESS_COUNT = 1 +# _PROCESS_COUNT = 1 # TODO: change this back after testing _THREAD_CONCURRENCY = 10 # heuristic _BIND_ADDRESS = "[::]:50051" @@ -94,19 +97,33 @@ def __init__(self): _LOGGER.info(self.pipe5c.model.config) _LOGGER.info('='*100) - def DetectNSFW(self, request, context): + def DetectNSFWVideoId(self, request, context): _LOGGER.info("Request received") video_id = request.video_id nsfw_tag, gore_tag = self.process_frames(video_id) - response = nsfw_detector_pb2.NSFWDetectorResponse(nsfw_ec=nsfw_tag, nsfw_gore=gore_tag) + response = nsfw_detector_pb2.NSFWDetectorResponse(nsfw_ec=nsfw_tag, nsfw_gore=gore_tag, csam_detected=False) + return response + + def DetectNSFWURL(self, request, context): + _LOGGER.info("Request received") + image_url = request.url + nsfw_tag, gore_tag = self.process_image_url(image_url) + response = nsfw_detector_pb2.NSFWDetectorResponse(nsfw_ec=nsfw_tag, nsfw_gore=gore_tag, csam_detected=False) + return response + + def DetectNSFWImg(self, request, context): + _LOGGER.info("Request received") + image_byte64 = request.image + nsfw_tag, gore_tag = self.process_image_byte64(image_byte64) + response = nsfw_detector_pb2.NSFWDetectorResponse(nsfw_ec=nsfw_tag, nsfw_gore=gore_tag, csam_detected=False) return response def process_frames(self, video_id): frames = get_images_from_gcs("yral-video-frames", video_id) - nsfw_tags = self.nsfw_detector.nsfw_detect([frame['image'] for frame in frames]) + nsfw_tags = self.nsfw_detector.explicit_detect([frame['image'] for frame in frames]) gore_tags = [] for frame in frames: - gore_tags.append(self.nsfw_detector.detect_nsfw_gore(frame['image'])) + gore_tags.append(self.nsfw_detector.gore_detect(frame['image'])) tag_priority = "explicit nudity provocative neutral".split() gore_priority = ["UNKNOWN", "VERY_UNLIKELY", "UNLIKELY", "POSSIBLE", "LIKELY", 
"VERY_LIKELY"][::-1] # Sort nsfw_tags based on the priority defined in tag_priority @@ -124,6 +141,26 @@ def process_frames(self, video_id): return [nsfw_tag, gore_tag] + def process_image_byte64(self, image_byte64): + image = Image.open(io.BytesIO(base64.b64decode(image_byte64))) + nsfw_res = self.nsfw_detector.explicit_detect([image])[0] + mark3c_score = nsfw_res[1] + mark5c_score = nsfw_res[2] + if mark3c_score > 0.82 and mark5c_score > 0.9: + nsfw_tag = nsfw_res[0] + else: + nsfw_tag = None + + gore_tag = self.nsfw_detector.gore_detect(image) + return [nsfw_tag, gore_tag] + + def process_image_url(self, image_url): + image = Image.open(requests.get(image_url, stream=True).raw) + buffered = io.BytesIO() + image.save(buffered, format="JPEG") + image_byte64 = base64.b64encode(buffered.getvalue()).decode('utf-8') + nsfw_res = self.process_image_byte64(image_byte64) + return nsfw_res