Skip to content

Commit

Permalink
adding Image APIs and deploying moving to prod fly account
Browse files Browse the repository at this point in the history
  • Loading branch information
jay-dhanwant-yral committed Nov 7, 2024
1 parent f235de0 commit ec659e9
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 44 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/deploy-on-merge-to-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ jobs:
run: |
flyctl secrets set --stage --app "yral-nsfw-classification" "SERVICE_CRED=$SERVICE_CRED"
env:
FLY_API_TOKEN: ${{ secrets.FLY_NSFW_TOKEN }}
FLY_API_TOKEN: ${{ secrets.FLY_IO_DEPLOY_TOKENS }}
SERVICE_CRED: ${{ secrets.SERVICE_CRED }}
- name: Deploy a docker container to fly.io
run: flyctl deploy --remote-only
env:
FLY_API_TOKEN: ${{ secrets.FLY_NSFW_TOKEN }}
FLY_API_TOKEN: ${{ secrets.FLY_IO_DEPLOY_TOKENS }}
6 changes: 3 additions & 3 deletions fly.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# See https://fly.io/docs/reference/configuration/ for information about how to use this file.
#

app = 'yral-nsfw-classification'
app = 'prod-yral-nsfw-classification'
primary_region = 'ewr'
kill_signal = 'SIGINT'
kill_timeout = '5s'
Expand All @@ -22,7 +22,7 @@ swap_size_mb = 32768
port = 443
handlers = ['tls']

[services.ports.tls_options]
[services.ports.tls_optionse
alpn = ['h2']

[services.concurrency]
Expand All @@ -32,4 +32,4 @@ swap_size_mb = 32768
[[vm]]
memory = '8gb'
cpu_kind = 'performance'
cpus = 4
cpus = 1
29 changes: 21 additions & 8 deletions localping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,41 @@
import nsfw_detector_pb2_grpc
import os
import jwt
import base64

token_fpath = "/Users/jaydhanwant/Documents/SS/nsfw_jwt_token.txt"
with open(token_fpath, 'r') as f:
_JWT_TOKEN = f.read()

server_url = 'localhost:50051'
# Load the private key from a path specified in an environment variable

def run():
# NOTE: Replace with the actual server address if different
channel = grpc.insecure_channel(server_url)
stub = nsfw_detector_pb2_grpc.NSFWDetectorStub(channel)

# Test video ID (replace with an actual video ID for testing)
video_id = "21727f4f6cbe496d9a410202537faff8"
# Test video ID
video_id = "21727f4f6cbe496d9a410202537faff8"
# Test image URL
image_url = "https://img-cdn.pixlr.com/image-generator/history/65bb506dcb310754719cf81f/ede935de-1138-4f66-8ed7-44bd16efc709/medium.webp"
# Test image path for byte64
image_path = "/Users/jaydhanwant/Downloads/WhatsApp Image 2024-08-29 at 13.18.09.jpeg"

try:
print("Sending request to server")
print("Sending video ID request to server")
metadata = [('authorization', f'Bearer {_JWT_TOKEN}')]
response = stub.DetectNSFW(nsfw_detector_pb2.NSFWDetectorRequest(video_id=video_id), metadata=metadata)
# response = stub.DetectNSFW(nsfw_detector_pb2.NSFWDetectorRequest(video_id=video_id))
print(response)
# video_response = stub.DetectNSFWVideoId(nsfw_detector_pb2.NSFWDetectorRequestVideoId(video_id=video_id), metadata=metadata)
# print(video_response)

print("Sending image URL request to server")
url_response = stub.DetectNSFWURL(nsfw_detector_pb2.NSFWDetectorRequestURL(url=image_url), metadata=metadata)
print(url_response)

print("Sending image to byte64 request to server")
with open(image_path, "rb") as image_file:
image_byte64 = base64.b64encode(image_file.read()).decode('utf-8')
byte64_response = stub.DetectNSFWImg(nsfw_detector_pb2.NSFWDetectorRequestImg(image=image_byte64), metadata=metadata)
print(byte64_response)

except grpc.RpcError as e:
print(f"RPC failed: {e}")

Expand Down
22 changes: 16 additions & 6 deletions nsfw_detect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, pipe3c, pipe5c):
credentials = service_account.Credentials.from_service_account_info(service_acc_creds)
self.gclient = vision.ImageAnnotatorClient(credentials=credentials)

