-
Notifications
You must be signed in to change notification settings - Fork 0
/
create_random_fewshot_file.py
executable file
·40 lines (33 loc) · 1.39 KB
/
create_random_fewshot_file.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import pickle
import argparse
import random
from utils import *
def main():
    """Write randomly sampled few-shot prompt sets to pickle files.

    Loads the "train" split for ``--domain`` / ``--src-lang`` / ``--tgt-lang``
    via ``get_data`` and, for each of ``--num-trials`` trials, draws ``--k``
    distinct training indices without replacement. Each drawn index becomes a
    ``FewShotSample`` whose ``data`` holds the source and target sentences and
    whose ``correct_candidates`` is ``[target]``. The list for trial ``j`` is
    pickled to ``{--out-prompt-file}.{j}.pkl``.

    Trial ``j`` is seeded with ``--seed * --num-trials + j``. Each file can
    therefore be regenerated on its own. For non-negative seeds and a fixed
    ``--num-trials``, different ``--seed`` values use non-overlapping seed
    ranges. With the default ``--seed 0`` the per-trial seeds are
    ``0 .. num_trials - 1``.

    Exits via ``parser.error`` (status 2) if ``--k`` is negative or larger
    than the number of training examples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--domain", type=str, required=True)
    parser.add_argument("--src-lang", type=str, default="de", help="Eval Source Language")
    parser.add_argument("--tgt-lang", type=str, default="en", help="Eval Target Language")
    parser.add_argument("--out-prompt-file", type=str, required=True)
    parser.add_argument("--seed", type=int, default=0)
    # NOTE(review): --split is parsed but not used; prompts are always drawn
    # from the "train" split below.
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--num-trials", type=int, default=100)
    # Previously type=str: the int default masked the bug, but any explicit
    # "--k N" reached random.sample as a string and raised TypeError.
    parser.add_argument("--k", type=int, default=1)
    args = parser.parse_args()

    # The same sampled index is used for both sides, so this assumes
    # train_src and train_tgt are index-aligned.
    train_src, train_tgt = get_data(args.domain, args.src_lang, args.tgt_lang, "train")
    num_examples = len(train_src)
    if not 0 <= args.k <= num_examples:
        parser.error(f"--k must be between 0 and {num_examples}, got {args.k}")

    num_trials = args.num_trials
    for trial in range(num_trials):
        # Reseed per trial so every file is independently reproducible.
        # Folding in --seed fixes it previously being overwritten (and thus
        # ignored). For --seed 0 this is random.seed(trial), as before.
        random.seed(args.seed * num_trials + trial)
        indices = random.sample(range(num_examples), args.k)
        prompts = [
            FewShotSample(
                data={"src": train_src[ind], "tgt": train_tgt[ind]},
                correct_candidates=[train_tgt[ind]],
            )
            for ind in indices
        ]
        prompt_file = f"{args.out_prompt_file}.{trial}.pkl"
        with open(prompt_file, "wb") as f:
            pickle.dump(prompts, f)


if __name__ == '__main__':
    main()