Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
andyz245 committed Nov 17, 2023
1 parent 5fa1ce7 commit e08d36f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 15 deletions.
17 changes: 4 additions & 13 deletions hotpot/lats.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,9 @@ def __init__(self, state, question, parent=None):

def uct(self):
if self.visits == 0:
return float('inf')
#return self.value * 2
return self.value
return self.value / self.visits + np.sqrt(2 * np.log(self.parent.visits) / self.visits)

def uct_with_depth(self, C1=1, C2=1):
if self.visits == 0:
return float('inf')
exploitation_term = self.value / self.visits
exploration_term = np.sqrt(2 * np.log(self.parent.visits) / self.visits)
depth_term = self.depth
return exploitation_term + C1 * exploration_term + C2 * depth_term

def __str__(self):
return f"Node(depth={self.depth}, value={self.value:.2f}, visits={self.visits}, thought={self.state['thought']}, action={self.state['action']}, observation={self.state['observation']})"

Expand Down Expand Up @@ -165,7 +156,7 @@ def collect_trajectory(node):
node = node.parent
return '\n'.join(reversed(trajectory))

def lats_search(args, task, idx, iterations=50, to_print=True):
def lats_search(args, task, idx, iterations=30, to_print=True):
global gpt
global failed_trajectories
global reflection_map
Expand Down Expand Up @@ -205,7 +196,7 @@ def lats_search(args, task, idx, iterations=50, to_print=True):

value = evaluate_node(node, args, task)
# Find the child with the highest value
reward, terminal_node = rollout(max(node.children, key=lambda child: child.value), args, task, idx, max_depth=7)
reward, terminal_node = rollout(max(node.children, key=lambda child: child.value), args, task, idx, max_depth=4)

terminal_nodes.append(terminal_node)

Expand Down Expand Up @@ -277,7 +268,7 @@ def expand_node(node, args, task):
new_nodes = generate_new_states(node, args, task, args.n_generate_sample)
node.children.extend(new_nodes)

def rollout(node, args, task, idx, max_depth=7):
def rollout(node, args, task, idx, max_depth=4):
logging.info("ROLLING OUT")
depth = node.depth
n = 5
Expand Down
5 changes: 3 additions & 2 deletions programming/mcts.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def run_mcts(
exe = executor_factory(language, is_leet=is_leetcode)
gen = generator_factory(language)
model = model_factory(model_name)
test_model = model_factory("gpt4")
print_v = make_printv(verbose)

num_items = len(dataset)
Expand All @@ -110,7 +111,7 @@ def run_mcts(
if is_leetcode:
tests_i = item['visible_tests']
else:
tests_i = gen.internal_tests(item["prompt"], model, 4)
tests_i = gen.internal_tests(item["prompt"], test_model, 6)

while cur_func_impl is None:
cur_func_impl = gen.func_impl(item["prompt"], model, "simple")
Expand Down Expand Up @@ -214,7 +215,7 @@ def run_mcts(
item["solution"] = child.solution
is_solved = True
reward_real = 1
break
break

if is_solved:
break
Expand Down

0 comments on commit e08d36f

Please sign in to comment.