Skip to content

Commit

Permalink
Merge pull request #44 from OSU-NLP-Group/som
Browse files Browse the repository at this point in the history
Add SOM Grounding and Update README
  • Loading branch information
boyuanzheng010 authored Jul 10, 2024
2 parents 3e1f548 + 47f3bd9 commit 1b41279
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 32 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ if __name__ == "__main__":

| Name | Description | Type | Default | Required |
|------|-------------|------|---------|:--------:|
| model | Preferred LLM model to run the task | str | gpt-4-turbo | no |
| model | Preferred LLM model to run the task | str | gpt-4o | no |
| default_task | Default task to run | str | Find the pdf of the paper "GPT-4V(ision) is a Generalist Web Agent, if Grounded" | no |
| default_website | Default starting website | str | https://www.google.com/ | no |
| grounding_strategy | Grounding strategy <ul><li>text_choice: use text choices</li><li>text_choice_som: use text choices with set of marks</li></ul> | str | text_choice_som | no |
| config_path | Configuration file path | str | None | no |
| save_file_dir | Folder to save output files | str | seeact_agent_files | no |
| temperature | Temperature passed to LLM | num | 0.9 | no |
Expand Down
2 changes: 1 addition & 1 deletion seeact_package/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# os.environ["GEMINI_API_KEY"] = "Your API KEY Here"

async def run_agent():
agent = SeeActAgent(model="gpt-4-turbo")
agent = SeeActAgent(model="gpt-4o")
await agent.start()
while not agent.complete_flag:
prediction_dict = await agent.predict()
Expand Down
36 changes: 29 additions & 7 deletions seeact_package/seeact/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
import json
import toml
from playwright.async_api import async_playwright,Locator
from os.path import dirname, join as joinpath
import asyncio

from .data_utils.format_prompt_utils import get_index_from_option_name, generate_new_query_prompt, \
generate_new_referring_prompt, format_options
generate_new_referring_prompt, format_options, generate_option_name
from .demo_utils.browser_helper import normal_launch_async, normal_new_context_async, \
get_interactive_elements_with_playwright, select_option, saveconfig
from .demo_utils.format_prompt import format_choices, postprocess_action_lmm
Expand All @@ -36,7 +38,7 @@ def __init__(self,
default_task='Find the pdf of the paper "GPT-4V(ision) is a Generalist Web Agent, if Grounded"',
default_website="https://www.google.com/",
input_info=["screenshot"],
grounding_strategy="text_choice",
grounding_strategy="text_choice_som",
max_auto_op=50,
max_continuous_no_op=5,
highlight=False,
Expand All @@ -57,7 +59,7 @@ def __init__(self,
"sources": True
},
rate_limit=-1,
model="gpt-4-turbo",
model="gpt-4o",
temperature=0.9

):
Expand Down Expand Up @@ -484,6 +486,11 @@ async def predict(self):
except Exception as e:
pass

if self.config["agent"]["grounding_strategy"] == "text_choice_som":
with open(os.path.join(dirname(__file__), "mark_page.js")) as f:
mark_page_script = f.read()
await self.session_control['active_page'].evaluate(mark_page_script)

elements = await get_interactive_elements_with_playwright(self.session_control['active_page'],
self.config['browser']['viewport'])

Expand All @@ -499,10 +506,22 @@ async def predict(self):
elements = sorted(elements, key=lambda el: (
el["center_point"][1], el["center_point"][0])) # Sorting by y and then x coordinate


elements = [{**x, "idx": i, "option": generate_option_name(i)} for i,x in enumerate(elements)]
page = self.session_control['active_page']


if self.config["agent"]["grounding_strategy"] == "text_choice_som":
await page.evaluate("unmarkPage()")
await page.evaluate("""elements => {
return window.som.drawBoxes(elements);
}""", elements)

# Generate choices for the prompt

# , self.config['basic']['default_task'], self.taken_actions
choices = format_choices(elements)
options = format_options(choices)

# print("\n\n",choices)
prompt = self.generate_prompt(task=self.tasks[-1], previous=self.taken_actions, choices=choices)
Expand All @@ -512,8 +531,8 @@ async def predict(self):

# Capture a screenshot for the current state of the webpage, if required by the model
screenshot_path = os.path.join(self.main_path, 'screenshots', f'screen_{self.time_step}.png')
try:
await self.session_control['active_page'].screenshot(path=screenshot_path)
try:
await page.screenshot(path=screenshot_path)
except Exception as e:
self.logger.info(f"Failed to take screenshot: {e}")

Expand All @@ -537,8 +556,7 @@ async def predict(self):
terminal_width = 10
self.logger.info("-" * (terminal_width))

