-
Notifications
You must be signed in to change notification settings - Fork 28
/
run_whole.py
132 lines (111 loc) · 3.98 KB
/
run_whole.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Script entry module: repeatedly asks the LLM for a whole solution to `prompt`,
# scores each completion and reports aggregate statistics (see main()).
# NOTE(review): statement order below matters -- `args` is used immediately to
# build the module constants, and `llm` is imported after `cmdline`; the comment
# in attempt() says sampling args are handled inside llm.py, so keep this order.
from cmdline import args
import time
# Run-mode switches read from the command line (consumed by main()).
GREEDY = args.greedy
N_SAMPLES = args.n_samples
MAX_N_SAMPLES = args.max_n_samples
MINS_TIMEOUT = args.mins_timeout
from lang import can_be_solution_whole
from lang import score_func_whole as uncached_score_func
from prompts import prompt, min_lines, check_func, check_string, test_dict
if test_dict:
    # When the prompt module supplies a test dictionary, bind it as the
    # scorer's second argument so callers can score with the text alone.
    uncached_score_func_before_dict = uncached_score_func
    uncached_score_func = lambda x: uncached_score_func_before_dict(x, test_dict)
from common_cache import create_cached_func
# Wrap the scorer via common_cache.create_cached_func (presumably memoizing
# repeated completions -- confirm in common_cache). It also returns
# cache_stats / reset_cache helpers, which the visible code does not use.
score_func, cache_stats, reset_cache = create_cached_func(uncached_score_func)
import llm
import wandb
if args.use_wandb:
    # Start a Weights & Biases run; config is args.dict() (presumably every
    # command-line option -- TODO confirm in cmdline).
    wandb.init(
        entity=args.wandb_entity,
        project=args.wandb_project,
        group=args.wandb_group,
        config=args.dict(),
        name=args.wandb_name,
    )
def attempt(prompt=prompt, attempt_id=0):
    """Generate one completion for ``prompt``, score it, and report the outcome.

    Greedy vs. sampled decoding is configured inside ``llm`` itself. Token
    usage is the change in ``llm.token_counter`` across the call. The
    elapsed time covers generation, scoring and the solution check.

    Returns a dict with keys, in order: ``attempt_id``, ``time``, ``text``,
    ``is_solution`` (1/0), ``score_sign`` (1 for a positive score, -1
    otherwise, 0 when the score is None) and ``n_tokens``.
    """
    tokens_before = llm.token_counter
    started_at = time.time()
    completion = llm.generate_full(prompt)
    score = score_func(completion)

    if score is None:
        sign = 0
        solved = False
    else:
        sign = 1 if score > 0 else -1
        # Only a positively scored completion is checked for solution shape.
        solved = score > 0 and can_be_solution_whole(
            completion, min_lines, check_func, check_string
        )

    elapsed = time.time() - started_at
    return {
        "attempt_id": attempt_id,
        "time": elapsed,
        "text": completion,
        "is_solution": 1 if solved else 0,
        "score_sign": sign,
        "n_tokens": llm.token_counter - tokens_before,
    }
def summary(all_stats):
    """Print aggregate counts over a list of per-attempt stat dicts.

    Each entry must provide ``is_solution``, ``score_sign``, ``n_tokens`` and
    ``time`` (the keys produced by ``attempt``). The totals are printed as a
    single dict and nothing is returned.
    """
    signs = [entry["score_sign"] for entry in all_stats]
    totals = {
        "n_attempts": len(all_stats),
        "n_solutions": sum(entry["is_solution"] for entry in all_stats),
        "n_positive": sum(1 for sign in signs if sign > 0),
        "n_negative": sum(1 for sign in signs if sign < 0),
        "n_zero": sum(1 for sign in signs if sign == 0),
        "n_tokens": sum(entry["n_tokens"] for entry in all_stats),
        "total_time": sum(entry["time"] for entry in all_stats),
    }
    print(totals)
def _log_attempt(stats):
    """Log one attempt's stats to Weights & Biases when ``args.use_wandb`` is set."""
    if args.use_wandb:
        wandb.log(stats)


def _sample_until_solution(prompt, max_n_samples):
    """Run at most ``max_n_samples`` attempts, stopping at the first solution.

    Prints "SOLUTION FOUND" plus the solution text, or "SOLUTION is None" when
    no attempt succeeded (including when ``max_n_samples <= 0`` and no attempt
    is made). Returns the list of per-attempt stats.
    """
    all_stats = []
    n_calls = 0
    while n_calls < max_n_samples:
        stats = attempt(prompt=prompt, attempt_id=n_calls)
        n_calls += 1
        all_stats.append(stats)
        _log_attempt(stats)
        if stats["is_solution"]:
            print("SOLUTION FOUND")
            print(stats["text"])
            return all_stats
    print("SOLUTION is None")
    return all_stats


def _sample_fixed(prompt, n_samples):
    """Run exactly ``n_samples`` attempts, then print every solution found."""
    all_stats = []
    for attempt_id in range(n_samples):
        stats = attempt(prompt=prompt, attempt_id=attempt_id)
        all_stats.append(stats)
        _log_attempt(stats)
    for stats in all_stats:
        if stats["is_solution"]:
            print("ONE SOLUTION")
            print(stats["text"])
    return all_stats


def _sample_until_timeout(prompt, mins_timeout):
    """Run attempts until one is a solution or ``mins_timeout`` minutes elapse.

    The deadline is checked before each attempt, so the final attempt may
    finish after it.
    """
    timeout_s = mins_timeout * 60
    # Monotonic clock: elapsed-time checks are unaffected by wall-clock changes.
    start = time.monotonic()
    all_stats = []
    attempt_id = 0
    while time.monotonic() - start < timeout_s:
        stats = attempt(prompt=prompt, attempt_id=attempt_id)
        attempt_id += 1
        all_stats.append(stats)
        _log_attempt(stats)
        if stats["is_solution"]:
            break
    return all_stats


def main(mins_timeout=MINS_TIMEOUT, prompt=prompt):
    """Run the sampling loop selected by the options, print a summary, return stats.

    Modes, in priority order:
      1. ``MAX_N_SAMPLES`` is set: sample until a solution is found or the cap
         is reached (not allowed together with ``GREEDY``).
      2. ``mins_timeout`` is None: one attempt if ``GREEDY``, else ``N_SAMPLES``.
      3. Otherwise: sample until a solution is found or the timeout expires,
         so ``mins_timeout`` takes precedence over ``N_SAMPLES``.

    Every attempt is logged to wandb when ``args.use_wandb`` is set.

    Returns:
        The list of per-attempt stat dicts produced by ``attempt``.
    """
    if MAX_N_SAMPLES is not None:
        assert not GREEDY, "max_n_samples cannot be combined with greedy decoding"
        all_stats = _sample_until_solution(prompt, MAX_N_SAMPLES)
    elif mins_timeout is None:
        all_stats = _sample_fixed(prompt, 1 if GREEDY else N_SAMPLES)
    else:
        all_stats = _sample_until_timeout(prompt, mins_timeout)
    summary(all_stats)
    return all_stats


if __name__ == "__main__":
    main()