Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve GT.save() usability #499

Merged
merged 11 commits into from
Nov 22, 2024
139 changes: 41 additions & 98 deletions great_tables/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,24 +177,6 @@ def as_raw_html(
DebugDumpOptions: TypeAlias = Literal["zoom", "width_resize", "final_resize"]


class _NoOpDriverCtx:
"""Context manager that no-ops entering a webdriver(options=...) instance."""

def __init__(self, driver: webdriver.Remote):
self.driver = driver

def __call__(self, options):
# no-op what is otherwise instantiating webdriver with options,
# since a webdriver instance was already passed on init
return self

def __enter__(self):
return self.driver

def __exit__(self, *args):
pass


def save(
self: GT,
file: Path | str,
Expand All @@ -206,7 +188,8 @@ def save(
debug_port: None | int = None,
encoding: str = "utf-8",
_debug_dump: DebugDumpOptions | None = None,
) -> None:
**params,
) -> GTSelf:
"""
Produce a high-resolution image file or PDF of the table.

Expand Down Expand Up @@ -239,17 +222,25 @@ def save(
debug_port
Port number to use for debugging. By default no debugging port is opened.
encoding
The encoding used when writing temporary files.
The character encoding used for the HTML content.
_debug_dump
Whether the saved image should be a big browser window, with key elements outlined. This is
helpful for debugging this function's resizing, cropping heuristics. This is an internal
parameter and subject to change.
**params
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it okay if we remove the **params piece of this PR? I'm onboard with everything else. Because **params only applies to when we go from a .png -> something else, it exposes PIL as part of the user API.

If we leave it out, we can swap out PIL down the road. As an alternative could we tell users that we're using Pillow? We could even tell them about PIL open and save if needed (so they can go from .png -> anything PIL supports on their own if more customization is needed).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation. I agree that we should remove **params, making Pillow an internal dependency rather than a public one.

Regarding hints, I believe that diligent readers likely already recognize we use Pillow under the hood and would consult its documentation for further customization if needed.

This aligns with our goal of providing a user-friendly way to save generated tables rather than focusing on creating highly customized table figures in various formats🤔.

Additional parameters supported by
[Image.save()](https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.save)
in Pillow. For instance, when saving the table as a
[PNG](https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html#png), you can
adjust the `compress_level` to balance between speed and compression. The `compress_level`
ranges from 0 to 9, where 1 offers the best speed, 9 offers maximum compression, and 0
applies no compression. The default value is 6.

Returns
-------
None
This function does not return anything; it simply saves the image to the specified file
path.
GT
The GT object is returned. This is the same object that the method is called on so that we
can facilitate method chaining.

Details
-------
Expand All @@ -271,95 +262,54 @@ def save(
```

"""
import base64

# Import the required packages
_try_import(name="selenium", pip_install_line="pip install selenium")

from selenium import webdriver
from ._utils_selenium import _get_web_driver

if selector != "table":
raise NotImplementedError("Currently, only selector='table' is supported.")

if isinstance(file, Path):
file = str(file)

# If there is no file extension, add the .png extension
if not Path(file).suffix:
file += ".png"
file = str(Path(file).with_suffix(".png"))

# Get the HTML content from the displayed output
html_content = as_raw_html(self)

# Set the webdriver and options based on the chosen browser (`web_driver=` argument)
if isinstance(web_driver, webdriver.Remote):
wdriver = _NoOpDriverCtx(web_driver)
wd_options = None

elif web_driver == "chrome":
wdriver = webdriver.Chrome
wd_options = webdriver.ChromeOptions()
elif web_driver == "safari":
wdriver = webdriver.Safari
wd_options = webdriver.SafariOptions()
elif web_driver == "firefox":
wdriver = webdriver.Firefox
wd_options = webdriver.FirefoxOptions()
elif web_driver == "edge":
wdriver = webdriver.Edge
wd_options = webdriver.EdgeOptions()
else:
raise ValueError(f"Unsupported web driver: {web_driver}")

# specify headless flag ----
if web_driver in {"firefox", "edge"}:
wd_options.add_argument("--headless")
elif web_driver == "chrome":
# Operate all webdrivers in headless mode
wd_options.add_argument("--headless=new")
else:
# note that safari currently doesn't support headless browsing
pass

if debug_port:
if web_driver == "chrome":
wd_options.add_argument(f"--remote-debugging-port={debug_port}")
elif web_driver == "firefox":
# TODO: not sure how to connect to this session on firefox?
wd_options.add_argument(f"--start-debugger-server {debug_port}")
else:
warnings.warn("debug_port argument only supported on chrome and firefox")
debug_port = None
wdriver = _get_web_driver(web_driver)

# run browser ----
with (
tempfile.TemporaryDirectory() as tmp_dir,
wdriver(options=wd_options) as headless_browser,
):
with wdriver(debug_port=debug_port) as headless_browser:
headless_browser.set_window_size(*window_size)
encoded = base64.b64encode(html_content.encode(encoding=encoding)).decode(encoding=encoding)
headless_browser.get(f"data:text/html;base64,{encoded}")

# Write the HTML content to the temp file
with open(f"{tmp_dir}/table.html", "w", encoding=encoding) as temp_file:
temp_file.write(html_content)
_save_screenshot(headless_browser, scale, file, debug=_debug_dump, **params)

# Open the HTML file in the headless browser
headless_browser.set_window_size(*window_size)
headless_browser.get("file://" + temp_file.name)
if debug_port and web_driver not in {"chrome", "firefox"}:
warnings.warn("debug_port argument only supported on chrome and firefox")
debug_port = None

_save_screenshot(headless_browser, scale, file, debug=_debug_dump)
if debug_port:
input(
f"Currently debugging on port {debug_port}.\n\n"
"If you are using Chrome, enter chrome://inspect to preview the headless browser."
"Other browsers may have different ways to preview headless browser sessions.\n\n"
"Press enter to continue."
)

if debug_port:
input(
f"Currently debugging on port {debug_port}.\n\n"
"If you are using Chrome, enter chrome://inspect to preview the headless browser."
"Other browsers may have different ways to preview headless browser sessions.\n\n"
"Press enter to continue."
)
return self


def _save_screenshot(
driver: webdriver.Chrome, scale, path: str, debug: DebugDumpOptions | None
driver: webdriver.Chrome, scale: float, path: str, debug: DebugDumpOptions | None, **params
) -> None:
from io import BytesIO
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait

# Based on: https://stackoverflow.com/a/52572919/
# In some headless browsers, element position and width do not always reflect
Expand All @@ -372,7 +322,6 @@ def _save_screenshot(
#
# I can't say for sure whether the final sleep is needed. Only that it seems like
# on CI with firefox sometimes the final screencapture is wider than necessary.

original_size = driver.get_window_size()

# set table zoom ----
Expand Down Expand Up @@ -423,19 +372,13 @@ def _save_screenshot(
if debug == "final_resize":
return _dump_debug_screenshot(driver, path)

el = driver.find_element(by=By.TAG_NAME, value="body")
el = WebDriverWait(driver, 1).until(EC.visibility_of_element_located((By.TAG_NAME, "body")))

time.sleep(0.05)

if path.endswith(".png"):
el.screenshot(path)
else:
_try_import(name="PIL", pip_install_line="pip install pillow")
_try_import(name="PIL", pip_install_line="pip install pillow")

from PIL import Image
from PIL import Image

# convert to other formats (e.g. pdf, bmp) using PIL
Image.open(fp=BytesIO(el.screenshot_as_png)).save(fp=path)
Image.open(fp=BytesIO(el.screenshot_as_png)).save(fp=path, **params)


def _dump_debug_screenshot(driver, path):
Expand Down
91 changes: 91 additions & 0 deletions great_tables/_utils_selenium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

from types import TracebackType
from typing import Literal
from typing_extensions import TypeAlias
from selenium import webdriver

# Create a list of all selenium webdrivers
WebDrivers: TypeAlias = Literal[
"chrome",
"firefox",
"safari",
"edge",
]


class _BaseWebDriver:

def __init__(self, debug_port: int | None = None):
self.debug_port = debug_port
self.wd_options = self.cls_wd_options()
self.add_arguments()
self.driver = self.cls_driver(self.wd_options)

def add_arguments(self): ...

def __enter__(self) -> WebDrivers | webdriver.Remote:
return self.driver

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> bool | None:
self.driver.quit()


class _ChromeWebDriver(_BaseWebDriver):
cls_driver = webdriver.Chrome
cls_wd_options = webdriver.ChromeOptions

def add_arguments(self):
self.wd_options.add_argument("--headless=new")
if self.debug_port is not None:
self.wd_options.add_argument(f"--remote-debugging-port={self.debug_port}")

Check warning on line 46 in great_tables/_utils_selenium.py

View check run for this annotation

Codecov / codecov/patch

great_tables/_utils_selenium.py#L46

Added line #L46 was not covered by tests


class _SafariWebDriver(_BaseWebDriver):
cls_driver = webdriver.Safari
cls_wd_options = webdriver.SafariOptions


class _FirefoxWebDriver(_BaseWebDriver):
cls_driver = webdriver.Firefox
cls_wd_options = webdriver.FirefoxOptions

def add_arguments(self):
self.wd_options.add_argument("--headless")
if self.debug_port is not None:
self.wd_options.add_argument(f"--start-debugger-server {self.debug_port}")

Check warning on line 61 in great_tables/_utils_selenium.py

View check run for this annotation

Codecov / codecov/patch

great_tables/_utils_selenium.py#L59-L61

Added lines #L59 - L61 were not covered by tests


class _EdgeWebDriver(_BaseWebDriver):
cls_driver = webdriver.Edge
cls_wd_options = webdriver.EdgeOptions

def add_arguments(self):
self.wd_options.add_argument("--headless")

Check warning on line 69 in great_tables/_utils_selenium.py

View check run for this annotation

Codecov / codecov/patch

great_tables/_utils_selenium.py#L69

Added line #L69 was not covered by tests


def no_op_callable(web_driver: webdriver.Remote):
def wrapper(*args, **kwargs):
return web_driver

return wrapper


def _get_web_driver(web_driver: WebDrivers | webdriver.Remote):
if isinstance(web_driver, webdriver.Remote):
return no_op_callable(web_driver)
elif web_driver == "chrome":
return _ChromeWebDriver
elif web_driver == "safari":
return _SafariWebDriver
elif web_driver == "firefox":
return _FirefoxWebDriver
elif web_driver == "edge":
return _EdgeWebDriver
else:
raise ValueError(f"Unsupported web driver: {web_driver}")
39 changes: 39 additions & 0 deletions tests/test__utils_selenium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import pytest

from great_tables._utils_selenium import (
_get_web_driver,
no_op_callable,
_ChromeWebDriver,
_SafariWebDriver,
_FirefoxWebDriver,
_EdgeWebDriver,
)


def test_no_op_callable():
"""
The test should cover the scenario of obtaining a remote driver in `_get_web_driver`.
"""
fake_input = object()
f = no_op_callable(fake_input)
assert f(1, x="x") is fake_input


@pytest.mark.parametrize(
"web_driver,Driver",
[
("chrome", _ChromeWebDriver),
("safari", _SafariWebDriver),
("firefox", _FirefoxWebDriver),
("edge", _EdgeWebDriver),
],
)
def test_get_web_driver(web_driver, Driver):
assert _get_web_driver(web_driver) is Driver


def test_get_web_driver_raise():
fake_web_driver = "fake_web_driver"
with pytest.raises(ValueError) as exc_info:
_get_web_driver(fake_web_driver)
assert exc_info.value.args[0] == f"Unsupported web driver: {fake_web_driver}"