feat: Add async tokenizer, add detokenize method (#144)
* feat: add detokenize method, add async tokenizer

* chore: update pyproject and poetry.lock

* fix: fix tokenizer name in examples and readme, add example
miri-bar authored Jun 19, 2024
1 parent 5b0eadb commit f2d06fc
Showing 10 changed files with 223 additions and 19 deletions.
README.md · 16 changes: 14 additions & 2 deletions

@@ -379,14 +379,26 @@ By using the `count_tokens` method, you can estimate the billing for a given request
 ```python
 from ai21.tokenizers import get_tokenizer
 
-tokenizer = get_tokenizer(name="jamba-instruct-tokenizer")
+tokenizer = get_tokenizer(name="jamba-tokenizer")
 total_tokens = tokenizer.count_tokens(text="some text")  # returns int
 print(total_tokens)
 ```
 
+### Async Usage
+
+```python
+from ai21.tokenizers import get_async_tokenizer
+
+## Your async function code
+#...
+tokenizer = await get_async_tokenizer(name="jamba-tokenizer")
+total_tokens = await tokenizer.count_tokens(text="some text")  # returns int
+print(total_tokens)
+```
+
 Available tokenizers are:
 
-- `jamba-instruct-tokenizer`
+- `jamba-tokenizer`
 - `j2-tokenizer`
 
 For more information on AI21 Tokenizers, see the [documentation](https://github.com/AI21Labs/ai21-tokenizer).
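As a side note on this README change: below is a minimal sketch of using `count_tokens` to pre-check a prompt against a token budget before sending a request. The budget constant and helper name are illustrative assumptions, not part of this commit.

```python
from ai21.tokenizers import get_tokenizer

# Hypothetical budget; pick whatever limit fits your model and plan.
MAX_PROMPT_TOKENS = 4096

tokenizer = get_tokenizer(name="jamba-tokenizer")


def fits_budget(prompt: str) -> bool:
    """Return True if the prompt fits under the assumed token budget."""
    return tokenizer.count_tokens(text=prompt) <= MAX_PROMPT_TOKENS


print(fits_budget("some text"))  # short prompt, so True
```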
ai21/tokenizers/__init__.py · 4 changes: 2 additions & 2 deletions

@@ -1,4 +1,4 @@
 from .ai21_tokenizer import AI21Tokenizer
-from .factory import get_tokenizer
+from .factory import get_tokenizer, get_async_tokenizer
 
-__all__ = ["AI21Tokenizer", "get_tokenizer"]
+__all__ = ["AI21Tokenizer", "get_tokenizer", "get_async_tokenizer"]
ai21/tokenizers/ai21_tokenizer.py · 31 changes: 30 additions & 1 deletion

@@ -1,6 +1,6 @@
 from typing import List, Any
 
-from ai21_tokenizer import BaseTokenizer
+from ai21_tokenizer import BaseTokenizer, AsyncBaseTokenizer
 
 
 class AI21Tokenizer:
@@ -20,3 +20,32 @@ def tokenize(self, text: str, **kwargs: Any) -> List[str]:
         encoded_text = self._tokenizer.encode(text, **kwargs)
 
         return self._tokenizer.convert_ids_to_tokens(encoded_text, **kwargs)
+
+    def detokenize(self, tokens: List[str], **kwargs: Any) -> str:
+        token_ids = self._tokenizer.convert_tokens_to_ids(tokens)
+
+        return self._tokenizer.decode(token_ids, **kwargs)
+
+
+class AsyncAI21Tokenizer:
+    """
+    A class that wraps an async tokenizer and provides additional functionality.
+    """
+
+    def __init__(self, tokenizer: AsyncBaseTokenizer):
+        self._tokenizer = tokenizer
+
+    async def count_tokens(self, text: str) -> int:
+        encoded_text = await self._tokenizer.encode(text)
+
+        return len(encoded_text)
+
+    async def tokenize(self, text: str, **kwargs: Any) -> List[str]:
+        encoded_text = await self._tokenizer.encode(text, **kwargs)
+
+        return await self._tokenizer.convert_ids_to_tokens(encoded_text, **kwargs)
+
+    async def detokenize(self, tokens: List[str], **kwargs: Any) -> str:
+        token_ids = await self._tokenizer.convert_tokens_to_ids(tokens)
+
+        return await self._tokenizer.decode(token_ids, **kwargs)
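To make the new surface concrete, here is a minimal round-trip sketch of `tokenize` followed by the new `detokenize` (sync wrapper shown; the async wrapper mirrors it with `await`). It assumes only the methods added above.

```python
from ai21.tokenizers import get_tokenizer

tokenizer = get_tokenizer(name="jamba-tokenizer")

# Text -> tokens -> text; the tests added below assert this round-trip is lossless.
tokens = tokenizer.tokenize("Text to Tokenize - Hello world!")
restored = tokenizer.detokenize(tokens)

print(tokens)    # e.g. ["<|startoftext|>", "Text", "▁to", ...]
print(restored)  # "Text to Tokenize - Hello world!"
```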
ai21/tokenizers/factory.py · 17 changes: 16 additions & 1 deletion

@@ -2,9 +2,10 @@
 
 from ai21_tokenizer import Tokenizer, PreTrainedTokenizers
 
-from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer
+from ai21.tokenizers.ai21_tokenizer import AI21Tokenizer, AsyncAI21Tokenizer
 
 _cached_tokenizers: Dict[str, AI21Tokenizer] = {}
+_cached_async_tokenizers: Dict[str, AsyncAI21Tokenizer] = {}
 
 
 def get_tokenizer(name: str = PreTrainedTokenizers.J2_TOKENIZER) -> AI21Tokenizer:
@@ -19,3 +20,17 @@ def get_tokenizer(name: str = PreTrainedTokenizers.J2_TOKENIZER) -> AI21Tokenizer:
         _cached_tokenizers[name] = AI21Tokenizer(Tokenizer.get_tokenizer(name))
 
     return _cached_tokenizers[name]
+
+
+async def get_async_tokenizer(name: str = PreTrainedTokenizers.J2_TOKENIZER) -> AsyncAI21Tokenizer:
+    """
+    Get the async tokenizer instance.
+    If the tokenizer instance is not cached, it will be created using the Tokenizer.get_async_tokenizer() method.
+    """
+
+    global _cached_async_tokenizers
+
+    if _cached_async_tokenizers.get(name) is None:
+        _cached_async_tokenizers[name] = AsyncAI21Tokenizer(await Tokenizer.get_async_tokenizer(name))
+
+    return _cached_async_tokenizers[name]
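A short sketch of the caching behavior `get_async_tokenizer` provides, mirroring the singleton assertions in the new tests below; nothing here goes beyond the factory code above.

```python
import asyncio

from ai21.tokenizers import get_async_tokenizer


async def main() -> None:
    # Same name twice yields the same cached wrapper instance.
    tok_a = await get_async_tokenizer(name="jamba-tokenizer")
    tok_b = await get_async_tokenizer(name="jamba-tokenizer")
    assert tok_a is tok_b

    # A different name builds (and caches) a separate instance.
    tok_j2 = await get_async_tokenizer(name="j2-tokenizer")
    assert tok_a is not tok_j2


asyncio.run(main())
```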
examples/studio/async_tokenization.py · 42 changes: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
+from ai21.tokenizers import get_async_tokenizer
+import asyncio
+
+prompt = (
+    "The following is a conversation between a user of an eCommerce store and a user operation"
+    " associate called Max. Max is very kind and keen to help."
+    " The following are important points about the business policies:\n- "
+    "Delivery takes up to 5 days\n- There is no return option\n\nUser gender:"
+    " Male.\n\nConversation:\nUser: Hi, had a question\nMax: "
+    "Hi there, happy to help!\nUser: Is there no way to return a product?"
+    " I got your blue T-Shirt size small but it doesn't fit.\n"
+    "Max: I'm sorry to hear that. Unfortunately we don't have a return policy. \n"
+    "User: That's a shame. \nMax: Is there anything else i can do for you?\n\n"
+    "##\n\nThe following is a conversation between a user of an eCommerce store and a user operation"
+    " associate called Max. Max is very kind and keen to help. The following are important points about"
+    " the business policies:\n- Delivery takes up to 5 days\n- There is no return option\n\n"
+    'User gender: Female.\n\nConversation:\nUser: Hi, I was wondering when you\'ll have the "Blue & White" '
+    "t-shirt back in stock?\nMax: Hi, happy to assist! We currently don't have it in stock. Do you want me"
+    " to send you an email once we do?\nUser: Yes!\nMax: Awesome. What's your email?\nUser: [email protected]\n"
+    "Max: Great. I'll send you an email as soon as we get it.\n\n##\n\nThe following is a conversation between"
+    " a user of an eCommerce store and a user operation associate called Max. Max is very kind and keen to help."
+    " The following are important points about the business policies:\n- Delivery takes up to 5 days\n"
+    "- There is no return option\n\nUser gender: Female.\n\nConversation:\nUser: Hi, how much time does it"
+    " take for the product to reach me?\nMax: Hi, happy to assist! It usually takes 5 working"
+    " days to reach you.\nUser: Got it! thanks. Is there a way to shorten that delivery time if i pay extra?\n"
+    "Max: I'm sorry, no.\nUser: Got it. How do i know if the White Crisp t-shirt will fit my size?\n"
+    "Max: The size charts are available on the website.\nUser: Can you tell me what will fit a young women.\n"
+    "Max: Sure. Can you share her exact size?\n\n##\n\nThe following is a conversation between a user of an"
+    " eCommerce store and a user operation associate called Max. Max is very kind and keen to help. The following"
+    " are important points about the business policies:\n- Delivery takes up to 5 days\n"
+    "- There is no return option\n\nUser gender: Female.\n\nConversation:\n"
+    "User: Hi, I have a question for you"
+)
+
+
+async def main():
+    tokenizer = await get_async_tokenizer(name="jamba-tokenizer")
+    response = await tokenizer.count_tokens(prompt)
+    print(response)
+
+
+asyncio.run(main())
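One practical payoff of the async tokenizer is concurrency. Below is a hedged sketch (the `count_all` helper and sample texts are assumptions, not part of this commit) that counts tokens for several prompts at once with `asyncio.gather`.

```python
import asyncio
from typing import List

from ai21.tokenizers import get_async_tokenizer


async def count_all(texts: List[str]) -> List[int]:
    tokenizer = await get_async_tokenizer(name="jamba-tokenizer")
    # Schedule all count_tokens calls concurrently and collect the results in order.
    return await asyncio.gather(*(tokenizer.count_tokens(t) for t in texts))


print(asyncio.run(count_all(["first text", "a somewhat longer second text"])))
```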
examples/studio/tokenization.py · 2 changes: 1 addition & 1 deletion

@@ -30,6 +30,6 @@
     "- There is no return option\n\nUser gender: Female.\n\nConversation:\n"
     "User: Hi, I have a question for you"
 )
-tokenizer = get_tokenizer(name="jamba-instruct-tokenizer")
+tokenizer = get_tokenizer(name="jamba-tokenizer")
 response = tokenizer.count_tokens(prompt)
 print(response)
poetry.lock · 17 changes: 9 additions & 8 deletions

Generated lockfile; diff not rendered by default.
pyproject.toml · 2 changes: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ packages = [
 
 [tool.poetry.dependencies]
 python = "^3.8"
-ai21-tokenizer = ">=0.9.1,<1.0.0"
+ai21-tokenizer = ">=0.11.0,<1.0.0"
 boto3 = { version = "^1.28.82", optional = true }
 dataclasses-json = "^0.6.3"
 typing-extensions = "^4.9.0"
tests/unittests/tokenizers/test_ai21_tokenizer.py · 25 changes: 22 additions & 3 deletions

@@ -13,7 +13,7 @@ class TestAI21Tokenizer:
         argnames=["tokenizer_name", "expected_tokens"],
         argvalues=[
             ("j2-tokenizer", 8),
-            ("jamba-instruct-tokenizer", 9),
+            ("jamba-tokenizer", 9),
         ],
     )
     def test__count_tokens__should_return_number_of_tokens(self, tokenizer_name: str, expected_tokens: int):
@@ -32,7 +32,7 @@ def test__count_tokens__should_return_number_of_tokens(self, tokenizer_name: str, expected_tokens: int):
         argvalues=[
             ("j2-tokenizer", ["▁Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"]),
             (
-                "jamba-instruct-tokenizer",
+                "jamba-tokenizer",
                 ["<|startoftext|>", "Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"],
             ),
         ],
@@ -44,6 +44,25 @@ def test__tokenize__should_return_list_of_tokens(self, tokenizer_name: str, expected_tokens: List[str]):
 
         assert actual_tokens == expected_tokens
 
+    @pytest.mark.parametrize(
+        ids=[
+            "when_j2_tokenizer",
+            "when_jamba_instruct_tokenizer",
+        ],
+        argnames=["tokenizer_name"],
+        argvalues=[
+            ("j2-tokenizer",),
+            ("jamba-tokenizer",),
+        ],
+    )
+    def test__detokenize__should_return_list_of_tokens(self, tokenizer_name: str):
+        tokenizer = get_tokenizer(tokenizer_name)
+        original_text = "Text to Tokenize - Hello world!"
+        actual_tokens = tokenizer.tokenize(original_text)
+        detokenized_text = tokenizer.detokenize(actual_tokens)
+
+        assert original_text == detokenized_text
+
     def test__tokenizer__should_be_singleton__when_called_twice(self):
         tokenizer1 = get_tokenizer()
         tokenizer2 = get_tokenizer()
@@ -52,7 +71,7 @@ def test__tokenizer__should_be_singleton__when_called_twice(self):
 
     def test__get_tokenizer__when_called_with_different_tokenizer_name__should_return_different_tokenizer(self):
         tokenizer1 = get_tokenizer("j2-tokenizer")
-        tokenizer2 = get_tokenizer("jamba-instruct-tokenizer")
+        tokenizer2 = get_tokenizer("jamba-tokenizer")
 
         assert tokenizer1._tokenizer is not tokenizer2._tokenizer
tests/unittests/tokenizers/test_async_ai21_tokenizer.py · 86 changes: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+from typing import List
+
+import pytest
+from ai21.tokenizers.factory import get_async_tokenizer
+
+
+class TestAsyncAI21Tokenizer:
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        ids=[
+            "when_j2_tokenizer",
+            "when_jamba_instruct_tokenizer",
+        ],
+        argnames=["tokenizer_name", "expected_tokens"],
+        argvalues=[
+            ("j2-tokenizer", 8),
+            ("jamba-tokenizer", 9),
+        ],
+    )
+    async def test__count_tokens__should_return_number_of_tokens(self, tokenizer_name: str, expected_tokens: int):
+        tokenizer = await get_async_tokenizer(tokenizer_name)
+
+        actual_number_of_tokens = await tokenizer.count_tokens("Text to Tokenize - Hello world!")
+
+        assert actual_number_of_tokens == expected_tokens
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        ids=[
+            "when_j2_tokenizer",
+            "when_jamba_instruct_tokenizer",
+        ],
+        argnames=["tokenizer_name", "expected_tokens"],
+        argvalues=[
+            ("j2-tokenizer", ["▁Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"]),
+            (
+                "jamba-tokenizer",
+                ["<|startoftext|>", "Text", "▁to", "▁Token", "ize", "▁-", "▁Hello", "▁world", "!"],
+            ),
+        ],
+    )
+    async def test__tokenize__should_return_list_of_tokens(self, tokenizer_name: str, expected_tokens: List[str]):
+        tokenizer = await get_async_tokenizer(tokenizer_name)
+
+        actual_tokens = await tokenizer.tokenize("Text to Tokenize - Hello world!")
+
+        assert actual_tokens == expected_tokens
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        ids=[
+            "when_j2_tokenizer",
+            "when_jamba_instruct_tokenizer",
+        ],
+        argnames=["tokenizer_name"],
+        argvalues=[
+            ("j2-tokenizer",),
+            ("jamba-tokenizer",),
+        ],
+    )
+    async def test__detokenize__should_return_list_of_tokens(self, tokenizer_name: str):
+        tokenizer = await get_async_tokenizer(tokenizer_name)
+        original_text = "Text to Tokenize - Hello world!"
+        actual_tokens = await tokenizer.tokenize(original_text)
+        detokenized_text = await tokenizer.detokenize(actual_tokens)
+
+        assert original_text == detokenized_text
+
+    @pytest.mark.asyncio
+    async def test__tokenizer__should_be_singleton__when_called_twice(self):
+        tokenizer1 = await get_async_tokenizer()
+        tokenizer2 = await get_async_tokenizer()
+
+        assert tokenizer1 is tokenizer2
+
+    @pytest.mark.asyncio
+    async def test__get_tokenizer__when_called_with_different_tokenizer_name__should_return_different_tokenizer(self):
+        tokenizer1 = await get_async_tokenizer("j2-tokenizer")
+        tokenizer2 = await get_async_tokenizer("jamba-tokenizer")
+
+        assert tokenizer1._tokenizer is not tokenizer2._tokenizer
+
+    @pytest.mark.asyncio
+    async def test__get_tokenizer__when_tokenizer_name_not_supported__should_raise_error(self):
+        with pytest.raises(ValueError):
+            await get_async_tokenizer("some-tokenizer")
