Commit

ui.Chat() now correctly handles new ollama.chat() return value introduced in ollama 0.4 (#1787)
cpsievert authored Nov 26, 2024
1 parent ba97d6d commit 46d8ab8
Showing 6 changed files with 50 additions and 22 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to Shiny for Python will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [UNRELEASED]
+
+### Bug fixes
+
+* `ui.Chat()` now correctly handles new `ollama.chat()` return value introduced in `ollama` v0.4. (#1787)
+
 ## [1.2.1] - 2024-11-14

 ### Bug fixes
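For context, the behavior change this commit adapts to: in `ollama` >= 0.4, `ollama.chat()` returns a pydantic `ChatResponse` model rather than the TypedDict-style dict returned by earlier versions (ollama/ollama-python#276). A minimal sketch of the difference, assuming a local Ollama server with the `llama3.2` model pulled:

```python
import ollama

# Assumes a local Ollama server is running and `llama3.2` has been pulled.
response = ollama.chat(
    model="llama3.2",
    messages=[{"role": "user", "content": "Hello!"}],
)

if isinstance(response, dict):
    # ollama < 0.4: plain dict access
    content = response["message"]["content"]
else:
    # ollama >= 0.4: pydantic ChatResponse with attribute access
    content = response.message.content

print(content)
```

Both shapes carry the same message payload; the normalization change below simply detects which one it received.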
2 changes: 1 addition & 1 deletion shiny/templates/chat/hello-providers/ollama/app.py
@@ -29,7 +29,7 @@ async def _():
     # Create a response message stream
     # Assumes you've run `ollama run llama3` to start the server
     response = ollama.chat(
-        model="llama3",
+        model="llama3.2",
         messages=messages,
         stream=True,
    )
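For reference, the surrounding template follows roughly this shape (a sketch, not the verbatim file; it assumes the Shiny Express chat API — `ui.Chat()`, `chat.messages(format="ollama")`, and `chat.append_message_stream()` — used by the hello-providers templates):

```python
import ollama
from shiny.express import ui

chat = ui.Chat(id="chat")
chat.ui()


@chat.on_user_submit
async def _():
    # Collect the conversation in the format ollama expects
    messages = chat.messages(format="ollama")
    # Assumes `ollama serve` is running and the llama3.2 model has been pulled
    response = ollama.chat(model="llama3.2", messages=messages, stream=True)
    # Stream the response into the chat UI
    await chat.append_message_stream(response)
```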
9 changes: 0 additions & 9 deletions shiny/ui/_chat.py
@@ -914,15 +914,6 @@ def _get_token_count(
         if self._tokenizer is None:
             self._tokenizer = get_default_tokenizer()

-        if self._tokenizer is None:
-            raise ValueError(
-                "A tokenizer is required to impose `token_limits` on messages. "
-                "To get a generic default tokenizer, install the `tokenizers` "
-                "package (`pip install tokenizers`). "
-                "To get a more precise token count, provide a specific tokenizer "
-                "to the `Chat` constructor."
-            )
-
         encoded = self._tokenizer.encode(content)
         if isinstance(encoded, TokenizersEncoding):
             return len(encoded.ids)
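The in-place ValueError is no longer needed because `get_default_tokenizer()` now raises a descriptive error itself instead of returning `None` (see the `shiny/ui/_chat_tokenizer.py` change below). The counting logic is unchanged; a simplified standalone sketch of the idea (hypothetical helper, not the actual Shiny internals):

```python
from __future__ import annotations

from tokenizers import Tokenizer


def count_tokens(content: str, tokenizer: Tokenizer | None = None) -> int:
    """Count tokens in `content`, falling back to a generic default tokenizer."""
    if tokenizer is None:
        # This download now raises a descriptive error if it fails,
        # rather than silently yielding None.
        tokenizer = Tokenizer.from_pretrained("bert-base-cased")
    encoded = tokenizer.encode(content)
    return len(encoded.ids)
```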
16 changes: 12 additions & 4 deletions shiny/ui/_chat_normalize.py
@@ -231,11 +231,19 @@ def normalize_chunk(self, chunk: "dict[str, Any]") -> ChatMessage:
         return super().normalize_chunk(msg)

     def can_normalize(self, message: Any) -> bool:
-        if not isinstance(message, dict):
-            return False
-        if "message" not in message:
+        try:
+            from ollama import ChatResponse
+
+            # Ollama<0.4 used TypedDict (now it uses pydantic)
+            # https://github.com/ollama/ollama-python/pull/276
+            if isinstance(ChatResponse, dict):
+                return "message" in message and super().can_normalize(
+                    message["message"]
+                )
+            else:
+                return isinstance(message, ChatResponse)
+        except Exception:
             return False
-        return super().can_normalize(message["message"])

     def can_normalize_chunk(self, chunk: Any) -> bool:
         return self.can_normalize(chunk)
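The effect is that normalization now recognizes both the old dict shape and the new pydantic model. A minimal standalone sketch of the same idea (hypothetical helper; assumes `ollama` is installed and falls back to dict handling when `ChatResponse` is unavailable or not a real class):

```python
from typing import Any


def ollama_message_content(response: Any) -> str:
    """Extract the assistant message content from an ollama.chat() result."""
    try:
        from ollama import ChatResponse  # pydantic model in ollama >= 0.4
    except ImportError:
        ChatResponse = None

    if ChatResponse is not None:
        try:
            if isinstance(response, ChatResponse):
                return response.message.content or ""
        except TypeError:
            pass  # ChatResponse was a TypedDict in older versions; fall through

    if isinstance(response, dict) and "message" in response:
        # ollama < 0.4 returned a TypedDict-style dict
        return response["message"].get("content", "")

    raise ValueError("Not a recognized ollama.chat() response")
```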
21 changes: 16 additions & 5 deletions shiny/ui/_chat_tokenizer.py
@@ -45,12 +45,23 @@ def encode(
 TokenEncoding = Union[TiktokenEncoding, TokenizersTokenizer]


-def get_default_tokenizer() -> TokenizersTokenizer | None:
+def get_default_tokenizer() -> TokenizersTokenizer:
     try:
         from tokenizers import Tokenizer

         return Tokenizer.from_pretrained("bert-base-cased")  # type: ignore
-    except Exception:
-        pass
-
-    return None
+    except ImportError:
+        raise ImportError(
+            "Failed to download a default tokenizer. "
+            "A tokenizer is required to impose `token_limits` on `chat.messages()`. "
+            "To get a generic default tokenizer, install the `tokenizers` "
+            "package (`pip install tokenizers`). "
+        )
+    except Exception as e:
+        raise RuntimeError(
+            "Failed to download a default tokenizer. "
+            "A tokenizer is required to impose `token_limits` on `chat.messages()`. "
+            "Try manually downloading a tokenizer using "
+            "`tokenizers.Tokenizer.from_pretrained()` and passing it to `ui.Chat()`."
+            f"Error: {e}"
+        ) from e
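Because the fallback now raises instead of returning `None`, apps that hit this error can supply a tokenizer explicitly, as the message suggests. A sketch, assuming `ui.Chat()` accepts a `tokenizers.Tokenizer` (or a tiktoken encoding, per the `TokenEncoding` union above) via its `tokenizer` argument:

```python
from shiny import ui
from tokenizers import Tokenizer

# Download a tokenizer explicitly and hand it to the chat component,
# rather than relying on the implicit bert-base-cased default download.
my_tokenizer = Tokenizer.from_pretrained("bert-base-cased")
chat = ui.Chat(id="chat", tokenizer=my_tokenizer)
```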
18 changes: 15 additions & 3 deletions tests/pytest/test_chat.py
@@ -333,6 +333,20 @@ def test_openai_normalization():
     assert msg == {"content": "Hello ", "role": "assistant"}


+def test_ollama_normalization():
+    from ollama import ChatResponse
+    from ollama import Message as OllamaMessage
+
+    # Mock return object from ollama.chat()
+    msg = ChatResponse(
+        message=OllamaMessage(content="Hello world!", role="assistant"),
+    )
+
+    msg_dict = {"content": "Hello world!", "role": "assistant"}
+    assert normalize_message(msg) == msg_dict
+    assert normalize_message_chunk(msg) == msg_dict
+
+
 # ------------------------------------------------------------------------------------
 # Unit tests for as_provider_message()
 #
@@ -462,9 +476,7 @@ def test_as_ollama_message():
     import ollama
     from ollama import Message as OllamaMessage

-    assert "typing.Sequence[ollama._types.Message]" in str(
-        ollama.chat.__annotations__["messages"]
-    )
+    assert "ollama._types.Message" in str(ollama.chat.__annotations__["messages"])

     from shiny.ui._chat_provider_types import as_ollama_message

