Skip to content

Commit

Permalink
When render_widget gets invalidated, invalidate anyone would reads th…
Browse files Browse the repository at this point in the history
…e value
  • Loading branch information
cpsievert committed Dec 29, 2023
1 parent f5e5327 commit 8976df6
Showing 1 changed file with 62 additions and 7 deletions.
69 changes: 62 additions & 7 deletions shinywidgets/_shinywidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from shiny import Session, reactive, req
from shiny.http_staticfiles import StaticFiles
from shiny.module import resolve_id
from shiny.reactive._core import Context, get_current_context
from shiny.render.transformer import (
TransformerMetadata,
ValueFn,
Expand Down Expand Up @@ -209,20 +210,31 @@ def _restore_state():
# Implement @render_widget()
# --------------------------------------------------------------------------------------------

# TODO: pass along IT/OT to get proper typing?
UserValueFn = ValueFn[object | None]

@output_transformer(default_ui=output_widget)
async def WidgetTransformer(
_meta: TransformerMetadata,
_fn: ValueFn[object | None],
_fn: UserValueFn,
) -> dict[str, Any] | None:
value = await resolve_value_fn(_fn)

# Attach value/widget attributes to user func so they can be accessed (in other reactive contexts)
_fn.value = value # type: ignore
_fn.widget = None # type: ignore

# Invalidate any reactive contexts that have read these attributes
invalidate_contexts(_fn)

if value is None:
return None

# Ensure we have a widget & smart layout defaults
widget = as_widget(value)
widget, fill = set_layout_defaults(widget)
_fn.widget = widget # type: ignore

return {"model_id": widget.model_id, "fill": fill} # type: ignore


Expand All @@ -243,27 +255,62 @@ def render_widget(
# Make the `res._value_fn.widget` attribute that we set in WidgetTransformer
# accessible via `res.widget`
def get_widget(*_: object) -> Optional[Widget]:
w = res._value_fn.widget # type: ignore
if w is None:
vfn = res._value_fn # pyright: ignore[reportFunctionMemberAccess]
vfn = register_current_context(vfn)
w = vfn.widget # type: ignore
if w is not None:
return w
# If widget is None, we're reading in a reactive context, other than the render context, throw a silent exception
if has_current_context():
req(False)
return None
return w
return None

def set_widget(*_: object):
raise RuntimeError("The widget attribute of a @render_widget function is read only.")

setattr(res.__class__, "widget", property(get_widget, set_widget))

def get_value(*_: object) -> object | None:
return res._value_fn.value # type: ignore
vfn = res._value_fn # pyright: ignore[reportFunctionMemberAccess]
vfn = register_current_context(vfn)
v = vfn.value # type: ignore
if v is not None:
return v
if has_current_context():
req(False)
return None

def set_value(*_: object):
raise RuntimeError("The value attribute of a @render_widget function is read only.")

setattr(res.__class__, "value", property(get_value, set_value))

# Define these attributes directly on the user function so they're defined, even
# if that function hasn't been called yet. (we don't want to raise an exception in that case)
fn.widget = None # type: ignore
fn.value = None # type: ignore

return res


def invalidate_contexts(fn: UserValueFn):
ctxs = getattr(fn, "_shinywidgets_contexts", set[Context]())
for ctx in ctxs:
# TODO: at what point should we be removing contexts?
ctx.invalidate()


# If the widget/value is read in a reactive context, then we'll need to invalidate
# that context when the widget's value changes
def register_current_context(fn: UserValueFn):
if not has_current_context():
return fn
ctxs = getattr(fn, "_shinywidgets_contexts", set[Context]())
ctxs.add(get_current_context())
fn._shinywidgets_contexts = ctxs # type: ignore
return fn


def reactive_read(widget: Widget, names: Union[str, Sequence[str]]) -> Any:
reactive_depend(widget, names)
if isinstance(names, str):
Expand All @@ -282,7 +329,7 @@ def reactive_depend(
"""

try:
ctx = reactive.get_current_context() # pyright: ignore[reportPrivateImportUsage]
ctx = get_current_context()
except RuntimeError:
raise RuntimeError("reactive_read() must be called within a reactive context")

Expand Down Expand Up @@ -381,6 +428,14 @@ def set_layout_defaults(widget: Widget) -> Tuple[Widget, bool]:

return (widget, fill)


def has_current_context() -> bool:
try:
get_current_context()
return True
except RuntimeError:
return False

# similar to base::system.file()
def package_dir(package: str) -> str:
with tempfile.TemporaryDirectory():
Expand Down

0 comments on commit 8976df6

Please sign in to comment.