inference.py
import argparse

import torch
from transformers import pipeline


def get_args():
    parser = argparse.ArgumentParser(description="Startup Parameters")
    parser.add_argument("--tf32", action="store_true", help="Enable TF32 matmul/cuDNN kernels")
    parser.add_argument("--device", type=int, default=0, help="CUDA device index")
    parser.add_argument("--image", type=str, default="test/test.jpg", help="Path to the input image")
    return parser.parse_args()


args = get_args()

# Optionally allow TF32 for faster matrix multiplications on supported GPUs
if args.tf32:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# Load the aesthetic-shadow-v2 image-classification pipeline on the chosen GPU
pipe = pipeline(
    "image-classification",
    model="shadowlilac/aesthetic-shadow-v2",
    device=f"cuda:{args.device}",
)

# Input image file
single_image_file = args.image

# Run inference and report the 'hq' (high quality) score
result = pipe(images=[single_image_file])
prediction_single = result[0]
hq_score = next(p["score"] for p in prediction_single if p["label"] == "hq")
print("High Quality: " + str(round(hq_score, 2)))