diff --git a/skyvern/webeye/persistent_sessions_manager.py b/skyvern/webeye/persistent_sessions_manager.py index bc262d828..e87bff3aa 100644 --- a/skyvern/webeye/persistent_sessions_manager.py +++ b/skyvern/webeye/persistent_sessions_manager.py @@ -53,7 +53,13 @@ async def create_session( organization_id=organization_id, ) - browser_context.on("close", lambda: self.sessions[organization_id].pop(session_id)) + def on_context_close(): + if organization_id in self.sessions: + self.sessions[organization_id].pop(session_id, None) + if not self.sessions[organization_id]: + self.sessions.pop(organization_id, None) + + browser_context.on("close", on_context_close) browser_state = BrowserState( pw=pw, diff --git a/skyvern/webeye/test_persistent_sessions_manager.py b/skyvern/webeye/test_persistent_sessions_manager.py index dbb060954..96900519d 100644 --- a/skyvern/webeye/test_persistent_sessions_manager.py +++ b/skyvern/webeye/test_persistent_sessions_manager.py @@ -18,8 +18,21 @@ class MockBrowserContext(AsyncMock): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.new_page = AsyncMock(return_value=MockPage()) - self.close = AsyncMock() self.tracing = MagicMock(stop=AsyncMock()) + self._close_callbacks = [] + + async def close_with_event(): + for callback in self._close_callbacks: + try: + callback() + except Exception as e: + print(f"Error in close callback: {e}") + self.close = AsyncMock(side_effect=close_with_event) + + def on(self, event_name, callback): + if event_name == "close": + print(f"Registering close callback: {callback}") + self._close_callbacks.append(callback) class MockPlaywright(AsyncMock): @@ -38,10 +51,11 @@ def __init__(self, *args, **kwargs): def mock_browser_factory(): """Mock the entire BrowserContextFactory class to prevent any real browser operations""" with patch("skyvern.webeye.persistent_sessions_manager.BrowserContextFactory") as mock: - # Mock the static method + browser_context = MockBrowserContext() + mock.create_browser_context = AsyncMock( return_value=( - MockBrowserContext(), + browser_context, MagicMock( video_artifacts=[], har_path=None, @@ -62,14 +76,18 @@ def mock_playwright(): @pytest.fixture -def mock_browser_state_class(): +def mock_browser_state_class(mock_browser_factory): """Mock the BrowserState class itself""" with patch("skyvern.webeye.persistent_sessions_manager.BrowserState") as mock: - mock.return_value = MagicMock(spec=BrowserState) - mock.return_value.get_or_create_page = AsyncMock(return_value=MockPage()) - mock.return_value.get_working_page = AsyncMock(return_value=MockPage()) - mock.return_value.close = AsyncMock() - mock.return_value.page = MockPage() + mock_instance = MagicMock(spec=BrowserState) + mock_instance.get_or_create_page = AsyncMock(return_value=MockPage()) + mock_instance.get_working_page = AsyncMock(return_value=MockPage()) + mock_instance.close = AsyncMock() + mock_instance.page = MockPage() + mock_instance.browser_context = mock_browser_factory.create_browser_context.return_value[0] + + # Make the mock class return our configured instance + mock.return_value = mock_instance yield mock @@ -165,4 +183,23 @@ async def test_multiple_organizations(sessions_manager): assert session_id1 in sessions_manager.get_active_session_ids(org_id1) assert session_id2 in sessions_manager.get_active_session_ids(org_id2) assert session_id1 not in sessions_manager.get_active_session_ids(org_id2) - assert session_id2 not in sessions_manager.get_active_session_ids(org_id1) \ No newline at end of file + assert session_id2 not in sessions_manager.get_active_session_ids(org_id1) + + +async def test_browser_context_close_removes_session(sessions_manager): + org_id = "test_org" + session_id, browser_state = await sessions_manager.create_session( + organization_id=org_id, + url="https://example.com" + ) + + # Verify session exists + assert session_id in sessions_manager.get_active_session_ids(org_id) + + # Simulate browser context close event + await browser_state.browser_context.close() + + # Verify the session was removed + assert session_id not in sessions_manager.get_active_session_ids(org_id) + assert org_id not in sessions_manager.sessions +