Skip to content

Commit

Permalink
Test browser close event callback
Browse files Browse the repository at this point in the history
  • Loading branch information
satansdeer committed Dec 16, 2024
1 parent 2ecbd8c commit 7364311
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 11 deletions.
8 changes: 7 additions & 1 deletion skyvern/webeye/persistent_sessions_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@ async def create_session(
organization_id=organization_id,
)

browser_context.on("close", lambda: self.sessions[organization_id].pop(session_id))
def on_context_close():
if organization_id in self.sessions:
self.sessions[organization_id].pop(session_id, None)
if not self.sessions[organization_id]:
self.sessions.pop(organization_id, None)

browser_context.on("close", on_context_close)

browser_state = BrowserState(
pw=pw,
Expand Down
57 changes: 47 additions & 10 deletions skyvern/webeye/test_persistent_sessions_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,21 @@ class MockBrowserContext(AsyncMock):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.new_page = AsyncMock(return_value=MockPage())
self.close = AsyncMock()
self.tracing = MagicMock(stop=AsyncMock())
self._close_callbacks = []

async def close_with_event():
for callback in self._close_callbacks:
try:
callback()
except Exception as e:
print(f"Error in close callback: {e}")
self.close = AsyncMock(side_effect=close_with_event)

def on(self, event_name, callback):
if event_name == "close":
print(f"Registering close callback: {callback}")
self._close_callbacks.append(callback)


class MockPlaywright(AsyncMock):
Expand All @@ -38,10 +51,11 @@ def __init__(self, *args, **kwargs):
def mock_browser_factory():
"""Mock the entire BrowserContextFactory class to prevent any real browser operations"""
with patch("skyvern.webeye.persistent_sessions_manager.BrowserContextFactory") as mock:
# Mock the static method
browser_context = MockBrowserContext()

mock.create_browser_context = AsyncMock(
return_value=(
MockBrowserContext(),
browser_context,
MagicMock(
video_artifacts=[],
har_path=None,
Expand All @@ -62,14 +76,18 @@ def mock_playwright():


@pytest.fixture
def mock_browser_state_class():
def mock_browser_state_class(mock_browser_factory):
"""Mock the BrowserState class itself"""
with patch("skyvern.webeye.persistent_sessions_manager.BrowserState") as mock:
mock.return_value = MagicMock(spec=BrowserState)
mock.return_value.get_or_create_page = AsyncMock(return_value=MockPage())
mock.return_value.get_working_page = AsyncMock(return_value=MockPage())
mock.return_value.close = AsyncMock()
mock.return_value.page = MockPage()
mock_instance = MagicMock(spec=BrowserState)
mock_instance.get_or_create_page = AsyncMock(return_value=MockPage())
mock_instance.get_working_page = AsyncMock(return_value=MockPage())
mock_instance.close = AsyncMock()
mock_instance.page = MockPage()
mock_instance.browser_context = mock_browser_factory.create_browser_context.return_value[0]

# Make the mock class return our configured instance
mock.return_value = mock_instance
yield mock


Expand Down Expand Up @@ -165,4 +183,23 @@ async def test_multiple_organizations(sessions_manager):
assert session_id1 in sessions_manager.get_active_session_ids(org_id1)
assert session_id2 in sessions_manager.get_active_session_ids(org_id2)
assert session_id1 not in sessions_manager.get_active_session_ids(org_id2)
assert session_id2 not in sessions_manager.get_active_session_ids(org_id1)
assert session_id2 not in sessions_manager.get_active_session_ids(org_id1)


async def test_browser_context_close_removes_session(sessions_manager):
org_id = "test_org"
session_id, browser_state = await sessions_manager.create_session(
organization_id=org_id,
url="https://example.com"
)

# Verify session exists
assert session_id in sessions_manager.get_active_session_ids(org_id)

# Simulate browser context close event
await browser_state.browser_context.close()

# Verify the session was removed
assert session_id not in sessions_manager.get_active_session_ids(org_id)
assert org_id not in sessions_manager.sessions

0 comments on commit 7364311

Please sign in to comment.