def nsfw_detect(self, imgs):
def explicit_detect(self, imgs):
marks = {'QUESTIONABLE provocative': 'provocative',
'QUESTIONABLE porn': 'explicit',
'QUESTIONABLE neutral': 'neutral',
Expand All @@ -38,13 +38,23 @@ def nsfw_detect(self, imgs):
res3cs = self.pipe3c(imgs)
res5cs = self.pipe5c(imgs)

mark3cs = [max(res3c, key=lambda x: x['score']) for res3c in res3cs]
mark5cs = [max(res5c, key=lambda x: x['score']) for res5c in res5cs]
"""
This is how the res3cs and res5cs look like
res3cs: [[{'label': 'QUESTIONABLE', 'score': 0.9604137539863586}, {'label': 'UNSAFE', 'score': 0.6502232551574707}, {'label': 'SAFE', 'score': 0.03905165567994118}], [{'label': 'SAFE', 'score': 0.9722312092781067}, {'label': 'UNSAFE', 'score': 0.31041547656059265}, {'label': 'QUESTIONABLE', 'score': 0.185553178191185}], [{'label': 'UNSAFE', 'score': 0.7398239970207214}, {'label': 'SAFE', 'score': 0.7222445607185364}, {'label': 'QUESTIONABLE', 'score': 0.2648926079273224}]]
res5cs: [[{'label': 'provocative', 'score': 0.9864626526832581}, {'label': 'neutral', 'score': 0.7108702063560486}, {'label': 'drawings', 'score': 0.22955366969108582}, {'label': 'hentai', 'score': 0.14005741477012634}, {'label': 'porn', 'score': 0.1315544694662094}], [{'label': 'neutral', 'score': 0.99712735414505}, {'label': 'hentai', 'score': 0.25641435384750366}, {'label': 'provocative', 'score': 0.23076564073562622}, {'label': 'drawings', 'score': 0.2289031445980072}, {'label': 'porn', 'score': 0.10852369666099548}], [{'label': 'neutral', 'score': 0.9969491362571716}, {'label': 'drawings', 'score': 0.37695300579071045}, {'label': 'hentai', 'score': 0.15976859629154205}, {'label': 'provocative', 'score': 0.14641618728637695}, {'label': 'porn', 'score': 0.14610938727855682}]]
Output corresponding to each image is a list of [{label, score}...]
and the res is a list of such lists
"""

mark3cs = [max(res3c, key=lambda x: x['score']) for res3c in res3cs] # maintaining the max score with the label
mark5cs = [max(res5c, key=lambda x: x['score']) for res5c in res5cs] # maintaining the max score with the label

marks = [(marks[mark3c['label'] + ' ' + mark5c['label']], mark3c['score'], mark5c['score']) for mark3c, mark5c in zip(mark3cs, mark5cs)]
marks = [(marks[mark3c['label'] + ' ' + mark5c['label']], mark3c['score'], mark5c['score']) for mark3c, mark5c in zip(mark3cs, mark5cs)] # zipped label
return marks

def detect_nsfw_gore(self, pil_image):
def gore_detect(self, pil_image):
try:
"""Detects NSFW content in a PIL image and returns the safe search annotation."""

Expand Down Expand Up @@ -112,6 +122,6 @@ def load_model_artifacts(artifact_path):
img2 = Image.open("/Users/jaydhanwant/Downloads/3.jpg")
img3 = Image.open("/Users/jaydhanwant/Documents/questionable.png")
imgs = [img, img2, img3]
result = detector.nsfw_detect(imgs)
result = detector.explicit_detect(imgs)
print(result)

15 changes: 13 additions & 2 deletions nsfw_detector.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,26 @@ syntax = "proto3";

package nsfw_detector;

message NSFWDetectorRequest {
message NSFWDetectorRequestVideoId {
string video_id = 1;
}

message NSFWDetectorRequestURL {
string url = 1;
}

message NSFWDetectorRequestImg {
string image = 1;
}

message NSFWDetectorResponse {
string nsfw_ec = 1;
string nsfw_gore = 2;
bool csam_detected = 3;
}

service NSFWDetector {
rpc DetectNSFW(NSFWDetectorRequest) returns (NSFWDetectorResponse) {};
rpc DetectNSFWVideoId(NSFWDetectorRequestVideoId) returns (NSFWDetectorResponse) {};
rpc DetectNSFWURL(NSFWDetectorRequestURL) returns (NSFWDetectorResponse) {};
rpc DetectNSFWImg(NSFWDetectorRequestImg) returns (NSFWDetectorResponse) {};
}
18 changes: 11 additions & 7 deletions nsfw_detector_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

106 changes: 96 additions & 10 deletions nsfw_detector_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,39 @@ def __init__(self, channel):
Args:
channel: A grpc.Channel.
"""
self.DetectNSFW = channel.unary_unary(
'/nsfw_detector.NSFWDetector/DetectNSFW',
request_serializer=nsfw__detector__pb2.NSFWDetectorRequest.SerializeToString,
self.DetectNSFWVideoId = channel.unary_unary(
'/nsfw_detector.NSFWDetector/DetectNSFWVideoId',
request_serializer=nsfw__detector__pb2.NSFWDetectorRequestVideoId.SerializeToString,
response_deserializer=nsfw__detector__pb2.NSFWDetectorResponse.FromString,
_registered_method=True)
self.DetectNSFWURL = channel.unary_unary(
'/nsfw_detector.NSFWDetector/DetectNSFWURL',
request_serializer=nsfw__detector__pb2.NSFWDetectorRequestURL.SerializeToString,
response_deserializer=nsfw__detector__pb2.NSFWDetectorResponse.FromString,
_registered_method=True)
self.DetectNSFWImg = channel.unary_unary(
'/nsfw_detector.NSFWDetector/DetectNSFWImg',
request_serializer=nsfw__detector__pb2.NSFWDetectorRequestImg.SerializeToString,
response_deserializer=nsfw__detector__pb2.NSFWDetectorResponse.FromString,
_registered_method=True)


class NSFWDetectorServicer(object):
"""Missing associated documentation comment in .proto file."""

def DetectNSFW(self, request, context):
def DetectNSFWVideoId(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def DetectNSFWURL(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def DetectNSFWImg(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
Expand All @@ -58,9 +80,19 @@ def DetectNSFW(self, request, context):

def add_NSFWDetectorServicer_to_server(servicer, server):
rpc_method_handlers = {
'DetectNSFW': grpc.unary_unary_rpc_method_handler(
servicer.DetectNSFW,
request_deserializer=nsfw__detector__pb2.NSFWDetectorRequest.FromString,
'DetectNSFWVideoId': grpc.unary_unary_rpc_method_handler(
servicer.DetectNSFWVideoId,
request_deserializer=nsfw__detector__pb2.NSFWDetectorRequestVideoId.FromString,
response_serializer=nsfw__detector__pb2.NSFWDetectorResponse.SerializeToString,
),
'DetectNSFWURL': grpc.unary_unary_rpc_method_handler(
servicer.DetectNSFWURL,
request_deserializer=nsfw__detector__pb2.NSFWDetectorRequestURL.FromString,
response_serializer=nsfw__detector__pb2.NSFWDetectorResponse.SerializeToString,
),
'DetectNSFWImg': grpc.unary_unary_rpc_method_handler(
servicer.DetectNSFWImg,
request_deserializer=nsfw__detector__pb2.NSFWDetectorRequestImg.FromString,
response_serializer=nsfw__detector__pb2.NSFWDetectorResponse.SerializeToString,
),
}
Expand All @@ -75,7 +107,61 @@ class NSFWDetector(object):
"""Missing associated documentation comment in .proto file."""

@staticmethod
def DetectNSFW(request,
def DetectNSFWVideoId(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/nsfw_detector.NSFWDetector/DetectNSFWVideoId',
nsfw__detector__pb2.NSFWDetectorRequestVideoId.SerializeToString,
nsfw__detector__pb2.NSFWDetectorResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def DetectNSFWURL(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(
request,
target,
'/nsfw_detector.NSFWDetector/DetectNSFWURL',
nsfw__detector__pb2.NSFWDetectorRequestURL.SerializeToString,
nsfw__detector__pb2.NSFWDetectorResponse.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True)

@staticmethod
def DetectNSFWImg(request,
target,
options=(),
channel_credentials=None,
Expand All @@ -88,8 +174,8 @@ def DetectNSFW(request,
return grpc.experimental.unary_unary(
request,
target,
'/nsfw_detector.NSFWDetector/DetectNSFW',
nsfw__detector__pb2.NSFWDetectorRequest.SerializeToString,
'/nsfw_detector.NSFWDetector/DetectNSFWImg',
nsfw__detector__pb2.NSFWDetectorRequestImg.SerializeToString,
nsfw__detector__pb2.NSFWDetectorResponse.FromString,
options,
channel_credentials,
Expand Down
Loading

0 comments on commit ec659e9

Please sign in to comment.