Skip to content

Commit

Permalink
feat: allow sending websocket messages from the main event thread. (#953
Browse files Browse the repository at this point in the history
)

This makes #841 simpler.
Instead of giving an error, we do our best, but warn the user it is
better to send from a different thread.
  • Loading branch information
maartenbreddels authored Dec 24, 2024
1 parent db4435e commit cce60d2
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion solara/server/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def __init__(self, ws: starlette.websockets.WebSocket, portal: Optional[anyio.fr
# we store a strong reference
self.tasks: Set[asyncio.Task] = set()
self.event_loop = asyncio.get_event_loop()
self._thread_id = threading.get_ident()
if settings.main.experimental_performance:
self.task = asyncio.ensure_future(self.process_messages_task())

Expand Down Expand Up @@ -164,7 +165,19 @@ def send_text(self, data: str) -> None:
if settings.main.experimental_performance:
self.to_send.append(data)
else:
self.portal.call(self._send_text_exc, data)
if self._thread_id == threading.get_ident():
warnings.warn("""You are triggering a websocket send from the event loop thread.
Support for this is experimental, and to avoid this message, make sure you trigger updates
that trigger this from a different thread, e.g.:
from anyio import to_thread
await to_thread.run_sync(my_update)
""")
task = self.event_loop.create_task(self._send_text_exc(data))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
else:
self.portal.call(self._send_text_exc, data)

def send_bytes(self, data: bytes) -> None:
if self.portal is None:
Expand All @@ -175,6 +188,18 @@ def send_bytes(self, data: bytes) -> None:
if settings.main.experimental_performance:
self.to_send.append(data)
else:
if self._thread_id == threading.get_ident():
warnings.warn("""You are triggering a websocket send from the event loop thread.
Support for this is experimental, and to avoid this message, make sure you trigger updates
that trigger this from a different thread, e.g.:
from anyio import to_thread
await to_thread.run_sync(my_update)
""")
task = self.event_loop.create_task(self._send_bytes_exc(data))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)

self.portal.call(self._send_bytes_exc, data)

async def receive(self):
Expand Down

0 comments on commit cce60d2

Please sign in to comment.