choice_text = f"Action Grounding ➡️" + "\n" + format_options(
choices)
choice_text = f"Action Grounding ➡️" + "\n" + options
choice_text = choice_text.replace("\n\n", "")

for line in choice_text.split('\n'):
Expand Down Expand Up @@ -581,6 +599,10 @@ async def execute(self, prediction_dict):
Execute the predicted action on the webpage.
"""

# Clear the marks before action
if self.config["agent"]["grounding_strategy"] == "text_choice_som":
await self.session_control['active_page'].evaluate("unmarkPage()")

pred_element = prediction_dict["element"]
pred_action = prediction_dict["action"]
pred_value = prediction_dict["value"]
Expand Down
11 changes: 10 additions & 1 deletion seeact_package/seeact/demo_utils/browser_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ async def get_element_data(element, tag_name,viewport_size,seen_elements=[]):
if rect['x']<0 or rect['y']<0 or rect['width']<=4 or rect['height']<=4 or rect['y']+rect['height']>viewport_size["height"] or rect['x']+ rect['width']>viewport_size["width"]:
return None

box_raw = [rect['x'], rect['y'], rect['width'], rect['height']]
box_model = [rect['x'], rect['y'], rect['x'] + rect['width'], rect['y'] + rect['height']]
center_point = (round((box_model[0] + box_model[2]) / 2 / viewport_size["width"], 3),
round((box_model[1] + box_model[3]) / 2 / viewport_size["height"], 3))
Expand Down Expand Up @@ -294,7 +295,15 @@ async def get_element_data(element, tag_name,viewport_size,seen_elements=[]):
5. tag
'''
selector = element
return {"center_point":center_point,"description":description,"tag_with_role":tag_head,"box":box_model,"selector":selector,"tag":real_tag_name}
return {
"center_point":center_point,
"description":description,
"tag_with_role":tag_head,
"box_raw":box_raw,
"box":box_model,
"selector":selector,
"tag":real_tag_name
}
# return [center_point, description, tag_head, box_model, selector, real_tag_name]
except Exception as e:
# print(e)
Expand Down
51 changes: 30 additions & 21 deletions seeact_package/seeact/demo_utils/format_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re

import shlex

def format_choices(elements):
    """Convert interactive-element dicts into human-readable choice strings.

    Each element is rendered as ``(cx, cy) <tag_with_role>description</tag>``.
    For non-``select`` elements the description is truncated to 30 words to
    keep the grounding prompt compact; ``select`` elements keep their full
    description so every option remains visible to the model.

    As a side effect, any ``key=value`` tokens embedded in an element's
    description (e.g. ``type=submit name=q``) are parsed with ``shlex`` and
    merged into the element dict itself, so downstream code can read them as
    regular keys.
    NOTE(review): ``shlex.split`` raises ``ValueError`` on unbalanced quotes
    in a description — confirm upstream input cannot contain them.

    Args:
        elements: list of dicts with at least ``center_point``,
            ``description``, ``tag_with_role`` and ``tag`` keys.
            Elements may be mutated in place (see side effect above).

    Returns:
        list[str]: one formatted choice string per input element.
    """
    elements_w_descriptions = []
    for element in elements:
        # Promote key=value tokens inside the description to real dict keys.
        if "description" in element and "=" in element["description"]:
            pairs = []
            for token in shlex.split(element["description"]):
                if "=" in token:
                    key, value = token.split("=", 1)
                    pairs.append((key.strip(), value.strip()))
            element.update(dict(pairs))
        elements_w_descriptions.append(element)

    converted_elements = []
    for element in elements_w_descriptions:
        words = element["description"].split()
        if element["tag"] != "select":
            # Truncate long descriptions; 30 words is plenty for grounding.
            description = (
                element["description"]
                if len(words) < 30
                else " ".join(words[:30]) + "..."
            )
        else:
            description = element["description"]
        # Bug fix: the original emitted a stray '"' after tag_with_role
        # (f'<{element["tag_with_role"]}">'), producing e.g. <button"> —
        # the select branch had no such quote, so it was a typo.
        converted_elements.append(
            f'{element["center_point"]} <{element["tag_with_role"]}>'
            f'{description}</{element["tag"]}>'
        )
    return converted_elements

Expand Down
2 changes: 1 addition & 1 deletion seeact_package/seeact/demo_utils/inference_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def engine_factory(api_key=None, model=None, **kwargs):
model="llava"
return OllamaEngine(model=model, **kwargs)
raise Exception(f"Unsupported model: {model}, currently supported models: \
gpt-4-vision-preview, gpt-4-turbo, gemini-1.5-pro-latest, llava")
gpt-4-vision-preview, gpt-4-turbo, gpt-4o, gemini-1.5-pro-latest, llava")

class Engine:
def __init__(
Expand Down
Loading

0 comments on commit 1b41279

Please sign in to comment.