Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add upload file button #168

Merged
merged 13 commits into from
Mar 21, 2024
9 changes: 9 additions & 0 deletions examples/02_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ def main() -> None:
initial_value=3,
marks=((0, "0"), (5, "5"), (7, "7"), 10),
)
gui_upload_button = server.add_gui_upload_button(
"Upload", icon=viser.Icon.UPLOAD
)

@gui_upload_button.on_upload
def _(_) -> None:
"""Callback for when a file is uploaded."""
file = gui_upload_button.value
print(file.name, len(file.content), "bytes")

# Pre-generate a point cloud to send.
point_positions = onp.random.uniform(low=-1.0, high=1.0, size=(5000, 3))
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ disable_error_code="var-annotated" # Common source of mypy + numpy false positi
exclude = ["./docs/**/*", "./examples/assets/**/*", "./src/viser/client/.nodeenv", "./build"]

[tool.ruff]
select = [
lint.select = [
"E", # pycodestyle errors.
"F", # Pyflakes rules.
"PLC", # Pylint convention warnings.
Expand All @@ -78,7 +78,7 @@ select = [
"PLW", # Pylint warnings.
"I", # Import sorting.
]
ignore = [
lint.ignore = [
"E741", # Ambiguous variable name. (l, O, or I)
"E501", # Line too long.
"E721", # Do not compare types, use `isinstance()`.
Expand Down
236 changes: 227 additions & 9 deletions src/viser/_gui_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import abc
import dataclasses
import functools
import threading
import time
import warnings
Expand All @@ -18,12 +19,21 @@
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)

import numpy as onp
from typing_extensions import Literal, LiteralString
from typing_extensions import (
Literal,
LiteralString,
TypedDict,
get_args,
get_origin,
get_type_hints,
)

from . import _messages
from ._gui_handles import (
Expand All @@ -37,14 +47,17 @@
GuiMarkdownHandle,
GuiModalHandle,
GuiTabGroupHandle,
GuiUploadButtonHandle,
SupportsRemoveProtocol,
UploadedFile,
_GuiHandleState,
_GuiInputHandle,
_make_unique_id,
)
from ._icons import base64_from_icon
from ._icons_enum import IconName
from ._message_api import MessageApi, cast_vector
from ._messages import FileTransferPartAck

if TYPE_CHECKING:
from .infra import ClientId
Expand Down Expand Up @@ -106,6 +119,64 @@ def _apply_default_order(order: Optional[float]) -> float:
return _global_order_counter


@functools.lru_cache(maxsize=None)
def get_type_hints_cached(cls: Type[Any]) -> Dict[str, Any]:
return get_type_hints(cls) # type: ignore


def cast_value(tp, value):
"""Cast a value to a type, or raise a TypeError if it cannot be cast."""
origin = get_origin(tp)

if (origin is tuple or tp is tuple) and isinstance(value, list):
return cast_value(tp, tuple(value))

if origin is Literal:
for val in get_args(tp):
try:
value_casted = cast_value(type(val), value)
if val == value_casted:
return value_casted
except ValueError:
pass
except TypeError:
pass
raise TypeError(f"Value {value} is not in {get_args(tp)}")

if origin is Union:
for t in get_args(tp):
try:
return cast_value(t, value)
except ValueError:
pass
except TypeError:
pass
raise TypeError(f"Value {value} is not in {tp}")

if tp in {int, float, bool, str}:
return tp(value)

if dataclasses.is_dataclass(tp):
return tp(
**{k: cast_value(v, value[k]) for k, v in get_type_hints_cached(tp).items()}
)

if isinstance(value, tp):
return value

raise TypeError(f"Cannot cast value {value} to type {tp}")


class FileUploadState(TypedDict):
filename: str
mime_type: str
part_count: int
parts: Dict[int, bytes]
total_bytes: int
transferred_bytes: int
lock: threading.Lock


class GuiApi(abc.ABC):
_target_container_from_thread_id: Dict[int, str] = {}
"""ID of container to put GUI elements into."""
Expand All @@ -117,9 +188,18 @@ def __init__(self) -> None:
self._container_handle_from_id: Dict[str, GuiContainerProtocol] = {
"root": _RootGuiContainer({})
}
self._current_file_upload_states: Dict[str, FileUploadState] = {}

self._get_api()._message_handler.register_handler(
_messages.GuiUpdateMessage, self._handle_gui_updates
)
self._get_api()._message_handler.register_handler(
_messages.FileTransferStart, self._handle_file_transfer_start
)
self._get_api()._message_handler.register_handler(
_messages.FileTransferPart,
self._handle_file_transfer_part,
)

def _handle_gui_updates(
self, client_id: ClientId, message: _messages.GuiUpdateMessage
Expand All @@ -139,14 +219,7 @@ def _handle_gui_updates(
# Do some type casting. This is brittle, but necessary when we
# expect floats but the Javascript side gives us integers.
if prop_name == "value":
if handle_state.typ is tuple:
assert len(prop_value) == len(handle_state.value)
prop_value = tuple(
type(handle_state.value[i])(prop_value[i])
for i in range(len(prop_value))
)
else:
prop_value = handle_state.typ(prop_value)
prop_value = cast_value(handle_state.typ, prop_value)

# Update handle property.
if current_value != prop_value:
Expand Down Expand Up @@ -179,6 +252,83 @@ def _handle_gui_updates(
if handle_state.sync_cb is not None:
handle_state.sync_cb(client_id, updates_cast)

def _handle_file_transfer_start(
self, client_id: ClientId, message: _messages.FileTransferStart
) -> None:
if message.source_component_id not in self._gui_handle_from_id:
brentyi marked this conversation as resolved.
Show resolved Hide resolved
return
self._current_file_upload_states[message.transfer_uuid] = {
"filename": message.filename,
"mime_type": message.mime_type,
"part_count": message.part_count,
"parts": {},
"total_bytes": message.size_bytes,
"transferred_bytes": 0,
"lock": threading.Lock(),
}

def _handle_file_transfer_part(
self, client_id: ClientId, message: _messages.FileTransferPart
) -> None:
if message.transfer_uuid not in self._current_file_upload_states:
return
assert message.source_component_id in self._gui_handle_from_id

state = self._current_file_upload_states[message.transfer_uuid]
state["parts"][message.part] = message.content
total_bytes = state["total_bytes"]

with state["lock"]:
state["transferred_bytes"] += len(message.content)

# Send ack to the server.
self._get_api()._queue(
FileTransferPartAck(
source_component_id=message.source_component_id,
transfer_uuid=message.transfer_uuid,
transferred_bytes=state["transferred_bytes"],
total_bytes=total_bytes,
)
)

if state["transferred_bytes"] < total_bytes:
return

# Finish the upload.
assert state["transferred_bytes"] == total_bytes
state = self._current_file_upload_states.pop(message.transfer_uuid)

handle = self._gui_handle_from_id.get(message.source_component_id, None)
if handle is None:
return

handle_state = handle._impl

value = UploadedFile(
name=state["filename"],
content=b"".join(state["parts"][i] for i in range(state["part_count"])),
)

# Update state.
with self._get_api()._atomic_lock:
handle_state.value = value
handle_state.update_timestamp = time.time()

# Trigger callbacks.
for cb in handle_state.update_cb:
from ._viser import ClientHandle, ViserServer

# Get the handle of the client that triggered this event.
api = self._get_api()
if isinstance(api, ClientHandle):
client = api
elif isinstance(api, ViserServer):
client = api.get_clients()[client_id]
else:
assert False

cb(GuiEvent(client, client_id, handle))

def _get_container_id(self) -> str:
"""Get container ID associated with the current thread."""
return self._target_container_from_thread_id.get(threading.get_ident(), "root")
Expand Down Expand Up @@ -413,6 +563,74 @@ def add_gui_button(
)._impl
)

def add_gui_upload_button(
self,
label: str,
disabled: bool = False,
visible: bool = True,
hint: Optional[str] = None,
color: Optional[
Literal[
"dark",
"gray",
"red",
"pink",
"grape",
"violet",
"indigo",
"blue",
"cyan",
"green",
"lime",
"yellow",
"orange",
"teal",
]
] = None,
icon: Optional[IconName] = None,
mime_type: str = "*/*",
order: Optional[float] = None,
) -> GuiUploadButtonHandle:
"""Add a button to the GUI. The value of this input is set to `True` every time
it is clicked; to detect clicks, we can manually set it back to `False`.

Args:
label: Label to display on the button.
visible: Whether the button is visible.
disabled: Whether the button is disabled.
hint: Optional hint to display on hover.
color: Optional color to use for the button.
icon: Optional icon to display on the button.
mime_type: Optional MIME type to filter the files that can be uploaded.
order: Optional ordering, smallest values will be displayed first.

Returns:
A handle that can be used to interact with the GUI element.
"""

# Re-wrap the GUI handle with a button interface.
id = _make_unique_id()
order = _apply_default_order(order)
return GuiUploadButtonHandle(
self._create_gui_input(
value=UploadedFile("", b""),
message=_messages.GuiAddUploadButtonMessage(
value=None,
disabled=disabled,
visible=visible,
order=order,
id=id,
label=label,
container_id=self._get_container_id(),
hint=hint,
color=color,
mime_type=mime_type,
icon_base64=None if icon is None else base64_from_icon(icon),
),
is_button=True,
)._impl
)

# The TLiteralString overload tells pyright to resolve the value type to a Literal
# whenever possible.
#
Expand Down
25 changes: 25 additions & 0 deletions src/viser/_gui_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,31 @@ def on_click(
return func


@dataclasses.dataclass
class UploadedFile:
"""Result of a file upload."""

name: str
"""Name of the file."""
content: bytes
"""Contents of the file."""


@dataclasses.dataclass
class GuiUploadButtonHandle(_GuiInputHandle[UploadedFile]):
"""Handle for an upload file button in our visualizer.

The `.value` attribute will be updated with the contents of uploaded files.
"""

def on_upload(
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None]
) -> Callable[[GuiEvent[TGuiHandle]], None]:
"""Attach a function to call when a button is pressed. Happens in a thread."""
self._impl.update_cb.append(func)
return func


@dataclasses.dataclass
class GuiButtonGroupHandle(_GuiInputHandle[StringType], Generic[StringType]):
"""Handle for a button group input in our visualizer.
Expand Down
Loading
Loading