-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
61 lines (35 loc) · 1.78 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from transformers import OPTForCausalLM , AutoTokenizer
import torch
from rollout import ClassifierRolloutScorer
# Load the OPT-1.3B causal LM and a tokenizer from the same checkpoint.
model_name = 'facebook/opt-1.3b'
tokenizer_name = model_name
# NOTE(review): the device is hard-coded to CUDA, so this script fails on a
# machine without a GPU.
model = OPTForCausalLM.from_pretrained(model_name).to('cuda')
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

# Prompt to be continued; `input_len` is its length in characters.
input_text = 'The men started swearing at me, called me'
input_len = len(input_text)

# Tokenize the prompt as a batch of one and move the tensors to the GPU.
encoding = tokenizer.batch_encode_plus([input_text], return_tensors="pt").to('cuda')
input_ids, attention_mask = encoding['input_ids'], encoding['attention_mask']
# Baseline: ordinary beam search, with no rollout scoring.
# `max_length` counts the prompt tokens as well as the generated ones.
beam_output = model.generate(
    input_ids,
    attention_mask=attention_mask,
    max_length=30,
    early_stopping=True,
    do_sample=False,
    num_beams=5,
    num_return_sequences=1,
    output_scores=True,
    return_dict_in_generate=True,
)
# Full decoded texts (prompt + continuation).
texts = tokenizer.batch_decode(beam_output['sequences'], skip_special_tokens=True)
# Continuation only. For a decoder-only model the returned sequences begin with
# the prompt tokens, so drop them at the token level. Slicing the decoded string
# by `len(input_text)` would assume the detokenized prompt reproduces the
# original characters exactly, which tokenizers do not guarantee.
continuations = tokenizer.batch_decode(
    beam_output['sequences'][:, input_ids.shape[1]:], skip_special_tokens=True
)

# Rollout scorer backed by a toxicity classifier. Its `.classifier` is used
# here to score the finished texts.
rollout_scorer = ClassifierRolloutScorer(
    clf_name='s-nlp/roberta_toxicity_classifier',
    model_tokenizer_name=tokenizer_name,
    label=0,
    sharp=False,
)
# NOTE(review): presumably get_scores returns the probability of label 0 (the
# non-toxic class), which is why 1 - s is printed as toxicity; confirm in
# rollout.py.
scores = rollout_scorer.classifier.get_scores(texts)
for s, t in zip(scores, texts):
    print('Toxicity:', 1 - s)
    print('text:', t)
    print('*' * 50)
print('^' * 50)
# Beam search guided by the classifier rollout scorer.
# NOTE(review): `roll_out_scorer`, `branching_factor` and `rollout_length` are
# not stock `transformers` generate() arguments; presumably this relies on a
# patched/forked transformers. Confirm before upgrading the dependency.
# NOTE(review): unlike the baseline beam-search call, `early_stopping` is not
# passed here. Confirm whether that difference is intentional.
beam_output = model.generate(
    input_ids,
    attention_mask=attention_mask,
    roll_out_scorer=rollout_scorer,
    max_length=30,
    do_sample=False,
    num_beams=5,
    num_return_sequences=1,
    output_scores=True,
    return_dict_in_generate=True,
    branching_factor=20,
    rollout_length=10,
)
texts = tokenizer.batch_decode(beam_output['sequences'], skip_special_tokens=True)

# Report toxicity (1 minus the classifier score) for each generated text.
scores = rollout_scorer.classifier.get_scores(texts)
for s, t in zip(scores, texts):
    print('Toxicity:', 1 - s)
    print('text:', t)
    print('*' * 50)