Skip to content

Commit

Permalink
Add upload file button (#168)
Browse files Browse the repository at this point in the history
* Add upload file button

* Fix typing

* Add file upload progressbar

* Add ack messages for uploadfilebutton

* Race condition, duplicate handling fixes, minor message tweaks

* nit types etc

* Formatting

* Add UploadButton.tsx, comments

* remove unused imports

* remove unnecessary dataclass serialization logic

* Add icon to example

---------

Co-authored-by: Brent Yi <[email protected]>
  • Loading branch information
jkulhanek and brentyi authored Mar 21, 2024
1 parent 2ec40b9 commit 96fb516
Show file tree
Hide file tree
Showing 14 changed files with 1,277 additions and 672 deletions.
9 changes: 9 additions & 0 deletions examples/02_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ def main() -> None:
initial_value=3,
marks=((0, "0"), (5, "5"), (7, "7"), 10),
)
gui_upload_button = server.add_gui_upload_button(
"Upload", icon=viser.Icon.UPLOAD
)

@gui_upload_button.on_upload
def _(_) -> None:
"""Callback for when a file is uploaded."""
file = gui_upload_button.value
print(file.name, len(file.content), "bytes")

# Pre-generate a point cloud to send.
point_positions = onp.random.uniform(low=-1.0, high=1.0, size=(5000, 3))
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ disable_error_code="var-annotated" # Common source of mypy + numpy false positi
exclude = ["./docs/**/*", "./examples/assets/**/*", "./src/viser/client/.nodeenv", "./build"]

