From 8976df6b37ecdc661af04592858e354f1a462817 Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 21 Dec 2023 11:44:55 -0600 Subject: [PATCH] When render_widget gets invalidated, invalidate anyone would reads the value --- shinywidgets/_shinywidgets.py | 69 +++++++++++++++++++++++++++++++---- 1 file changed, 62 insertions(+), 7 deletions(-) diff --git a/shinywidgets/_shinywidgets.py b/shinywidgets/_shinywidgets.py index 9d4a6a5..fedf715 100644 --- a/shinywidgets/_shinywidgets.py +++ b/shinywidgets/_shinywidgets.py @@ -29,6 +29,7 @@ from shiny import Session, reactive, req from shiny.http_staticfiles import StaticFiles from shiny.module import resolve_id +from shiny.reactive._core import Context, get_current_context from shiny.render.transformer import ( TransformerMetadata, ValueFn, @@ -209,20 +210,31 @@ def _restore_state(): # Implement @render_widget() # -------------------------------------------------------------------------------------------- +# TODO: pass along IT/OT to get proper typing? +UserValueFn = ValueFn[object | None] @output_transformer(default_ui=output_widget) async def WidgetTransformer( _meta: TransformerMetadata, - _fn: ValueFn[object | None], + _fn: UserValueFn, ) -> dict[str, Any] | None: value = await resolve_value_fn(_fn) + + # Attach value/widget attributes to user func so they can be accessed (in other reactive contexts) _fn.value = value # type: ignore _fn.widget = None # type: ignore + + # Invalidate any reactive contexts that have read these attributes + invalidate_contexts(_fn) + if value is None: return None + + # Ensure we have a widget & smart layout defaults widget = as_widget(value) widget, fill = set_layout_defaults(widget) _fn.widget = widget # type: ignore + return {"model_id": widget.model_id, "fill": fill} # type: ignore @@ -243,11 +255,15 @@ def render_widget( # Make the `res._value_fn.widget` attribute that we set in WidgetTransformer # accessible via `res.widget` def get_widget(*_: object) -> Optional[Widget]: - w = res._value_fn.widget # type: ignore - if w is None: + vfn = res._value_fn # pyright: ignore[reportFunctionMemberAccess] + vfn = register_current_context(vfn) + w = vfn.widget # type: ignore + if w is not None: + return w + # If widget is None, we're reading in a reactive context, other than the render context, throw a silent exception + if has_current_context(): req(False) - return None - return w + return None def set_widget(*_: object): raise RuntimeError("The widget attribute of a @render_widget function is read only.") @@ -255,15 +271,46 @@ def set_widget(*_: object): setattr(res.__class__, "widget", property(get_widget, set_widget)) def get_value(*_: object) -> object | None: - return res._value_fn.value # type: ignore + vfn = res._value_fn # pyright: ignore[reportFunctionMemberAccess] + vfn = register_current_context(vfn) + v = vfn.value # type: ignore + if v is not None: + return v + if has_current_context(): + req(False) + return None def set_value(*_: object): raise RuntimeError("The value attribute of a @render_widget function is read only.") setattr(res.__class__, "value", property(get_value, set_value)) + # Define these attributes directly on the user function so they're defined, even + # if that function hasn't been called yet. (we don't want to raise an exception in that case) + fn.widget = None # type: ignore + fn.value = None # type: ignore + return res + +def invalidate_contexts(fn: UserValueFn): + ctxs = getattr(fn, "_shinywidgets_contexts", set[Context]()) + for ctx in ctxs: + # TODO: at what point should we be removing contexts? + ctx.invalidate() + + +# If the widget/value is read in a reactive context, then we'll need to invalidate +# that context when the widget's value changes +def register_current_context(fn: UserValueFn): + if not has_current_context(): + return fn + ctxs = getattr(fn, "_shinywidgets_contexts", set[Context]()) + ctxs.add(get_current_context()) + fn._shinywidgets_contexts = ctxs # type: ignore + return fn + + def reactive_read(widget: Widget, names: Union[str, Sequence[str]]) -> Any: reactive_depend(widget, names) if isinstance(names, str): @@ -282,7 +329,7 @@ def reactive_depend( """ try: - ctx = reactive.get_current_context() # pyright: ignore[reportPrivateImportUsage] + ctx = get_current_context() except RuntimeError: raise RuntimeError("reactive_read() must be called within a reactive context") @@ -381,6 +428,14 @@ def set_layout_defaults(widget: Widget) -> Tuple[Widget, bool]: return (widget, fill) + +def has_current_context() -> bool: + try: + get_current_context() + return True + except RuntimeError: + return False + # similar to base::system.file() def package_dir(package: str) -> str: with tempfile.TemporaryDirectory():