
ui.Chat() now correctly handles new ollama.chat() return value introduced in ollama 0.4 #1787

Merged (8 commits) on Nov 26, 2024
CHANGELOG.md (6 additions, 0 deletions)
@@ -5,6 +5,12 @@ All notable changes to Shiny for Python will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [UNRELEASED]
+
+### Bug fixes
+
+* `ui.Chat()` now correctly handles new `ollama.chat()` return value introduced in `ollama` v0.4. (#1787)
+
 ## [1.2.1] - 2024-11-14
 
 ### Bug fixes
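For context (an illustration, not part of this diff): `ollama-python` 0.4 replaced the plain-dict (TypedDict) return value of `ollama.chat()` with a pydantic `ChatResponse` model (ollama/ollama-python#276). A minimal sketch of the two access styles, assuming `ollama >= 0.4` and a running local server:

```python
# Sketch only: assumes `ollama >= 0.4` is installed and a local server is running.
import ollama

response = ollama.chat(
    model="llama3.2",
    messages=[{"role": "user", "content": "Hello!"}],
)

# New in 0.4: attribute access on the pydantic model...
print(response.message.content)

# ...while dict-style access still works, because ollama defines __getitem__()
# on the model (see the discussion under shiny/ui/_chat_normalize.py below).
print(response["message"]["content"])
```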
shiny/templates/chat/hello-providers/ollama/app.py (1 addition, 1 deletion)
@@ -29,7 +29,7 @@ async def _():
     # Create a response message stream
     # Assumes you've run `ollama run llama3` to start the server
     response = ollama.chat(
-        model="llama3",
+        model="llama3.2",
         messages=messages,
         stream=True,
     )
shiny/ui/_chat.py (0 additions, 9 deletions)
@@ -914,15 +914,6 @@ def _get_token_count(
         if self._tokenizer is None:
             self._tokenizer = get_default_tokenizer()
 
-        if self._tokenizer is None:
-            raise ValueError(
-                "A tokenizer is required to impose `token_limits` on messages. "
-                "To get a generic default tokenizer, install the `tokenizers` "
-                "package (`pip install tokenizers`). "
-                "To get a more precise token count, provide a specific tokenizer "
-                "to the `Chat` constructor."
-            )
-
         encoded = self._tokenizer.encode(content)
         if isinstance(encoded, TokenizersEncoding):
             return len(encoded.ids)
shiny/ui/_chat_normalize.py (12 additions, 4 deletions)
@@ -231,11 +231,19 @@ def normalize_chunk(self, chunk: "dict[str, Any]") -> ChatMessage:
         return super().normalize_chunk(msg)
 
     def can_normalize(self, message: Any) -> bool:
-        if not isinstance(message, dict):
-            return False
-        if "message" not in message:
+        try:
+            from ollama import ChatResponse
+
+            # Ollama<0.4 used TypedDict (now it uses pydantic)
+            # https://github.com/ollama/ollama-python/pull/276
+            if isinstance(ChatResponse, dict):
+                return "message" in message and super().can_normalize(
+                    message["message"]
+                )
+            else:
+                return isinstance(message, ChatResponse)
+        except Exception:
             return False
-        return super().can_normalize(message["message"])
 
     def can_normalize_chunk(self, chunk: Any) -> bool:
         return self.can_normalize(chunk)

Review thread on lines +239 to +244:

Collaborator: Is some of the context here that you have a message normalizer for Pydantic models?

Collaborator: I don't think you necessarily need to leave a comment, but it'd be helpful for my understanding of the code if you just quickly explained how this fixes the problem (beyond the simple explanation that before it was a dict and now it isn't, that part I get).

@cpsievert (author), Nov 26, 2024: DictNormalizer works for either case, since ollama defines __getitem__() on the pydantic model. I suppose that is a weird/subtle thing that requires extra context, and it'd be nice to take advantage of stronger pydantic typing, but I opted for the minimal change (especially if we're going to support older versions).
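To illustrate @cpsievert's point above, here is a minimal, self-contained sketch (an assumed shape, not ollama's actual source) of a pydantic model that supports dict-style lookup. This is why the dict-based normalizer keeps working against `ollama >= 0.4`:

```python
from pydantic import BaseModel


class SubscriptableModel(BaseModel):
    # Dict-style lookup delegates to attribute access, so code written
    # against the old TypedDict return value keeps working unchanged.
    def __getitem__(self, key: str):
        return getattr(self, key)


class Message(SubscriptableModel):
    role: str
    content: str


class ChatResponse(SubscriptableModel):
    message: Message


resp = ChatResponse(message=Message(role="assistant", content="Hello!"))
assert resp["message"]["content"] == "Hello!"  # dict-style access on a pydantic model
```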
shiny/ui/_chat_tokenizer.py (16 additions, 5 deletions)
@@ -45,12 +45,23 @@ def encode(
 TokenEncoding = Union[TiktokenEncoding, TokenizersTokenizer]
 
 
-def get_default_tokenizer() -> TokenizersTokenizer | None:
+def get_default_tokenizer() -> TokenizersTokenizer:
     try:
         from tokenizers import Tokenizer
 
         return Tokenizer.from_pretrained("bert-base-cased")  # type: ignore
-    except Exception:
-        pass
-
-    return None
+    except ImportError:
+        raise ImportError(
+            "Failed to download a default tokenizer. "
+            "A tokenizer is required to impose `token_limits` on `chat.messages()`. "
+            "To get a generic default tokenizer, install the `tokenizers` "
+            "package (`pip install tokenizers`). "
+        )
+    except Exception as e:
+        raise RuntimeError(
+            "Failed to download a default tokenizer. "
+            "A tokenizer is required to impose `token_limits` on `chat.messages()`. "
+            "Try manually downloading a tokenizer using "
+            "`tokenizers.Tokenizer.from_pretrained()` and passing it to `ui.Chat()`. "
+            f"Error: {e}"
+        ) from e
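As the new error message suggests, callers can sidestep the default-download path entirely by passing their own tokenizer. A hedged usage sketch (assumes the `tokenizers` package is installed; the `tokenizer=` argument name is taken from `ui.Chat()`'s constructor):

```python
# Sketch: supply an explicit tokenizer so get_default_tokenizer() is never hit.
from tokenizers import Tokenizer

from shiny import ui

tok = Tokenizer.from_pretrained("bert-base-cased")
chat = ui.Chat(id="chat", tokenizer=tok)
```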
tests/pytest/test_chat.py (15 additions, 3 deletions)
@@ -333,6 +333,20 @@ def test_openai_normalization():
     assert msg == {"content": "Hello ", "role": "assistant"}
 
 
+def test_ollama_normalization():
+    from ollama import ChatResponse
+    from ollama import Message as OllamaMessage
+
+    # Mock return object from ollama.chat()
+    msg = ChatResponse(
+        message=OllamaMessage(content="Hello world!", role="assistant"),
+    )
+
+    msg_dict = {"content": "Hello world!", "role": "assistant"}
+    assert normalize_message(msg) == msg_dict
+    assert normalize_message_chunk(msg) == msg_dict
+
+
 # ------------------------------------------------------------------------------------
 # Unit tests for as_provider_message()
 #
@@ -462,9 +476,7 @@ def test_as_ollama_message():
     import ollama
     from ollama import Message as OllamaMessage
 
-    assert "typing.Sequence[ollama._types.Message]" in str(
-        ollama.chat.__annotations__["messages"]
-    )
+    assert "ollama._types.Message" in str(ollama.chat.__annotations__["messages"])
 
     from shiny.ui._chat_provider_types import as_ollama_message
 