Skip to content

Commit

Permalink
test: Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
asafgardin committed Jan 2, 2024
1 parent e565a62 commit 57a25f7
Showing 1 changed file with 75 additions and 1 deletion.
76 changes: 75 additions & 1 deletion tests/test_jurassic_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import json
from pathlib import Path
from typing import Union, List
from typing import Union, List, BinaryIO, Optional

import pytest

from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer
from ai21_tokenizer.utils import PathLike

_LOCAL_RESOURCES_PATH = Path(__file__).parents[1] / "ai21_tokenizer" / "resources" / "j2-tokenizer"


def test_tokenizer_encode_decode(tokenizer: JurassicTokenizer):
Expand Down Expand Up @@ -87,3 +90,74 @@ def test_tokenizer__convert_tokens_to_ids(
actual_ids = tokenizer.convert_tokens_to_ids(tokens)

assert actual_ids == expected_ids


def test_tokenizer__from_file_handle():
text = "Hello world!"
model_config = {
"vocab_size": 262144,
"pad_id": 0,
"bos_id": 1,
"eos_id": 2,
"unk_id": 3,
"add_dummy_prefix": False,
"newline_piece": "<|newline|>",
"number_mode": "right_keep",
"space_mode": "left",
}

with (_LOCAL_RESOURCES_PATH / "j2-tokenizer.model").open("rb") as tokenizer_file:
tokenizer = JurassicTokenizer.from_file_handle(model_file_handle=tokenizer_file, config=model_config)

encoded = tokenizer.encode(text)
decoded = tokenizer.decode(encoded)

assert decoded == text


def test_tokenizer__from_file_path():
text = "Hello world!"
model_config = {
"vocab_size": 262144,
"pad_id": 0,
"bos_id": 1,
"eos_id": 2,
"unk_id": 3,
"add_dummy_prefix": False,
"newline_piece": "<|newline|>",
"number_mode": "right_keep",
"space_mode": "left",
}

tokenizer = JurassicTokenizer.from_file_path(
model_path=(_LOCAL_RESOURCES_PATH / "j2-tokenizer.model"), config=model_config
)

encoded = tokenizer.encode(text)
decoded = tokenizer.decode(encoded)

assert decoded == text


@pytest.mark.parametrize(
ids=[
"when_model_path_and_file_handle_are_none__should_raise_value_error",
"when_model_path_and_file_handle_are_not_none__should_raise_value_error",
],
argnames=["model_path", "model_file_handle", "expected_error_message"],
argvalues=[
(None, None, "Must provide exactly one of model_path or model_file_handle. Got none."),
(
Path("some_path"),
"some_file_handle",
"Must provide exactly one of model_path or model_file_handle. Got both.",
),
],
)
def test_tokenizer__(
model_path: Optional[PathLike], model_file_handle: Optional[BinaryIO], expected_error_message: str
):
with pytest.raises(ValueError) as error:
JurassicTokenizer(model_file_handle=model_file_handle, model_path=model_path, config={})

assert error.value.args[0] == expected_error_message

0 comments on commit 57a25f7

Please sign in to comment.