Merge branch 'main' into dependabot/pip/pytest-7.4.4
asafgardin authored Jan 2, 2024
2 parents c9001d9 + dcb73a7 commit 0c8ae53
Showing 2 changed files with 100 additions and 4 deletions.
28 changes: 25 additions & 3 deletions ai21_tokenizer/jurassic_tokenizer.py
@@ -2,7 +2,7 @@

import re
from dataclasses import dataclass
-from typing import List, Union, Optional, Dict, Any, Tuple
+from typing import List, Union, Optional, Dict, Any, Tuple, BinaryIO

import sentencepiece as spm

@@ -19,11 +19,16 @@ class SpaceSymbol:
 class JurassicTokenizer(BaseTokenizer):
     def __init__(
         self,
-        model_path: PathLike,
+        model_path: Optional[PathLike] = None,
+        model_file_handle: Optional[BinaryIO] = None,
         config: Optional[Dict[str, Any]] = None,
     ):
+        self._validate_init(model_path=model_path, model_file_handle=model_file_handle)
+
+        model_proto = load_binary(model_path) if model_path else model_file_handle.read()
+
         # noinspection PyArgumentList
-        self._sp = spm.SentencePieceProcessor(model_proto=load_binary(model_path))
+        self._sp = spm.SentencePieceProcessor(model_proto=model_proto)
         config = config or {}

         self.pad_id = config.get("pad_id")
@@ -52,6 +57,13 @@ def __init__(
         self._space_mode = config.get("space_mode")
         self._space_tokens = self._map_space_tokens()

+    def _validate_init(self, model_path: Optional[PathLike], model_file_handle: Optional[BinaryIO]) -> None:
+        if model_path is None and model_file_handle is None:
+            raise ValueError("Must provide exactly one of model_path or model_file_handle. Got none.")
+
+        if model_path is not None and model_file_handle is not None:
+            raise ValueError("Must provide exactly one of model_path or model_file_handle. Got both.")
+
     def _map_space_tokens(self) -> List[SpaceSymbol]:
         res = []
         for count in range(32, 0, -1):
@@ -226,3 +238,13 @@ def convert_ids_to_tokens(self, token_ids: Union[int, List[int]], **kwargs) -> U
             return self._id_to_token(token_ids)

         return [self._id_to_token(token_id) for token_id in token_ids]
+
+    @classmethod
+    def from_file_handle(
+        cls, model_file_handle: BinaryIO, config: Optional[Dict[str, Any]] = None
+    ) -> JurassicTokenizer:
+        return cls(model_file_handle=model_file_handle, config=config)
+
+    @classmethod
+    def from_file_path(cls, model_path: PathLike, config: Optional[Dict[str, Any]] = None) -> JurassicTokenizer:
+        return cls(model_path=model_path, config=config)
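
Taken together, this change lets JurassicTokenizer be built either from a path on disk or from an already-open binary stream. A minimal usage sketch of the two new classmethods, assuming the repository root is the working directory and reusing the config values from the tests below:

from pathlib import Path

from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer

model_path = Path("ai21_tokenizer/resources/j2-tokenizer/j2-tokenizer.model")
config = {
    "vocab_size": 262144,
    "pad_id": 0,
    "bos_id": 1,
    "eos_id": 2,
    "unk_id": 3,
    "add_dummy_prefix": False,
    "newline_piece": "<|newline|>",
    "number_mode": "right_keep",
    "space_mode": "left",
}

# Build from a filesystem path (wraps the original model_path flow).
tokenizer = JurassicTokenizer.from_file_path(model_path=model_path, config=config)

# Build from an open binary handle; __init__ reads the serialized model proto from it.
with model_path.open("rb") as fh:
    tokenizer_from_handle = JurassicTokenizer.from_file_handle(model_file_handle=fh, config=config)

assert tokenizer.decode(tokenizer.encode("Hello world!")) == "Hello world!"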
76 changes: 75 additions & 1 deletion tests/test_jurassic_tokenizer.py
@@ -1,10 +1,13 @@
import json
from pathlib import Path
-from typing import Union, List
+from typing import Union, List, BinaryIO, Optional

import pytest

from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer
+from ai21_tokenizer.utils import PathLike
+
+_LOCAL_RESOURCES_PATH = Path(__file__).parents[1] / "ai21_tokenizer" / "resources" / "j2-tokenizer"


def test_tokenizer_encode_decode(tokenizer: JurassicTokenizer):
@@ -87,3 +90,74 @@ def test_tokenizer__convert_tokens_to_ids(
     actual_ids = tokenizer.convert_tokens_to_ids(tokens)

     assert actual_ids == expected_ids
+
+
+def test_tokenizer__from_file_handle():
+    text = "Hello world!"
+    model_config = {
+        "vocab_size": 262144,
+        "pad_id": 0,
+        "bos_id": 1,
+        "eos_id": 2,
+        "unk_id": 3,
+        "add_dummy_prefix": False,
+        "newline_piece": "<|newline|>",
+        "number_mode": "right_keep",
+        "space_mode": "left",
+    }
+
+    with (_LOCAL_RESOURCES_PATH / "j2-tokenizer.model").open("rb") as tokenizer_file:
+        tokenizer = JurassicTokenizer.from_file_handle(model_file_handle=tokenizer_file, config=model_config)
+
+    encoded = tokenizer.encode(text)
+    decoded = tokenizer.decode(encoded)
+
+    assert decoded == text
+
+
+def test_tokenizer__from_file_path():
+    text = "Hello world!"
+    model_config = {
+        "vocab_size": 262144,
+        "pad_id": 0,
+        "bos_id": 1,
+        "eos_id": 2,
+        "unk_id": 3,
+        "add_dummy_prefix": False,
+        "newline_piece": "<|newline|>",
+        "number_mode": "right_keep",
+        "space_mode": "left",
+    }
+
+    tokenizer = JurassicTokenizer.from_file_path(
+        model_path=(_LOCAL_RESOURCES_PATH / "j2-tokenizer.model"), config=model_config
+    )
+
+    encoded = tokenizer.encode(text)
+    decoded = tokenizer.decode(encoded)
+
+    assert decoded == text
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "when_model_path_and_file_handle_are_none__should_raise_value_error",
+        "when_model_path_and_file_handle_are_not_none__should_raise_value_error",
+    ],
+    argnames=["model_path", "model_file_handle", "expected_error_message"],
+    argvalues=[
+        (None, None, "Must provide exactly one of model_path or model_file_handle. Got none."),
+        (
+            Path("some_path"),
+            "some_file_handle",
+            "Must provide exactly one of model_path or model_file_handle. Got both.",
+        ),
+    ],
+)
+def test_tokenizer__(
+    model_path: Optional[PathLike], model_file_handle: Optional[BinaryIO], expected_error_message: str
+):
+    with pytest.raises(ValueError) as error:
+        JurassicTokenizer(model_file_handle=model_file_handle, model_path=model_path, config={})
+
+    assert error.value.args[0] == expected_error_message
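
The parametrized test above pins the two failure modes of _validate_init. A quick sketch of triggering them directly; since the validation runs before any file I/O, the hypothetical path below need not exist:

import io
from pathlib import Path

from ai21_tokenizer.jurassic_tokenizer import JurassicTokenizer

try:
    JurassicTokenizer(config={})  # neither model_path nor model_file_handle
except ValueError as e:
    print(e)  # Must provide exactly one of model_path or model_file_handle. Got none.

try:
    JurassicTokenizer(model_path=Path("unused.model"), model_file_handle=io.BytesIO(b""), config={})
except ValueError as e:
    print(e)  # Must provide exactly one of model_path or model_file_handle. Got both.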
