Skip to content

Commit

Permalink
feat: fix and add tests, examples, update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
miri-bar committed Jun 16, 2024
1 parent 54b8c0d commit 23ecd60
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 62 deletions.
44 changes: 32 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,40 +35,60 @@ poetry add ai21-tokenizer

### Tokenizer Creation

### Jamba-Instruct Tokenizer

```python
from ai21_tokenizer import Tokenizer
from ai21_tokenizer import Tokenizer, PreTrainedTokenizers

tokenizer = Tokenizer.get_tokenizer()
tokenizer = Tokenizer.get_tokenizer(PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER)
# Your code here
```

Another way would be to use our Jurassic model directly:
Another way would be to use our Jamba model directly:

```python
from ai21_tokenizer import JurassicTokenizer
from ai21_tokenizer import JambaInstructTokenizer

model_path = "<Path to your vocabs file. This is usually a binary file that ends with .model>"
config = {} # "dictionary object of your config.json file"
tokenizer = JurassicTokenizer(model_path=model_path, config=config)
model_path = "<Path to your vocabs file>"
tokenizer = JambaInstructTokenizer(model_path=model_path)
# Your code here
```

#### Async usage

```python
from ai21_tokenizer import Tokenizer, PreTrainedTokenizers

tokenizer = await Tokenizer.get_async_tokenizer(PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER)
# Your code here
```

### Async usage
### J2 Tokenizer

```python
from ai21_tokenizer import Tokenizer

tokenizer = Tokenizer.get_tokenizer(is_async=True)
tokenizer = Tokenizer.get_tokenizer()
# Your code here
```

Direct usage of async Jurassic model:
Another way would be to use our Jurassic model directly:

```python
from ai21_tokenizer import AsyncJurassicTokenizer
from ai21_tokenizer import JurassicTokenizer

model_path = "<Path to your vocabs file. This is usually a binary file that ends with .model>"
config = {} # "dictionary object of your config.json file"
tokenizer = AsyncJurassicTokenizer(model_path=model_path, config=config)
tokenizer = JurassicTokenizer(model_path=model_path, config=config)
```

#### Async usage

```python
from ai21_tokenizer import Tokenizer

tokenizer = await Tokenizer.get_async_tokenizer()
# Your code here
```

