Skip to content

Commit

Permalink
Direct object passing for decimator (#1591) (#1677)
Browse files Browse the repository at this point in the history
  • Loading branch information
dinhlongviolin1 authored Sep 7, 2024
1 parent af60cf7 commit e11dd0c
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 42 deletions.
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
rev: v1.11.2
hooks:
- id: mypy
additional_dependencies: [
Expand All @@ -16,7 +16,7 @@ repos:
- --exclude=(taipy/templates/|generate_pyi.py|tools)
- --follow-imports=skip
- repo: https://github.com/Lucas-C/pre-commit-hooks
rev: v1.1.10
rev: v1.5.5
hooks:
- id: forbid-crlf
- id: remove-crlf
Expand All @@ -28,21 +28,21 @@ repos:
- --license-filepath
- .license-header
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
rev: v4.6.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-merge-conflict
- id: check-yaml
args: [--unsafe]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.8
rev: v0.6.4
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
rev: v2.3.0
hooks:
- id: codespell
additional_dependencies: [tomli]
95 changes: 58 additions & 37 deletions taipy/gui/builder/_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class _Element(ABC):
_DEFAULT_PROPERTY = ""
__RE_INDEXED_PROPERTY = re.compile(r"^(.*?)__([\w\d]+)$")
_NEW_LAMBDA_NAME = "new_lambda"
_TAIPY_EMBEDDED_PREFIX = "_tp_embedded_"
_EMBEDED_PROPERTIES = ["decimator"]

def __new__(cls, *args, **kwargs):
obj = super(_Element, cls).__new__(cls)
Expand Down Expand Up @@ -93,47 +95,66 @@ def _parse_property(self, key: str, value: t.Any) -> t.Any:
return value
if isinstance(value, FunctionType):
if key.startswith("on_"):
if value.__name__.startswith("<"):
return value
return value.__name__

try:
source = inspect.findsource(value)
st = ast.parse("".join(source[0]))
lambda_by_name: t.Dict[str, ast.Lambda] = {}
_LambdaByName(self._ELEMENT_NAME, source[1], lambda_by_name).visit(st)
lambda_fn = lambda_by_name.get(
key,
lambda_by_name.get(_LambdaByName._DEFAULT_NAME, None) if key == self._DEFAULT_PROPERTY else None,
)
if lambda_fn is not None:
args = [arg.arg for arg in lambda_fn.args.args]
targets = [
compr.target.id # type: ignore[attr-defined]
for node in ast.walk(lambda_fn.body)
if isinstance(node, ast.ListComp)
for compr in node.generators
]
tree = _TransformVarToValue(self.__calling_frame, args + targets + _python_builtins).visit(
lambda_fn
)
ast.fix_missing_locations(tree)
if sys.version_info < (3, 9): # python 3.8 ast has no unparse
string_fd = io.StringIO()
_Unparser(tree, string_fd)
string_fd.seek(0)
lambda_text = string_fd.read()
else:
lambda_text = ast.unparse(tree)
lambda_name = f"__lambda_{uuid.uuid4().hex}"
self._lambdas[lambda_name] = lambda_text
return f'{{{lambda_name}({", ".join(args)})}}'
except Exception as e:
_warn("Error in lambda expression", e)
return value if value.__name__.startswith("<") else value.__name__
# Parse lambda function
if (lambda_name := self.__parse_lambda_property(key, value)) is not None:
return lambda_name
# Embed value in the caller frame
if not isinstance(value, str) and key in self._EMBEDED_PROPERTIES:
return self.__embed_object(value, is_expression=False)
if hasattr(value, "__name__"):
return str(getattr(value, "__name__")) # noqa: B009
return str(value)

def __parse_lambda_property(self, key: str, value: t.Any) -> t.Any:
try:
source = inspect.findsource(value)
st = ast.parse("".join(source[0]))
lambda_by_name: t.Dict[str, ast.Lambda] = {}
_LambdaByName(self._ELEMENT_NAME, source[1], lambda_by_name).visit(st)
lambda_fn = lambda_by_name.get(
key,
lambda_by_name.get(_LambdaByName._DEFAULT_NAME, None) if key == self._DEFAULT_PROPERTY else None,
)
if lambda_fn is None:
return None
args = [arg.arg for arg in lambda_fn.args.args]
targets = [
compr.target.id # type: ignore[attr-defined]
for node in ast.walk(lambda_fn.body)
if isinstance(node, ast.ListComp)
for compr in node.generators
]
tree = _TransformVarToValue(self.__calling_frame, args + targets + _python_builtins).visit(lambda_fn)
ast.fix_missing_locations(tree)
if sys.version_info < (3, 9): # python 3.8 ast has no unparse
string_fd = io.StringIO()
_Unparser(tree, string_fd)
string_fd.seek(0)
lambda_text = string_fd.read()
else:
lambda_text = ast.unparse(tree)
lambda_name = f"__lambda_{uuid.uuid4().hex}"
self._lambdas[lambda_name] = lambda_text
return f'{{{lambda_name}({", ".join(args)})}}'
except Exception as e:
_warn("Error in lambda expression", e)
return None

def __embed_object(self, obj: t.Any, is_expression=True) -> str:
"""Embed an object in the caller frame
Return the Taipy expression of the embedded object
"""
frame_locals = self.__calling_frame.f_locals
obj_var_name = self._TAIPY_EMBEDDED_PREFIX + obj.__class__.__name__
index = 0
while f"{obj_var_name}_{index}" in frame_locals:
index += 1
obj_var_name = f"{obj_var_name}_{index}"
frame_locals[obj_var_name] = obj
return f"{{{obj_var_name}}}" if is_expression else obj_var_name

@abstractmethod
def _render(self, gui: "Gui") -> str:
pass
Expand Down
23 changes: 23 additions & 0 deletions tests/gui/builder/test_embed_object.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Copyright 2021-2024 Avaiga Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import taipy.gui.builder as tgb
from taipy.gui import Gui
from taipy.gui.data.decimator import ScatterDecimator


def test_decimator_embed_object(gui: Gui, test_client, helpers):
chart_builder = tgb.chart(decimator=ScatterDecimator()) # type: ignore[attr-defined] # noqa: B023
frame_locals = locals()
decimator_property = chart_builder._properties.get("decimator", None)
assert decimator_property is not None
assert decimator_property in frame_locals
assert isinstance(frame_locals[decimator_property], ScatterDecimator)

0 comments on commit e11dd0c

Please sign in to comment.