Skip to content

Commit

Permalink
ENH Adds default success into config plugin (#2189)
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas J. Fan <[email protected]>
  • Loading branch information
thomasjpfan authored Feb 14, 2024
1 parent 060d480 commit d840369
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
10 changes: 10 additions & 0 deletions flytekit/clients/auth/default_html.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
from contextlib import suppress


def get_default_success_html(endpoint: str) -> str:
from flytekit.configuration.plugin import get_plugin

with suppress(AttributeError):
success_html = get_plugin().get_auth_success_html(endpoint)
if success_html is not None:
return success_html

return f"""
<html>
<head>
Expand Down
9 changes: 9 additions & 0 deletions flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def secret_requires_group() -> bool:
def get_default_image() -> Optional[str]:
"""Get default image. Return None to use the images from flytekit.configuration.DefaultImages"""

@staticmethod
def get_auth_success_html(endpoint: str) -> Optional[str]:
"""Get default success html for auth. Return None to use flytekit's default success html."""


class FlytekitPlugin:
@staticmethod
Expand Down Expand Up @@ -80,6 +84,11 @@ def get_default_image() -> Optional[str]:
"""Get default image. Return None to use the images from flytekit.configuration.DefaultImages"""
return None

@staticmethod
def get_auth_success_html(endpoint: str) -> Optional[str]:
"""Get default success html. Return None to use flytekit's default success html."""
return None


def _get_plugin_from_entrypoint():
"""Get plugin from entrypoint."""
Expand Down
15 changes: 15 additions & 0 deletions tests/flytekit/unit/clients/auth/test_default_html.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from unittest.mock import Mock

import flytekit
from flytekit.clients.auth.default_html import get_default_success_html


Expand All @@ -16,3 +19,15 @@ def test_default_html():
</html>
"""
) # noqa


def test_default_html_plugin(monkeypatch):
def get_auth_success_html(endpoint):
return f"<html><head><title>Successful Auth into {endpoint}!</title></head></html>"

plugin_mock = Mock()
plugin_mock.get_auth_success_html.side_effect = get_auth_success_html
mock_global_plugin = {"plugin": plugin_mock}
monkeypatch.setattr(flytekit.configuration.plugin, "_GLOBAL_CONFIG", mock_global_plugin)

assert get_default_success_html("flyte.org") == get_auth_success_html("flyte.org")

0 comments on commit d840369

Please sign in to comment.