### Functions
Expand Down
8 changes: 6 additions & 2 deletions ai21_tokenizer/jamba_instruct_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,12 @@ def __init__(
"""
self._model_path = model_path
self._cache_dir = cache_dir or _DEFAULT_MODEL_CACHE_DIR
# BaseJambaInstructTokenizer.__init__(self, model_path=model_path, cache_dir=cache_dir)

async def __aenter__(self):
await self._init_tokenizer()
return self

def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(self, exc_type, exc_val, exc_tb):
pass

async def encode(self, text: str, **kwargs) -> List[int]:
Expand All @@ -110,6 +109,11 @@ async def convert_ids_to_tokens(self, token_ids: Union[int, List[int]], **kwargs

@property
def vocab_size(self) -> int:
if not self._tokenizer:
raise ValueError(
"Tokenizer not properly initialized. Please do not initialize the tokenizer directly. Use "
"Tokenizer.get_async_tokenizer instead."
)
return self._tokenizer.get_vocab_size()

async def _init_tokenizer(self):
Expand Down
5 changes: 5 additions & 0 deletions ai21_tokenizer/jurassic_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

@property
def vocab_size(self) -> int:
if not self._sp:
raise ValueError(
"Tokenizer not properly initialized. Please do not initialize the tokenizer directly. Use "
"Tokenizer.get_async_tokenizer instead."
)
return self._sp.vocab_size()

async def encode(self, text: str, **kwargs) -> List[int]:
Expand Down
18 changes: 18 additions & 0 deletions examples/async_jamba_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import asyncio

from ai21_tokenizer import Tokenizer, PreTrainedTokenizers


async def main():
tokenizer = await Tokenizer.get_async_tokenizer(PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER)

example_sentence = "This sentence should be encoded and then decoded. Hurray!!!!"
encoded = await tokenizer.encode(example_sentence)
decoded = await tokenizer.decode(encoded)

assert decoded == example_sentence
print("Example sentence: " + example_sentence)
print("Encoded and decoded: " + decoded)


asyncio.run(main())
4 changes: 2 additions & 2 deletions examples/async_jurassic_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from pathlib import Path

from ai21_tokenizer import AsyncJurassicTokenizer
from ai21_tokenizer import Tokenizer, PreTrainedTokenizers
from ai21_tokenizer.utils import load_json

resource_path = Path(__file__).parent.parent / "ai21_tokenizer" / "resources"
Expand All @@ -12,7 +12,7 @@


async def main():
tokenizer = AsyncJurassicTokenizer(model_path=model_path, config=config)
tokenizer = await Tokenizer.get_async_tokenizer(PreTrainedTokenizers.J2_TOKENIZER)

example_sentence = "This sentence should be encoded and then decoded. Hurray!!!!"
encoded = await tokenizer.encode(example_sentence)
Expand Down
11 changes: 11 additions & 0 deletions examples/jamba_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from ai21_tokenizer import JambaInstructTokenizer

model_path = "ai21labs/Jamba-v0.1"

tokenizer = JambaInstructTokenizer(model_path=model_path)

example_sentence = "This sentence should be encoded and then decoded. Hurray!!!!"
encoded = tokenizer.encode(example_sentence)
decoded = tokenizer.decode(encoded)

assert decoded == example_sentence
2 changes: 1 addition & 1 deletion examples/use_tokenizer_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


async def main():
tokenizer = Tokenizer.get_tokenizer(is_async=True)
tokenizer = await Tokenizer.get_async_tokenizer()
example_sentence = "This sentence should be encoded and then decoded. Hurray!!"
encoded = await tokenizer.encode(example_sentence)
decoded = await tokenizer.decode(encoded)
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ newline_sequence = "\n"





[tool.poetry.group.test.dependencies]
coverage = "^7.1.0"
pytest = "7.4.4"
Expand Down
9 changes: 4 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def tokenizer() -> JurassicTokenizer:
raise ValueError("JurassicTokenizer not found")


@pytest.fixture(scope="session")
def async_tokenizer() -> AsyncJurassicTokenizer:
jurassic_tokenizer = Tokenizer.get_tokenizer(tokenizer_name=PreTrainedTokenizers.J2_TOKENIZER, is_async=True)
@pytest.fixture()
async def async_tokenizer() -> AsyncJurassicTokenizer:
jurassic_tokenizer = await Tokenizer.get_async_tokenizer(tokenizer_name=PreTrainedTokenizers.J2_TOKENIZER)

if isinstance(jurassic_tokenizer, AsyncJurassicTokenizer):
return jurassic_tokenizer
Expand All @@ -49,8 +49,7 @@ def jamba_instruct_tokenizer() -> JambaInstructTokenizer:
raise ValueError("JambaInstructTokenizer not found")


@pytest.mark.asyncio
@pytest.fixture()
@pytest.fixture
async def async_jamba_instruct_tokenizer() -> AsyncJambaInstructTokenizer:
jamba_tokenizer = await Tokenizer.get_async_tokenizer(PreTrainedTokenizers.JAMBA_INSTRUCT_TOKENIZER)

Expand Down
52 changes: 12 additions & 40 deletions tests/test_jamba_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path
from typing import List, Union
from unittest.mock import patch, AsyncMock
from unittest.mock import patch

import pytest
from ai21_tokenizer import JambaInstructTokenizer, AsyncJambaInstructTokenizer
Expand Down Expand Up @@ -179,46 +179,18 @@ async def test_async_tokenizer_encode_caches_tokenizer__should_have_tokenizer_in


@pytest.mark.asyncio
@patch("ai21_tokenizer.jamba_instruct_tokenizer._load_from_cache", new_callable=AsyncMock)
async def test_async_tokenizer_when_cache_dir_exists__should_load_from_cache(
async def test_async_tokenizer_initialized_directly_and_uses_vocab_size__should_raise_error(
tmp_path: Path,
mock_async_jamba_instruct_tokenizer: AsyncJambaInstructTokenizer,
):
# Creating tokenizer once from repo
assert not (tmp_path / "tokenizer.json").exists()
tokenizer = AsyncJambaInstructTokenizer(JAMBA_TOKENIZER_HF_PATH, tmp_path)
_ = await tokenizer.encode("Hello world!")

assert (tmp_path / "tokenizer.json").exists()
with pytest.raises(ValueError):
tokenizer = AsyncJambaInstructTokenizer(model_path=JAMBA_TOKENIZER_HF_PATH, cache_dir=tmp_path)
_ = tokenizer.vocab_size

tokenizer2 = AsyncJambaInstructTokenizer(JAMBA_TOKENIZER_HF_PATH, tmp_path)
assert (tmp_path / "tokenizer.json").exists()

_ = await tokenizer2.encode("Hello world!")

# Assert that _load_from_cache was called once
mock_async_jamba_instruct_tokenizer._load_from_cache.assert_called_once()


# @pytest.mark.asyncio
# async def test_async_tokenizer__when_cache_dir_not_exists__should_save_tokenizer_in_cache_dir(tmp_path: Path):
# assert not (tmp_path / "tokenizer.json").exists()
# AsyncJambaInstructTokenizer(JAMABA_TOKENIZER_HF_PATH, tmp_path)
#
# assert (tmp_path / "tokenizer.json").exists()


# @pytest.mark.asyncio
# async def test_async_tokenizer__when_cache_dir_exists__should_load_from_cache(tmp_path: Path):
# # Creating tokenizer once from repo
# assert not (tmp_path / "tokenizer.json").exists()
# AsyncJambaInstructTokenizer(JAMABA_TOKENIZER_HF_PATH, tmp_path)
#
# # Creating tokenizer again to load from cache
# with patch.object(
# AsyncJambaInstructTokenizer, AsyncJambaInstructTokenizer._load_from_cache.__name__
# ) as mock_load_from_cache:
# AsyncJambaInstructTokenizer(JAMABA_TOKENIZER_HF_PATH, tmp_path)
#
# # assert load_from_cache was called
# mock_load_from_cache.assert_called_once()
@pytest.mark.asyncio
async def test_async_tokenizer_initialized_with_manager_and_uses_vocab_size__should_not_raise_error(
tmp_path: Path,
):
tokenizer = AsyncJambaInstructTokenizer(model_path=JAMBA_TOKENIZER_HF_PATH, cache_dir=tmp_path)
async with tokenizer:
assert tokenizer.vocab_size > 0
14 changes: 14 additions & 0 deletions tests/test_jurassic_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,17 @@ async def test_async_init__when_model_path_is_a_file__should_support_backwards_c
decoded = await async_tokenizer.decode(encoded)

assert decoded == TEXT


@pytest.mark.asyncio
async def test_async_tokenizer_initialized_directly_and_uses_vocab_size__should_raise_error():
with pytest.raises(ValueError):
tokenizer = AsyncJurassicTokenizer(model_path=_LOCAL_RESOURCES_PATH / "j2-tokenizer.model")
_ = tokenizer.vocab_size


@pytest.mark.asyncio
async def test_async_tokenizer_initialized_with_manager_and_uses_vocab_size__should_not_raise_error():
tokenizer = AsyncJurassicTokenizer(model_path=_LOCAL_RESOURCES_PATH / "j2-tokenizer.model")
async with tokenizer:
assert tokenizer.vocab_size > 0

0 comments on commit 23ecd60

Please sign in to comment.