[tool.ruff]
select = [
lint.select = [
"E", # pycodestyle errors.
"F", # Pyflakes rules.
"PLC", # Pylint convention warnings.
Expand All @@ -78,7 +78,7 @@ select = [
"PLW", # Pylint warnings.
"I", # Import sorting.
]
ignore = [
lint.ignore = [
"E741", # Ambiguous variable name. (l, O, or I)
"E501", # Line too long.
"E721", # Do not compare types, use `isinstance()`.
Expand Down
236 changes: 227 additions & 9 deletions src/viser/_gui_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import abc
import dataclasses
import functools
import threading
import time
import warnings
Expand All @@ -18,12 +19,21 @@
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)

import numpy as onp
from typing_extensions import Literal, LiteralString
from typing_extensions import (
Literal,
LiteralString,
TypedDict,
get_args,
get_origin,
get_type_hints,
)

from . import _messages
from ._gui_handles import (
Expand All @@ -37,14 +47,17 @@
GuiMarkdownHandle,
GuiModalHandle,
GuiTabGroupHandle,
GuiUploadButtonHandle,
SupportsRemoveProtocol,
UploadedFile,
_GuiHandleState,
_GuiInputHandle,
_make_unique_id,
)
from ._icons import base64_from_icon
from ._icons_enum import IconName
from ._message_api import MessageApi, cast_vector
from ._messages import FileTransferPartAck

if TYPE_CHECKING:
from .infra import ClientId
Expand Down Expand Up @@ -106,6 +119,64 @@ def _apply_default_order(order: Optional[float]) -> float:
return _global_order_counter


@functools.lru_cache(maxsize=None)
def get_type_hints_cached(cls: Type[Any]) -> Dict[str, Any]:
return get_type_hints(cls) # type: ignore


def cast_value(tp, value):
"""Cast a value to a type, or raise a TypeError if it cannot be cast."""
origin = get_origin(tp)

if (origin is tuple or tp is tuple) and isinstance(value, list):
return cast_value(tp, tuple(value))

if origin is Literal:
for val in get_args(tp):
try:
value_casted = cast_value(type(val), value)
if val == value_casted:
return value_casted
except ValueError:
pass
except TypeError:
pass
raise TypeError(f"Value {value} is not in {get_args(tp)}")

if origin is Union:
for t in get_args(tp):
try:
return cast_value(t, value)
except ValueError:
pass
except TypeError:
pass
raise TypeError(f"Value {value} is not in {tp}")

if tp in {int, float, bool, str}:
return tp(value)

if dataclasses.is_dataclass(tp):
return tp(
**{k: cast_value(v, value[k]) for k, v in get_type_hints_cached(tp).items()}
)

if isinstance(value, tp):
return value

raise TypeError(f"Cannot cast value {value} to type {tp}")


class FileUploadState(TypedDict):
filename: str
mime_type: str
part_count: int
parts: Dict[int, bytes]
total_bytes: int
transferred_bytes: int
lock: threading.Lock


class GuiApi(abc.ABC):
_target_container_from_thread_id: Dict[int, str] = {}
"""ID of container to put GUI elements into."""
Expand All @@ -117,9 +188,18 @@ def __init__(self) -> None:
self._container_handle_from_id: Dict[str, GuiContainerProtocol] = {
"root": _RootGuiContainer({})
}
self._current_file_upload_states: Dict[str, FileUploadState] = {}

self._get_api()._message_handler.register_handler(
_messages.GuiUpdateMessage, self._handle_gui_updates
)
self._get_api()._message_handler.register_handler(
_messages.FileTransferStart, self._handle_file_transfer_start
)
self._get_api()._message_handler.register_handler(
_messages.FileTransferPart,
self._handle_file_transfer_part,
)

def _handle_gui_updates(
self, client_id: ClientId, message: _messages.GuiUpdateMessage
Expand All @@ -139,14 +219,7 @@ def _handle_gui_updates(
# Do some type casting. This is brittle, but necessary when we
# expect floats but the Javascript side gives us integers.
if prop_name == "value":
if handle_state.typ is tuple:
assert len(prop_value) == len(handle_state.value)
prop_value = tuple(
type(handle_state.value[i])(prop_value[i])
for i in range(len(prop_value))
)
else:
prop_value = handle_state.typ(prop_value)
prop_value = cast_value(handle_state.typ, prop_value)

# Update handle property.
if current_value != prop_value:
Expand Down Expand Up @@ -179,6 +252,83 @@ def _handle_gui_updates(
if handle_state.sync_cb is not None:
handle_state.sync_cb(client_id, updates_cast)

def _handle_file_transfer_start(
self, client_id: ClientId, message: _messages.FileTransferStart
) -> None:
if message.source_component_id not in self._gui_handle_from_id:
return
self._current_file_upload_states[message.transfer_uuid] = {
"filename": message.filename,
"mime_type": message.mime_type,
"part_count": message.part_count,
"parts": {},
"total_bytes": message.size_bytes,
"transferred_bytes": 0,
"lock": threading.Lock(),
}

def _handle_file_transfer_part(
self, client_id: ClientId, message: _messages.FileTransferPart
) -> None:
if message.transfer_uuid not in self._current_file_upload_states:
return
assert message.source_component_id in self._gui_handle_from_id

state = self._current_file_upload_states[message.transfer_uuid]
state["parts"][message.part] = message.content
total_bytes = state["total_bytes"]

with state["lock"]:
state["transferred_bytes"] += len(message.content)

# Send ack to the server.
self._get_api()._queue(
FileTransferPartAck(
source_component_id=message.source_component_id,
transfer_uuid=message.transfer_uuid,
transferred_bytes=state["transferred_bytes"],
total_bytes=total_bytes,
)
)

if state["transferred_bytes"] < total_bytes:
return

# Finish the upload.
assert state["transferred_bytes"] == total_bytes
state = self._current_file_upload_states.pop(message.transfer_uuid)

handle = self._gui_handle_from_id.get(message.source_component_id, None)
if handle is None:
return

handle_state = handle._impl

value = UploadedFile(
name=state["filename"],
content=b"".join(state["parts"][i] for i in range(state["part_count"])),
)

# Update state.
with self._get_api()._atomic_lock:
handle_state.value = value
handle_state.update_timestamp = time.time()

# Trigger callbacks.
for cb in handle_state.update_cb:
from ._viser import ClientHandle, ViserServer

# Get the handle of the client that triggered this event.
api = self._get_api()
if isinstance(api, ClientHandle):
client = api
elif isinstance(api, ViserServer):
client = api.get_clients()[client_id]
else:
assert False

cb(GuiEvent(client, client_id, handle))

def _get_container_id(self) -> str:
"""Get container ID associated with the current thread."""
return self._target_container_from_thread_id.get(threading.get_ident(), "root")
Expand Down Expand Up @@ -413,6 +563,74 @@ def add_gui_button(
)._impl
)

def add_gui_upload_button(
self,
label: str,
disabled: bool = False,
visible: bool = True,
hint: Optional[str] = None,
color: Optional[
Literal[
"dark",
"gray",
"red",
"pink",
"grape",
"violet",
"indigo",
"blue",
"cyan",
"green",
"lime",
"yellow",
"orange",
"teal",
]
] = None,
icon: Optional[IconName] = None,
mime_type: str = "*/*",
order: Optional[float] = None,
) -> GuiUploadButtonHandle:
"""Add a button to the GUI. The value of this input is set to `True` every time
it is clicked; to detect clicks, we can manually set it back to `False`.
Args:
label: Label to display on the button.
visible: Whether the button is visible.
disabled: Whether the button is disabled.
hint: Optional hint to display on hover.
color: Optional color to use for the button.
icon: Optional icon to display on the button.
mime_type: Optional MIME type to filter the files that can be uploaded.
order: Optional ordering, smallest values will be displayed first.
Returns:
A handle that can be used to interact with the GUI element.
"""

# Re-wrap the GUI handle with a button interface.
id = _make_unique_id()
order = _apply_default_order(order)
return GuiUploadButtonHandle(
self._create_gui_input(
value=UploadedFile("", b""),
message=_messages.GuiAddUploadButtonMessage(
value=None,
disabled=disabled,
visible=visible,
order=order,
id=id,
label=label,
container_id=self._get_container_id(),
hint=hint,
color=color,
mime_type=mime_type,
icon_base64=None if icon is None else base64_from_icon(icon),
),
is_button=True,
)._impl
)

# The TLiteralString overload tells pyright to resolve the value type to a Literal
# whenever possible.
#
Expand Down
25 changes: 25 additions & 0 deletions src/viser/_gui_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,31 @@ def on_click(
return func


@dataclasses.dataclass
class UploadedFile:
"""Result of a file upload."""

name: str
"""Name of the file."""
content: bytes
"""Contents of the file."""


@dataclasses.dataclass
class GuiUploadButtonHandle(_GuiInputHandle[UploadedFile]):
"""Handle for an upload file button in our visualizer.
The `.value` attribute will be updated with the contents of uploaded files.
"""

def on_upload(
self: TGuiHandle, func: Callable[[GuiEvent[TGuiHandle]], None]
) -> Callable[[GuiEvent[TGuiHandle]], None]:
"""Attach a function to call when a button is pressed. Happens in a thread."""
self._impl.update_cb.append(func)
return func


@dataclasses.dataclass
class GuiButtonGroupHandle(_GuiInputHandle[StringType], Generic[StringType]):
"""Handle for a button group input in our visualizer.
Expand Down
Loading

0 comments on commit 96fb516

Please sign in to comment.